import asyncio
import inspect
from copy import deepcopy
from typing import Any, List, Optional

from pydantic import Field, create_model

from ..core.logging import logger
from ..core.module import BaseModule
from ..core.message import Message, MessageType
from ..core.module_utils import generate_id
from ..models.base_model import BaseLLM
from ..agents.agent import Agent
from ..agents.agent_manager import AgentManager, AgentState
from ..storages.base import StorageHandler
from .environment import Environment, TrajectoryState
from .workflow_manager import WorkFlowManager, NextAction
from .workflow_graph import WorkFlowNode, WorkFlowGraph
from .action_graph import ActionGraph
from ..hitl import HITLManager, HITLBaseAgent
from ..utils.utils import generate_dynamic_class_name
from ..actions import ActionInput, ActionOutput
|
|
class WorkFlow(BaseModule):
    """
    Executes a `WorkFlowGraph` one subtask (node) at a time.

    Scheduling decisions (which subtask next, which action next, how to extract the final
    output) are delegated to `workflow_manager`. A subtask is carried out either by an
    `ActionGraph` or by agents obtained from `agent_manager`, and every message produced
    along the way is recorded in `environment`.
    """

    # The graph of subtasks to execute; also provides the overall goal (`graph.goal`).
    graph: WorkFlowGraph
    # Only used by this class to build a default `WorkFlowManager` in `init_module`.
    llm: Optional[BaseLLM] = None
    agent_manager: AgentManager = Field(default=None, description="Responsible for managing agents")
    workflow_manager: WorkFlowManager = Field(default=None, description="Responsible for task and action scheduling for workflow execution")
    environment: Environment = Field(default_factory=Environment)
    # Not referenced by any method of this class.
    storage_handler: StorageHandler = None
    workflow_id: str = Field(default_factory=generate_id)
    version: int = 0
    max_execution_steps: int = Field(default=5, description="The maximum number of steps to complete a subtask (node) in the workflow")
    hitl_manager: HITLManager = Field(default=None, description="Responsible for HITL work management")

    def init_module(self) -> None:
        """
        Fill in a default `workflow_manager` and warn about a missing `agent_manager`.

        NOTE(review): presumably invoked by `BaseModule` after construction -- confirm.

        Raises:
            ValueError: If neither `workflow_manager` nor `llm` was provided.
        """
        if self.workflow_manager is None:
            if self.llm is None:
                raise ValueError("Must provide `llm` when `workflow_manager` is None")
            self.workflow_manager = WorkFlowManager(llm=self.llm)
        if self.agent_manager is None:
            logger.warning("agent_manager is NoneType when initializing a WorkFlow instance")

    def execute(self, inputs: Optional[dict] = None, **kwargs: Any) -> str:
        """
        Synchronous wrapper for async_execute. Creates a new event loop and runs the async method.

        The new loop is installed as the calling thread's current event loop and is closed
        once the run finishes. Because it relies on `run_until_complete`, this method cannot
        be called from code that is already running inside an event loop on the same thread.

        Args:
            inputs: Dictionary of inputs for workflow execution. `None` (the default) is
                treated as an empty dictionary.
            **kwargs (Any): Additional keyword arguments

        Returns:
            str: The output of the workflow execution
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(self.async_execute(inputs, **kwargs))
        finally:
            loop.close()

    async def async_execute(self, inputs: Optional[dict] = None, **kwargs: Any) -> str:
        """
        Asynchronously execute the workflow.

        Args:
            inputs: Dictionary of inputs for workflow execution. `None` (the default) is
                treated as an empty dictionary. The dictionary is copied, never mutated.
            **kwargs (Any): Additional keyword arguments

        Returns:
            str: The output of the workflow execution, or the literal string
            "Workflow Execution Failed" when scheduling or executing a subtask raised.

        Raises:
            ValueError: If HITL preparation or the workflow structure validation fails.
                Both run before the execution loop, so their errors are not converted
                into the failure string.
        """
        goal = self.graph.goal
        inputs = self._prepare_inputs(inputs)

        # HITL actions need their input/output formats generated before any validation.
        if getattr(self, "hitl_manager", None) is not None:
            self._prepare_hitl()

        self._validate_workflow_structure(inputs=inputs, **kwargs)
        inp_message = Message(content=inputs, msg_type=MessageType.INPUT, wf_goal=goal)
        self.environment.update(message=inp_message, state=TrajectoryState.COMPLETED)

        while not self.graph.is_complete:
            try:
                task = await self.get_next_task()
                if task is None:
                    # Nothing left to schedule even though the graph is not marked complete.
                    break
                logger.info(f"Executing subtask: {task.name}")
                await self.execute_task(task=task)
            except Exception as e:
                # Any failure aborts the run; it is recorded in the environment and
                # reported through the return value instead of being re-raised.
                error_message = Message(
                    content=f"An Error occurs when executing the workflow: {e}",
                    msg_type=MessageType.ERROR,
                    wf_goal=goal
                )
                self.environment.update(message=error_message, state=TrajectoryState.FAILED, error=str(e))
                logger.error(error_message.content)
                return "Workflow Execution Failed"

        logger.info("Extracting WorkFlow Output ...")
        output: str = await self.workflow_manager.extract_output(graph=self.graph, env=self.environment)
        return output

    def _prepare_inputs(self, inputs: Optional[dict]) -> dict:
        """
        Prepare the inputs for the workflow execution. Mainly determine whether the goal should be added to the inputs.

        Args:
            inputs: The inputs given by the caller, or `None`.

        Returns:
            dict: A new dictionary holding the caller's inputs, plus `"goal"` (set to
            `graph.goal`) when an initial node requires it and the caller did not supply it.
        """
        # Work on a copy so that neither the caller's dict nor a shared default is mutated.
        prepared_inputs = dict(inputs) if inputs else {}

        required_by_initial_nodes = set()
        for initial_node_name in self.graph.find_initial_nodes():
            initial_node = self.graph.get_node(initial_node_name)
            if initial_node.inputs:
                required_by_initial_nodes.update(inp.name for inp in initial_node.inputs if inp.required)

        if "goal" in required_by_initial_nodes and "goal" not in prepared_inputs:
            prepared_inputs["goal"] = self.graph.goal
        return prepared_inputs

    async def get_next_task(self) -> Optional[WorkFlowNode]:
        """
        Ask the workflow manager which subtask should run next.

        Returns:
            Optional[WorkFlowNode]: The scheduled subtask, or `None` when the manager
            did not return one.
        """
        task_execution_history = " -> ".join(self.environment.task_execution_history)
        if not task_execution_history:
            task_execution_history = "None"
        logger.info(f"Task Execution Trajectory: {task_execution_history}. Scheduling next subtask ...")
        task = await self.workflow_manager.schedule_next_task(graph=self.graph, env=self.environment)
        # `async_execute` treats None as "stop", so it must not be dereferenced here.
        if task is not None:
            logger.info(f"The next subtask to be executed is: {task.name}")
        return task

    async def execute_task(self, task: WorkFlowNode) -> None:
        """
        Asynchronously execute a workflow task.

        Advances the graph from the last executed task to `task`, asks the workflow
        manager for the next action, runs it through an action graph when one is provided
        (otherwise through agents), and finally marks the node as completed.

        Args:
            task: The workflow node to execute
        """
        last_executed_task = self.environment.get_last_executed_task()
        self.graph.step(source_node=last_executed_task, target_node=task)
        next_action: NextAction = await self.workflow_manager.schedule_next_action(
            goal=self.graph.goal,
            task=task,
            agent_manager=self.agent_manager,
            env=self.environment
        )
        if next_action.action_graph is not None:
            await self._async_execute_task_by_action_graph(task=task, next_action=next_action)
        else:
            await self._async_execute_task_by_agents(task=task, next_action=next_action)
        self.graph.completed(node=task)

    async def _async_execute_task_by_action_graph(self, task: WorkFlowNode, next_action: NextAction) -> None:
        """
        Asynchronously execute a task using an action graph.

        The arguments for the action graph are taken from the execution data stored in the
        environment, matched by parameter name. Parameters without stored data fall back
        to their declared default, or to `None` when they have no default.

        Args:
            task: The workflow node to execute
            next_action: The next action to perform with its action graph
        """
        action_graph: ActionGraph = next_action.action_graph

        # Heuristic: an `async_execute` whose source mentions NotImplementedError is
        # considered not implemented, in which case the synchronous `execute` is used.
        async_execute_source = inspect.getsource(action_graph.async_execute)
        use_async = "NotImplementedError" not in async_execute_source
        execute_function = action_graph.async_execute if use_async else action_graph.execute

        variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        action_input_data = self.environment.get_all_execution_data()
        execute_params = {}
        for param_name, param_obj in inspect.signature(execute_function).parameters.items():
            if param_name in ("self", "args", "kwargs"):
                continue
            # `*args` / `**kwargs` style parameters can not be filled by name, whatever
            # they are called; passing them as keywords would raise or inject a bogus key.
            if param_obj.kind in variadic_kinds:
                continue

            if param_name in action_input_data:
                execute_params[param_name] = action_input_data[param_name]
            elif param_obj.default is not param_obj.empty:
                execute_params[param_name] = param_obj.default
            else:
                execute_params[param_name] = None

        if use_async:
            action_graph_output: dict = await execute_function(**execute_params)
        else:
            action_graph_output: dict = execute_function(**execute_params)

        message = Message(
            content=action_graph_output, action=action_graph.name, msg_type=MessageType.RESPONSE,
            wf_goal=self.graph.goal, wf_task=task.name, wf_task_desc=task.description
        )
        self.environment.update(message=message, state=TrajectoryState.COMPLETED)

    async def _async_execute_task_by_agents(self, task: WorkFlowNode, next_action: NextAction) -> None:
        """
        Asynchronously execute a task using agents.

        Keeps executing the scheduled action and asking for the next one until all outputs
        of the task are available, no further action is scheduled, or the step budget
        (`max_execution_steps`) is exhausted.

        Args:
            task: The workflow node to execute
            next_action: The next action to perform using agents

        Raises:
            ValueError: If `max_execution_steps` actions ran without completing the task.
            TimeoutError: If the selected agent does not become available in time.
        """
        num_execution = 0
        while next_action:
            if num_execution >= self.max_execution_steps:
                raise ValueError(
                    f"Maximum number of steps ({self.max_execution_steps}) reached when executing {task.name}. "
                    "Please check the workflow structure (e.g., inputs and outputs of the nodes and the agents) "
                    "or increase the `max_execution_steps` parameter."
                )
            agent: Agent = self.agent_manager.get_agent(agent_name=next_action.agent)
            # NOTE(review): this looks like a synchronous wait inside a coroutine -- confirm
            # that it does not stall the event loop for up to the whole timeout.
            if not self.agent_manager.wait_for_agent_available(agent_name=agent.name, timeout=300):
                raise TimeoutError(f"Timeout waiting for agent {agent.name} to become available")
            self.agent_manager.set_agent_state(agent_name=next_action.agent, new_state=AgentState.RUNNING)
            try:
                message = await self._async_execute_action(task=task, agent=agent, next_action=next_action)
                self.environment.update(message=message, state=TrajectoryState.COMPLETED)
            finally:
                # Always release the agent, even when the action raised.
                self.agent_manager.set_agent_state(agent_name=next_action.agent, new_state=AgentState.AVAILABLE)
            if self.is_task_completed(task=task):
                break
            next_action = await self.workflow_manager.schedule_next_action(
                goal=self.graph.goal,
                task=task,
                agent_manager=self.agent_manager,
                env=self.environment
            )
            num_execution += 1

    async def _async_execute_action(self, task: WorkFlowNode, agent: Agent, next_action: NextAction) -> Message:
        """
        Asynchronously execute an action using an agent.

        When the execution data already contains every required input of the action, the
        action is called with that data directly. Otherwise the agent receives the INPUT
        and RESPONSE messages of the task and its predecessors, and the input data it
        extracts from them is stored back into the environment.

        Args:
            task: The workflow node the action belongs to
            agent: The agent that executes the action
            next_action: The scheduled action; only its `action` name is used here

        Returns:
            Message: The message returned by the agent.
        """
        action_name = next_action.action
        all_execution_data = self.environment.get_all_execution_data()
        hitl_manager = getattr(self, "hitl_manager", None)

        action_inputs_format = agent.get_action(action_name).inputs_format
        action_input_data = {}
        if action_inputs_format:
            action_input_data = {
                input_name: all_execution_data[input_name]
                for input_name in action_inputs_format.get_attrs()
                if input_name in all_execution_data
            }
            action_required_input_names = action_inputs_format.get_required_input_names()
            if not all(inp in action_input_data for inp in action_required_input_names):
                # Some required inputs are unknown: let the agent extract them from context.
                predecessors = self.graph.get_node_predecessors(node=task)
                predecessors_messages = self.environment.get_task_messages(
                    tasks=predecessors + [task.name], include_inputs=True
                )
                predecessors_messages = [
                    message for message in predecessors_messages
                    if message.msg_type in [MessageType.INPUT, MessageType.RESPONSE]
                ]
                message, extracted_data = await agent.async_execute(
                    action_name=action_name,
                    msgs=predecessors_messages,
                    return_msg_type=MessageType.RESPONSE,
                    return_action_input_data=True,
                    wf_goal=self.graph.goal,
                    wf_task=task.name,
                    wf_task_desc=task.description,
                    hitl_manager=hitl_manager
                )
                self.environment.update_execution_data_from_context_extraction(extracted_data)
                return message

        message = await agent.async_execute(
            action_name=action_name,
            action_input_data=action_input_data,
            return_msg_type=MessageType.RESPONSE,
            wf_goal=self.graph.goal,
            wf_task=task.name,
            wf_task_desc=task.description,
            hitl_manager=hitl_manager
        )
        return message

    def is_task_completed(self, task: WorkFlowNode) -> bool:
        """
        Tell whether every declared output of `task` is present in the execution data.

        Args:
            task: The workflow node to check

        Returns:
            bool: True when all output names of the task are available (also true for a
            task without outputs).
        """
        current_execution_data = self.environment.get_all_execution_data()
        return all(output.name in current_execution_data for output in task.outputs)

    def _validate_workflow_structure(self, inputs: dict, **kwargs: Any) -> None:
        """
        Check that the workflow can run with the given inputs.

        Every node must be able to obtain each of its required inputs either from the
        workflow inputs or from the outputs of its direct predecessors, and no node may
        use an agent flagged with `forbidden_in_workflow`.

        Args:
            inputs: The prepared workflow inputs
            **kwargs (Any): Accepted for interface compatibility; not used.

        Raises:
            ValueError: If a node misses required inputs or uses a forbidden agent.
        """
        input_names = set(inputs.keys())
        for node in self.graph.nodes:
            available_names = set(input_names)
            is_initial_node = True
            for name in self.graph.get_node_predecessors(node):
                is_initial_node = False
                predecessor = self.graph.get_node(name)
                available_names.update(predecessor.get_output_names())

            missing_required_inputs = set(node.get_input_names(required=True)) - available_names
            if not missing_required_inputs:
                continue
            if is_initial_node:
                raise ValueError(
                    f"The initial node '{node.name}' is missing required inputs: {list(missing_required_inputs)}. "
                    "You should provide these inputs by specifying the `inputs={'input_name': 'input_value'}` parameter in the `execute` method, "
                    "or return the valid inputs in the `collate_func` when using `Evaluator`."
                )
            raise ValueError(
                f"The node '{node.name}' is missing required inputs: {list(missing_required_inputs)}. "
                f"You may need to check the `inputs` and `outputs` of the nodes to ensure that all the required inputs of node '{node.name}' are provided "
                f"by either its predecessors or the `inputs` parameter in the `execute` method."
            )

        for node in self.graph.nodes:
            # A node may have no agents at all (e.g. when it is driven by an action graph).
            for agent in node.agents or []:
                if hasattr(agent, "forbidden_in_workflow") and (agent.forbidden_in_workflow):
                    raise ValueError(f"The Agent of class {agent.__class__} is forbidden to be used in the workflow.")

    def _build_hitl_fields(self, params) -> dict:
        """Map workflow parameters to `create_model` field definitions (all typed as strings)."""
        fields = {}
        for param in params or []:
            if param.required:
                fields[param.name] = (str, Field(description=param.description))
            else:
                # An explicit default is needed: without it pydantic v2 still treats an
                # `Optional[...]` field as required.
                fields[param.name] = (Optional[str], Field(default=None, description=param.description))
        return fields

    def _prepare_single_hitl_agent(self, agent: Agent, node: WorkFlowNode) -> None:
        """
        Add the complementary information and settings that must be set up dynamically for a single HITL agent.

        For example, the `inputs_format` and `outputs_format` of its HITL action are
        generated here: the inputs mirror the outputs of the node's predecessors and the
        outputs mirror the inputs of the node's successors.
        Up to now, a HITL agent must be the only agent in its WorkFlowNode instance; this
        condition may be changed in the future.

        Args:
            agent (Agent): a single HITL Agent instance
            node (WorkFlowNode): the WorkFlowNode instance which contains exactly this agent.

        Raises:
            ValueError: If the agent has no HITL action, the node has no successor, or the
                HITL manager has no `hitl_input_output_mapping`.
        """
        # The HITL action is the first action that still lacks an input or output format
        # and exposes an `interaction_type` attribute.
        hitl_action = None
        for action in agent.actions:
            if (action.inputs_format) and (action.outputs_format):
                continue
            if hasattr(action, "interaction_type"):
                hitl_action = action
                break
        if not hitl_action:
            raise ValueError(f"Can not find a HITL action in agent {agent}")

        # Inputs of the HITL action: everything its predecessor nodes output.
        predecessors: List[str] = self.graph.get_node_predecessors(node)
        hitl_inputs_data_fields = {}
        for predecessor in predecessors:
            predecessor_node = self.graph.get_node(predecessor)
            hitl_inputs_data_fields.update(self._build_hitl_fields(predecessor_node.outputs))
        inputs_format = create_model(
            agent._get_unique_class_name(
                generate_dynamic_class_name(hitl_action.class_name+" action_input")
            ),
            **hitl_inputs_data_fields,
            __base__=ActionInput
        )

        # Outputs of the HITL action: everything its successor nodes expect as input.
        successors: List[str] = self.graph.get_node_children(node)
        if not successors:
            raise ValueError("WorkFlowNode with a HITL Agent can not be set as the ending node.")
        hitl_outputs_data_fields = {}
        for successor in successors:
            successor_node = self.graph.get_node(successor)
            hitl_outputs_data_fields.update(self._build_hitl_fields(successor_node.inputs))
        outputs_format = create_model(
            agent._get_unique_class_name(
                generate_dynamic_class_name(hitl_action.class_name+" action_output")
            ),
            **hitl_outputs_data_fields,
            __base__=ActionOutput
        )

        hitl_action.inputs_format = inputs_format
        hitl_action.outputs_format = outputs_format

        if self.hitl_manager.hitl_input_output_mapping is None:
            raise ValueError("hitl_input_output_mapping attribute missing in HITLManager instance.")

    def _prepare_hitl(self) -> None:
        """
        Prepare hitl settings before executing the WorkFlow

        Finds every node that contains a HITL agent and sets that agent up for the node.
        Agents referenced by name or by config dict are resolved through `agent_manager`.

        Raises:
            ValueError: If a node contains more than one HITL agent, or the same HITL
                agent is used by more than one node.
        """
        if self.hitl_manager is None:
            return

        # Each HITL agent is paired with its own node while scanning, so the pairing can
        # never drift the way two independently collected lists could.
        used_hitl_agents: List[Agent] = []
        agent_node_pairs = []
        for node in self.graph.nodes:
            node_hitl_agents: List[Agent] = []
            for agent in node.agents or []:
                if isinstance(agent, dict):
                    agent = self.agent_manager.get_agent(self.agent_manager.get_agent_name(agent))
                elif isinstance(agent, str):
                    agent = self.agent_manager.get_agent(agent)

                if isinstance(agent, HITLBaseAgent) and agent not in node_hitl_agents:
                    node_hitl_agents.append(agent)

            if not node_hitl_agents:
                continue
            if len(node_hitl_agents) > 1:
                raise ValueError("Incorrect WorkFlowNode definition: A HITL Agent must be the only agent in its WorkFlowNode instance")

            hitl_agent = node_hitl_agents[0]
            if hitl_agent in used_hitl_agents:
                # The formats generated for one node would be overwritten by the next one.
                raise ValueError("Incorrect WorkFlowNode definition: the same HITL Agent can not be used in more than one WorkFlowNode instance")
            used_hitl_agents.append(hitl_agent)
            agent_node_pairs.append((hitl_agent, node))

        for agent, node in agent_node_pairs:
            self._prepare_single_hitl_agent(agent, node)