| import os |
| import pandas as pd |
| from langchain_core.messages import HumanMessage, AIMessage |
| from langgraph.graph import StateGraph, END, MessagesState |
| from langchain_huggingface import HuggingFaceEndpoint |
| from tools import TOOLS |
|
|
QA_PATH = "metadata.jsonl"


def load_qa_pairs(path=QA_PATH):
    """Read the question/answer JSON-lines file at ``path`` into a DataFrame.

    The file is opened explicitly as UTF-8 and parsed with ``dtype=False`` so
    pandas keeps the values exactly as written.  With the default dtype
    inference, ``read_json`` converts a column whose strings all look numeric
    (e.g. answers "4", "8") to numbers, and the later ``.strip()`` then fails.

    A missing file yields an empty frame plus a printed warning: the agent can
    still answer through the LLM, just without retriever hits.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            return pd.read_json(fh, lines=True, dtype=False)
    except FileNotFoundError:
        print(f"⚠️ QA file not found: {path}. Retriever will have no entries.")
        return pd.DataFrame(columns=["Question", "Final answer"])


def build_qa_dict(frame):
    """Map each stripped ``Question`` in ``frame`` to its stripped ``Final answer``.

    Rows whose question or answer is not a string (e.g. a missing field, which
    pandas stores as NaN) are skipped instead of raising.  As with a plain dict
    build, a later duplicate question overwrites an earlier one.
    """
    qa = {}
    # to_dict("records") yields plain dicts, avoiding iterrows' per-row Series
    # upcasting, and tolerates a frame without the expected columns.
    for record in frame.to_dict("records"):
        question = record.get("Question")
        answer = record.get("Final answer")
        if isinstance(question, str) and isinstance(answer, str):
            qa[question.strip()] = answer.strip()
    return qa


qa_pairs = load_qa_pairs(QA_PATH)
qa_dict = build_qa_dict(qa_pairs)
|
|
def build_graph():
    """Build and compile the retriever -> assistant <-> tool LangGraph agent.

    Nodes:

    * ``retriever``: looks the stripped question up verbatim in the
      module-level ``qa_dict``.  On a hit it emits the stored answer and the
      graph ends; on a miss it leaves the state unchanged and hands over to
      ``assistant``.
    * ``assistant``: prompts the Mistral-7B-Instruct endpoint with the system
      instructions plus the *whole* conversation so far.  Sending the full
      history (not only the last message) keeps the original question
      visible after a tool call, when the last message is the tool output.
    * ``tool``: entered when the assistant's reply starts with ``tool:``
      (case-insensitive).  Parses ``tool:<name>:<input>``, runs
      ``TOOLS[name](input)``, and feeds the outcome back to ``assistant`` as a
      labelled user-role message.  Any other assistant reply ends the graph.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        KeyError: if the ``HF_TOKEN`` environment variable is unset.
    """
    llm = HuggingFaceEndpoint(
        endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3",
        task="text-generation",
        max_new_tokens=512,
        temperature=0.1,
        top_k=50,
        top_p=0.95,
        huggingfacehub_api_token=os.environ["HF_TOKEN"],
    )

    system_prompt = (
        "You are a helpful assistant. To answer the user's question, you can use tools. "
        "To use a tool, respond with a single line: 'tool:tool_name:input'. "
        "For example: 'tool:wiki_search:Apple Inc.' "
        "If you have the final answer, provide it directly without any prefixes. "
        "Never justify or explain your final answer."
    )
    # Prefixes the model sometimes echoes before its answer; removed
    # case-insensitively, in this order.
    answer_prefixes = ("Answer:", "assistant:")

    def _as_chat_dict(message):
        """Convert a graph message into the role/content dict sent to the LLM."""
        role = "assistant" if isinstance(message, AIMessage) else "user"
        return {"role": role, "content": message.content.strip()}

    def _tool_feedback(text):
        """Wrap a tool result or tool failure as a user-role state update."""
        # A HumanMessage (not AIMessage) so the model does not read the tool
        # output as something it said itself.
        return {"messages": [HumanMessage(content=text)]}

    def retriever_node(state: MessagesState):
        query = state["messages"][-1].content.strip()
        if query in qa_dict:
            print("✅ Exact match found in retriever.")
            return {"messages": [AIMessage(content=qa_dict[query])]}
        print("🔍 No match. Passing to LLM.")
        # No state update: the last message stays the user's HumanMessage,
        # which route_after_retriever uses to pick "assistant".
        return {}

    def assistant_node(state: MessagesState):
        prompt = [{"role": "system", "content": system_prompt}]
        prompt.extend(_as_chat_dict(message) for message in state["messages"])

        response = llm.invoke(prompt).strip()
        for prefix in answer_prefixes:
            if response.lower().startswith(prefix.lower()):
                response = response[len(prefix):].strip()

        return {"messages": [AIMessage(content=response)]}

    def tool_node(state: MessagesState):
        call = state["messages"][-1].content.strip()
        # maxsplit=2 so the tool input itself may contain ':'.
        parts = call.split(":", 2)
        if len(parts) != 3:
            print(f"⚠️ Tool error: malformed tool call {call!r}")
            return _tool_feedback(
                "Tool error: malformed tool call; use 'tool:tool_name:input'."
            )
        tool_name, tool_input = parts[1].strip(), parts[2].strip()

        tool_fn = TOOLS.get(tool_name)
        if not tool_fn:
            print(f"❌ Unknown tool: {tool_name}")
            return _tool_feedback(f"Tool error: unknown tool '{tool_name}'.")

        print(f"🔧 Using tool: {tool_name} with input: {tool_input}")
        try:
            tool_result = tool_fn(tool_input)
        except Exception as e:  # report tool failures to the model instead of aborting the run
            print(f"⚠️ Tool error: {e}")
            return _tool_feedback(f"Tool error: {tool_name} failed: {e}")
        return _tool_feedback(f"Tool result ({tool_name}): {tool_result}")

    def route_after_retriever(state: MessagesState):
        # Only the retriever's hit path appends an AIMessage.
        if isinstance(state["messages"][-1], AIMessage):
            return END
        return "assistant"

    def route_after_assistant(state: MessagesState):
        last_message = state["messages"][-1].content.strip().lower()
        if last_message.startswith("tool:"):
            return "tool"
        return END

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tool", tool_node)

    builder.set_entry_point("retriever")

    builder.add_conditional_edges(
        "retriever",
        route_after_retriever,
        {"assistant": "assistant", END: END},
    )
    builder.add_conditional_edges(
        "assistant",
        route_after_assistant,
        {"tool": "tool", END: END},
    )
    builder.add_edge("tool", "assistant")

    return builder.compile()
|
|
class BasicAgent:
    """Callable question-answering agent backed by the graph from ``build_graph()``.

    The graph is compiled once in ``__init__``.  Every call runs it from
    scratch with the question as the only input message and returns the text
    of the last message in the resulting state.
    """

    def __init__(self):
        self.graph = build_graph()
        # Report readiness only after the graph has actually been built.
        print("✅ BasicAgent initialized with retriever + LLM + tools")

    def __call__(self, question: str) -> str:
        """Answer ``question``.

        Returns:
            The stripped content of the final graph message.  If anything
            raises while running the graph, the error is printed and
            ``"AGENT ERROR: <exception>"`` is returned instead; exceptions are
            not propagated to the caller.
        """
        print(f"📥 Question: {question[:100]}")
        # Upper bound on graph steps, so an assistant <-> tool loop that never
        # produces a final answer ends with an error instead of running on.
        config = {"recursion_limit": 50}
        try:
            result = self.graph.invoke(
                {"messages": [HumanMessage(content=question)]},
                config=config,
            )
            answer = result["messages"][-1].content.strip()
        except Exception as e:
            print(f"Error during graph invocation: {e}")
            return f"AGENT ERROR: {e}"
        print(f"📤 Answer: {answer}")
        return answer