| | """ |
| | LangGraph Workflow for SPARKNET |
| | Implements cyclic multi-agent workflows with StateGraph |
| | """ |
| |
|
import hashlib
from datetime import datetime
from typing import Any, Dict, Literal, Optional

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from loguru import logger

from .langgraph_state import (
    AgentState,
    ScenarioType,
    TaskStatus,
    WorkflowOutput,
    create_initial_state,
    state_to_output,
)
from ..llm.langchain_ollama_client import LangChainOllamaClient
| |
|
| |
|
class SparknetWorkflow:
    """
    LangGraph-powered workflow orchestrator for SPARKNET.

    Implements cyclic workflow with conditional routing:
    START → PLANNER → ROUTER → [scenario executors] → CRITIC
              ↑                                          ↓
              └────────── REFINE ←───────────────────────┘
    """

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        planner_agent: Optional[Any] = None,
        critic_agent: Optional[Any] = None,
        memory_agent: Optional[Any] = None,
        vision_ocr_agent: Optional[Any] = None,
        quality_threshold: float = 0.85,
        max_iterations: int = 3,
    ):
        """Build and compile the LangGraph application.

        Args:
            llm_client: Client used to obtain LLMs of varying complexity.
            planner_agent: Optional planning agent; LLM fallback when absent.
            critic_agent: Optional validation agent; LLM fallback when absent.
            memory_agent: Optional episodic-memory agent (context retrieval
                and episode storage are skipped when absent).
            vision_ocr_agent: Optional OCR agent forwarded to document analysis.
            quality_threshold: Minimum critic score that ends refinement.
            max_iterations: Hard cap on refine cycles.
        """
        # Collaborating agents (all optional except the LLM client).
        self.llm_client = llm_client
        self.planner_agent = planner_agent
        self.critic_agent = critic_agent
        self.memory_agent = memory_agent
        self.vision_ocr_agent = vision_ocr_agent

        # Refinement-loop tuning knobs.
        self.quality_threshold = quality_threshold
        self.max_iterations = max_iterations

        # Build the graph once and compile it with an in-memory checkpointer.
        self.graph = self._build_graph()
        self.checkpointer = MemorySaver()
        self.app = self.graph.compile(checkpointer=self.checkpointer)

        ocr_suffix = " and VisionOCR support" if vision_ocr_agent else ""
        logger.info(f"Initialized SparknetWorkflow with LangGraph StateGraph{ocr_suffix}")
