| | """ |
| | Base Tool for SPARKNET |
| | Defines the interface for all tools that agents can use |
| | """ |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import Any, Dict, Optional |
| | from pydantic import BaseModel, Field |
| | from loguru import logger |
| | import json |
| |
|
| |
|
class ToolParameter(BaseModel):
    """Schema entry describing a single argument that a tool accepts."""

    # Identifier the caller uses as the keyword argument name.
    name: str = Field(..., description="Parameter name")
    # Type name as a string; see the description for the expected values.
    type: str = Field(
        ...,
        description="Parameter type (str, int, float, bool, list, dict)",
    )
    description: str = Field(..., description="Parameter description")
    required: bool = Field(True, description="Whether parameter is required")
    # Value supplied for this parameter when it is optional and omitted.
    default: Optional[Any] = Field(
        None,
        description="Default value if not required",
    )
| |
|
| |
|
class ToolResult(BaseModel):
    """Outcome of running a tool: success flag, output, optional error and metadata."""

    success: bool = Field(..., description="Whether execution was successful")
    output: Any = Field(..., description="Tool output")
    # Populated when success is False.
    error: Optional[str] = Field(None, description="Error message if failed")
    # default_factory gives each result its own fresh dict.
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
| |
|
| |
|
class BaseTool(ABC):
    """
    Base class for all tools that agents can use.

    Subclasses implement :meth:`execute`. Callers should normally go through
    :meth:`safe_execute`, which validates arguments against the declared
    parameters, fills in defaults for omitted optional parameters, and turns
    exceptions raised by ``execute`` into a failed :class:`ToolResult`.
    """

    # Python type names accepted in ``ToolParameter.type`` and the runtime
    # types ``validate_parameters`` checks them against. Type names that are
    # not listed here are accepted without a runtime type check.
    _PYTHON_TYPES: Dict[str, type] = {
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
    }

    # The same type names translated to JSON Schema type names, which is what
    # LLM function-calling schemas use. Names not listed here (e.g. ones that
    # are already written in JSON Schema form) are emitted unchanged.
    _JSON_SCHEMA_TYPES: Dict[str, str] = {
        "str": "string",
        "int": "integer",
        "float": "number",
        "bool": "boolean",
        "list": "array",
        "dict": "object",
    }

    def __init__(self, name: str, description: str):
        """
        Initialize tool.

        Args:
            name: Tool name (also used as the key when registered in a
                ToolRegistry and as ``name`` in the schema).
            description: Tool description emitted in the schema.
        """
        self.name = name
        self.description = description
        self.parameters: list[ToolParameter] = []

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """
        Execute the tool with given parameters.

        Subclasses must override this. When invoked via ``safe_execute``, the
        arguments have already been validated and every declared optional
        parameter the caller omitted is present with its default value.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results
        """
        pass

    def add_parameter(
        self,
        name: str,
        param_type: str,
        description: str,
        required: bool = True,
        default: Optional[Any] = None,
    ):
        """
        Add a parameter definition to the tool.

        Args:
            name: Parameter name
            param_type: Parameter type name (str, int, float, bool, list, dict)
            description: Parameter description
            required: Whether parameter is required
            default: Default value used by ``safe_execute`` when an optional
                parameter is omitted
        """
        param = ToolParameter(
            name=name,
            type=param_type,
            description=description,
            required=required,
            default=default,
        )
        self.parameters.append(param)

    def validate_parameters(self, **kwargs) -> tuple[bool, Optional[str]]:
        """
        Validate provided parameters against tool definition.

        First checks that every required parameter is present, then that each
        provided declared parameter matches its declared type. Type checks
        follow JSON conventions: ``bool`` values are rejected for ``int`` and
        ``float`` parameters, and ``int`` values are accepted for ``float``
        parameters. Keyword arguments that are not declared are not checked.

        Args:
            **kwargs: Provided parameters

        Returns:
            Tuple of (is_valid, error_message); error_message is None when valid.
        """
        for param in self.parameters:
            if param.required and param.name not in kwargs:
                return False, f"Missing required parameter: {param.name}"

        for param in self.parameters:
            if param.name not in kwargs:
                continue
            expected_type = param.type
            if expected_type not in self._PYTHON_TYPES:
                # Unknown type names are intentionally not checked.
                continue
            if not self._matches_type(kwargs[param.name], expected_type):
                return False, f"Parameter {param.name} must be of type {expected_type}"

        return True, None

    @classmethod
    def _matches_type(cls, value: Any, type_name: str) -> bool:
        """Return True if ``value`` is acceptable for declared type ``type_name``."""
        # bool is a subclass of int, so isinstance(True, int) is True; a
        # boolean flag is not an acceptable number.
        if type_name in ("int", "float") and isinstance(value, bool):
            return False
        # Accept integers where a float is declared (e.g. 1 for 1.0), matching
        # JSON "number" semantics.
        if type_name == "float":
            return isinstance(value, (int, float))
        return isinstance(value, cls._PYTHON_TYPES[type_name])

    def get_schema(self) -> Dict[str, Any]:
        """
        Get tool schema for LLM function calling.

        Parameter types are emitted as JSON Schema type names ("string",
        "integer", "number", "boolean", "array", "object"). Unrecognized type
        names are emitted unchanged.

        Returns:
            Dict with ``name``, ``description`` and a ``parameters`` object
            holding ``type``, ``properties`` and ``required``.
        """
        properties = {
            param.name: {
                "type": self._JSON_SCHEMA_TYPES.get(param.type, param.type),
                "description": param.description,
            }
            for param in self.parameters
        }
        required = [param.name for param in self.parameters if param.required]
        return {
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    async def safe_execute(self, **kwargs) -> ToolResult:
        """
        Execute tool with parameter validation and error handling.

        Validation failures and exceptions raised by ``execute`` are returned
        as ``ToolResult(success=False, ...)`` instead of propagating.

        Args:
            **kwargs: Tool parameters

        Returns:
            ToolResult with execution results
        """
        import copy

        is_valid, error_msg = self.validate_parameters(**kwargs)
        if not is_valid:
            logger.error(f"Tool {self.name} parameter validation failed: {error_msg}")
            return ToolResult(success=False, output=None, error=error_msg)

        # Fill in omitted optional parameters. Deep-copy the default so that
        # execute() mutating it (e.g. a list/dict default) cannot leak into
        # later calls through the shared ToolParameter.default object.
        for param in self.parameters:
            if not param.required and param.name not in kwargs:
                kwargs[param.name] = copy.deepcopy(param.default)

        try:
            logger.info(f"Executing tool: {self.name}")
            result = await self.execute(**kwargs)
        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception(f"Tool {self.name} execution failed: {e}")
            return ToolResult(
                success=False,
                output=None,
                error=str(e),
            )

        # Only report success when the tool itself did not report failure.
        if isinstance(result, ToolResult) and not result.success:
            logger.warning(f"Tool {self.name} reported failure: {result.error}")
        else:
            logger.info(f"Tool {self.name} executed successfully")
        return result

    def __repr__(self) -> str:
        return f"<Tool: {self.name}>"
| |
|
| |
|
class ToolRegistry:
    """Registry mapping tool names to BaseTool instances."""

    def __init__(self):
        """Initialize an empty tool registry."""
        self.tools: Dict[str, BaseTool] = {}
        logger.info("Tool registry initialized")

    def register(self, tool: BaseTool):
        """
        Register a tool under its ``name``.

        A different tool already registered under the same name is replaced.
        A warning is logged so accidental name collisions are visible.

        Args:
            tool: Tool instance to register
        """
        existing = self.tools.get(tool.name)
        if existing is not None and existing is not tool:
            logger.warning(f"Replacing already-registered tool: {tool.name}")
        self.tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def unregister(self, tool_name: str):
        """
        Unregister a tool. Unknown names are ignored.

        Args:
            tool_name: Name of tool to unregister
        """
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.info(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """
        Get a tool by name.

        Args:
            tool_name: Name of tool

        Returns:
            Tool instance, or None if no tool is registered under that name
        """
        return self.tools.get(tool_name)

    def list_tools(self) -> list[str]:
        """
        List all registered tools.

        Returns:
            List of tool names, in registration order
        """
        return list(self.tools.keys())

    def get_schemas(self) -> list[Dict[str, Any]]:
        """
        Get schemas for all tools.

        Returns:
            List of ``tool.get_schema()`` results, one per registered tool
        """
        return [tool.get_schema() for tool in self.tools.values()]
| |
|
| |
|
| | |
# Module-level registry instance; created lazily by get_tool_registry().
_tool_registry: Optional[ToolRegistry] = None


def get_tool_registry() -> ToolRegistry:
    """
    Return the module-wide ToolRegistry.

    The registry is created on the first call; later calls return that same
    instance.
    """
    global _tool_registry
    if _tool_registry is not None:
        return _tool_registry
    _tool_registry = ToolRegistry()
    return _tool_registry
| |
|