| import json |
| from uuid import uuid4 |
| from datetime import datetime |
| from typing import Union, List, Dict, Any, Optional, Tuple |
|
|
| from pydantic import Field |
|
|
| from .long_term_memory import LongTermMemory |
| from ..rag.schema import Query |
| from ..core.logging import logger |
| from ..core.module import BaseModule |
| from ..core.message import Message, MessageType |
| from ..models.base_model import BaseLLM |
| from ..prompts.memory.manager import MANAGER_PROMPT |
|
|
|
|
# Actions accepted by ``MemoryManager.handle_memory``.
_SUPPORTED_ACTIONS = ("add", "search", "get", "update", "delete", "clear", "save", "load", "create_message")

# Operations the LLM may return when reviewing a proposed memory change.
_LLM_ACTIONS = ("add", "update", "delete")


def _parse_memory_operations(raw: str, expected_count: Optional[int] = None) -> List[Dict[str, Any]]:
    """Parse and validate the memory-operation JSON produced by the LLM.

    The response may be a single JSON object or a JSON array of objects, optionally wrapped in a
    Markdown ```json fence. Each object must carry an ``action`` from ``_LLM_ACTIONS``;
    ``update``/``delete`` additionally need a ``memory_id`` and ``add``/``update`` need a ``message``.

    Args:
        raw: Raw text returned by the LLM.
        expected_count: If given, the exact number of operations required. Callers pair decisions
            with their inputs by position, so a different count cannot be aligned and is rejected.

    Returns:
        List[Dict[str, Any]]: The validated operations, always as a list.

    Raises:
        ValueError: If the text is not valid JSON (``json.JSONDecodeError`` is a ``ValueError``)
            or any operation fails validation.
    """
    parsed = json.loads(raw.replace("```json", "").replace("```", "").strip())
    operations = parsed if isinstance(parsed, list) else [parsed]
    if expected_count is not None and len(operations) != expected_count:
        raise ValueError(f"Expected {expected_count} memory operations, got {len(operations)}")
    for operation in operations:
        if not isinstance(operation, dict):
            raise ValueError(f"Memory operation must be a JSON object, got: {operation!r}")
        op_action = operation.get("action")
        if op_action not in _LLM_ACTIONS:
            raise ValueError(f"Invalid action: {op_action}")
        if op_action in ("update", "delete") and not operation.get("memory_id"):
            raise ValueError(f"memory_id required for {op_action}")
        if op_action in ("add", "update") and not operation.get("message"):
            raise ValueError(f"message required for {op_action}")
    return operations


def _coerce_message(msg: Union[str, Message], refresh_timestamp: bool = False) -> Message:
    """Build a new Message from a raw string, or copy the storable fields of an existing Message.

    A string becomes a ``MessageType.REQUEST`` from agent ``"user"``, stamped with the current time
    and a fresh UUID ``message_id``. For a Message, content/type/agent are copied, a missing
    ``message_id`` is replaced by a fresh UUID, and the original timestamp is kept unless
    ``refresh_timestamp`` is True (then the current time is used).
    """
    if isinstance(msg, str):
        return Message(
            content=msg,
            msg_type=MessageType.REQUEST,
            timestamp=datetime.now().isoformat(),
            agent="user",
            message_id=str(uuid4()),
        )
    return Message(
        content=msg.content,
        msg_type=msg.msg_type,
        timestamp=datetime.now().isoformat() if refresh_timestamp else msg.timestamp,
        agent=msg.agent,
        message_id=msg.message_id or str(uuid4()),
    )


