rsobieski's picture
Update agent.py
cfa3fb2 verified
import os
import pandas as pd
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph, END, MessagesState
from langchain_huggingface import HuggingFaceEndpoint
from tools import TOOLS
QA_PATH = "metadata.jsonl"
qa_pairs = pd.read_json(QA_PATH, lines=True)
qa_dict = {
row["Question"].strip(): row["Final answer"].strip()
for _, row in qa_pairs.iterrows()
}
def build_graph():
llm = HuggingFaceEndpoint(
endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
task="text-generation",
max_new_tokens=512,
temperature=0.1,
top_k=50,
top_p=0.95,
huggingfacehub_api_token=os.environ["HF_TOKEN"]
)
def retriever_node(state: MessagesState):
query = state["messages"][-1].content.strip()
if query in qa_dict:
print("βœ… Exact match found in retriever.")
answer = qa_dict[query]
return {"messages": [AIMessage(content=answer)]}
print("πŸ” No match. Passing to LLM.")
return
def assistant_node(state: MessagesState):
query = state["messages"][-1].content.strip()
system_prompt = (
"You are a helpful assistant. To answer the user's question, you can use tools. "
"To use a tool, respond with a single line: 'tool:tool_name:input'. "
"For example: 'tool:wiki_search:Apple Inc.' "
"If you have the final answer, provide it directly without any prefixes. "
"Never justify or explain your final answer."
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": query}
]
response = llm.invoke(messages).strip()
for tag in ("", "Answer:", "assistant:"):
if response.lower().startswith(tag.lower()):
response = response[len(tag):].strip()
return {"messages": [AIMessage(content=response)]}
def tool_node(state: MessagesState):
last_message = state["messages"][-1].content.strip()
try:
_, tool_name, tool_input = last_message.split(":", 2)
tool_name = tool_name.strip()
tool_input = tool_input.strip()
tool_fn = TOOLS.get(tool_name)
if not tool_fn:
print(f"❌ Unknown tool: {tool_name}")
return {"messages": [AIMessage(content="Unknown")]}
print(f"πŸ”§ Using tool: {tool_name} with input: {tool_input}")
tool_result = tool_fn(tool_input)
return {"messages": [AIMessage(content=str(tool_result))]}
except Exception as e:
print(f"⚠️ Tool error: {e}")
return {"messages": [AIMessage(content="Unknown")]}
def route_after_retriever(state: MessagesState):
if isinstance(state["messages"][-1], AIMessage):
return END
return "assistant"
def route_after_assistant(state: MessagesState):
last_message = state["messages"][-1].content.strip().lower()
if last_message.startswith("tool:"):
return "tool"
return END
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever_node)
builder.add_node("assistant", assistant_node)
builder.add_node("tool", tool_node)
builder.set_entry_point("retriever")
builder.add_conditional_edges(
"retriever",
route_after_retriever,
{"assistant": "assistant", END: END}
)
builder.add_conditional_edges(
"assistant",
route_after_assistant,
{"tool": "tool", END: END}
)
builder.add_edge("tool", "assistant")
return builder.compile()
class BasicAgent:
def __init__(self):
print("βœ… BasicAgent initialized with retriever + LLM + tools")
self.graph = build_graph()
def __call__(self, question: str) -> str:
print(f"πŸ“₯ Question: {question[:100]}")
config = {"recursion_limit": 50}
try:
result = self.graph.invoke(
{"messages": [HumanMessage(content=question)]},
config=config
)
answer = result["messages"][-1].content.strip()
print(f"πŸ“€ Answer: {answer}")
return answer
except Exception as e:
print(f"Error during graph invocation: {e}")
return f"AGENT ERROR: {e}"