from time import time
from pprint import pprint
from typing import Literal, Dict

import huggingface_hub
import streamlit as st
from typing_extensions import TypedDict
from langchain.agents import AgentExecutor
from langchain_community.chat_models import ChatOllama
from langgraph.graph import END, StateGraph

from logger import logger
from config import config
from agents import get_agents, tools_dict


class GraphState(TypedDict):
    """Shared graph state; each node returns a partial update with one or more of these keys."""
    question: str
    rephrased_question: str
    function_agent_output: str
    generation: str


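# st.cache_resource builds the agents once per process and reuses them across
# Streamlit reruns, so the Ollama chat model is not re-initialised on every interaction.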
| @st.cache_resource(show_spinner="Loading model..") |
| def init_agents() -> dict[str, langchain.agents.AgentExecutor]: |
| huggingface_hub.login(token=config.hf_token, new_session=False) |
| llm = ChatOllama(model = config.ollama_model, temperature = 0.8) |
| return get_agents(llm) |
|
|
|
|
| |
|
|
| def question_node(state: GraphState) -> Dict[str, str]: |
| """ |
| Generate a question for the function agent. |
| """ |
| logger.info("Generating question for function agent") |
| |
| question = state["question"] |
| logger.info(f"Original question: {question}") |
| rephrased_question = agents["rephrase_agent"].invoke({"question": question}) |
| logger.info(f"Rephrased question: {rephrased_question}") |
| return {"rephrased_question": rephrased_question} |
|
|
| def function_agent_node(state: GraphState) -> Literal["finished"]: |
| """ |
| Call the function agent |
| """ |
| logger.info("Calling function agent") |
| question = state["rephrased_question"] |
| response = agents["function_agent"].invoke({"input": question, "tools": tools_dict}).get("output") |
| |
| logger.info(f"Function agent output: {response}") |
| return {"function_agent_output": response} |
|
|
| def output_node(state: GraphState) -> Dict[str, str]: |
| """ |
| Generate the final output |
| """ |
| logger.info("Generating output") |
| |
| generation = agents["output_agent"].invoke({"context": state["function_agent_output"], |
| "question": state["rephrased_question"]}) |
| return {"generation": generation} |
|
|
| |
|
|
| def route_question(state: GraphState) -> Literal["vectorstore", "websearch"]: |
| """ |
| Route quesition to web search or RAG |
| """ |
| logger.info("Routing question") |
| |
| question = state["question"] |
| logger.info(f"Question: {question}") |
| source = agents["router_agent"].invoke({"question": question}) |
| logger.info(source) |
| logger.info(source["datasource"]) |
| if source["datasource"] == "vectorstore": |
| return "vectorstore" |
| elif source["datasource"] == "websearch": |
| return "websearch" |


workflow = StateGraph(GraphState)
workflow.add_node("question_rephrase", question_node)
workflow.add_node("function_agent", function_agent_node)
workflow.add_node("output_node", output_node)

workflow.set_entry_point("question_rephrase")
workflow.add_edge("question_rephrase", "function_agent")
workflow.add_edge("function_agent", "output_node")
workflow.set_finish_point("output_node")

flow = workflow.compile()
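# Each item yielded by flow.stream() maps the node that just finished to the partial
# state update it returned; the loops in main() and main_headless() rely on that shape
# and read the final answer from value["generation"] after the last node has run.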

progress_map = {
    "question_rephrase": ":mag: Collecting data",
    "function_agent": ":bulb: Preparing response",
    "output_node": ":bulb: Done!",
}


def main():
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    def get_response(input_text: str, depth: int = 1) -> str:
        try:
            for output in flow.stream({"question": input_text}):
                for key, value in output.items():
                    config.status.update(label=progress_map[key])
                    pprint(f"Finished running: {key}")
            return value["generation"]
        except Exception as e:
            logger.error(e)
            logger.info("Retrying..")
            if depth < 5:
                return get_response(input_text, depth + 1)
            raise  # give up after five attempts instead of silently returning None

    if st.button("Generate") or input_text:
        start = time()
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                config.status.update(label=":question: Breaking down question")
                response = get_response(input_text)
                response = response.replace("$", "\\$")  # keep literal "$" from being rendered as LaTeX
                st.info(response)
                config.status.update(label=f"Finished! ({time() - start:.2f}s)",
                                     state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


def main_headless(prompt: str):
    """Run the graph once from the command line and print the generated answer."""
    start = time()
    for output in flow.stream({"question": prompt}):
        for key, value in output.items():
            pprint(f"Finished running: {key}")
    print("\033[94m" + value["generation"] + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
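# Built at module level so the Streamlit UI and the headless CLI path share the same agents.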
agents = init_agents()


if __name__ == "__main__":
    if config.headless:
        import fire
        fire.Fire(main_headless)
    else:
        main()