class MemoryManager(BaseModule):
    """
    The Memory Manager organizes and manages LongTermMemory data at a higher level.
    It retrieves data, processes it with optional LLM-based action inference, and stores new or updated data.
    It creates Message objects for agent use, combining user prompts with memory context.

    Attributes:
        memory (LongTermMemory): The LongTermMemory instance for storing and retrieving messages.
        llm (Optional[BaseLLM]): LLM for deciding memory operations.
        use_llm_management (bool): Toggle LLM-based memory management. LLM review of
            add/update/delete only happens when this is True AND ``llm`` is set.
    """

    memory: LongTermMemory = Field(..., description="Long-term memory instance")
    llm: Optional[BaseLLM] = Field(default=None, description="LLM for deciding memory operations")
    use_llm_management: bool = Field(default=True, description="Toggle LLM-based memory management")

    def init_module(self):
        """No extra initialization is performed; all state lives in the declared fields."""
        pass

    async def _prompt_llm_for_memory_operation(
        self,
        input_data: List[Dict[str, Any]],
        relevant_data: Optional[List[Tuple[Message, str]]] = None,
    ) -> List[Dict[str, Any]]:
        """Ask the LLM to confirm or revise a batch of proposed memory operations.

        Args:
            input_data: Proposed operations, one dict per item with keys ``action``,
                ``memory_id`` and (for add/update) ``message``.
            relevant_data: Existing ``(message, memory_id)`` pairs shown to the LLM as context.

        Returns:
            List[Dict[str, Any]]: One decision per proposed operation, in the same order as
            ``input_data``. If LLM management is disabled, or the LLM response cannot be parsed
            or validated, ``input_data`` is returned unchanged so the request proceeds as-is.
        """
        if not self.llm or not self.use_llm_management:
            return input_data

        relevant_data_str = "\n".join(
            json.dumps({"message": msg.to_dict(), "memory_id": mid}, ensure_ascii=False)
            for msg, mid in (relevant_data or [])
        )
        prompt = (
            MANAGER_PROMPT
            .replace("<<INPUT_DATA>>", json.dumps(input_data, ensure_ascii=False))
            .replace("<<RELEVANT_DATA>>", relevant_data_str)
        )

        logger.info(f"Memory Manager LLM Prompt: \n\n{prompt}")
        try:
            # NOTE(review): ``generate`` is called synchronously inside a coroutine and will block
            # the event loop for the duration of the LLM call; consider an async variant if BaseLLM has one.
            response = self.llm.generate(prompt=prompt)
            return _parse_memory_operations(response.content, expected_count=len(input_data))
        except Exception as e:
            # Deliberate best-effort: any LLM or parsing failure falls back to the original request.
            logger.error(f"LLM failed to generate valid memory operation: {str(e)}")
            return input_data

    async def handle_memory(
        self,
        action: str,
        user_prompt: Optional[Union[str, Message, Query]] = None,
        data: Optional[Union[Message, str, List[Union[Message, str]], Dict, List[Tuple[str, Union[Message, str]]]]] = None,
        top_k: Optional[int] = None,
        metadata_filters: Optional[Dict] = None
    ) -> Union[List[str], List[Tuple[Message, str]], List[bool], Message, None]:
        """
        Handle memory operations based on the specified action, with optional LLM inference.

        Args:
            action (str): The memory operation ("add", "search", "get", "update", "delete", "clear", "save", "load", "create_message").
            user_prompt (Optional[Union[str, Message, Query]]): The user prompt or query to process with memory data
                (used by "search" and "create_message").
            data (Optional): Input data for the operation:
                - "add": a message/string or a list of them.
                - "get"/"delete": memory id(s).
                - "update": a ``(memory_id, message_or_string)`` pair or a list of such pairs.
                - "save"/"load": passed straight through to ``LongTermMemory.save``/``load``.
            top_k (Optional[int]): Number of results to retrieve for search operations.
            metadata_filters (Optional[Dict]): Filters for memory retrieval.

        Returns:
            Union[List[str], List[Tuple[Message, str]], List[bool], Message, None]: Result of the operation.
            Empty input for add/get/update/delete returns ``[]``; missing ``user_prompt`` returns ``[]``
            for "search" and ``None`` for "create_message".

        Raises:
            ValueError: If ``action`` is not one of the supported actions.
        """
        if action not in _SUPPORTED_ACTIONS:
            logger.error(f"Invalid action: {action}")
            raise ValueError(f"Invalid action: {action}")

        use_llm = bool(self.use_llm_management and self.llm)

        if action == "add":
            if not data:
                logger.warning("No data provided for add operation")
                return []
            items = data if isinstance(data, list) else [data]
            messages = [_coerce_message(item) for item in items]
            if not use_llm:
                return self.memory.add(messages)
            input_data = [
                {"action": "add", "memory_id": str(uuid4()), "message": msg.to_dict()}
                for msg in messages
            ]
            decisions = await self._prompt_llm_for_memory_operation(input_data)
            accepted = []
            for decision, msg in zip(decisions, messages):
                if decision.get("action") != "add":
                    logger.info(f"LLM rejected adding memory: {decision}")
                    continue
                # The LLM's proposed memory_id is not used; ids come from ``self.memory.add``.
                accepted.append(msg)
            return self.memory.add(accepted) if accepted else []

        if action == "search":
            if not user_prompt:
                logger.warning("No user_prompt provided for search operation")
                return []
            if isinstance(user_prompt, Message):
                user_prompt = user_prompt.content
            return await self.memory.search_async(user_prompt, top_k, metadata_filters)

        if action == "get":
            if not data:
                logger.warning("No memory IDs provided for get operation")
                return []
            return await self.memory.get(data, return_chunk=False)

        if action == "update":
            if not data:
                logger.warning("No updates provided for update operation")
                return []
            pairs = data if isinstance(data, list) else [data]
            # Updated messages are always re-stamped with the current time.
            updates = [(mid, _coerce_message(msg, refresh_timestamp=True)) for mid, msg in pairs]
            if not use_llm:
                return self.memory.update(updates)
            input_data = [
                {"action": "update", "memory_id": mid, "message": msg.to_dict()}
                for mid, msg in updates
            ]
            # return_chunk=False matches the "get" action and yields the (Message, id) pairs the prompt expects.
            existing_memories = await self.memory.get([mid for mid, _ in updates], return_chunk=False)
            decisions = await self._prompt_llm_for_memory_operation(input_data, relevant_data=existing_memories)
            accepted_updates = []
            for decision, (mid, msg) in zip(decisions, updates):
                if decision.get("action") != "update":
                    logger.info(f"LLM rejected updating memory {mid}: {decision}")
                    continue
                accepted_updates.append((mid, msg))
            # NOTE(review): on partial acceptance the result covers only the accepted updates,
            # so its length can differ from len(updates); confirm callers do not index it positionally.
            return self.memory.update(accepted_updates) if accepted_updates else [False] * len(updates)

        if action == "delete":
            if not data:
                logger.warning("No memory IDs provided for delete operation")
                return []
            memory_ids = data if isinstance(data, list) else [data]
            if not use_llm:
                return self.memory.delete(memory_ids)
            input_data = [{"action": "delete", "memory_id": mid} for mid in memory_ids]
            existing_memories = await self.memory.get(memory_ids, return_chunk=False)
            decisions = await self._prompt_llm_for_memory_operation(input_data, relevant_data=existing_memories)
            # Only ids the caller asked to delete may be deleted; ignore any other id the LLM returns.
            requested = set(memory_ids)
            valid_memory_ids = [
                decision.get("memory_id")
                for decision in decisions
                if decision.get("action") == "delete" and decision.get("memory_id") in requested
            ]
            return self.memory.delete(valid_memory_ids) if valid_memory_ids else [False] * len(memory_ids)

        if action == "clear":
            self.memory.clear()
            return None

        if action == "save":
            self.memory.save(data)
            return None

        if action == "load":
            return self.memory.load(data)

        # action == "create_message" (the only remaining supported action)
        if not user_prompt:
            logger.warning("No user_prompt provided for create_message operation")
            return None
        if isinstance(user_prompt, Query):
            user_prompt = user_prompt.query_str
        elif isinstance(user_prompt, Message):
            user_prompt = user_prompt.content
        memories = await self.memory.search_async(user_prompt, top_k, metadata_filters)
        context = "\n".join([msg.content for msg, _ in memories])
        memory_ids = [mid for _, mid in memories]
        combined_content = f"User Prompt: {user_prompt}\nContext: {context}" if context else user_prompt
        return Message(
            content=combined_content,
            msg_type=MessageType.REQUEST,
            timestamp=datetime.now().isoformat(),
            agent="user",
            memory_ids=memory_ids
        )

    async def create_conversation_message(
        self,
        user_prompt: Union[str, Message],
        conversation_id: str,
        top_k: Optional[int] = None,
        metadata_filters: Optional[Dict] = None
    ) -> Message:
        """
        Create a Message combining the user prompt with relevant history from one conversation.

        History is retrieved with ``search_async`` restricted to ``corpus_id == conversation_id``
        (any ``metadata_filters`` are merged on top), returning at most ``top_k`` entries (default 10).

        Args:
            user_prompt (Union[str, Message]): The user's input prompt or message.
            conversation_id (str): ID of the conversation thread.
            top_k (Optional[int]): Number of results to retrieve (defaults to 10).
            metadata_filters (Optional[Dict]): Additional filters for memory retrieval.

        Returns:
            Message: A new REQUEST Message from agent "user" whose content holds the prompt and the
            retrieved history, and whose ``memory_ids`` lists the ids of the retrieved entries.
        """
        if isinstance(user_prompt, Message):
            user_prompt = user_prompt.content

        history_filter = {"corpus_id": conversation_id}
        if metadata_filters:
            history_filter.update(metadata_filters)
        history_results = await self.memory.search_async(
            query=user_prompt, n=top_k or 10, metadata_filters=history_filter
        )
        history = "\n".join([f"{msg.content}" for msg, _ in history_results])

        combined_content = (
            f"User Prompt: \n{user_prompt}\n"
            f"Conversation History: \n\n{history or 'No history available'}\n"
        )
        return Message(
            content=combined_content,
            msg_type=MessageType.REQUEST,
            timestamp=datetime.now().isoformat(),
            agent="user",
            # Previously this passed a single random UUID string (the Message check was dead code
            # because user_prompt had already been converted to str); now it lists the ids of the
            # retrieved history entries, matching create_message.
            memory_ids=[mid for _, mid in history_results]
        )