| import logging |
| import re |
| import time |
| from typing import List, Dict, Any, Optional |
| from langgraph.graph import StateGraph, END |
| from langgraph.checkpoint.memory import MemorySaver |
|
|
| from pydantic import BaseModel, Field |
|
|
| from langchain_core.messages import HumanMessage, AIMessage, SystemMessage |
| from langchain_core.output_parsers import StrOutputParser, JsonOutputParser |
|
|
| from .config import settings |
| from .schemas import PlannerState, KeyIssue, GraphConfig |
| from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT |
| from .llm_interface import get_llm, invoke_llm |
| from .graph_operations import ( |
| generate_cypher_auto, generate_cypher_guided, |
| retrieve_documents, evaluate_documents |
| ) |
| from .processing import process_documents |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
|
|
def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generate the initial plan from the user query.

    Prompts the main LLM for a numbered plan and parses the steps out of the
    "Plan: ... <END_OF_PLAN>" section of the response.

    Returns a partial state update containing either:
      - plan / current_plan_step_index / messages / step_outputs on success, or
      - an "error" entry when the query is empty, the plan cannot be parsed,
        or the LLM call fails.
    """
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})
        logger.debug(f"Raw plan text: {plan_text}")

        # The prompt instructs the model to emit "Plan: ... <END_OF_PLAN>".
        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if not plan_match:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}

        # Split on numbered-list markers ("1. ", "2. ", ...) to get the steps.
        plan_steps = [step.strip() for step in re.split(r"\n\s*\d+\.\s*", plan_match.group(1)) if step.strip()]
        if not plan_steps:
            # Guard: a matched-but-empty plan would previously be accepted and
            # route straight to finalization with no gathered context.
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}

        logger.info(f"Extracted plan: {plan_steps}")
        return {
            "plan": plan_steps,
            "current_plan_step_index": 0,
            "messages": [AIMessage(content=plan_text)],
            "step_outputs": {}
        }
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}
|
|
|
|
def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Execute the current plan step: retrieve and process graph context.

    Builds a step-focused retrieval query, generates a Cypher query using the
    configured method ('auto' or 'guided'), retrieves and evaluates documents,
    processes them, and records the condensed result under the step index.

    Returns a partial state update advancing current_plan_step_index, or an
    "error" entry if the index is already past the end of the plan.
    """
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']

    if current_index >= len(plan):
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # Anchor the step description to the original query for retrieval.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # Cypher generation strategy is configured globally; any other value
    # falls through with an empty query string.
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)

    retrieved_docs = retrieve_documents(cypher_query)
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."

    # Fix: copy before updating. The original mutated the dict held by the
    # current state in place, bypassing LangGraph's state-update mechanism;
    # `or {}` also tolerates an explicit None value for 'step_outputs'.
    updated_step_outputs = dict(state.get('step_outputs') or {})
    updated_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],
        "step_outputs": updated_step_outputs
    }
|
|
class KeyIssue(BaseModel):
    # NOTE(review): this local definition shadows the KeyIssue imported from
    # .schemas (top of file) — confirm which one downstream importers expect.
    # Used only as the element type of KeyIssueList's parser schema here.
    # No class docstring on purpose: pydantic would surface it as the model
    # description in the JSON schema fed to the output parser.
    id: int
    description: str
|
|
class KeyIssueList(BaseModel):
    # Wrapper model handed to JsonOutputParser so the LLM is asked for a
    # top-level {"key_issues": [...]} object rather than a bare JSON array.
    key_issues: List[KeyIssue] = Field(description="List of key issues")
|
|
class KeyIssueInvoke(BaseModel):
    # Richer issue shape actually instantiated from the parsed LLM output in
    # generate_structured_issues. NOTE(review): the parser schema advertises
    # KeyIssueList (id + description only), so title/challenges presumably
    # come from the prompt's own instructions — verify against the prompt.
    id: int
    title: str
    description: str
    challenges: List[str]
    potential_impact: Optional[str] = None
|
|
def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generate the final structured Key Issues based on all gathered context.

    Concatenates the per-step outputs (in step order) into one context block,
    asks the main LLM for JSON matching the KeyIssueList schema, and converts
    each parsed entry into a KeyIssueInvoke. Issue ids are renumbered
    sequentially from 1 regardless of what the model produced.

    Returns a partial state update with "key_issues" on success, or an
    "error" entry (including a hint of the raw LLM output) on failure.
    """
    logger.info("Node: generate_structured_issues")

    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # Assemble the full context in plan-step order.
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"

    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")

    issue_llm = get_llm(settings.main_llm_model)
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)

    # Inject the parser's format instructions into the structuring prompt.
    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),
    )

    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        # Fix: debug output went to stdout via print(); route it through the
        # module logger instead.
        logger.debug(f"structured_issues_obj => type : {type(structured_issues_obj)}, value : {structured_issues_obj}")

        # The parser may yield either the wrapper object or a bare list.
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj

        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]

        # Renumber ids sequentially so they are consistent and 1-based.
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Best-effort second call to capture the raw (unparsed) LLM output
        # for debugging; failures here are logged but never raised.
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")

        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}
|
|
|
|
| |
|
|
def should_continue_planning(state: PlannerState) -> str:
    """Routing edge: choose the next node after a plan step.

    Returns "error_state" if an error has been recorded, "continue_execution"
    while plan steps remain, and "finalize" once the plan is exhausted.
    """
    logger.debug("Edge: should_continue_planning")

    error = state.get("error")
    if error:
        logger.error(f"Error state detected: {error}. Ending execution.")
        return "error_state"

    next_index = state['current_plan_step_index']
    total_steps = len(state.get('plan', []))

    if next_index >= total_steps:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"

    logger.debug(f"Continuing plan execution. Next step index: {next_index}")
    return "continue_execution"
|
|
|
|
| |
def build_graph():
    """Assemble and compile the LangGraph planning workflow.

    Topology: start_planning -> execute_plan_step, which loops on itself via
    should_continue_planning until the plan is exhausted (-> generate_issues)
    or an error is recorded (-> error_node); both terminal nodes lead to END.
    """
    workflow = StateGraph(PlannerState)

    def _report_failure(state):
        # Terminal node: surface the recorded error as a system message.
        return {"messages": [SystemMessage(content=f"Execution failed: {state.get('error', 'Unknown error')}")]}

    # Nodes.
    workflow.add_node("start_planning", start_planning)
    workflow.add_node("execute_plan_step", execute_plan_step)
    workflow.add_node("generate_issues", generate_structured_issues)
    workflow.add_node("error_node", _report_failure)

    # Wiring.
    workflow.set_entry_point("start_planning")
    workflow.add_edge("start_planning", "execute_plan_step")
    workflow.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",
            "finalize": "generate_issues",
            "error_state": "error_node",
        },
    )
    workflow.add_edge("generate_issues", END)
    workflow.add_edge("error_node", END)

    return workflow.compile()