decomposeRL-7b / example.py
dipta007's picture
Update
54b21c8 verified
"""Run DecomposeRL-7B on a (claim, evidence_doc) pair and pretty-print the trace.
Usage:
python example.py
"""
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = "dipta007/decomposeRL-7b"
# Instruction prompt sent as the single user message. The `{evidence_doc}` and
# `{claim}` placeholders are filled in with `str.format`, so any literal brace
# added to this template in the future must be doubled (`{{` / `}}`).
PROMPT_TEMPLATE = """You are tasked with systematically verifying the accuracy of a claim. You will be provided with a claim to verify and an evidence document to consult.
Here is the evidence document you should consult:
<evidence_document>
{evidence_doc}
</evidence_document>
Here is the claim you need to verify:
<claim>
{claim}
</claim>
Your task is to verify whether this claim is Supported or Refuted through an iterative process of asking questions and gathering information.
# Verification Process
Begin by analyzing the claim in <think> tags, then enter an iterative cycle of <question>/<answer> pairs answered ONLY from the evidence document. When every sub-claim is addressed, output your final label inside <verification> tags. The label must be exactly one of: Supported, Refuted.
Stop immediately after the closing </verification> tag.
Begin your verification process now."""
# Matches one complete <think>, <question>, <answer> or <verification> element.
# Group 1 is the tag name and group 2 the (non-greedy) body; the `\1`
# backreference forces the closing tag to match the opening one, and DOTALL
# lets a body span multiple lines.
TAG_RE = re.compile(r"<(think|question|answer|verification)>(.*?)</\1>", re.DOTALL)
def build_prompt(claim: str, evidence_doc: str) -> str:
    """Fill the DecomposeRL verification template with a claim and its evidence.

    Returns PROMPT_TEMPLATE with its ``{evidence_doc}`` and ``{claim}``
    placeholders replaced by the given strings.
    """
    fields = {"evidence_doc": evidence_doc, "claim": claim}
    return PROMPT_TEMPLATE.format(**fields)
def verify(
    model,
    tokenizer,
    claim: str,
    evidence_doc: str,
    max_new_tokens: int = 4500,
    temperature: float = 0.7,
) -> str:
    """Run the model end-to-end on a (claim, evidence_doc) pair and return the raw trace.

    Args:
        model: Causal LM exposing ``generate`` and ``device`` (in this script,
            the object returned by ``AutoModelForCausalLM.from_pretrained``).
        tokenizer: Matching tokenizer; must provide ``apply_chat_template``
            and ``decode``.
        claim: The statement to verify.
        evidence_doc: The document the claim is checked against.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            (deterministic) decoding instead of sampling.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        stripped), with special tokens removed.
    """
    messages = [{"role": "user", "content": build_prompt(claim, evidence_doc)}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Sampling needs a strictly positive temperature; passing 0 together with
    # do_sample=True makes `generate` fail. Treat non-positive temperatures as
    # a request for greedy decoding and do not forward the temperature at all.
    if temperature > 0:
        decoding_kwargs = {"do_sample": True, "temperature": temperature}
    else:
        decoding_kwargs = {"do_sample": False}

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        **decoding_kwargs,
    )

    # `generate` returns prompt + continuation; keep only the continuation.
    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
def parse_trace(text: str):
    """Extract every well-formed tagged element from a model trace.

    Returns a list of ``(tag, content)`` tuples in order of appearance, where
    ``tag`` is one of the names TAG_RE recognises and ``content`` is the
    element body with surrounding whitespace stripped.
    """
    elements = []
    for found in TAG_RE.finditer(text):
        name, content = found.groups()
        elements.append((name, content.strip()))
    return elements
def pretty_print(text: str) -> None:
    """Print the trace as a readable conversation.

    Falls back to dumping the raw output when the trace contains no
    ``<verification>`` element (which also covers a trace with no parseable
    elements at all).

    Questions and answers are printed as numbered pairs. A question that is
    never answered is still printed with a placeholder answer, and an answer
    that has no preceding question is printed with a placeholder question, so
    nothing the model emitted is silently dropped or shown as ``None``.

    Args:
        text: Raw model output as returned by ``verify``.
    """
    thin_rule = "─" * 78
    thick_rule = "=" * 78

    def emit_pair(idx, question, answer):
        # One question/answer cycle, followed by a blank separator line.
        print(f"🔸 Q{idx}: {question}")
        print(f"💬 A{idx}: {answer}")
        print()

    parsed = parse_trace(text)
    if not any(tag == "verification" for tag, _ in parsed):
        print("⚠️ Could not parse output into the expected think/question/answer/verification structure.")
        print("Raw output:")
        print(thin_rule)
        print(text)
        print(thin_rule)
        return

    cycle_idx = 0
    pending_q = None
    for tag, body in parsed:
        # A new question or the final verdict arriving while a question is
        # still open means that question was never answered; show it anyway.
        if pending_q is not None and tag in ("question", "verification"):
            emit_pair(cycle_idx, pending_q, "(no answer given)")
            pending_q = None

        if tag == "think":
            print(thin_rule)
            print("🧠 THINK")
            print(thin_rule)
            print(body)
            print()
        elif tag == "question":
            cycle_idx += 1
            pending_q = body
        elif tag == "answer":
            if pending_q is None:
                # Orphan answer: give it its own cycle number rather than
                # reusing the previous one and printing "None" as the question.
                cycle_idx += 1
                pending_q = "(no question given)"
            emit_pair(cycle_idx, pending_q, body)
            pending_q = None
        elif tag == "verification":
            print(thick_rule)
            print(f"✅ VERIFICATION: {body}")
            print(thick_rule)

    # A question may trail the verification element; do not lose it.
    if pending_q is not None:
        emit_pair(cycle_idx, pending_q, "(no answer given)")
def extract_label(text: str):
    """Pull the final verdict out of a model trace.

    Looks for the first ``<verification>`` element whose body, ignoring
    surrounding whitespace, is exactly ``Supported`` or ``Refuted`` and
    returns that word; returns ``None`` when no such element exists.
    """
    verdict = re.search(r"<verification>\s*(Supported|Refuted)\s*</verification>", text)
    if verdict is None:
        return None
    return verdict.group(1)
def main():
    """Load the model, verify one built-in example claim, and print the outcome.

    The example pairs a short Eiffel Tower passage with a claim about its
    completion year and height, prints the formatted trace, and finally
    prints the extracted label.
    """
    print(f"Loading {MODEL_NAME} ...")
    load_options = {"torch_dtype": "auto", "device_map": "auto"}
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_options)

    # Evidence passage, assembled from fragments that already carry their own
    # trailing spaces.
    evidence_doc = "".join(
        [
            "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, ",
            "France. It is named after the engineer Gustave Eiffel, whose company designed and ",
            "built the tower from 1887 to 1889. Locally nicknamed 'La dame de fer', it was ",
            "constructed as the centerpiece of the 1889 World's Fair. The tower is 330 metres ",
            "(1,083 ft) tall.",
        ]
    )
    claim = "The Eiffel Tower was completed in 1887 and stands 330 metres tall."

    print(f"\nClaim: {claim}\n")
    response = verify(lm, tok, claim, evidence_doc)
    pretty_print(response)
    print(f"\nFinal label: {extract_label(response)}")
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()