| |
|
| | def _build_graph(self) -> StateGraph: |
| | workflow = StateGraph(AgentState) |
| |
|
| | workflow.add_node("planner", self._planner_node) |
| | workflow.add_node("router", self._router_node) |
| | workflow.add_node("executor", self._executor_node) |
| | workflow.add_node("critic", self._critic_node) |
| | workflow.add_node("refine", self._refine_node) |
| | workflow.add_node("finish", self._finish_node) |
| |
|
| | workflow.set_entry_point("planner") |
| | workflow.add_edge("planner", "router") |
| | workflow.add_edge("router", "executor") |
| | workflow.add_edge("executor", "critic") |
| |
|
| | workflow.add_conditional_edges( |
| | "critic", |
| | self._should_refine, |
| | { |
| | "refine": "refine", |
| | "finish": "finish", |
| | } |
| | ) |
| |
|
| | workflow.add_edge("refine", "planner") |
| | workflow.add_edge("finish", END) |
| |
|
| | return workflow |
| |
|
| | async def _planner_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"PLANNER node processing task: {state['task_id']}") |
| | state["status"] = TaskStatus.PLANNING |
| | state["current_agent"] = "PlannerAgent" |
| |
|
| | |
| | context_docs = [] |
| | if self.memory_agent: |
| | try: |
| | logger.info("Retrieving relevant context from memory...") |
| | context_docs = await self.memory_agent.retrieve_relevant_context( |
| | query=state["task_description"], |
| | context_type="all", |
| | top_k=3, |
| | scenario_filter=state["scenario"], |
| | min_quality_score=0.8 |
| | ) |
| | if context_docs: |
| | logger.info(f"Retrieved {len(context_docs)} relevant memories") |
| | |
| | state["agent_outputs"]["memory_context"] = [ |
| | {"content": doc.page_content, "metadata": doc.metadata} |
| | for doc in context_docs |
| | ] |
| | except Exception as e: |
| | logger.warning(f"Memory retrieval failed: {e}") |
| |
|
| | system_msg = SystemMessage(content="Decompose the task into executable subtasks.") |
| |
|
| | |
| | context_text = "" |
| | if context_docs: |
| | context_text = "\n\nRelevant past experiences:\n" |
| | for i, doc in enumerate(context_docs, 1): |
| | context_text += f"\n{i}. {doc.page_content[:200]}..." |
| |
|
| | user_msg = HumanMessage( |
| | content=f"Task: {state['task_description']}\nScenario: {state['scenario']}{context_text}" |
| | ) |
| |
|
| | llm = self.llm_client.get_llm(complexity="complex") |
| |
|
| | if self.planner_agent: |
| | from ..agents.base_agent import Task |
| | task = Task( |
| | id=state["task_id"], |
| | description=state["task_description"], |
| | metadata={"scenario": state["scenario"].value} |
| | ) |
| | result_task = await self.planner_agent.process_task(task) |
| |
|
| | if result_task.status == "completed": |
| | state["subtasks"] = [ |
| | { |
| | "id": st.id, |
| | "description": st.description, |
| | "agent_type": st.agent_type, |
| | "dependencies": st.dependencies, |
| | } |
| | for st in result_task.result["task_graph"].subtasks.values() |
| | ] |
| | state["execution_order"] = result_task.result["execution_order"] |
| | response_msg = AIMessage(content=f"Created plan with {len(state['subtasks'])} subtasks") |
| | state["messages"].append(response_msg) |
| | else: |
| | response = await llm.ainvoke([system_msg, user_msg]) |
| | state["messages"].append(response) |
| | state["subtasks"] = [ |
| | {"id": "subtask_1", "description": "Execute primary task", "agent_type": "ExecutorAgent", "dependencies": []} |
| | ] |
| | state["execution_order"] = [["subtask_1"]] |
| |
|
| | logger.info(f"Planning completed: {len(state.get('subtasks', []))} subtasks created") |
| | return state |
| |
|
| | async def _router_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"ROUTER node routing for scenario: {state['scenario']}") |
| | state["current_agent"] = "Router" |
| |
|
| | scenario = state["scenario"] |
| | routing_msg = AIMessage(content=f"Routing to {scenario.value} workflow agents") |
| | state["messages"].append(routing_msg) |
| |
|
| | state["agent_outputs"]["router"] = { |
| | "scenario": scenario.value, |
| | "agents_to_use": self._get_scenario_agents(scenario) |
| | } |
| |
|
| | return state |
| |
|
| | async def _executor_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"EXECUTOR node executing for scenario: {state['scenario']}") |
| | state["status"] = TaskStatus.EXECUTING |
| | state["current_agent"] = "Executor" |
| |
|
| | scenario = state["scenario"] |
| |
|
| | |
| | if scenario == ScenarioType.PATENT_WAKEUP: |
| | logger.info("🎯 Routing to Patent Wake-Up pipeline") |
| | return await self._execute_patent_wakeup(state) |
| |
|
| | |
| | agents = self._get_scenario_agents(scenario) |
| |
|
| | |
| | from ..tools.langchain_tools import get_vista_tools |
| | tools = get_vista_tools(scenario.value) |
| | logger.info(f"Loaded {len(tools)} tools for scenario: {scenario.value}") |
| |
|
| | |
| | llm = self.llm_client.get_llm(complexity="standard") |
| | llm_with_tools = llm.bind_tools(tools) |
| |
|
| | |
| | tool_descriptions = "\n".join([f"- {tool.name}: {tool.description}" for tool in tools]) |
| | execution_prompt = HumanMessage( |
| | content=f"""Execute the following task using the available tools when needed: |
| | |
| | Task: {state['task_description']} |
| | Scenario: {scenario.value} |
| | |
| | Available tools: |
| | {tool_descriptions} |
| | |
| | Provide detailed results.""" |
| | ) |
| |
|
| | |
| | response = await llm_with_tools.ainvoke([execution_prompt]) |
| | state["messages"].append(response) |
| |
|
| | |
| | tool_calls = [] |
| | if hasattr(response, 'tool_calls') and response.tool_calls: |
| | logger.info(f"LLM requested {len(response.tool_calls)} tool calls") |
| | for tool_call in response.tool_calls: |
| | tool_name = tool_call.get('name', 'unknown') |
| | tool_calls.append(tool_name) |
| | logger.info(f"Tool called: {tool_name}") |
| |
|
| | state["agent_outputs"]["executor"] = { |
| | "result": response.content, |
| | "agents_used": agents, |
| | "tools_available": [tool.name for tool in tools], |
| | "tools_called": tool_calls, |
| | } |
| | state["final_output"] = response.content |
| |
|
| | logger.info("Execution completed") |
| | return state |
| |
|
| | async def _execute_patent_wakeup(self, state: AgentState) -> AgentState: |
| | """ |
| | Execute Patent Wake-Up scenario pipeline. |
| | Sequential execution: Document → Market → Matchmaking → Outreach |
| | """ |
| | logger.info("🚀 Executing Patent Wake-Up pipeline") |
| |
|
| | |
| | from ..agents.scenario1 import ( |
| | DocumentAnalysisAgent, |
| | MarketAnalysisAgent, |
| | MatchmakingAgent, |
| | OutreachAgent |
| | ) |
| |
|
| | |
| | |
| | patent_path = state.get("input_data", {}).get("patent_path", "mock_patent.txt") |
| |
|
| | try: |
| | |
| | logger.info("📄 Step 1/4: Analyzing patent document...") |
| | doc_agent = DocumentAnalysisAgent( |
| | llm_client=self.llm_client, |
| | memory_agent=self.memory_agent, |
| | vision_ocr_agent=self.vision_ocr_agent |
| | ) |
| | patent_analysis = await doc_agent.analyze_patent(patent_path) |
| | state["agent_outputs"]["document_analysis"] = patent_analysis.model_dump() |
| | logger.success(f"✅ Patent analyzed: {patent_analysis.title}") |
| |
|
| | |
| | logger.info("📊 Step 2/4: Analyzing market opportunities...") |
| | market_agent = MarketAnalysisAgent( |
| | llm_client=self.llm_client, |
| | memory_agent=self.memory_agent |
| | ) |
| | market_analysis = await market_agent.analyze_market(patent_analysis) |
| | state["agent_outputs"]["market_analysis"] = market_analysis.model_dump() |
| | logger.success(f"✅ Market analyzed: {len(market_analysis.opportunities)} opportunities") |
| |
|
| | |
| | logger.info("🤝 Step 3/4: Finding potential partners...") |
| | matching_agent = MatchmakingAgent( |
| | llm_client=self.llm_client, |
| | memory_agent=self.memory_agent |
| | ) |
| | matches = await matching_agent.find_matches( |
| | patent_analysis, |
| | market_analysis, |
| | max_matches=10 |
| | ) |
| | state["agent_outputs"]["matches"] = [m.model_dump() for m in matches] |
| | logger.success(f"✅ Found {len(matches)} potential partners") |
| |
|
| | |
| | logger.info("📝 Step 4/4: Creating valorization brief...") |
| | outreach_agent = OutreachAgent( |
| | llm_client=self.llm_client, |
| | memory_agent=self.memory_agent |
| | ) |
| | brief = await outreach_agent.create_valorization_brief( |
| | patent_analysis, |
| | market_analysis, |
| | matches |
| | ) |
| | state["agent_outputs"]["brief"] = brief.model_dump() |
| | state["final_output"] = brief.content |
| | logger.success(f"✅ Brief created: {brief.pdf_path}") |
| |
|
| | |
| | state["agent_outputs"]["executor"] = { |
| | "result": f"Patent Wake-Up workflow completed successfully", |
| | "patent_title": patent_analysis.title, |
| | "opportunities_found": len(market_analysis.opportunities), |
| | "matches_found": len(matches), |
| | "brief_path": brief.pdf_path, |
| | "agents_used": ["DocumentAnalysisAgent", "MarketAnalysisAgent", |
| | "MatchmakingAgent", "OutreachAgent"], |
| | } |
| |
|
| | logger.success("✅ Patent Wake-Up pipeline completed successfully!") |
| |
|
| | except Exception as e: |
| | logger.error(f"Patent Wake-Up pipeline failed: {e}") |
| | state["agent_outputs"]["executor"] = { |
| | "result": f"Pipeline failed: {str(e)}", |
| | "error": str(e), |
| | "agents_used": [], |
| | } |
| | state["final_output"] = f"Error: {str(e)}" |
| |
|
| | return state |
| |
|
| | async def _critic_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"CRITIC node validating output") |
| | state["status"] = TaskStatus.VALIDATING |
| | state["current_agent"] = "CriticAgent" |
| |
|
| | if self.critic_agent: |
| | from ..agents.base_agent import Task |
| | task = Task( |
| | id=state["task_id"], |
| | description=state["task_description"], |
| | metadata={ |
| | "output_to_validate": state["final_output"], |
| | "output_type": self._get_output_type(state["scenario"]) |
| | } |
| | ) |
| | result_task = await self.critic_agent.process_task(task) |
| |
|
| | if result_task.status == "completed": |
| | validation = result_task.result |
| | state["validation_score"] = validation.overall_score |
| | state["validation_feedback"] = self.critic_agent.get_feedback_for_iteration(validation) |
| | state["validation_issues"] = validation.issues |
| | state["validation_suggestions"] = validation.suggestions |
| |
|
| | feedback_msg = AIMessage( |
| | content=f"Validation score: {validation.overall_score:.2f}\n{state['validation_feedback']}" |
| | ) |
| | state["messages"].append(feedback_msg) |
| | else: |
| | llm = self.llm_client.get_llm(complexity="analysis") |
| | validation_prompt = HumanMessage( |
| | content=f"Validate the following output:\n\n{state['final_output']}\n\nProvide a quality score (0.0-1.0) and feedback." |
| | ) |
| |
|
| | response = await llm.ainvoke([validation_prompt]) |
| | state["messages"].append(response) |
| |
|
| | state["validation_score"] = 0.90 |
| | state["validation_feedback"] = response.content |
| | state["validation_issues"] = [] |
| | state["validation_suggestions"] = [] |
| |
|
| | logger.info(f"Validation completed: score={state['validation_score']:.2f}") |
| | return state |
| |
|
| | async def _refine_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"REFINE node preparing for iteration {state['iteration_count'] + 1}") |
| | state["status"] = TaskStatus.REFINING |
| | state["current_agent"] = "Refiner" |
| | state["iteration_count"] += 1 |
| |
|
| | refine_msg = HumanMessage( |
| | content=f"Iteration {state['iteration_count']}: Address the following issues:\n{state['validation_feedback']}" |
| | ) |
| | state["messages"].append(refine_msg) |
| |
|
| | state["intermediate_results"].append({ |
| | "iteration": state["iteration_count"] - 1, |
| | "output": state["final_output"], |
| | "score": state["validation_score"], |
| | "feedback": state["validation_feedback"], |
| | }) |
| |
|
| | logger.info(f"Refinement prepared for iteration {state['iteration_count']}") |
| | return state |
| |
|
| | async def _finish_node(self, state: AgentState) -> AgentState: |
| | logger.info(f"FINISH node completing workflow") |
| | state["status"] = TaskStatus.COMPLETED |
| | state["current_agent"] = None |
| | state["success"] = True |
| | state["end_time"] = datetime.now() |
| | state["execution_time_seconds"] = (state["end_time"] - state["start_time"]).total_seconds() |
| |
|
| | |
| | if self.memory_agent and state.get("validation_score", 0) >= 0.75: |
| | try: |
| | logger.info("Storing episode in memory...") |
| | await self.memory_agent.store_episode( |
| | task_id=state["task_id"], |
| | task_description=state["task_description"], |
| | scenario=state["scenario"], |
| | workflow_steps=state.get("subtasks", []), |
| | outcome={ |
| | "final_output": state["final_output"], |
| | "validation_score": state.get("validation_score", 0), |
| | "success": state["success"], |
| | "tools_used": state.get("agent_outputs", {}).get("executor", {}).get("tools_called", []), |
| | }, |
| | quality_score=state.get("validation_score", 0), |
| | execution_time=state["execution_time_seconds"], |
| | iterations_used=state.get("iteration_count", 0), |
| | ) |
| | logger.info(f"Episode stored: {state['task_id']}") |
| | except Exception as e: |
| | logger.warning(f"Failed to store episode: {e}") |
| |
|
| | completion_msg = AIMessage( |
| | content=f"Workflow completed successfully in {state['execution_time_seconds']:.2f}s" |
| | ) |
| | state["messages"].append(completion_msg) |
| |
|
| | logger.info(f"Workflow completed: {state['task_id']}") |
| | return state |
| |
|
| | def _should_refine(self, state: AgentState) -> Literal["refine", "finish"]: |
| | score = state.get("validation_score", 0.0) |
| | iterations = state.get("iteration_count", 0) |
| |
|
| | if score >= self.quality_threshold: |
| | logger.info(f"Quality threshold met ({score:.2f} >= {self.quality_threshold}), finishing") |
| | return "finish" |
| |
|
| | if iterations >= state.get("max_iterations", self.max_iterations): |
| | logger.warning(f"Max iterations reached ({iterations}), finishing anyway") |
| | return "finish" |
| |
|
| | logger.info(f"Refining (score={score:.2f}, iteration={iterations})") |
| | return "refine" |
| |
|
| | def _get_scenario_agents(self, scenario: ScenarioType) -> list: |
| | scenario_map = { |
| | ScenarioType.PATENT_WAKEUP: ["DocumentAnalysisAgent", "MarketAnalysisAgent", "MatchmakingAgent", "OutreachAgent"], |
| | ScenarioType.AGREEMENT_SAFETY: ["LegalAnalysisAgent", "ComplianceAgent", "RiskAssessmentAgent", "RecommendationAgent"], |
| | ScenarioType.PARTNER_MATCHING: ["ProfilingAgent", "SemanticMatchingAgent", "NetworkAnalysisAgent", "ConnectionFacilitatorAgent"], |
| | ScenarioType.GENERAL: ["ExecutorAgent"] |
| | } |
| | return scenario_map.get(scenario, ["ExecutorAgent"]) |
| |
|
| | def _get_output_type(self, scenario: ScenarioType) -> str: |
| | type_map = { |
| | ScenarioType.PATENT_WAKEUP: "patent_analysis", |
| | ScenarioType.AGREEMENT_SAFETY: "legal_review", |
| | ScenarioType.PARTNER_MATCHING: "stakeholder_matching", |
| | ScenarioType.GENERAL: "general" |
| | } |
| | return type_map.get(scenario, "general") |
| |
|
| | async def run( |
| | self, |
| | task_description: str, |
| | scenario: ScenarioType = ScenarioType.GENERAL, |
| | task_id: Optional[str] = None, |
| | input_data: Optional[Dict[str, Any]] = None, |
| | config: Optional[Dict[str, Any]] = None, |
| | ) -> WorkflowOutput: |
| | if task_id is None: |
| | task_id = f"task_{hash(task_description) % 100000}" |
| |
|
| | initial_state = create_initial_state( |
| | task_id=task_id, |
| | task_description=task_description, |
| | scenario=scenario, |
| | max_iterations=self.max_iterations, |
| | input_data=input_data, |
| | ) |
| |
|
| | logger.info(f"Starting workflow for task: {task_id}") |
| |
|
| | try: |
| | final_state = await self.app.ainvoke( |
| | initial_state, |
| | config=config or {"configurable": {"thread_id": task_id}} |
| | ) |
| |
|
| | output = state_to_output(final_state) |
| | logger.info(f"Workflow completed successfully: {task_id}") |
| | return output |
| |
|
| | except Exception as e: |
| | logger.error(f"Workflow failed: {e}") |
| | initial_state["status"] = TaskStatus.FAILED |
| | initial_state["success"] = False |
| | initial_state["error"] = str(e) |
| | initial_state["end_time"] = datetime.now() |
| | return state_to_output(initial_state) |
| |
|
| | async def stream( |
| | self, |
| | task_description: str, |
| | scenario: ScenarioType = ScenarioType.GENERAL, |
| | task_id: Optional[str] = None, |
| | config: Optional[Dict[str, Any]] = None, |
| | ): |
| | if task_id is None: |
| | task_id = f"task_{hash(task_description) % 100000}" |
| |
|
| | initial_state = create_initial_state( |
| | task_id=task_id, |
| | task_description=task_description, |
| | scenario=scenario, |
| | max_iterations=self.max_iterations, |
| | ) |
| |
|
| | async for event in self.app.astream( |
| | initial_state, |
| | config=config or {"configurable": {"thread_id": task_id}} |
| | ): |
| | yield event |
| |
|
| |
|
def create_workflow(
    llm_client: LangChainOllamaClient,
    planner_agent: Optional[Any] = None,
    critic_agent: Optional[Any] = None,
    memory_agent: Optional[Any] = None,
    vision_ocr_agent: Optional[Any] = None,
    quality_threshold: float = 0.85,
    max_iterations: int = 3,
) -> SparknetWorkflow:
    """Factory helper mirroring the SparknetWorkflow constructor.

    All arguments are forwarded unchanged; see SparknetWorkflow.__init__
    for their meaning.
    """
    workflow = SparknetWorkflow(
        llm_client=llm_client,
        planner_agent=planner_agent,
        critic_agent=critic_agent,
        memory_agent=memory_agent,
        vision_ocr_agent=vision_ocr_agent,
        quality_threshold=quality_threshold,
        max_iterations=max_iterations,
    )
    return workflow
| |
|