Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Optional, Sequence | |
| from agent_base.model_profiles import ModelProfile | |
| from agent_base.utils import safe_jsonable | |
| COMPACT_MEMORY_PREFIX = ( | |
| "Runtime memory summary from earlier turns.\n" | |
| "This is compressed context, not ground truth.\n" | |
| "The workspace files remain authoritative; re-read any file if exact details matter.\n\n" | |
| ) | |
| class CompactionOutcome: | |
| status: str | |
| compacted_messages: list[dict[str, Any]] | |
| summary_text: str = "" | |
| error: str = "" | |
| trigger_reason: str = "" | |
| prior_token_estimate: int = 0 | |
| new_token_estimate: int = 0 | |
| compacted_group_count: int = 0 | |
| kept_group_count: int = 0 | |
| existing_memory_text: str = "" | |
| summary_request: list[dict[str, Any]] | None = None | |
| summary_response: dict[str, Any] | None = None | |
| pre_messages: list[dict[str, Any]] | None = None | |
| post_messages: list[dict[str, Any]] | None = None | |
| def should_compact_messages( | |
| *, | |
| last_input_tokens: Optional[int], | |
| current_token_estimate: int, | |
| model_profile: ModelProfile, | |
| ) -> tuple[bool, str]: | |
| usage_hit = last_input_tokens is not None and int(last_input_tokens) >= model_profile.compact_trigger_tokens | |
| estimate_hit = current_token_estimate >= model_profile.compact_trigger_tokens | |
| if usage_hit and estimate_hit: | |
| return True, "usage+estimate" | |
| if usage_hit: | |
| return True, "usage" | |
| if estimate_hit: | |
| return True, "estimate" | |
| return False, "" | |
| def compact_messages( | |
| *, | |
| messages: Sequence[dict[str, Any]], | |
| original_prompt_text: str, | |
| model_name: str, | |
| model_profile: ModelProfile, | |
| llm_caller: Callable[..., dict[str, Any]], | |
| token_counter: Callable[[Sequence[dict[str, Any]]], int], | |
| runtime_deadline: Optional[float] = None, | |
| ) -> CompactionOutcome: | |
| safe_messages = [dict(message) for message in messages] | |
| if len(safe_messages) <= 2: | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| error="context compaction requires at least one conversational turn beyond the initial prompt", | |
| ) | |
| prior_token_estimate = token_counter(safe_messages) | |
| existing_memory_text, eligible_messages = _split_existing_memory_messages(safe_messages[2:]) | |
| turn_groups = _turn_groups(eligible_messages) | |
| if not turn_groups: | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| prior_token_estimate=prior_token_estimate, | |
| existing_memory_text=existing_memory_text, | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| error="context compaction found no eligible conversational turns", | |
| ) | |
| compacted_groups, recent_groups = _split_turn_groups(turn_groups, model_profile) | |
| if not compacted_groups: | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| prior_token_estimate=prior_token_estimate, | |
| existing_memory_text=existing_memory_text, | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| error="context compaction did not find any older turns to summarize", | |
| ) | |
| history_text = _render_history_text(compacted_groups, model_profile) | |
| prior_memory_block = "" | |
| if existing_memory_text: | |
| prior_memory_block = ( | |
| "Previously compressed memory to preserve and refine:\n" | |
| f"{_truncate_summary_text(existing_memory_text, max_chars=max(1200, model_profile.context_window // 3))}\n\n" | |
| ) | |
| summary_request = [ | |
| { | |
| "role": "system", | |
| "content": ( | |
| "You compress older tool-using agent history into short working memory for continued execution. " | |
| "Return plain text only. Do not call tools. Do not invent facts." | |
| ), | |
| }, | |
| { | |
| "role": "user", | |
| "content": ( | |
| "Summarize the earlier conversation history for a tool-using agent.\n\n" | |
| f"Original task:\n{original_prompt_text}\n\n" | |
| "Write a concise working memory with these sections:\n" | |
| "- Goal\n" | |
| "- Constraints\n" | |
| "- Files and artifacts\n" | |
| "- Evidence and results\n" | |
| "- Open issues\n" | |
| "- Next useful actions\n\n" | |
| "Rules:\n" | |
| "- Prefer concrete file paths, numeric results, and grounded facts.\n" | |
| "- Mention uncertainty when details may need to be re-read from files.\n" | |
| "- Merge any prior compressed memory with the newer history below into one refreshed memory.\n" | |
| "- Deduplicate repeated sections and do not repeat earlier summaries verbatim.\n" | |
| "- The workspace remains authoritative.\n\n" | |
| f"{prior_memory_block}" | |
| f"Older history to compress:\n{history_text}" | |
| ), | |
| }, | |
| ] | |
| summary_reply = llm_caller( | |
| summary_request, | |
| runtime_deadline=runtime_deadline, | |
| max_output_tokens=model_profile.compact_summary_max_tokens, | |
| ) | |
| if not isinstance(summary_reply, dict) or summary_reply.get("status") != "ok": | |
| error = summary_reply.get("error", "context compaction summary call failed") if isinstance(summary_reply, dict) else str(summary_reply) | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| prior_token_estimate=prior_token_estimate, | |
| existing_memory_text=existing_memory_text, | |
| summary_request=summary_request, | |
| summary_response=safe_jsonable(summary_reply) if isinstance(summary_reply, dict) else {"status": "error", "error": error}, | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| error=error, | |
| compacted_group_count=len(compacted_groups), | |
| kept_group_count=len(recent_groups), | |
| ) | |
| if summary_reply.get("tool_calls"): | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| prior_token_estimate=prior_token_estimate, | |
| existing_memory_text=existing_memory_text, | |
| summary_request=summary_request, | |
| summary_response=safe_jsonable(summary_reply), | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| compacted_group_count=len(compacted_groups), | |
| kept_group_count=len(recent_groups), | |
| error="context compaction summary call returned tool calls", | |
| ) | |
| summary_text = str(summary_reply.get("content", "") or "").strip() | |
| if not summary_text: | |
| return CompactionOutcome( | |
| status="error", | |
| compacted_messages=safe_messages, | |
| prior_token_estimate=prior_token_estimate, | |
| existing_memory_text=existing_memory_text, | |
| summary_request=summary_request, | |
| summary_response=safe_jsonable(summary_reply), | |
| pre_messages=safe_messages, | |
| post_messages=safe_messages, | |
| compacted_group_count=len(compacted_groups), | |
| kept_group_count=len(recent_groups), | |
| error="context compaction summary call returned empty text", | |
| ) | |
| summary_message = {"role": "user", "content": COMPACT_MEMORY_PREFIX + summary_text} | |
| compacted_messages = safe_messages[:2] + [summary_message] | |
| for group in recent_groups: | |
| compacted_messages.extend(group) | |
| new_token_estimate = token_counter(compacted_messages) | |
| return CompactionOutcome( | |
| status="ok", | |
| compacted_messages=compacted_messages, | |
| summary_text=summary_text, | |
| prior_token_estimate=prior_token_estimate, | |
| new_token_estimate=new_token_estimate, | |
| compacted_group_count=len(compacted_groups), | |
| kept_group_count=len(recent_groups), | |
| existing_memory_text=existing_memory_text, | |
| summary_request=summary_request, | |
| summary_response=safe_jsonable(summary_reply), | |
| pre_messages=safe_messages, | |
| post_messages=compacted_messages, | |
| ) | |
| def _turn_groups(messages: Sequence[dict[str, Any]]) -> list[list[dict[str, Any]]]: | |
| groups: list[list[dict[str, Any]]] = [] | |
| current_group: list[dict[str, Any]] = [] | |
| for message in messages: | |
| role = str(message.get("role", "")) | |
| if role == "assistant" and current_group: | |
| groups.append(current_group) | |
| current_group = [message] | |
| continue | |
| current_group.append(message) | |
| if current_group: | |
| groups.append(current_group) | |
| return groups | |
| def _split_existing_memory_messages(messages: Sequence[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: | |
| existing_summaries: list[str] = [] | |
| remaining_messages: list[dict[str, Any]] = [] | |
| preserving_summary_prefix = True | |
| for message in messages: | |
| content = message.get("content", "") | |
| if ( | |
| preserving_summary_prefix | |
| and str(message.get("role", "")) == "user" | |
| and isinstance(content, str) | |
| and content.startswith(COMPACT_MEMORY_PREFIX) | |
| ): | |
| existing_summaries.append(content[len(COMPACT_MEMORY_PREFIX) :].strip()) | |
| continue | |
| preserving_summary_prefix = False | |
| remaining_messages.append(dict(message)) | |
| merged_summary = "\n\n".join(summary for summary in existing_summaries if summary).strip() | |
| return merged_summary, remaining_messages | |
| def _split_turn_groups(turn_groups: Sequence[Sequence[dict[str, Any]]], model_profile: ModelProfile) -> tuple[list[list[dict[str, Any]]], list[list[dict[str, Any]]]]: | |
| recent_char_budget = max(400, model_profile.recent_history_budget_tokens * 4) | |
| recent_groups: list[list[dict[str, Any]]] = [] | |
| recent_chars = 0 | |
| for group in reversed(turn_groups): | |
| rendered = _render_group(group, max_chars_per_message=240) | |
| if recent_groups and recent_chars >= recent_char_budget: | |
| break | |
| recent_groups.insert(0, [dict(message) for message in group]) | |
| recent_chars += len(rendered) | |
| if len(recent_groups) >= 4: | |
| break | |
| if len(recent_groups) >= len(turn_groups): | |
| recent_groups = recent_groups[1:] | |
| compacted_count = max(0, len(turn_groups) - len(recent_groups)) | |
| compacted_groups = [[dict(message) for message in group] for group in turn_groups[:compacted_count]] | |
| return compacted_groups, recent_groups | |
| def _render_history_text(turn_groups: Sequence[Sequence[dict[str, Any]]], model_profile: ModelProfile) -> str: | |
| max_history_chars = max(600, min(64000, model_profile.context_window * 2)) | |
| max_chars_per_message = max(200, min(4000, max_history_chars // 10)) | |
| parts: list[str] = [] | |
| used = 0 | |
| for index, group in enumerate(turn_groups, start=1): | |
| rendered = f"[Turn group {index}]\n{_render_group(group, max_chars_per_message=max_chars_per_message)}" | |
| if parts and used + len(rendered) > max_history_chars: | |
| remaining = max_history_chars - used | |
| if remaining > 80: | |
| parts.append(rendered[: remaining - 40].rstrip() + "\n...[history truncated]") | |
| break | |
| parts.append(rendered) | |
| used += len(rendered) | |
| return "\n\n".join(parts).strip() | |
| def _render_group(group: Sequence[dict[str, Any]], *, max_chars_per_message: int) -> str: | |
| lines: list[str] = [] | |
| for message in group: | |
| role = str(message.get("role", "")) | |
| content = _message_excerpt(message, max_chars=max_chars_per_message) | |
| lines.append(f"{role}: {content}") | |
| return "\n".join(lines).strip() | |
| def _message_excerpt(message: dict[str, Any], *, max_chars: int) -> str: | |
| content = message.get("content", "") | |
| text: str | |
| if isinstance(content, str): | |
| text = content | |
| elif isinstance(content, list): | |
| parts: list[str] = [] | |
| for part in content: | |
| if isinstance(part, dict) and part.get("type") == "text": | |
| parts.append(str(part.get("text", ""))) | |
| elif isinstance(part, dict) and part.get("type") == "image_url": | |
| parts.append("[image_url]") | |
| else: | |
| parts.append(str(part)) | |
| text = " ".join(part for part in parts if part) | |
| else: | |
| text = str(content) | |
| tool_calls = message.get("tool_calls") | |
| if tool_calls: | |
| tool_names = [] | |
| for tool_call in tool_calls: | |
| function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} | |
| tool_names.append(str(function_block.get("name", ""))) | |
| if tool_names: | |
| text = (text + "\nTool calls: " + ", ".join(name for name in tool_names if name)).strip() | |
| compacted = " ".join(text.split()) | |
| if len(compacted) <= max_chars: | |
| return compacted | |
| return compacted[: max_chars - 16].rstrip() + "...[truncated]" | |
| def _truncate_summary_text(text: str, *, max_chars: int) -> str: | |
| compacted = " ".join(str(text).split()) | |
| if len(compacted) <= max_chars: | |
| return compacted | |
| return compacted[: max_chars - 16].rstrip() + "...[truncated]" | |