Spaces:
Sleeping
Feat/advisor tools (#38)
Browse files
* Added /scraper directory from Clary branch for inspection.
* Added Gemini function-calling MVP with course search tool.
* Added Rate My Professor tool and multi-tool dispatcher.
* Added tool registry with auto-discovery and tool filtering.
* fixed test warning in RMP tool test.
* added structured tool-call return result and get_tool_response method to orchestrator.
* wired tool calling into orchestrator chat routes.
* cleaned up standalone test routes and did minor refactoring.
* fixed context manager dropping non-allowlist persona and tool responses.
* pinned backend requirements to <pkg>~=M.m from latest container build.
* removed license notices and cu_info_scraper.
* moved the tool-calling loop into LLMClient, with provider-specific overrides implemented in each client class.
* refactored tools config to support per-tool settings.
* added RMP school_id lookup script for CLI usage.
* added dark-mode awareness to the advisor color fallback, which is used by the orchestrator when a tool call is made.
* fixed vLLM tool calling to handle multiple tool calls.
* fixed Gemini tool calling to handle multiple tool calls.
* deleted reference /scrapers directory.
* reverted the Gemini base URL to be defined in init.
* refactored Gemini tool calling to use the /openai endpoint format.
- multi_llm_chatbot_backend/app/api/routes/chat.py +41 -1
- multi_llm_chatbot_backend/app/api/routes/provider.py +2 -0
- multi_llm_chatbot_backend/app/config.py +19 -1
- multi_llm_chatbot_backend/app/core/bootstrap.py +1 -1
- multi_llm_chatbot_backend/app/core/context_manager.py +5 -6
- multi_llm_chatbot_backend/app/core/improved_orchestrator.py +47 -2
- multi_llm_chatbot_backend/app/llm/improved_gemini_client.py +124 -4
- multi_llm_chatbot_backend/app/llm/improved_vllm_client.py +122 -4
- multi_llm_chatbot_backend/app/llm/llm_client.py +49 -4
- multi_llm_chatbot_backend/app/main.py +1 -0
- multi_llm_chatbot_backend/app/tests/unit/test_course_search_tool.py +33 -0
- multi_llm_chatbot_backend/app/tests/unit/test_gemini_client.py +432 -0
- multi_llm_chatbot_backend/app/tests/unit/test_rmp_tool.py +169 -0
- multi_llm_chatbot_backend/app/tests/unit/test_tool_registry.py +123 -0
- multi_llm_chatbot_backend/app/tests/unit/test_vllm_client.py +412 -0
- multi_llm_chatbot_backend/app/tools/__init__.py +125 -0
- multi_llm_chatbot_backend/app/tools/rate_my_professor.py +202 -0
- multi_llm_chatbot_backend/app/tools/search_courses.py +191 -0
- multi_llm_chatbot_backend/requirements.txt +19 -22
- phd-advisor-frontend/src/contexts/AppConfigContext.js +5 -1
- phd_config.yaml +10 -0
- scripts/rmp_school_lookup.py +112 -0
|
@@ -115,6 +115,27 @@ async def chat_stream(
|
|
| 115 |
).to_ndjson()
|
| 116 |
return
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# Get personas most relevant to the current session
|
| 119 |
top_personas = await chat_orchestrator.get_top_personas(
|
| 120 |
session_id=sid,
|
|
@@ -363,7 +384,26 @@ async def chat_sequential_enhanced(
|
|
| 363 |
"trigger": "vague_input"
|
| 364 |
}
|
| 365 |
}
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
# RESTORED: Get intelligently ordered personas based on context
|
| 368 |
top_personas = await chat_orchestrator.get_top_personas(
|
| 369 |
session_id=session_id,
|
|
|
|
| 115 |
).to_ndjson()
|
| 116 |
return
|
| 117 |
|
| 118 |
+
# If an enabled tool can handle this query, return its response
|
| 119 |
+
# directly and skip persona generation.
|
| 120 |
+
tool_result = await chat_orchestrator.get_tool_response(message.user_input)
|
| 121 |
+
if tool_result.used_tool:
|
| 122 |
+
session.append_message("orchestrator", tool_result.text)
|
| 123 |
+
yield ChatStreamLine(
|
| 124 |
+
type="advisor",
|
| 125 |
+
data={
|
| 126 |
+
"persona_id": "orchestrator",
|
| 127 |
+
"persona_name": "Orchestrator",
|
| 128 |
+
"content": tool_result.text,
|
| 129 |
+
"used_documents": False,
|
| 130 |
+
"document_chunks_used": 0,
|
| 131 |
+
},
|
| 132 |
+
).to_ndjson()
|
| 133 |
+
yield ChatStreamLine(
|
| 134 |
+
type="progress",
|
| 135 |
+
data={"phase": "complete"},
|
| 136 |
+
).to_ndjson()
|
| 137 |
+
return
|
| 138 |
+
|
| 139 |
# Get personas most relevant to the current session
|
| 140 |
top_personas = await chat_orchestrator.get_top_personas(
|
| 141 |
session_id=sid,
|
|
|
|
| 384 |
"trigger": "vague_input"
|
| 385 |
}
|
| 386 |
}
|
| 387 |
+
|
| 388 |
+
# If an enabled tool can handle this query, return its response
|
| 389 |
+
# directly and skip persona generation.
|
| 390 |
+
tool_result = await chat_orchestrator.get_tool_response(message.user_input)
|
| 391 |
+
if tool_result.used_tool:
|
| 392 |
+
session.append_message("orchestrator", tool_result.text)
|
| 393 |
+
return {
|
| 394 |
+
"responses": [{
|
| 395 |
+
"persona_id": "orchestrator",
|
| 396 |
+
"persona_name": "Orchestrator",
|
| 397 |
+
"content": tool_result.text,
|
| 398 |
+
"used_documents": False,
|
| 399 |
+
"document_chunks_used": 0,
|
| 400 |
+
}],
|
| 401 |
+
"session_debug": {
|
| 402 |
+
"session_id": session_id,
|
| 403 |
+
"tool_used": True,
|
| 404 |
+
}
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
# RESTORED: Get intelligently ordered personas based on context
|
| 408 |
top_personas = await chat_orchestrator.get_top_personas(
|
| 409 |
session_id=session_id,
|
|
@@ -69,6 +69,8 @@ async def switch_provider(provider_data: ProviderSwitch):
|
|
| 69 |
new_llm = create_llm_client(current_provider)
|
| 70 |
llm = new_llm
|
| 71 |
|
|
|
|
|
|
|
| 72 |
new_personas = get_default_personas(new_llm)
|
| 73 |
chat_orchestrator.personas.clear()
|
| 74 |
for persona in new_personas:
|
|
|
|
| 69 |
new_llm = create_llm_client(current_provider)
|
| 70 |
llm = new_llm
|
| 71 |
|
| 72 |
+
chat_orchestrator.llm_client = new_llm
|
| 73 |
+
|
| 74 |
new_personas = get_default_personas(new_llm)
|
| 75 |
chat_orchestrator.personas.clear()
|
| 76 |
for persona in new_personas:
|
|
@@ -11,7 +11,7 @@ import os
|
|
| 11 |
import logging
|
| 12 |
import colorsys
|
| 13 |
from pathlib import Path
|
| 14 |
-
from typing import List, Optional
|
| 15 |
from colorhash import ColorHash
|
| 16 |
|
| 17 |
import yaml
|
|
@@ -234,6 +234,23 @@ class RAGConfig(BaseModel):
|
|
| 234 |
chroma_collection: str = "phd_advisor_documents"
|
| 235 |
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
class AppSettings(BaseModel):
|
| 238 |
"""Top-level container that mirrors the YAML structure."""
|
| 239 |
app: AppConfig = AppConfig()
|
|
@@ -246,6 +263,7 @@ class AppSettings(BaseModel):
|
|
| 246 |
mongodb: MongoDBConfig = MongoDBConfig()
|
| 247 |
llm: LLMConfig = LLMConfig()
|
| 248 |
rag: RAGConfig = RAGConfig()
|
|
|
|
| 249 |
|
| 250 |
# ------------------------------------------------------------------
|
| 251 |
# Convenience helpers
|
|
|
|
| 11 |
import logging
|
| 12 |
import colorsys
|
| 13 |
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
from colorhash import ColorHash
|
| 16 |
|
| 17 |
import yaml
|
|
|
|
| 234 |
chroma_collection: str = "phd_advisor_documents"
|
| 235 |
|
| 236 |
|
| 237 |
+
class ToolsConfig(BaseModel):
|
| 238 |
+
model_config = {"extra": "allow"}
|
| 239 |
+
|
| 240 |
+
def get_enabled_names(self) -> List[str]:
|
| 241 |
+
"""Return tool names whose config has ``enabled: true``."""
|
| 242 |
+
return [
|
| 243 |
+
name
|
| 244 |
+
for name, cfg in self.__pydantic_extra__.items()
|
| 245 |
+
if isinstance(cfg, dict) and cfg.get("enabled", True)
|
| 246 |
+
]
|
| 247 |
+
|
| 248 |
+
def get_tool_config(self, name: str) -> Dict[str, Any]:
|
| 249 |
+
"""Return the raw config dict for a single tool, or ``{}``."""
|
| 250 |
+
cfg = self.__pydantic_extra__.get(name, {})
|
| 251 |
+
return cfg if isinstance(cfg, dict) else {}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
class AppSettings(BaseModel):
|
| 255 |
"""Top-level container that mirrors the YAML structure."""
|
| 256 |
app: AppConfig = AppConfig()
|
|
|
|
| 263 |
mongodb: MongoDBConfig = MongoDBConfig()
|
| 264 |
llm: LLMConfig = LLMConfig()
|
| 265 |
rag: RAGConfig = RAGConfig()
|
| 266 |
+
tools: ToolsConfig = ToolsConfig()
|
| 267 |
|
| 268 |
# ------------------------------------------------------------------
|
| 269 |
# Convenience helpers
|
|
@@ -30,7 +30,7 @@ def create_llm_client(provider=None):
|
|
| 30 |
)
|
| 31 |
|
| 32 |
llm = create_llm_client()
|
| 33 |
-
chat_orchestrator = ImprovedChatOrchestrator()
|
| 34 |
|
| 35 |
DEFAULT_PERSONAS = get_default_personas(llm)
|
| 36 |
for persona in DEFAULT_PERSONAS:
|
|
|
|
| 30 |
)
|
| 31 |
|
| 32 |
llm = create_llm_client()
|
| 33 |
+
chat_orchestrator = ImprovedChatOrchestrator(llm_client=llm)
|
| 34 |
|
| 35 |
DEFAULT_PERSONAS = get_default_personas(llm)
|
| 36 |
for persona in DEFAULT_PERSONAS:
|
|
@@ -209,17 +209,16 @@ class ContextManager:
|
|
| 209 |
"role": "user",
|
| 210 |
"parts": [{"text": content}]
|
| 211 |
})
|
| 212 |
-
elif role in ['assistant', 'methodologist', 'theorist', 'pragmatist']:
|
| 213 |
-
formatted.append({
|
| 214 |
-
"role": "model",
|
| 215 |
-
"parts": [{"text": content}]
|
| 216 |
-
})
|
| 217 |
elif role == 'document':
|
| 218 |
-
# Add document as user context
|
| 219 |
formatted.append({
|
| 220 |
"role": "user",
|
| 221 |
"parts": [{"text": f"[Context Document] {content}"}]
|
| 222 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
return formatted
|
| 225 |
|
|
|
|
| 209 |
"role": "user",
|
| 210 |
"parts": [{"text": content}]
|
| 211 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
elif role == 'document':
|
|
|
|
| 213 |
formatted.append({
|
| 214 |
"role": "user",
|
| 215 |
"parts": [{"text": f"[Context Document] {content}"}]
|
| 216 |
})
|
| 217 |
+
else:
|
| 218 |
+
formatted.append({
|
| 219 |
+
"role": "model",
|
| 220 |
+
"parts": [{"text": content}]
|
| 221 |
+
})
|
| 222 |
|
| 223 |
return formatted
|
| 224 |
|
|
@@ -4,6 +4,8 @@ from app.core.session_manager import ConversationContext, get_session_manager
|
|
| 4 |
from app.core.context_manager import get_context_manager
|
| 5 |
from app.core.rag_manager import get_rag_manager
|
| 6 |
from app.config import get_settings
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import json
|
| 9 |
import logging
|
|
@@ -16,8 +18,9 @@ class ImprovedChatOrchestrator:
|
|
| 16 |
Enhanced orchestrator with document awareness and improved context handling
|
| 17 |
"""
|
| 18 |
|
| 19 |
-
def __init__(self):
|
| 20 |
self.personas: Dict[str, Persona] = {}
|
|
|
|
| 21 |
self.session_manager = get_session_manager()
|
| 22 |
self.context_manager = get_context_manager()
|
| 23 |
|
|
@@ -33,7 +36,49 @@ class ImprovedChatOrchestrator:
|
|
| 33 |
def list_personas(self) -> List[str]:
|
| 34 |
"""List all available persona IDs"""
|
| 35 |
return list(self.personas.keys())
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
async def process_message(self,
|
| 38 |
user_input: str,
|
| 39 |
session_id: Optional[str] = None,
|
|
|
|
| 4 |
from app.core.context_manager import get_context_manager
|
| 5 |
from app.core.rag_manager import get_rag_manager
|
| 6 |
from app.config import get_settings
|
| 7 |
+
from app.llm.llm_client import LLMClient, ToolCallResult
|
| 8 |
+
from app.tools import get_tool_definitions, get_tool_executor
|
| 9 |
|
| 10 |
import json
|
| 11 |
import logging
|
|
|
|
| 18 |
Enhanced orchestrator with document awareness and improved context handling
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
def __init__(self, llm_client: LLMClient = None):
|
| 22 |
self.personas: Dict[str, Persona] = {}
|
| 23 |
+
self.llm_client = llm_client
|
| 24 |
self.session_manager = get_session_manager()
|
| 25 |
self.context_manager = get_context_manager()
|
| 26 |
|
|
|
|
| 36 |
def list_personas(self) -> List[str]:
|
| 37 |
"""List all available persona IDs"""
|
| 38 |
return list(self.personas.keys())
|
| 39 |
+
|
| 40 |
+
async def get_tool_response(self, user_message: str) -> ToolCallResult:
|
| 41 |
+
"""Check whether a tool can handle *user_message*.
|
| 42 |
+
|
| 43 |
+
If tools are disabled in config, no LLM client is available, or the
|
| 44 |
+
model decides no tool is needed, returns
|
| 45 |
+
``ToolCallResult(used_tool=False)``. Otherwise executes the tool and
|
| 46 |
+
returns the grounded response with ``used_tool=True``.
|
| 47 |
+
"""
|
| 48 |
+
if self.llm_client is None:
|
| 49 |
+
return ToolCallResult(text="", used_tool=False)
|
| 50 |
+
|
| 51 |
+
settings = get_settings()
|
| 52 |
+
tools_enabled = settings.tools.get_enabled_names()
|
| 53 |
+
|
| 54 |
+
if not tools_enabled:
|
| 55 |
+
return ToolCallResult(text="", used_tool=False)
|
| 56 |
+
|
| 57 |
+
tool_definitions = get_tool_definitions(enabled=tools_enabled)
|
| 58 |
+
tool_executor = get_tool_executor(enabled=tools_enabled)
|
| 59 |
+
|
| 60 |
+
if not tool_definitions:
|
| 61 |
+
return ToolCallResult(text="", used_tool=False)
|
| 62 |
+
|
| 63 |
+
system_prompt = (
|
| 64 |
+
"You are a helpful assistant with access to external tools. "
|
| 65 |
+
"Use the available tools when the user's question can be answered "
|
| 66 |
+
"by one of them. If no tool is relevant, respond with a brief "
|
| 67 |
+
"text answer. "
|
| 68 |
+
"If a tool response includes 'truncated': true, let the user know "
|
| 69 |
+
"how many total results were found and suggest they narrow their "
|
| 70 |
+
"search for more specific results. "
|
| 71 |
+
"Format your responses using markdown. Use bullet points "
|
| 72 |
+
"to present structured data like course listings or professor ratings."
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
return await self.llm_client.generate_with_tools(
|
| 76 |
+
system_prompt=system_prompt,
|
| 77 |
+
user_message=user_message,
|
| 78 |
+
tool_definitions=tool_definitions,
|
| 79 |
+
tool_executor=tool_executor,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
async def process_message(self,
|
| 83 |
user_input: str,
|
| 84 |
session_id: Optional[str] = None,
|
|
@@ -1,10 +1,14 @@
|
|
| 1 |
import httpx
|
| 2 |
-
import
|
| 3 |
-
|
| 4 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from app.core.context_manager import get_context_manager
|
| 6 |
from app.config import get_settings
|
| 7 |
-
import logging
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
|
@@ -20,8 +24,16 @@ class ImprovedGeminiClient(LLMClient):
|
|
| 20 |
if not self.api_key:
|
| 21 |
raise ValueError("Gemini API key not set. Provide it in config.yaml (llm.gemini.api_key).")
|
| 22 |
|
|
|
|
| 23 |
self.base_url = "https://generativelanguage.googleapis.com/v1beta/models"
|
| 24 |
self.context_manager = get_context_manager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
|
| 27 |
"""
|
|
@@ -120,3 +132,111 @@ class ImprovedGeminiClient(LLMClient):
|
|
| 120 |
except Exception as e:
|
| 121 |
logger.error(f"Unexpected error in Gemini client: {str(e)}")
|
| 122 |
return "I encountered an unexpected error. Please try again."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import httpx
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
from openai import AsyncOpenAI, APIConnectionError, APIStatusError
|
| 7 |
+
|
| 8 |
+
from app.llm.llm_client import LLMClient, ToolCallInfo, ToolCallResult
|
| 9 |
+
|
| 10 |
from app.core.context_manager import get_context_manager
|
| 11 |
from app.config import get_settings
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
|
|
| 24 |
if not self.api_key:
|
| 25 |
raise ValueError("Gemini API key not set. Provide it in config.yaml (llm.gemini.api_key).")
|
| 26 |
|
| 27 |
+
# Native Gemini REST API
|
| 28 |
self.base_url = "https://generativelanguage.googleapis.com/v1beta/models"
|
| 29 |
self.context_manager = get_context_manager()
|
| 30 |
+
|
| 31 |
+
# OpenAI-compatible endpoint (for tool calling)
|
| 32 |
+
self.openai_client = AsyncOpenAI(
|
| 33 |
+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
| 34 |
+
api_key=self.api_key,
|
| 35 |
+
timeout=90.0,
|
| 36 |
+
)
|
| 37 |
|
| 38 |
async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
|
| 39 |
"""
|
|
|
|
| 132 |
except Exception as e:
|
| 133 |
logger.error(f"Unexpected error in Gemini client: {str(e)}")
|
| 134 |
return "I encountered an unexpected error. Please try again."
|
| 135 |
+
|
| 136 |
+
# ------------------------------------------------------------------
|
| 137 |
+
# Tool-calling support (via Gemini OpenAI-compatible endpoint)
|
| 138 |
+
# ------------------------------------------------------------------
|
| 139 |
+
|
| 140 |
+
_MAX_TOOL_ROUNDS = 5
|
| 141 |
+
|
| 142 |
+
async def generate_with_tools(
|
| 143 |
+
self,
|
| 144 |
+
system_prompt: str,
|
| 145 |
+
user_message: str,
|
| 146 |
+
tool_definitions: Optional[List[Dict[str, Any]]] = None,
|
| 147 |
+
tool_executor: Optional[Callable] = None,
|
| 148 |
+
temperature: float = 0.7,
|
| 149 |
+
max_tokens: int = 2048,
|
| 150 |
+
) -> ToolCallResult:
|
| 151 |
+
"""OpenAI-compatible tool-calling loop via Gemini's /openai/ endpoint.
|
| 152 |
+
|
| 153 |
+
Tool definitions are expected in OpenAI format (as returned by the
|
| 154 |
+
tool registry). Loops through the standard tool-call protocol
|
| 155 |
+
until the model produces a plain text response:
|
| 156 |
+
|
| 157 |
+
request → detect tool_calls → execute all → feed results
|
| 158 |
+
back → repeat (up to ``_MAX_TOOL_ROUNDS`` rounds).
|
| 159 |
+
|
| 160 |
+
All tool calls in a single response are executed before the next
|
| 161 |
+
round, so multi-tool queries (e.g. "compare professor A vs B")
|
| 162 |
+
work correctly.
|
| 163 |
+
"""
|
| 164 |
+
messages: List[Dict[str, Any]] = [
|
| 165 |
+
{"role": "system", "content": system_prompt},
|
| 166 |
+
{"role": "user", "content": user_message},
|
| 167 |
+
]
|
| 168 |
+
|
| 169 |
+
openai_tools = tool_definitions or []
|
| 170 |
+
all_tool_calls: List[ToolCallInfo] = []
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
for _round in range(self._MAX_TOOL_ROUNDS):
|
| 174 |
+
response = await self.openai_client.chat.completions.create(
|
| 175 |
+
model=self.model_name,
|
| 176 |
+
messages=messages,
|
| 177 |
+
tools=openai_tools or None,
|
| 178 |
+
temperature=temperature,
|
| 179 |
+
max_tokens=max_tokens,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
choice = response.choices[0].message
|
| 183 |
+
|
| 184 |
+
if not choice.tool_calls:
|
| 185 |
+
return ToolCallResult(
|
| 186 |
+
text=choice.content or "",
|
| 187 |
+
used_tool=bool(all_tool_calls),
|
| 188 |
+
tool_name=all_tool_calls[0].name if all_tool_calls else None,
|
| 189 |
+
tool_args=all_tool_calls[0].args if all_tool_calls else {},
|
| 190 |
+
tool_calls_made=all_tool_calls,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
messages.append(choice.model_dump(exclude_none=True))
|
| 194 |
+
|
| 195 |
+
for tc in choice.tool_calls:
|
| 196 |
+
fn_name = tc.function.name
|
| 197 |
+
fn_args = json.loads(tc.function.arguments)
|
| 198 |
+
logger.info("Gemini requested tool call: %s(%s)", fn_name, fn_args)
|
| 199 |
+
all_tool_calls.append(ToolCallInfo(name=fn_name, args=fn_args))
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
tool_result = await tool_executor(name=fn_name, **fn_args)
|
| 203 |
+
except Exception as exc:
|
| 204 |
+
logger.error("Tool %s failed: %s", fn_name, exc)
|
| 205 |
+
tool_result = {"error": str(exc)}
|
| 206 |
+
|
| 207 |
+
messages.append({
|
| 208 |
+
"role": "tool",
|
| 209 |
+
"tool_call_id": tc.id,
|
| 210 |
+
"content": json.dumps(tool_result),
|
| 211 |
+
})
|
| 212 |
+
|
| 213 |
+
logger.warning(
|
| 214 |
+
"Tool-calling loop exhausted after %d rounds", self._MAX_TOOL_ROUNDS,
|
| 215 |
+
)
|
| 216 |
+
last_content = response.choices[0].message.content or ""
|
| 217 |
+
return ToolCallResult(
|
| 218 |
+
text=last_content or "I was unable to finish looking that up. Please try again.",
|
| 219 |
+
used_tool=bool(all_tool_calls),
|
| 220 |
+
tool_name=all_tool_calls[0].name if all_tool_calls else None,
|
| 221 |
+
tool_args=all_tool_calls[0].args if all_tool_calls else {},
|
| 222 |
+
tool_calls_made=all_tool_calls,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
except APIConnectionError:
|
| 226 |
+
logger.error("Unable to connect to Gemini OpenAI-compat endpoint")
|
| 227 |
+
return ToolCallResult(
|
| 228 |
+
text="I'm unable to connect to the AI service. Please try again.",
|
| 229 |
+
used_tool=False,
|
| 230 |
+
)
|
| 231 |
+
except APIStatusError as e:
|
| 232 |
+
logger.error("Gemini tool-call API error: %s - %s", e.status_code, e.message)
|
| 233 |
+
return ToolCallResult(
|
| 234 |
+
text="The AI service encountered an error. Please try again.",
|
| 235 |
+
used_tool=False,
|
| 236 |
+
)
|
| 237 |
+
except Exception as e:
|
| 238 |
+
logger.error("Unexpected error in Gemini tool-calling: %s", e)
|
| 239 |
+
return ToolCallResult(
|
| 240 |
+
text="I encountered an unexpected error. Please try again.",
|
| 241 |
+
used_tool=False,
|
| 242 |
+
)
|
|
@@ -1,8 +1,11 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
| 2 |
from openai import AsyncOpenAI, APIConnectionError, APIStatusError
|
| 3 |
-
|
|
|
|
| 4 |
from app.core.context_manager import get_context_manager
|
| 5 |
-
import logging
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
|
@@ -15,7 +18,7 @@ class ImprovedVllmClient(LLMClient):
|
|
| 15 |
self.client = AsyncOpenAI(
|
| 16 |
base_url=f"{api_url}/v1",
|
| 17 |
api_key=api_key,
|
| 18 |
-
timeout=
|
| 19 |
)
|
| 20 |
self.context_manager = get_context_manager()
|
| 21 |
|
|
@@ -70,3 +73,118 @@ class ImprovedVllmClient(LLMClient):
|
|
| 70 |
logger.error(f"Unexpected error in vLLM client: {str(e)}")
|
| 71 |
return "I encountered an unexpected error. Please try again."
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 4 |
+
|
| 5 |
from openai import AsyncOpenAI, APIConnectionError, APIStatusError
|
| 6 |
+
|
| 7 |
+
from app.llm.llm_client import LLMClient, ToolCallInfo, ToolCallResult
|
| 8 |
from app.core.context_manager import get_context_manager
|
|
|
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
|
|
| 18 |
self.client = AsyncOpenAI(
|
| 19 |
base_url=f"{api_url}/v1",
|
| 20 |
api_key=api_key,
|
| 21 |
+
timeout=90.0,
|
| 22 |
)
|
| 23 |
self.context_manager = get_context_manager()
|
| 24 |
|
|
|
|
| 73 |
logger.error(f"Unexpected error in vLLM client: {str(e)}")
|
| 74 |
return "I encountered an unexpected error. Please try again."
|
| 75 |
|
| 76 |
+
# ------------------------------------------------------------------
|
| 77 |
+
# Tool-calling support (OpenAI-compatible format)
|
| 78 |
+
# ------------------------------------------------------------------
|
| 79 |
+
|
| 80 |
+
_MAX_TOOL_ROUNDS = 5
|
| 81 |
+
|
| 82 |
+
async def generate_with_tools(
|
| 83 |
+
self,
|
| 84 |
+
system_prompt: str,
|
| 85 |
+
user_message: str,
|
| 86 |
+
tool_definitions: Optional[List[Dict[str, Any]]] = None,
|
| 87 |
+
tool_executor: Optional[Callable] = None,
|
| 88 |
+
temperature: float = 0.7,
|
| 89 |
+
max_tokens: int = 2048,
|
| 90 |
+
) -> ToolCallResult:
|
| 91 |
+
"""OpenAI-compatible tool-calling loop for vLLM.
|
| 92 |
+
|
| 93 |
+
Tool definitions are expected in OpenAI format (as returned by the
|
| 94 |
+
tool registry). Loops through the standard tool-call protocol
|
| 95 |
+
until the model produces a plain text response:
|
| 96 |
+
|
| 97 |
+
request → detect tool_calls → execute all → feed results
|
| 98 |
+
back → repeat (up to ``_MAX_TOOL_ROUNDS`` rounds).
|
| 99 |
+
|
| 100 |
+
All tool calls in a single response are executed before the next
|
| 101 |
+
round, so multi-tool queries (e.g. "compare professor A vs B")
|
| 102 |
+
work correctly.
|
| 103 |
+
"""
|
| 104 |
+
if not self.model_name:
|
| 105 |
+
await self.refresh_model()
|
| 106 |
+
|
| 107 |
+
messages: List[Dict[str, Any]] = [
|
| 108 |
+
{"role": "system", "content": system_prompt},
|
| 109 |
+
{"role": "user", "content": user_message},
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
openai_tools = tool_definitions or []
|
| 113 |
+
|
| 114 |
+
all_tool_calls: List[ToolCallInfo] = []
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
for _round in range(self._MAX_TOOL_ROUNDS):
|
| 118 |
+
response = await self.client.chat.completions.create(
|
| 119 |
+
model=self.model_name,
|
| 120 |
+
messages=messages,
|
| 121 |
+
tools=openai_tools or None,
|
| 122 |
+
temperature=temperature,
|
| 123 |
+
max_tokens=max_tokens,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
choice = response.choices[0].message
|
| 127 |
+
|
| 128 |
+
if not choice.tool_calls:
|
| 129 |
+
return ToolCallResult(
|
| 130 |
+
text=choice.content or "",
|
| 131 |
+
used_tool=bool(all_tool_calls),
|
| 132 |
+
tool_name=all_tool_calls[0].name if all_tool_calls else None,
|
| 133 |
+
tool_args=all_tool_calls[0].args if all_tool_calls else {},
|
| 134 |
+
tool_calls_made=all_tool_calls,
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
messages.append(choice.model_dump())
|
| 138 |
+
|
| 139 |
+
for tc in choice.tool_calls:
|
| 140 |
+
fn_name = tc.function.name
|
| 141 |
+
fn_args = json.loads(tc.function.arguments)
|
| 142 |
+
logger.info("vLLM requested tool call: %s(%s)", fn_name, fn_args)
|
| 143 |
+
all_tool_calls.append(ToolCallInfo(name=fn_name, args=fn_args))
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
tool_result = await tool_executor(name=fn_name, **fn_args)
|
| 147 |
+
except Exception as exc:
|
| 148 |
+
logger.error("Tool %s failed: %s", fn_name, exc)
|
| 149 |
+
tool_result = {"error": str(exc)}
|
| 150 |
+
|
| 151 |
+
messages.append({
|
| 152 |
+
"role": "tool",
|
| 153 |
+
"tool_call_id": tc.id,
|
| 154 |
+
"content": json.dumps(tool_result),
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
logger.warning(
|
| 158 |
+
"Tool-calling loop exhausted after %d rounds", self._MAX_TOOL_ROUNDS,
|
| 159 |
+
)
|
| 160 |
+
last_content = response.choices[0].message.content or ""
|
| 161 |
+
return ToolCallResult(
|
| 162 |
+
text=last_content or "I was unable to finish looking that up. Please try again.",
|
| 163 |
+
used_tool=bool(all_tool_calls),
|
| 164 |
+
tool_name=all_tool_calls[0].name if all_tool_calls else None,
|
| 165 |
+
tool_args=all_tool_calls[0].args if all_tool_calls else {},
|
| 166 |
+
tool_calls_made=all_tool_calls,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
except APIConnectionError:
|
| 170 |
+
logger.error("Unable to connect to vLLM at %s", self.api_url)
|
| 171 |
+
return ToolCallResult(
|
| 172 |
+
text="I'm unable to connect to the AI service. Please ensure the vLLM endpoint is available.",
|
| 173 |
+
used_tool=False,
|
| 174 |
+
)
|
| 175 |
+
except APIStatusError as e:
|
| 176 |
+
logger.error("vLLM tool-call API error: %s - %s", e.status_code, e.message)
|
| 177 |
+
if e.status_code == 404:
|
| 178 |
+
self.model_name = None
|
| 179 |
+
return ToolCallResult(
|
| 180 |
+
text="The AI service encountered an error. Please try again.",
|
| 181 |
+
used_tool=False,
|
| 182 |
+
)
|
| 183 |
+
except Exception as e:
|
| 184 |
+
logger.error("Unexpected error in vLLM tool-calling: %s", e)
|
| 185 |
+
return ToolCallResult(
|
| 186 |
+
text="I encountered an unexpected error. Please try again.",
|
| 187 |
+
used_tool=False,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
|
|
@@ -1,27 +1,72 @@
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
-
from
|
|
|
|
| 3 |
import re
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class LLMClient(ABC):
|
| 6 |
"""Abstract base class for all LLM clients"""
|
| 7 |
-
|
| 8 |
@abstractmethod
|
| 9 |
async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
|
| 10 |
"""
|
| 11 |
Generate a response using the LLM.
|
| 12 |
-
|
| 13 |
Args:
|
| 14 |
system_prompt (str): The system prompt defining the persona/role
|
| 15 |
context (List[dict]): List of conversation messages with 'role' and 'content' keys
|
| 16 |
temperature (float): Sampling temperature for generation
|
| 17 |
max_tokens (int): Maximum number of tokens to generate
|
| 18 |
response_mime_type (str, optional): MIME type for the response format. Defaults to None.
|
| 19 |
-
|
| 20 |
Returns:
|
| 21 |
str: The generated response text
|
| 22 |
"""
|
| 23 |
pass
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def _clean_response(self, response: str) -> str:
|
| 26 |
"""Clean up response text, preserving Markdown formatting."""
|
| 27 |
response = response.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
|
| 1 |
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 4 |
import re
|
| 5 |
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class ToolCallInfo:
|
| 9 |
+
"""Record of a single tool invocation."""
|
| 10 |
+
|
| 11 |
+
name: str
|
| 12 |
+
args: dict = field(default_factory=dict)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class ToolCallResult:
|
| 17 |
+
"""Structured return value from ``generate_with_tools``."""
|
| 18 |
+
|
| 19 |
+
text: str
|
| 20 |
+
used_tool: bool
|
| 21 |
+
tool_name: Optional[str] = None
|
| 22 |
+
tool_args: dict = field(default_factory=dict)
|
| 23 |
+
tool_calls_made: List["ToolCallInfo"] = field(default_factory=list)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
class LLMClient(ABC):
|
| 27 |
"""Abstract base class for all LLM clients"""
|
| 28 |
+
|
| 29 |
@abstractmethod
|
| 30 |
async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
|
| 31 |
"""
|
| 32 |
Generate a response using the LLM.
|
| 33 |
+
|
| 34 |
Args:
|
| 35 |
system_prompt (str): The system prompt defining the persona/role
|
| 36 |
context (List[dict]): List of conversation messages with 'role' and 'content' keys
|
| 37 |
temperature (float): Sampling temperature for generation
|
| 38 |
max_tokens (int): Maximum number of tokens to generate
|
| 39 |
response_mime_type (str, optional): MIME type for the response format. Defaults to None.
|
| 40 |
+
|
| 41 |
Returns:
|
| 42 |
str: The generated response text
|
| 43 |
"""
|
| 44 |
pass
|
| 45 |
|
| 46 |
+
async def generate_with_tools(
    self,
    system_prompt: str,
    user_message: str,
    tool_definitions: Optional[List[Dict[str, Any]]] = None,
    tool_executor: Optional[Callable] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
) -> ToolCallResult:
    """Produce a reply for *user_message*, with tools as an optional extra.

    This base version never touches ``tool_definitions`` or
    ``tool_executor``: it forwards the message to ``generate()`` as a
    one-message user context and reports ``used_tool=False``. Clients
    with native tool calling are expected to override it; everything
    else falls back to this plain-generation path.
    """
    # The whole conversation handed to generate() is this single user turn.
    single_turn = [{"role": "user", "content": user_message}]
    reply = await self.generate(
        system_prompt=system_prompt,
        context=single_turn,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return ToolCallResult(text=reply, used_tool=False)
|
| 69 |
+
|
| 70 |
def _clean_response(self, response: str) -> str:
|
| 71 |
"""Clean up response text, preserving Markdown formatting."""
|
| 72 |
response = response.replace("\r\n", "\n").replace("\r", "\n")
|
|
@@ -58,6 +58,7 @@ app.include_router(auth_router, prefix="/auth", tags=["authentication"])
|
|
| 58 |
app.include_router(chat_sessions_router, prefix="/api", tags=["chat-sessions"])
|
| 59 |
app.include_router(phd_canvas_router, prefix="/api", tags=["phd-canvas"])
|
| 60 |
|
|
|
|
| 61 |
# ---------------------------------------------------------------------------
|
| 62 |
# Public configuration endpoint — serves the frontend-safe subset
|
| 63 |
# ---------------------------------------------------------------------------
|
|
|
|
| 58 |
app.include_router(chat_sessions_router, prefix="/api", tags=["chat-sessions"])
|
| 59 |
app.include_router(phd_canvas_router, prefix="/api", tags=["phd-canvas"])
|
| 60 |
|
| 61 |
+
|
| 62 |
# ---------------------------------------------------------------------------
|
| 63 |
# Public configuration endpoint — serves the frontend-safe subset
|
| 64 |
# ---------------------------------------------------------------------------
|
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import unittest
|
| 3 |
+
|
| 4 |
+
from app.tools.search_courses import TOOL_DEFINITION, execute
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestSearchCoursesContract(unittest.TestCase):
    """Contract checks for the search_courses tool module.

    The module must export ``TOOL_DEFINITION`` as an OpenAI-style function
    tool definition and ``execute`` as a coroutine function.
    """

    def test_tool_definition_has_required_fields(self):
        """The definition is a "function" tool carrying name, description and parameters."""
        self.assertEqual(TOOL_DEFINITION["type"], "function")
        self.assertIn("function", TOOL_DEFINITION)
        fn = TOOL_DEFINITION["function"]
        self.assertIn("name", fn)
        self.assertIn("description", fn)
        self.assertIn("parameters", fn)

    def test_tool_definition_name(self):
        """The tool is registered under the name ``search_courses``."""
        self.assertEqual(TOOL_DEFINITION["function"]["name"], "search_courses")

    def test_tool_definition_has_nonempty_description(self):
        """The description is a non-empty string."""
        self.assertIsInstance(TOOL_DEFINITION["function"]["description"], str)
        self.assertGreater(len(TOOL_DEFINITION["function"]["description"]), 0)

    def test_tool_definition_parameters_is_valid_schema(self):
        """The parameter schema is an object schema exposing a ``subject`` property."""
        params = TOOL_DEFINITION["function"]["parameters"]
        self.assertEqual(params["type"], "object")
        self.assertIn("properties", params)
        self.assertIn("subject", params["properties"])

    def test_execute_is_async_callable(self):
        """``execute`` must be declared as a coroutine function."""
        # asyncio.iscoroutinefunction is deprecated (Python 3.14); the inspect
        # variant is the supported check. Imported locally so this block does
        # not depend on the module's import list.
        import inspect

        self.assertTrue(inspect.iscoroutinefunction(execute))
|
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import unittest
|
| 4 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
+
|
| 6 |
+
from openai import APIConnectionError, APIStatusError
|
| 7 |
+
|
| 8 |
+
from app.llm.llm_client import ToolCallResult
|
| 9 |
+
from app.llm.improved_gemini_client import ImprovedGeminiClient
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Minimal OpenAI-format function tool definition shared by the tests below:
# a "search_courses" function taking a single string property, "subject".
FAKE_TOOL = {
    "type": "function",
    "function": {
        "name": "search_courses",
        "description": "Search courses",
        "parameters": {
            "type": "object",
            "properties": {
                "subject": {
                    "type": "string",
                    "description": "Subject code",
                },
            },
        },
    },
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _make_text_completion_mock(content="Response"):
    """Return a fake ChatCompletion whose only choice is plain text.

    The single choice's message carries *content* and has ``tool_calls``
    set to None, i.e. the model did not request any tool.
    """
    message = MagicMock(content=content, tool_calls=None)
    only_choice = MagicMock(message=message)
    return MagicMock(choices=[only_choice])
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _make_tool_call_mock(fn_name, fn_args_dict, tool_call_id="call_123"):
    """Return a fake ChatCompletion in which the model asks for one tool call.

    The message has no text content, exposes a single tool call whose
    arguments are *fn_args_dict* encoded as JSON, and its ``model_dump()``
    returns the matching assistant-message dict.
    """
    serialized_args = json.dumps(fn_args_dict)

    requested_call = MagicMock(id=tool_call_id)
    # ``name`` is reserved by the Mock constructor, so set it by assignment.
    requested_call.function.name = fn_name
    requested_call.function.arguments = serialized_args

    dumped_call = {
        "id": tool_call_id,
        "type": "function",
        "function": {"name": fn_name, "arguments": serialized_args},
    }
    message = MagicMock(content=None, tool_calls=[requested_call])
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [dumped_call],
    }

    return MagicMock(choices=[MagicMock(message=message)])
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _make_multi_tool_call_mock(calls):
    """Return a fake ChatCompletion requesting several tool calls at once.

    *calls* is a list of (fn_name, fn_args_dict, tool_call_id) tuples.
    Each tuple becomes one tool-call mock on the message and one entry in
    the dict returned by the message's ``model_dump()``, in the same order.
    """
    def _as_mock(name, arguments_json, call_id):
        entry = MagicMock()
        entry.id = call_id
        entry.function.name = name
        entry.function.arguments = arguments_json
        return entry

    # Encode the arguments once so the mock and the dumped dict agree.
    encoded = [
        (name, json.dumps(arguments), call_id)
        for name, arguments, call_id in calls
    ]

    message = MagicMock()
    message.content = None
    message.tool_calls = [_as_mock(*item) for item in encoded]
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": call_id,
                "type": "function",
                "function": {"name": name, "arguments": arguments_json},
            }
            for name, arguments_json, call_id in encoded
        ],
    }

    choice = MagicMock()
    choice.message = message
    return MagicMock(choices=[choice])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _make_gemini_client(MockSettings, MockCtxMgr):
    """Build an ImprovedGeminiClient on top of the two patched factories.

    *MockSettings* is made to return a settings mock exposing a fake Gemini
    API key and model name, *MockCtxMgr* is made to return a bare mock, and
    the client is then constructed with no arguments.
    """
    fake_settings = MagicMock()
    fake_settings.configure_mock(**{
        "llm.gemini.api_key": "fake-key",
        "llm.gemini.model": "gemini-2.0-flash",
    })
    MockSettings.return_value = fake_settings
    MockCtxMgr.return_value = MagicMock()
    return ImprovedGeminiClient()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@patch("app.llm.improved_gemini_client.get_context_manager")
@patch("app.llm.improved_gemini_client.get_settings")
class TestGeminiGenerateWithTools(unittest.TestCase):
    """Exercise ImprovedGeminiClient.generate_with_tools() against a mocked
    OpenAI-compatible chat-completions endpoint.

    The class-level patches pass two extra arguments to every test method:
    the ``get_settings`` mock first, then the ``get_context_manager`` mock.
    """

    # ------------------------------------------------------------------
    # Shared helpers (not collected or patched: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _ask(client, executor, user_message, system_prompt="You are helpful."):
        """Run generate_with_tools() to completion with FAKE_TOOL as the only tool."""
        return asyncio.run(client.generate_with_tools(
            system_prompt=system_prompt,
            user_message=user_message,
            tool_definitions=[FAKE_TOOL],
            tool_executor=executor,
        ))

    @staticmethod
    def _two_professor_calls():
        """Return a fresh pair of rate_my_professor call tuples (Dubson, West)."""
        return [
            ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
            ("rate_my_professor", {"professor_name": "West"}, "call_b"),
        ]

    @staticmethod
    def _followup_tool_messages(create_mock):
        """Return the ``role: tool`` messages sent with the second completion request."""
        followup = create_mock.call_args_list[1].kwargs["messages"]
        return [m for m in followup if m.get("role") == "tool"]

    # ------------------------------------------------------------------
    # Happy path — no tool call
    # ------------------------------------------------------------------

    def test_direct_text_response_returns_text(self, MockSettings, MockCtxMgr):
        """A plain-text completion is returned as-is and the executor stays idle."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(
            return_value=_make_text_completion_mock("Hello, world!"),
        )
        executor = AsyncMock()

        outcome = self._ask(client, executor, "Hi there")

        self.assertIsInstance(outcome, ToolCallResult)
        self.assertEqual(outcome.text, "Hello, world!")
        self.assertFalse(outcome.used_tool)
        executor.assert_not_called()

    # ------------------------------------------------------------------
    # Happy path — tool call
    # ------------------------------------------------------------------

    def test_function_call_triggers_executor_and_returns_final_text(self, MockSettings, MockCtxMgr):
        """A requested tool call is executed and the follow-up completion's
        text becomes the result."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("CSCI 1300 is available MWF 10-10:50."),
        ])
        client.openai_client.chat.completions.create = create
        executor = AsyncMock(
            return_value={"courses": [{"title": "Intro to CS"}]}
        )

        outcome = self._ask(client, executor, "What CSCI classes are there?")

        executor.assert_called_once_with(
            name="search_courses", subject="CSCI",
        )
        self.assertIsInstance(outcome, ToolCallResult)
        self.assertEqual(outcome.text, "CSCI 1300 is available MWF 10-10:50.")
        self.assertTrue(outcome.used_tool)
        self.assertEqual(outcome.tool_name, "search_courses")
        self.assertEqual(outcome.tool_args, {"subject": "CSCI"})
        self.assertEqual(len(outcome.tool_calls_made), 1)
        self.assertEqual(outcome.tool_calls_made[0].name, "search_courses")
        self.assertEqual(client.openai_client.chat.completions.create.call_count, 2)

    # ------------------------------------------------------------------
    # Payload format
    # ------------------------------------------------------------------

    def test_tool_definitions_passed_through_in_openai_format(self, MockSettings, MockCtxMgr):
        """Tool definitions reach the completions API unchanged, still in
        OpenAI format."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(
            return_value=_make_text_completion_mock("Ok"),
        )

        self._ask(client, AsyncMock(), "Hello")

        sent_kwargs = client.openai_client.chat.completions.create.call_args.kwargs
        sent_tools = sent_kwargs["tools"]
        self.assertEqual(len(sent_tools), 1)
        self.assertEqual(sent_tools[0]["type"], "function")
        self.assertEqual(sent_tools[0]["function"]["name"], "search_courses")
        self.assertIn("parameters", sent_tools[0]["function"])

    def test_tool_result_appended_to_followup(self, MockSettings, MockCtxMgr):
        """The follow-up request ends with the assistant message and a
        ``role: tool`` message, and still carries ``tools=``."""
        tool_output = {"courses": [{"title": "Algorithms"}]}
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("Here are the results."),
        ])
        executor = AsyncMock(return_value=tool_output)

        self._ask(client, executor, "Find CSCI courses")

        followup_kwargs = client.openai_client.chat.completions.create.call_args_list[1].kwargs
        sent_messages = followup_kwargs["messages"]

        second_to_last = sent_messages[-2]
        self.assertEqual(second_to_last["role"], "assistant")

        last = sent_messages[-1]
        self.assertEqual(last["role"], "tool")
        self.assertEqual(last["tool_call_id"], "call_123")
        self.assertEqual(json.loads(last["content"]), tool_output)

        self.assertIn("tools", followup_kwargs,
                      "Follow-up call must include tools= so the model can "
                      "request additional tool calls if needed")

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    def test_tool_executor_failure_serialises_error_and_continues(self, MockSettings, MockCtxMgr):
        """An executor exception is serialised into the tool result and the
        follow-up completion still happens."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("Sorry, I couldn't look that up."),
        ])
        executor = AsyncMock(side_effect=RuntimeError("network down"))

        outcome = self._ask(client, executor, "Find CSCI courses")

        self.assertTrue(outcome.used_tool)
        self.assertEqual(outcome.tool_name, "search_courses")
        self.assertEqual(len(outcome.tool_calls_made), 1)

        first_tool_msg = self._followup_tool_messages(
            client.openai_client.chat.completions.create
        )[0]
        self.assertIn("network down", json.loads(first_tool_msg["content"])["error"])

    def test_connection_error_returns_not_used(self, MockSettings, MockCtxMgr):
        """An APIConnectionError from the endpoint yields used_tool=False."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(
            side_effect=APIConnectionError(request=MagicMock()),
        )

        outcome = self._ask(client, AsyncMock(), "Hi", system_prompt="Test")

        self.assertIsInstance(outcome, ToolCallResult)
        self.assertFalse(outcome.used_tool)
        self.assertIn("unable to connect", outcome.text.lower())

    def test_status_error_returns_not_used(self, MockSettings, MockCtxMgr):
        """An APIStatusError from the endpoint yields used_tool=False."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        http_response = MagicMock(status_code=500)
        client.openai_client.chat.completions.create = AsyncMock(
            side_effect=APIStatusError(
                message="Server error", response=http_response, body=None,
            )
        )

        outcome = self._ask(client, AsyncMock(), "Hi", system_prompt="Test")

        self.assertIsInstance(outcome, ToolCallResult)
        self.assertFalse(outcome.used_tool)
        self.assertIn("error", outcome.text.lower())

    # ------------------------------------------------------------------
    # Multi-tool call in a single response
    # ------------------------------------------------------------------

    def test_parallel_tool_calls_all_executed(self, MockSettings, MockCtxMgr):
        """Every tool call requested in one response is executed and
        recorded in the result."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_multi_tool_call_mock(self._two_professor_calls()),
            _make_text_completion_mock("Dubson has a 4.5 rating. West has a 3.8 rating."),
        ])
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson", "rating": 4.5}]},
            {"professors": [{"name": "West", "rating": 3.8}]},
        ])

        outcome = self._ask(client, executor, "Is professor Dubson or West rated better?")

        self.assertEqual(executor.call_count, 2)
        self.assertTrue(outcome.used_tool)
        self.assertIn("Dubson", outcome.text)
        self.assertIn("West", outcome.text)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(outcome.tool_calls_made[0].name, "rate_my_professor")
        self.assertEqual(outcome.tool_calls_made[1].args, {"professor_name": "West"})
        self.assertEqual(client.openai_client.chat.completions.create.call_count, 2)

    def test_parallel_tool_results_all_in_followup_messages(self, MockSettings, MockCtxMgr):
        """Each tool result is sent back as its own role:tool message, in
        request order."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_multi_tool_call_mock(self._two_professor_calls()),
            _make_text_completion_mock("Comparison complete."),
        ])
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson"}]},
            {"professors": [{"name": "West"}]},
        ])

        self._ask(client, executor, "Compare")

        tool_msgs = self._followup_tool_messages(
            client.openai_client.chat.completions.create
        )
        self.assertEqual(len(tool_msgs), 2)
        self.assertEqual(tool_msgs[0]["tool_call_id"], "call_a")
        self.assertEqual(tool_msgs[1]["tool_call_id"], "call_b")

    # ------------------------------------------------------------------
    # Multi-round tool calling
    # ------------------------------------------------------------------

    def test_sequential_tool_rounds(self, MockSettings, MockCtxMgr):
        """A second tool-call round after the first results is handled
        before the final text arrives."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_tool_call_mock("rate_my_professor", {"professor_name": "Dubson"}, "call_1"),
            _make_tool_call_mock("rate_my_professor", {"professor_name": "West"}, "call_2"),
            _make_text_completion_mock("Dubson is rated higher than West."),
        ])
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson", "rating": 4.5}]},
            {"professors": [{"name": "West", "rating": 3.8}]},
        ])

        outcome = self._ask(client, executor, "Compare Dubson and West")

        self.assertEqual(executor.call_count, 2)
        self.assertTrue(outcome.used_tool)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(outcome.tool_name, "rate_my_professor")
        self.assertEqual(client.openai_client.chat.completions.create.call_count, 3)

    # ------------------------------------------------------------------
    # Partial failure in multi-tool context
    # ------------------------------------------------------------------

    def test_partial_tool_failure_continues(self, MockSettings, MockCtxMgr):
        """When one call of a batch fails, its error is serialised and the
        follow-up completion is still requested."""
        client = _make_gemini_client(MockSettings, MockCtxMgr)
        client.openai_client.chat.completions.create = AsyncMock(side_effect=[
            _make_multi_tool_call_mock(self._two_professor_calls()),
            _make_text_completion_mock("Only Dubson data available."),
        ])
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson", "rating": 4.5}]},
            RuntimeError("network down"),
        ])

        outcome = self._ask(client, executor, "Compare")

        self.assertTrue(outcome.used_tool)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(client.openai_client.chat.completions.create.call_count, 2)

        tool_msgs = self._followup_tool_messages(
            client.openai_client.chat.completions.create
        )
        self.assertEqual(len(tool_msgs), 2)
        failed_payload = json.loads(tool_msgs[1]["content"])
        self.assertIn("error", failed_payload)
|
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import unittest
|
| 3 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 4 |
+
|
| 5 |
+
from app.tools.rate_my_professor import TOOL_DEFINITION, execute
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _graphql_success_response(nodes):
    """Wrap *nodes* in a fake RMP GraphQL teacher-search response.

    Each node becomes one edge whose cursor is ``c<index>``; the page info
    reports that no further page exists.
    """
    teachers = {
        "didFallback": False,
        "edges": [
            {"cursor": f"c{position}", "node": node}
            for position, node in enumerate(nodes)
        ],
        "pageInfo": {"hasNextPage": False, "endCursor": ""},
    }
    return {"data": {"search": {"teachers": teachers}}}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# One teacher node shaped like the entries passed to _graphql_success_response
# in the executor tests below.
SAMPLE_NODE = {
    # Identity
    "id": "VGVhY2hlci0xMjM0",
    "legacyId": 1234,
    "firstName": "Jane",
    "lastName": "Smith",
    # Affiliation
    "department": "Computer Science",
    "school": {
        "id": "U2Nob29sLTEwODc=",
        "name": "University of Colorado Boulder",
    },
    # Rating figures
    "avgRating": 4.2,
    "avgDifficulty": 3.1,
    "wouldTakeAgainPercent": 85.0,
    "numRatings": 42,
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TestRMPToolContract(unittest.TestCase):
    """Contract checks for the rate_my_professor tool module.

    The module must export ``TOOL_DEFINITION`` as an OpenAI-style function
    tool definition and ``execute`` as a coroutine function.
    """

    def test_tool_definition_has_required_fields(self):
        """The definition is a "function" tool carrying name, description and parameters."""
        self.assertEqual(TOOL_DEFINITION["type"], "function")
        self.assertIn("function", TOOL_DEFINITION)
        fn = TOOL_DEFINITION["function"]
        self.assertIn("name", fn)
        self.assertIn("description", fn)
        self.assertIn("parameters", fn)

    def test_tool_definition_name(self):
        """The tool is registered under the name ``rate_my_professor``."""
        self.assertEqual(TOOL_DEFINITION["function"]["name"], "rate_my_professor")

    def test_tool_definition_has_nonempty_description(self):
        """The description is a non-empty string."""
        self.assertIsInstance(TOOL_DEFINITION["function"]["description"], str)
        self.assertGreater(len(TOOL_DEFINITION["function"]["description"]), 0)

    def test_tool_definition_parameters_schema(self):
        """The parameter schema is an object schema exposing ``professor_name``."""
        params = TOOL_DEFINITION["function"]["parameters"]
        self.assertEqual(params["type"], "object")
        self.assertIn("properties", params)
        self.assertIn("professor_name", params["properties"])

    def test_execute_is_async_callable(self):
        """``execute`` must be declared as a coroutine function."""
        # asyncio.iscoroutinefunction is deprecated (Python 3.14); the inspect
        # variant is the supported check. Imported locally so this block does
        # not depend on the module's import list.
        import inspect

        self.assertTrue(inspect.iscoroutinefunction(execute))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _fake_tool_config(name):
    """Stand-in for the settings' per-tool config lookup.

    Only ``rate_my_professor`` has a config (enabled, with a school_id);
    any other tool name yields an empty dict.
    """
    if name != "rate_my_professor":
        return {}
    return {"enabled": True, "school_id": "U2Nob29sLTEwODc="}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@patch("app.tools.rate_my_professor.get_settings")
class TestRMPToolExecutor(unittest.TestCase):
    """Exercise rate_my_professor.execute() against a mocked httpx client.

    The class-level patch replaces ``get_settings`` in the tool module, so
    every test method receives that mock as ``mock_get_settings``.
    """

    def _mock_client(self, get_response, post_response):
        """Return ``(context_manager, client)`` mocks for ``httpx.AsyncClient``.

        *get_response* is the text served for GET requests; ``None`` selects
        a canned snippet containing an Authorization value.  *post_response*
        is what the POST response's ``json()`` returns.
        """
        page = MagicMock()
        page.text = (
            '<script>"Authorization":"Basic dGVzdDp0ZXN0"</script>'
            if get_response is None
            else get_response
        )
        page.raise_for_status = MagicMock()

        graphql = MagicMock()
        graphql.status_code = 200
        graphql.json.return_value = post_response
        graphql.raise_for_status = MagicMock()

        http = AsyncMock()
        http.get = AsyncMock(return_value=page)
        http.post = AsyncMock(return_value=graphql)

        # ``async with httpx.AsyncClient(...) as client`` yields ``http``.
        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=http)
        session.__aexit__ = AsyncMock(return_value=False)
        return session, http

    def test_execute_returns_professor_data(self, mock_get_settings):
        """A successful GraphQL reply comes back as structured professor data."""
        mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
        session, _ = self._mock_client(
            get_response=None,
            post_response=_graphql_success_response([SAMPLE_NODE]),
        )

        with patch("httpx.AsyncClient", return_value=session):
            outcome = asyncio.run(execute(professor_name="Smith"))

        self.assertIn("professors", outcome)
        self.assertEqual(len(outcome["professors"]), 1)

        first = outcome["professors"][0]
        self.assertEqual(first["name"], "Jane Smith")
        self.assertEqual(first["department"], "Computer Science")
        self.assertAlmostEqual(first["rating"], 4.2)
        self.assertAlmostEqual(first["difficulty"], 3.1)
        self.assertEqual(first["num_ratings"], 42)

    def test_execute_returns_empty_on_no_results(self, mock_get_settings):
        """No matching professors yields an empty list rather than an error."""
        mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
        session, _ = self._mock_client(
            get_response=None,
            post_response=_graphql_success_response([]),
        )

        with patch("httpx.AsyncClient", return_value=session):
            outcome = asyncio.run(execute(professor_name="Nonexistent"))

        self.assertIn("professors", outcome)
        self.assertEqual(len(outcome["professors"]), 0)

    def test_execute_returns_error_on_api_failure(self, mock_get_settings):
        """A failing HTTP request produces an error payload, not an exception."""
        mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config

        # Both verbs raise, so whichever request execute() issues first fails.
        broken_http = AsyncMock()
        broken_http.get = AsyncMock(side_effect=Exception("connection refused"))
        broken_http.post = AsyncMock(side_effect=Exception("connection refused"))

        session = MagicMock()
        session.__aenter__ = AsyncMock(return_value=broken_http)
        session.__aexit__ = AsyncMock(return_value=False)

        with patch("httpx.AsyncClient", return_value=session):
            outcome = asyncio.run(execute(professor_name="Smith"))

        self.assertIn("professors", outcome)
        self.assertEqual(len(outcome["professors"]), 0)
        self.assertIn("error", outcome)

    def test_execute_accepts_name_kwarg(self, mock_get_settings):
        """execute() tolerates the ``name=`` kwarg that the dispatcher passes."""
        mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
        session, _ = self._mock_client(
            get_response=None,
            post_response=_graphql_success_response([SAMPLE_NODE]),
        )

        with patch("httpx.AsyncClient", return_value=session):
            pending = execute(name="rate_my_professor", professor_name="Smith")
            outcome = asyncio.run(pending)

        self.assertIn("professors", outcome)
        self.assertEqual(len(outcome["professors"]), 1)
|
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import unittest
|
| 3 |
+
from unittest.mock import AsyncMock
|
| 4 |
+
|
| 5 |
+
from app.tools import (
|
| 6 |
+
get_tool_definitions,
|
| 7 |
+
get_tool_executor,
|
| 8 |
+
list_registered_tools,
|
| 9 |
+
_REGISTRY,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Tool names the tests in this module expect to find in the registry.
KNOWN_TOOLS = {"search_courses", "rate_my_professor"}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestToolDiscovery(unittest.TestCase):
    """Auto-discovery should find every tool module that exports
    TOOL_DEFINITION + execute.

    Each per-tool check runs inside ``subTest`` so that one broken tool
    does not hide failures in the others, and every failure report names
    the tool it belongs to.
    """

    def test_known_tools_are_discovered(self):
        """Every name in KNOWN_TOOLS appears in list_registered_tools()."""
        registered = set(list_registered_tools())
        # Sorted only to make the sub-test order deterministic.
        for name in sorted(KNOWN_TOOLS):
            with self.subTest(tool=name):
                self.assertIn(name, registered, f"Tool '{name}' was not discovered")

    def test_registry_entries_have_definition_and_executor(self):
        """Each registry entry carries both a definition and an executor."""
        for name, entry in _REGISTRY.items():
            with self.subTest(tool=name):
                self.assertIn("definition", entry, f"'{name}' missing definition")
                self.assertIn("executor", entry, f"'{name}' missing executor")

    def test_definitions_have_required_fields(self):
        """Definitions are function-type dicts named after their registry key."""
        for name, entry in _REGISTRY.items():
            with self.subTest(tool=name):
                defn = entry["definition"]
                self.assertEqual(defn["type"], "function")
                self.assertIn("function", defn)
                fn = defn["function"]
                for field in ("name", "description", "parameters"):
                    self.assertIn(field, fn)
                self.assertEqual(fn["name"], name)

    def test_executors_are_async_callables(self):
        """Every registered executor is a coroutine function."""
        for name, entry in _REGISTRY.items():
            with self.subTest(tool=name):
                self.assertTrue(
                    asyncio.iscoroutinefunction(entry["executor"]),
                    f"Executor for '{name}' is not an async function",
                )
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class TestGetToolDefinitions(unittest.TestCase):
    """Check get_tool_definitions() with and without an ``enabled`` filter."""

    @staticmethod
    def _names(definitions):
        """Collect the function names out of a list of tool definitions."""
        return {definition["function"]["name"] for definition in definitions}

    def test_returns_all_when_no_filter(self):
        """Without a filter, at least every known tool is returned."""
        available = self._names(get_tool_definitions())
        self.assertTrue(KNOWN_TOOLS.issubset(available))

    def test_filter_to_single_tool(self):
        """Enabling one tool returns exactly that tool's definition."""
        selected = get_tool_definitions(enabled=["search_courses"])
        self.assertEqual(len(selected), 1)
        self.assertEqual(selected[0]["function"]["name"], "search_courses")

    def test_filter_to_multiple_tools(self):
        """Enabling both known tools returns exactly those two."""
        selected = get_tool_definitions(
            enabled=["search_courses", "rate_my_professor"],
        )
        self.assertEqual(self._names(selected), KNOWN_TOOLS)

    def test_filter_with_unknown_name_returns_empty(self):
        """A filter naming only an unknown tool returns an empty list."""
        self.assertEqual(get_tool_definitions(enabled=["nonexistent_tool"]), [])

    def test_filter_with_empty_list_returns_empty(self):
        """An empty filter list returns an empty list."""
        self.assertEqual(get_tool_definitions(enabled=[]), [])

    def test_filter_ignores_unknown_names_keeps_valid(self):
        """Unknown names in the filter are dropped; valid ones are kept."""
        selected = get_tool_definitions(enabled=["search_courses", "bogus"])
        self.assertEqual(len(selected), 1)
        self.assertEqual(selected[0]["function"]["name"], "search_courses")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestGetToolExecutor(unittest.TestCase):
    """get_tool_executor() returns a dispatcher that routes to the
    correct tool executor."""

    def _stub_executor(self, tool_name, result):
        """Swap *tool_name*'s registry executor for an AsyncMock returning *result*.

        The original executor is put back via ``addCleanup``, so the shared
        registry is restored even when the test fails part-way through.
        Returns the stub so the caller can assert on how it was called.
        """
        entry = _REGISTRY[tool_name]
        # Register the restore before mutating so it always runs.
        self.addCleanup(entry.__setitem__, "executor", entry["executor"])
        stub = AsyncMock(return_value=result)
        entry["executor"] = stub
        return stub

    def test_dispatch_known_tool(self):
        """The dispatcher forwards all kwargs to the named tool's executor."""
        mock_exec = self._stub_executor("search_courses", {"courses": []})

        dispatch = get_tool_executor()
        result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))

        mock_exec.assert_called_once_with(name="search_courses", subject="CSCI")
        self.assertEqual(result, {"courses": []})

    def test_dispatch_unknown_tool_returns_error(self):
        """Dispatching an unregistered name returns an error payload."""
        dispatch = get_tool_executor()
        result = asyncio.run(dispatch(name="nonexistent"))
        self.assertIn("error", result)

    def test_filtered_executor_allows_enabled_tool(self):
        """A tool listed in ``enabled`` is dispatched without an error."""
        self._stub_executor("search_courses", {"courses": []})

        dispatch = get_tool_executor(enabled=["search_courses"])
        result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))

        self.assertNotIn("error", result)

    def test_filtered_executor_blocks_disabled_tool(self):
        """A registered tool missing from ``enabled`` is refused."""
        dispatch = get_tool_executor(enabled=["search_courses"])
        result = asyncio.run(dispatch(name="rate_my_professor", professor_name="Smith"))
        self.assertIn("error", result)
        self.assertIn("not enabled", result["error"])

    def test_filtered_executor_with_empty_list_blocks_all(self):
        """An empty ``enabled`` list refuses every tool."""
        dispatch = get_tool_executor(enabled=[])
        result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))
        self.assertIn("error", result)
|
|
@@ -1,15 +1,31 @@
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
import unittest
|
| 3 |
from unittest.mock import AsyncMock, MagicMock, patch
|
| 4 |
|
| 5 |
from openai import APIConnectionError, APIStatusError
|
| 6 |
|
|
|
|
| 7 |
from app.llm.improved_vllm_client import ImprovedVllmClient
|
| 8 |
|
| 9 |
|
| 10 |
FAKE_URL = "https://fake.example.com/vllm0"
|
| 11 |
FAKE_KEY = "test-key"
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def _make_completion_mock(content="Response"):
|
| 15 |
"""Build a mock that looks like an OpenAI ChatCompletion."""
|
|
@@ -20,6 +36,77 @@ def _make_completion_mock(content="Response"):
|
|
| 20 |
return MagicMock(choices=[mock_choice])
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
@patch("app.llm.improved_vllm_client.get_context_manager")
|
| 24 |
@patch("app.llm.improved_vllm_client.AsyncOpenAI")
|
| 25 |
class TestImprovedVllmClient(unittest.TestCase):
|
|
@@ -214,3 +301,328 @@ class TestImprovedVllmClient(unittest.TestCase):
|
|
| 214 |
self.assertNotIn("\r", cleaned)
|
| 215 |
self.assertNotIn("\n\n\n", cleaned)
|
| 216 |
self.assertEqual(cleaned, "Line one.\n\nLine two.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
+
import json
|
| 3 |
import unittest
|
| 4 |
from unittest.mock import AsyncMock, MagicMock, patch
|
| 5 |
|
| 6 |
from openai import APIConnectionError, APIStatusError
|
| 7 |
|
| 8 |
+
from app.llm.llm_client import ToolCallResult
|
| 9 |
from app.llm.improved_vllm_client import ImprovedVllmClient
|
| 10 |
|
| 11 |
|
| 12 |
# Placeholder endpoint passed as api_url when the tests build ImprovedVllmClient.
FAKE_URL = "https://fake.example.com/vllm0"
|
| 13 |
# Placeholder credential passed as api_key when the tests build ImprovedVllmClient.
FAKE_KEY = "test-key"
|
| 14 |
|
| 15 |
+
# Minimal OpenAI-format function tool definition handed to
# generate_with_tools() as the only entry of tool_definitions in these tests.
FAKE_TOOL = {
    "type": "function",
    "function": {
        "name": "search_courses",
        "description": "Search courses",
        "parameters": {
            "type": "object",
            "properties": {
                "subject": {"type": "string", "description": "Subject code"},
            },
        },
    },
}
|
| 28 |
+
|
| 29 |
|
| 30 |
def _make_completion_mock(content="Response"):
|
| 31 |
"""Build a mock that looks like an OpenAI ChatCompletion."""
|
|
|
|
| 36 |
return MagicMock(choices=[mock_choice])
|
| 37 |
|
| 38 |
|
| 39 |
+
def _make_text_completion_mock(content="Response"):
    """Build a ChatCompletion mock with no tool calls.

    The returned mock has a single choice whose message carries *content*
    and has ``tool_calls`` explicitly set to ``None`` (a bare MagicMock
    attribute would otherwise be truthy).
    """
    mock_message = MagicMock()
    mock_message.content = content
    mock_message.tool_calls = None

    mock_choice = MagicMock()
    mock_choice.message = mock_message
    return MagicMock(choices=[mock_choice])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _make_tool_call_mock(fn_name, fn_args_dict, tool_call_id="call_123"):
    """Build a ChatCompletion mock where the model requests a tool call.

    The single choice's message has ``content`` of ``None`` and exactly one
    tool call for *fn_name*, whose arguments are *fn_args_dict* serialised
    to JSON.  ``message.model_dump()`` returns the matching assistant
    message dict.
    """
    fn_args_json = json.dumps(fn_args_dict)

    tool_call = MagicMock()
    tool_call.id = tool_call_id
    # Keep the attribute view in step with the model_dump() payload below.
    tool_call.type = "function"
    tool_call.function.name = fn_name
    tool_call.function.arguments = fn_args_json

    mock_message = MagicMock()
    mock_message.content = None
    mock_message.tool_calls = [tool_call]
    mock_message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": tool_call_id,
            "type": "function",
            "function": {"name": fn_name, "arguments": fn_args_json},
        }],
    }

    mock_choice = MagicMock()
    mock_choice.message = mock_message
    return MagicMock(choices=[mock_choice])
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _make_multi_tool_call_mock(calls):
    """Build a ChatCompletion mock with multiple parallel tool calls.

    *calls* is a list of (fn_name, fn_args_dict, tool_call_id) tuples.
    The message's ``tool_calls`` and the ``tool_calls`` list returned by
    ``message.model_dump()`` follow the order of *calls*.
    """
    tool_calls = []
    dump_calls = []
    for fn_name, fn_args_dict, tool_call_id in calls:
        fn_args_json = json.dumps(fn_args_dict)

        tc = MagicMock()
        tc.id = tool_call_id
        # Keep the attribute view in step with the dumped dict below.
        tc.type = "function"
        tc.function.name = fn_name
        tc.function.arguments = fn_args_json
        tool_calls.append(tc)

        dump_calls.append({
            "id": tool_call_id,
            "type": "function",
            "function": {"name": fn_name, "arguments": fn_args_json},
        })

    mock_message = MagicMock()
    mock_message.content = None
    mock_message.tool_calls = tool_calls
    mock_message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": dump_calls,
    }

    mock_choice = MagicMock()
    mock_choice.message = mock_message
    return MagicMock(choices=[mock_choice])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
@patch("app.llm.improved_vllm_client.get_context_manager")
|
| 111 |
@patch("app.llm.improved_vllm_client.AsyncOpenAI")
|
| 112 |
class TestImprovedVllmClient(unittest.TestCase):
|
|
|
|
| 301 |
self.assertNotIn("\r", cleaned)
|
| 302 |
self.assertNotIn("\n\n\n", cleaned)
|
| 303 |
self.assertEqual(cleaned, "Line one.\n\nLine two.")
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@patch("app.llm.improved_vllm_client.get_context_manager")
|
| 307 |
+
@patch("app.llm.improved_vllm_client.AsyncOpenAI")
|
| 308 |
+
class TestVllmGenerateWithTools(unittest.TestCase):
|
| 309 |
+
"""Unit tests for ImprovedVllmClient.generate_with_tools()."""
|
| 310 |
+
|
| 311 |
+
# ------------------------------------------------------------------
|
| 312 |
+
# Happy path — no tool call
|
| 313 |
+
# ------------------------------------------------------------------
|
| 314 |
+
|
| 315 |
+
def test_text_response_returns_not_used(self, MockAsyncOpenAI, mock_get_ctx):
|
| 316 |
+
"""When the model responds with plain text, return used_tool=False."""
|
| 317 |
+
client = ImprovedVllmClient(
|
| 318 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 319 |
+
)
|
| 320 |
+
client.client.chat.completions.create = AsyncMock(
|
| 321 |
+
return_value=_make_text_completion_mock("Hello, world!"),
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
result = asyncio.run(client.generate_with_tools(
|
| 325 |
+
system_prompt="You are helpful.",
|
| 326 |
+
user_message="Hi there",
|
| 327 |
+
tool_definitions=[FAKE_TOOL],
|
| 328 |
+
tool_executor=AsyncMock(),
|
| 329 |
+
))
|
| 330 |
+
|
| 331 |
+
self.assertIsInstance(result, ToolCallResult)
|
| 332 |
+
self.assertEqual(result.text, "Hello, world!")
|
| 333 |
+
self.assertFalse(result.used_tool)
|
| 334 |
+
|
| 335 |
+
# ------------------------------------------------------------------
|
| 336 |
+
# Happy path — tool call
|
| 337 |
+
# ------------------------------------------------------------------
|
| 338 |
+
|
| 339 |
+
def test_tool_call_executes_and_returns_final_text(self, MockAsyncOpenAI, mock_get_ctx):
|
| 340 |
+
"""When the model requests a tool call, execute it and return
|
| 341 |
+
the text from the follow-up completion."""
|
| 342 |
+
client = ImprovedVllmClient(
|
| 343 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 344 |
+
)
|
| 345 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 346 |
+
_make_tool_call_mock("search_courses", {"subject": "CSCI"}),
|
| 347 |
+
_make_text_completion_mock("CSCI 1300 is available MWF 10-10:50."),
|
| 348 |
+
])
|
| 349 |
+
mock_executor = AsyncMock(
|
| 350 |
+
return_value={"courses": [{"title": "Intro to CS"}]},
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
result = asyncio.run(client.generate_with_tools(
|
| 354 |
+
system_prompt="You are helpful.",
|
| 355 |
+
user_message="What CSCI classes are there?",
|
| 356 |
+
tool_definitions=[FAKE_TOOL],
|
| 357 |
+
tool_executor=mock_executor,
|
| 358 |
+
))
|
| 359 |
+
|
| 360 |
+
mock_executor.assert_called_once_with(
|
| 361 |
+
name="search_courses", subject="CSCI",
|
| 362 |
+
)
|
| 363 |
+
self.assertIsInstance(result, ToolCallResult)
|
| 364 |
+
self.assertEqual(result.text, "CSCI 1300 is available MWF 10-10:50.")
|
| 365 |
+
self.assertTrue(result.used_tool)
|
| 366 |
+
self.assertEqual(result.tool_name, "search_courses")
|
| 367 |
+
self.assertEqual(result.tool_args, {"subject": "CSCI"})
|
| 368 |
+
self.assertEqual(len(result.tool_calls_made), 1)
|
| 369 |
+
self.assertEqual(result.tool_calls_made[0].name, "search_courses")
|
| 370 |
+
self.assertEqual(client.client.chat.completions.create.call_count, 2)
|
| 371 |
+
|
| 372 |
+
# ------------------------------------------------------------------
|
| 373 |
+
# Payload format
|
| 374 |
+
# ------------------------------------------------------------------
|
| 375 |
+
|
| 376 |
+
def test_tool_definitions_passed_through_in_openai_format(self, MockAsyncOpenAI, mock_get_ctx):
|
| 377 |
+
"""Tool definitions (already in OpenAI format) are passed through
|
| 378 |
+
directly to the completions API."""
|
| 379 |
+
client = ImprovedVllmClient(
|
| 380 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 381 |
+
)
|
| 382 |
+
client.client.chat.completions.create = AsyncMock(
|
| 383 |
+
return_value=_make_text_completion_mock("Ok"),
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
asyncio.run(client.generate_with_tools(
|
| 387 |
+
system_prompt="You are helpful.",
|
| 388 |
+
user_message="Hello",
|
| 389 |
+
tool_definitions=[FAKE_TOOL],
|
| 390 |
+
tool_executor=AsyncMock(),
|
| 391 |
+
))
|
| 392 |
+
|
| 393 |
+
call_kwargs = client.client.chat.completions.create.call_args[1]
|
| 394 |
+
tools = call_kwargs["tools"]
|
| 395 |
+
self.assertEqual(len(tools), 1)
|
| 396 |
+
self.assertEqual(tools[0]["type"], "function")
|
| 397 |
+
self.assertEqual(tools[0]["function"]["name"], "search_courses")
|
| 398 |
+
self.assertIn("parameters", tools[0]["function"])
|
| 399 |
+
|
| 400 |
+
def test_tool_result_appended_to_followup(self, MockAsyncOpenAI, mock_get_ctx):
|
| 401 |
+
"""After executing a tool, the follow-up call must include
|
| 402 |
+
the assistant message, a ``role: tool`` message, and ``tools=``."""
|
| 403 |
+
tool_output = {"courses": [{"title": "Algorithms"}]}
|
| 404 |
+
client = ImprovedVllmClient(
|
| 405 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 406 |
+
)
|
| 407 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 408 |
+
_make_tool_call_mock("search_courses", {"subject": "CSCI"}),
|
| 409 |
+
_make_text_completion_mock("Here are the results."),
|
| 410 |
+
])
|
| 411 |
+
mock_executor = AsyncMock(return_value=tool_output)
|
| 412 |
+
|
| 413 |
+
asyncio.run(client.generate_with_tools(
|
| 414 |
+
system_prompt="You are helpful.",
|
| 415 |
+
user_message="Find CSCI courses",
|
| 416 |
+
tool_definitions=[FAKE_TOOL],
|
| 417 |
+
tool_executor=mock_executor,
|
| 418 |
+
))
|
| 419 |
+
|
| 420 |
+
second_call_kwargs = client.client.chat.completions.create.call_args_list[1][1]
|
| 421 |
+
messages = second_call_kwargs["messages"]
|
| 422 |
+
|
| 423 |
+
assistant_msg = messages[-2]
|
| 424 |
+
self.assertEqual(assistant_msg["role"], "assistant")
|
| 425 |
+
|
| 426 |
+
tool_msg = messages[-1]
|
| 427 |
+
self.assertEqual(tool_msg["role"], "tool")
|
| 428 |
+
self.assertEqual(tool_msg["tool_call_id"], "call_123")
|
| 429 |
+
self.assertEqual(json.loads(tool_msg["content"]), tool_output)
|
| 430 |
+
|
| 431 |
+
self.assertIn("tools", second_call_kwargs,
|
| 432 |
+
"Follow-up call must include tools= so the model can "
|
| 433 |
+
"request additional tool calls if needed")
|
| 434 |
+
|
| 435 |
+
# ------------------------------------------------------------------
|
| 436 |
+
# Error handling
|
| 437 |
+
# ------------------------------------------------------------------
|
| 438 |
+
|
| 439 |
+
def test_tool_executor_failure_serialises_error_and_continues(self, MockAsyncOpenAI, mock_get_ctx):
|
| 440 |
+
"""If the tool executor raises, the error is serialised as the
|
| 441 |
+
tool result and the loop continues to the follow-up completion."""
|
| 442 |
+
client = ImprovedVllmClient(
|
| 443 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 444 |
+
)
|
| 445 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 446 |
+
_make_tool_call_mock("search_courses", {"subject": "CSCI"}),
|
| 447 |
+
_make_text_completion_mock("Sorry, I couldn't look that up."),
|
| 448 |
+
])
|
| 449 |
+
mock_executor = AsyncMock(side_effect=RuntimeError("network down"))
|
| 450 |
+
|
| 451 |
+
result = asyncio.run(client.generate_with_tools(
|
| 452 |
+
system_prompt="You are helpful.",
|
| 453 |
+
user_message="Find CSCI courses",
|
| 454 |
+
tool_definitions=[FAKE_TOOL],
|
| 455 |
+
tool_executor=mock_executor,
|
| 456 |
+
))
|
| 457 |
+
|
| 458 |
+
self.assertTrue(result.used_tool)
|
| 459 |
+
self.assertEqual(result.tool_name, "search_courses")
|
| 460 |
+
self.assertEqual(len(result.tool_calls_made), 1)
|
| 461 |
+
|
| 462 |
+
second_call_msgs = client.client.chat.completions.create.call_args_list[1][1]["messages"]
|
| 463 |
+
tool_msg = [m for m in second_call_msgs if m.get("role") == "tool"][0]
|
| 464 |
+
self.assertIn("network down", json.loads(tool_msg["content"])["error"])
|
| 465 |
+
|
| 466 |
+
def test_connection_error_returns_not_used(self, MockAsyncOpenAI, mock_get_ctx):
|
| 467 |
+
"""APIConnectionError during tool calling returns used_tool=False."""
|
| 468 |
+
client = ImprovedVllmClient(
|
| 469 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 470 |
+
)
|
| 471 |
+
client.client.chat.completions.create = AsyncMock(
|
| 472 |
+
side_effect=APIConnectionError(request=MagicMock()),
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
result = asyncio.run(client.generate_with_tools(
|
| 476 |
+
system_prompt="Test",
|
| 477 |
+
user_message="Hi",
|
| 478 |
+
tool_definitions=[FAKE_TOOL],
|
| 479 |
+
tool_executor=AsyncMock(),
|
| 480 |
+
))
|
| 481 |
+
|
| 482 |
+
self.assertIsInstance(result, ToolCallResult)
|
| 483 |
+
self.assertFalse(result.used_tool)
|
| 484 |
+
self.assertIn("unable to connect", result.text.lower())
|
| 485 |
+
|
| 486 |
+
# ------------------------------------------------------------------
|
| 487 |
+
# Multi-tool call in a single response
|
| 488 |
+
# ------------------------------------------------------------------
|
| 489 |
+
|
| 490 |
+
def test_parallel_tool_calls_all_executed(self, MockAsyncOpenAI, mock_get_ctx):
|
| 491 |
+
"""When the model requests multiple tool calls in one response,
|
| 492 |
+
all of them are executed and their results fed back."""
|
| 493 |
+
client = ImprovedVllmClient(
|
| 494 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 495 |
+
)
|
| 496 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 497 |
+
_make_multi_tool_call_mock([
|
| 498 |
+
("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
|
| 499 |
+
("rate_my_professor", {"professor_name": "West"}, "call_b"),
|
| 500 |
+
]),
|
| 501 |
+
_make_text_completion_mock("Dubson has a 4.5 rating. West has a 3.8 rating."),
|
| 502 |
+
])
|
| 503 |
+
mock_executor = AsyncMock(
|
| 504 |
+
side_effect=[
|
| 505 |
+
{"professors": [{"name": "Dubson", "rating": 4.5}]},
|
| 506 |
+
{"professors": [{"name": "West", "rating": 3.8}]},
|
| 507 |
+
],
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
result = asyncio.run(client.generate_with_tools(
|
| 511 |
+
system_prompt="You are helpful.",
|
| 512 |
+
user_message="Is professor Dubson or West rated better?",
|
| 513 |
+
tool_definitions=[FAKE_TOOL],
|
| 514 |
+
tool_executor=mock_executor,
|
| 515 |
+
))
|
| 516 |
+
|
| 517 |
+
self.assertEqual(mock_executor.call_count, 2)
|
| 518 |
+
self.assertTrue(result.used_tool)
|
| 519 |
+
self.assertIn("Dubson", result.text)
|
| 520 |
+
self.assertIn("West", result.text)
|
| 521 |
+
self.assertEqual(len(result.tool_calls_made), 2)
|
| 522 |
+
self.assertEqual(result.tool_calls_made[0].name, "rate_my_professor")
|
| 523 |
+
self.assertEqual(result.tool_calls_made[1].args, {"professor_name": "West"})
|
| 524 |
+
self.assertEqual(client.client.chat.completions.create.call_count, 2)
|
| 525 |
+
|
| 526 |
+
def test_parallel_tool_results_all_in_followup_messages(self, MockAsyncOpenAI, mock_get_ctx):
|
| 527 |
+
"""All tool results must appear as separate role:tool messages
|
| 528 |
+
in the follow-up request."""
|
| 529 |
+
client = ImprovedVllmClient(
|
| 530 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 531 |
+
)
|
| 532 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 533 |
+
_make_multi_tool_call_mock([
|
| 534 |
+
("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
|
| 535 |
+
("rate_my_professor", {"professor_name": "West"}, "call_b"),
|
| 536 |
+
]),
|
| 537 |
+
_make_text_completion_mock("Comparison complete."),
|
| 538 |
+
])
|
| 539 |
+
mock_executor = AsyncMock(side_effect=[
|
| 540 |
+
{"professors": [{"name": "Dubson"}]},
|
| 541 |
+
{"professors": [{"name": "West"}]},
|
| 542 |
+
])
|
| 543 |
+
|
| 544 |
+
asyncio.run(client.generate_with_tools(
|
| 545 |
+
system_prompt="You are helpful.",
|
| 546 |
+
user_message="Compare",
|
| 547 |
+
tool_definitions=[FAKE_TOOL],
|
| 548 |
+
tool_executor=mock_executor,
|
| 549 |
+
))
|
| 550 |
+
|
| 551 |
+
second_call_msgs = client.client.chat.completions.create.call_args_list[1][1]["messages"]
|
| 552 |
+
tool_msgs = [m for m in second_call_msgs if m.get("role") == "tool"]
|
| 553 |
+
self.assertEqual(len(tool_msgs), 2)
|
| 554 |
+
self.assertEqual(tool_msgs[0]["tool_call_id"], "call_a")
|
| 555 |
+
self.assertEqual(tool_msgs[1]["tool_call_id"], "call_b")
|
| 556 |
+
|
| 557 |
+
# ------------------------------------------------------------------
|
| 558 |
+
# Multi-round tool calling
|
| 559 |
+
# ------------------------------------------------------------------
|
| 560 |
+
|
| 561 |
+
def test_sequential_tool_rounds(self, MockAsyncOpenAI, mock_get_ctx):
|
| 562 |
+
"""The loop handles a second round of tool calls after the first
|
| 563 |
+
results are fed back."""
|
| 564 |
+
client = ImprovedVllmClient(
|
| 565 |
+
api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
|
| 566 |
+
)
|
| 567 |
+
client.client.chat.completions.create = AsyncMock(side_effect=[
|
| 568 |
+
_make_tool_call_mock("rate_my_professor", {"professor_name": "Dubson"}, "call_1"),
|
| 569 |
+
_make_tool_call_mock("rate_my_professor", {"professor_name": "West"}, "call_2"),
|
| 570 |
+
_make_text_completion_mock("Dubson is rated higher than West."),
|
| 571 |
+
])
|
| 572 |
+
mock_executor = AsyncMock(side_effect=[
|
| 573 |
+
{"professors": [{"name": "Dubson", "rating": 4.5}]},
|
| 574 |
+
{"professors": [{"name": "West", "rating": 3.8}]},
|
| 575 |
+
])
|
| 576 |
+
|
| 577 |
+
result = asyncio.run(client.generate_with_tools(
|
| 578 |
+
system_prompt="You are helpful.",
|
| 579 |
+
user_message="Compare Dubson and West",
|
| 580 |
+
tool_definitions=[FAKE_TOOL],
|
| 581 |
+
tool_executor=mock_executor,
|
| 582 |
+
))
|
| 583 |
+
|
| 584 |
+
self.assertEqual(mock_executor.call_count, 2)
|
| 585 |
+
self.assertTrue(result.used_tool)
|
| 586 |
+
self.assertEqual(len(result.tool_calls_made), 2)
|
| 587 |
+
self.assertEqual(result.tool_name, "rate_my_professor")
|
| 588 |
+
self.assertEqual(client.client.chat.completions.create.call_count, 3)
|
| 589 |
+
|
| 590 |
+
# ------------------------------------------------------------------
|
| 591 |
+
# Tool executor failure in multi-tool context
|
| 592 |
+
# ------------------------------------------------------------------
|
| 593 |
+
|
| 594 |
+
def test_partial_tool_failure_continues(self, MockAsyncOpenAI, mock_get_ctx):
    """A failing tool call inside a batch is reported as JSON, not raised.

    The first completion requests two tool calls at once; the executor
    succeeds for the first and raises ``RuntimeError`` for the second. The
    follow-up completion request must still happen and must carry two
    ``tool`` messages, the second of which decodes to a dict with an
    ``"error"`` key.
    """
    vllm = ImprovedVllmClient(
        api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
    )

    batched_calls = _make_multi_tool_call_mock([
        ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
        ("rate_my_professor", {"professor_name": "West"}, "call_b"),
    ])
    final_reply = _make_text_completion_mock("Only Dubson data available.")
    vllm.client.chat.completions.create = AsyncMock(
        side_effect=[batched_calls, final_reply],
    )

    # An exception instance in side_effect is raised when that call is made.
    executor = AsyncMock(side_effect=[
        {"professors": [{"name": "Dubson", "rating": 4.5}]},
        RuntimeError("network down"),
    ])

    outcome = asyncio.run(vllm.generate_with_tools(
        system_prompt="You are helpful.",
        user_message="Compare",
        tool_definitions=[FAKE_TOOL],
        tool_executor=executor,
    ))

    self.assertTrue(outcome.used_tool)
    self.assertEqual(len(outcome.tool_calls_made), 2)
    self.assertEqual(vllm.client.chat.completions.create.call_count, 2)

    # Inspect the messages sent with the follow-up (second) completion request.
    follow_up = vllm.client.chat.completions.create.call_args_list[1]
    follow_up_messages = follow_up.kwargs["messages"]
    tool_messages = [
        message for message in follow_up_messages if message.get("role") == "tool"
    ]
    self.assertEqual(len(tool_messages), 2)
    failed_payload = json.loads(tool_messages[1]["content"])
    self.assertIn("error", failed_payload)
|
| 628 |
+
|
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tool registry — auto-discovers tool modules in app.tools and provides
|
| 3 |
+
a central API for retrieving definitions and dispatching calls.
|
| 4 |
+
|
| 5 |
+
Every tool module in this package must export:
|
| 6 |
+
TOOL_DEFINITION : Dict[str, Any] — OpenAI tool format
|
| 7 |
+
{"type": "function", "function": {"name": ..., ...}}
|
| 8 |
+
execute : async (**kwargs) — returns Dict[str, Any]
|
| 9 |
+
|
| 10 |
+
Modules that don't export both are silently skipped.
|
| 11 |
+
|
| 12 |
+
Filtering semantics for the ``enabled`` parameter:
|
| 13 |
+
None — no filter; all registered tools are available (default)
|
| 14 |
+
[] — explicit empty list; no tools are available
|
| 15 |
+
[ids] — only the named tools are available
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import importlib
|
| 19 |
+
import inspect
|
| 20 |
+
import logging
|
| 21 |
+
import pkgutil
|
| 22 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
# Shared User-Agent for HTTP clients in tool modules (FOSE, RMP, etc.).
|
| 27 |
+
BROWSER_UA = (
|
| 28 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
| 29 |
+
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
| 30 |
+
"Chrome/131.0.0.0 Safari/537.36"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _discover_tools() -> None:
    """Import each module under ``app.tools`` and register the valid tools.

    A module is registered under its ``TOOL_DEFINITION["function"]["name"]``
    when it exposes both a ``TOOL_DEFINITION`` dict in OpenAI tool format and
    an ``execute`` coroutine function. Modules missing either attribute are
    skipped without logging; import failures, malformed definitions,
    non-async executors and duplicate names are skipped with a warning.
    """
    import app.tools as tools_pkg

    for module_info in pkgutil.iter_modules(tools_pkg.__path__):
        qualified = f"app.tools.{module_info.name}"

        try:
            module = importlib.import_module(qualified)
        except Exception:
            logger.warning("Failed to import tool module: %s", qualified, exc_info=True)
            continue

        definition = getattr(module, "TOOL_DEFINITION", None)
        runner = getattr(module, "execute", None)
        if definition is None or runner is None:
            # Not a tool module (helpers, etc.): skip without logging.
            continue

        well_formed = (
            isinstance(definition, dict)
            and definition.get("type") == "function"
            and isinstance(definition.get("function"), dict)
            and "name" in definition["function"]
        )
        if not well_formed:
            logger.warning("Skipping %s: TOOL_DEFINITION not in OpenAI tool format", qualified)
            continue

        if not (callable(runner) and inspect.iscoroutinefunction(runner)):
            logger.warning("Skipping %s: execute is not an async callable", qualified)
            continue

        tool_name = definition["function"]["name"]
        if tool_name in _REGISTRY:
            # First registration wins; later modules with the same name are ignored.
            logger.warning(
                "Duplicate tool name '%s' from %s — skipping", tool_name, qualified,
            )
            continue

        _REGISTRY[tool_name] = {"definition": definition, "executor": runner}
        logger.info("Registered tool: %s (from %s)", tool_name, qualified)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_tool_definitions(enabled: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Return the OpenAI-format definitions of registered tools.

    Args:
        enabled: ``None`` selects every registered tool. A list selects only
            the named tools, in the list's order; names that are not
            registered are ignored, so an empty list yields no tools.

    Returns:
        The selected ``TOOL_DEFINITION`` dicts.
    """
    selected_names = _REGISTRY if enabled is None else enabled

    definitions = []
    for tool_name in selected_names:
        if tool_name in _REGISTRY:
            definitions.append(_REGISTRY[tool_name]["definition"])
    return definitions
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_tool_executor(enabled: Optional[List[str]] = None) -> Callable:
    """Build the dispatcher passed as ``generate_with_tools(tool_executor=...)``.

    Args:
        enabled: ``None`` makes every registered tool dispatchable. A list
            restricts dispatch to the named tools that are registered at the
            time this function is called.

    Returns:
        An async callable ``dispatch(name, **kwargs)``. It awaits the matching
        tool's ``execute(name=name, **kwargs)`` and returns its result, or
        returns an ``{"error": ...}`` dict when the tool is not enabled or
        not registered.
    """
    # The allow-list is computed once, when the dispatcher is created.
    allowed = None if enabled is None else {n for n in enabled if n in _REGISTRY}

    async def dispatch(name: str, **kwargs: Any) -> Dict[str, Any]:
        restricted = allowed is not None
        if restricted and name not in allowed:
            logger.warning("Tool '%s' is not enabled", name)
            return {"error": f"Tool not enabled: {name}"}

        entry = _REGISTRY.get(name)
        if entry is None:
            logger.warning("Unknown tool requested: %s", name)
            return {"error": f"Unknown tool: {name}"}

        run_tool = entry["executor"]
        return await run_tool(name=name, **kwargs)

    return dispatch
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def list_registered_tools() -> List[str]:
    """Return a new list holding the name of every registered tool."""
    return [*_REGISTRY]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
_discover_tools()
|
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
rate_my_professor tool — live query against RateMyProfessors' GraphQL API.
|
| 3 |
+
|
| 4 |
+
Exposes TOOL_DEFINITION (OpenAI tool format) and an execute() coroutine
|
| 5 |
+
that the tool-calling loop dispatches to.
|
| 6 |
+
|
| 7 |
+
Requires ``school_id`` in the tool config (see phd_config.yaml).
|
| 8 |
+
Use ``scripts/rmp_school_lookup.py`` to find the ID for a given school.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import re
|
| 13 |
+
from typing import Any, Dict, List
|
| 14 |
+
import httpx
|
| 15 |
+
from app.tools import BROWSER_UA
|
| 16 |
+
from app.config import get_settings
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
RMP_GRAPHQL_URL = "https://www.ratemyprofessors.com/graphql"
|
| 21 |
+
RMP_LANDING_URL = "https://www.ratemyprofessors.com/"
|
| 22 |
+
RMP_SEARCH_URL = "https://www.ratemyprofessors.com/search/professors/1087"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
TEACHER_SEARCH_QUERY = """
|
| 26 |
+
query TeacherSearchPaginationQuery(
|
| 27 |
+
$count: Int!
|
| 28 |
+
$cursor: String
|
| 29 |
+
$query: TeacherSearchQuery!
|
| 30 |
+
) {
|
| 31 |
+
search: newSearch {
|
| 32 |
+
teachers(query: $query, first: $count, after: $cursor) {
|
| 33 |
+
didFallback
|
| 34 |
+
edges {
|
| 35 |
+
cursor
|
| 36 |
+
node {
|
| 37 |
+
id
|
| 38 |
+
legacyId
|
| 39 |
+
firstName
|
| 40 |
+
lastName
|
| 41 |
+
department
|
| 42 |
+
school { id name }
|
| 43 |
+
avgRating
|
| 44 |
+
avgDifficulty
|
| 45 |
+
wouldTakeAgainPercent
|
| 46 |
+
numRatings
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
pageInfo {
|
| 50 |
+
hasNextPage
|
| 51 |
+
endCursor
|
| 52 |
+
}
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
TOOL_DEFINITION: Dict[str, Any] = {
|
| 59 |
+
"type": "function",
|
| 60 |
+
"function": {
|
| 61 |
+
"name": "rate_my_professor",
|
| 62 |
+
"description": (
|
| 63 |
+
"Look up RateMyProfessors ratings for a CU Boulder professor. "
|
| 64 |
+
"Returns rating, difficulty, percentage of students who would "
|
| 65 |
+
"take the professor again, and number of ratings."
|
| 66 |
+
),
|
| 67 |
+
"parameters": {
|
| 68 |
+
"type": "object",
|
| 69 |
+
"properties": {
|
| 70 |
+
"professor_name": {
|
| 71 |
+
"type": "string",
|
| 72 |
+
"description": (
|
| 73 |
+
"Full or partial name of the professor to search for, "
|
| 74 |
+
"e.g. 'Hoenigman', 'Jane Smith'."
|
| 75 |
+
),
|
| 76 |
+
},
|
| 77 |
+
},
|
| 78 |
+
"required": ["professor_name"],
|
| 79 |
+
},
|
| 80 |
+
},
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _node_to_professor(node: Dict[str, Any]) -> Dict[str, Any]:
|
| 85 |
+
"""Convert a GraphQL teacher node to a lightweight result dict."""
|
| 86 |
+
return {
|
| 87 |
+
"name": f"{node.get('firstName', '')} {node.get('lastName', '')}".strip(),
|
| 88 |
+
"department": node.get("department", ""),
|
| 89 |
+
"rating": node.get("avgRating", 0),
|
| 90 |
+
"difficulty": node.get("avgDifficulty", 0),
|
| 91 |
+
"would_take_again_pct": node.get("wouldTakeAgainPercent", -1),
|
| 92 |
+
"num_ratings": node.get("numRatings", 0),
|
| 93 |
+
"rmp_id": node.get("id", ""),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def _extract_auth_token(client: httpx.AsyncClient) -> str:
    """Return the ``Authorization`` header value to use for RMP GraphQL calls.

    Downloads the RMP landing page and searches its text for an embedded
    ``"Authorization": "Basic ..."`` pair. If the request fails or no such
    pair is found, the well-known ``Basic test:test`` token is returned.
    """
    fallback_token = "Basic dGVzdDp0ZXN0"
    token_pattern = r'"Authorization"\s*:\s*"(Basic\s+[A-Za-z0-9+/=]+)"'

    try:
        landing_page = await client.get(
            RMP_LANDING_URL, headers={"User-Agent": BROWSER_UA},
        )
        found = re.search(token_pattern, landing_page.text)
        if found is not None:
            logger.info("Extracted RMP auth token from page JS")
            return found.group(1)
    except Exception as exc:
        # Best-effort lookup: any failure falls through to the fallback token.
        logger.debug("RMP auth token extraction failed: %s", exc)

    return fallback_token
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
async def execute(
    *,
    name: str = "",
    professor_name: str,
) -> Dict[str, Any]:
    """Query RateMyProfessors for professors matching *professor_name*.

    The search is scoped to the ``school_id`` configured for the
    ``rate_my_professor`` tool.

    Args:
        name: Tool name supplied by the dispatcher; ignored here.
        professor_name: Full or partial professor name to search for.

    Returns:
        ``{"professors": [...], "query": {...}}`` on success. On any failure
        (missing ``school_id``, HTTP 403, a GraphQL error with no data,
        network or parsing errors) ``professors`` is empty and an ``"error"``
        string is added. This coroutine never raises for those conditions.
    """
    # Function-scope import so this block does not rely on a file-level change.
    from urllib.parse import quote

    tool_cfg = get_settings().tools.get_tool_config("rate_my_professor")
    school_id = tool_cfg.get("school_id")
    if not school_id:
        logger.error("No school_id configured for rate_my_professor")
        return {
            "professors": [],
            "error": "No school_id configured for rate_my_professor",
            "query": {"professor_name": professor_name},
        }

    professors: List[Dict[str, Any]] = []

    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            auth_token = await _extract_auth_token(client)

            headers = {
                "User-Agent": BROWSER_UA,
                "Authorization": auth_token,
                "Content-Type": "application/json",
                # Percent-encode the name: a raw name containing spaces or
                # non-ASCII characters (e.g. "José") is not a valid URL and
                # cannot be encoded as an ASCII header value.
                # NOTE(review): RMP_SEARCH_URL carries a fixed school path
                # segment regardless of the configured school_id — confirm
                # that is acceptable for other schools.
                "Referer": f"{RMP_SEARCH_URL}?q={quote(professor_name)}",
                "Origin": "https://www.ratemyprofessors.com",
            }

            variables = {
                "count": 20,
                "cursor": "",
                "query": {
                    "text": professor_name,
                    "schoolID": school_id,
                    "fallback": True,
                    "departmentID": None,
                },
            }

            resp = await client.post(
                RMP_GRAPHQL_URL,
                json={"query": TEACHER_SEARCH_QUERY, "variables": variables},
                headers=headers,
            )

            if resp.status_code == 403:
                logger.warning("RMP GraphQL returned 403 — auth may be invalid")
                return {
                    "professors": [],
                    "error": "RateMyProfessors authentication failed",
                    "query": {"professor_name": professor_name},
                }

            resp.raise_for_status()
            data = resp.json()

            # GraphQL may answer with `"data": null` (or null sub-objects)
            # alongside an "errors" list; `or {}` keeps the traversal safe
            # instead of failing with an opaque AttributeError.
            search = (data.get("data") or {}).get("search") or {}
            teachers = search.get("teachers") or {}

            if not teachers and data.get("errors"):
                logger.warning(
                    "RMP GraphQL returned errors for %s: %s",
                    professor_name, data["errors"],
                )
                return {
                    "professors": [],
                    "error": "RateMyProfessors returned a GraphQL error",
                    "query": {"professor_name": professor_name},
                }

            for edge in teachers.get("edges") or []:
                node = (edge or {}).get("node") or {}
                if node:
                    professors.append(_node_to_professor(node))

    except Exception as exc:
        # Tool boundary: report the failure to the caller as data, not as an
        # exception.
        logger.error("RMP API error for %s: %s", professor_name, exc)
        return {
            "professors": [],
            "error": str(exc),
            "query": {"professor_name": professor_name},
        }

    return {
        "professors": professors,
        "query": {"professor_name": professor_name},
    }
|
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
search_courses tool — live query against CU Boulder's FOSE class-search API.
|
| 3 |
+
|
| 4 |
+
Exposes TOOL_DEFINITION (OpenAI tool format) and an execute() coroutine
|
| 5 |
+
that the tool-calling loop dispatches to.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import re
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
import httpx
|
| 12 |
+
from app.tools import BROWSER_UA
|
| 13 |
+
from app.config import get_settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
FOSE_SEARCH_URL = "https://classes.colorado.edu/api/?page=fose&route=search"
|
| 18 |
+
CLASSES_BASE_URL = "https://classes.colorado.edu"
|
| 19 |
+
|
| 20 |
+
TOOL_DEFINITION: Dict[str, Any] = {
|
| 21 |
+
"type": "function",
|
| 22 |
+
"function": {
|
| 23 |
+
"name": "search_courses",
|
| 24 |
+
"description": (
|
| 25 |
+
"Search the CU Boulder course catalog for classes in a given "
|
| 26 |
+
"subject, optionally filtered by course number and semester. "
|
| 27 |
+
"Returns a list of matching sections with title, instructor, "
|
| 28 |
+
"schedule, and location."
|
| 29 |
+
),
|
| 30 |
+
"parameters": {
|
| 31 |
+
"type": "object",
|
| 32 |
+
"properties": {
|
| 33 |
+
"subject": {
|
| 34 |
+
"type": "string",
|
| 35 |
+
"description": (
|
| 36 |
+
"Department / subject code, e.g. 'CSCI', 'MATH', 'PHYS'."
|
| 37 |
+
),
|
| 38 |
+
},
|
| 39 |
+
"course_number": {
|
| 40 |
+
"type": "string",
|
| 41 |
+
"description": (
|
| 42 |
+
"Catalog number to filter on, e.g. '1300'. "
|
| 43 |
+
"Omit to return all courses in the subject."
|
| 44 |
+
),
|
| 45 |
+
},
|
| 46 |
+
"semester": {
|
| 47 |
+
"type": "string",
|
| 48 |
+
"description": (
|
| 49 |
+
"Semester name, e.g. 'Spring 2026', 'Fall 2025'. "
|
| 50 |
+
"Defaults to 'Spring 2026' if not provided."
|
| 51 |
+
),
|
| 52 |
+
},
|
| 53 |
+
},
|
| 54 |
+
"required": ["subject"],
|
| 55 |
+
},
|
| 56 |
+
},
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
def _term_to_srcdb(term: str) -> str:
|
| 60 |
+
"""Convert 'Spring 2026' to '2261', 'Fall 2025' to '2257', etc.
|
| 61 |
+
|
| 62 |
+
CU Boulder's FOSE API uses a 4-digit code: literal '2', the last
|
| 63 |
+
two digits of the year, and a season digit (1=Spring, 4=Summer, 7=Fall).
|
| 64 |
+
"""
|
| 65 |
+
term_lower = term.lower()
|
| 66 |
+
ym = re.search(r"20(\d{2})", term)
|
| 67 |
+
yy = ym.group(1) if ym else "26"
|
| 68 |
+
if "spring" in term_lower:
|
| 69 |
+
return f"2{yy}1"
|
| 70 |
+
if "summer" in term_lower:
|
| 71 |
+
return f"2{yy}4"
|
| 72 |
+
if "fall" in term_lower:
|
| 73 |
+
return f"2{yy}7"
|
| 74 |
+
return f"2{yy}1"
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _parse_schedule(meets: str) -> Dict[str, str]:
|
| 78 |
+
"""Parse 'MWF 10:00am-10:50am' into structured fields."""
|
| 79 |
+
if not meets:
|
| 80 |
+
return {"days": "", "start_time": "", "end_time": "", "raw": ""}
|
| 81 |
+
|
| 82 |
+
day_match = re.match(r"([A-Za-z]+)", meets)
|
| 83 |
+
days = day_match.group(1) if day_match else ""
|
| 84 |
+
|
| 85 |
+
time_match = re.search(
|
| 86 |
+
r"(\d{1,2}:\d{2}\s*[ap]m)\s*-\s*(\d{1,2}:\d{2}\s*[ap]m)", meets, re.I
|
| 87 |
+
)
|
| 88 |
+
start = time_match.group(1).strip() if time_match else ""
|
| 89 |
+
end = time_match.group(2).strip() if time_match else ""
|
| 90 |
+
|
| 91 |
+
return {"days": days, "start_time": start, "end_time": end, "raw": meets}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _row_to_course(item: Dict[str, Any], term: str) -> Optional[Dict[str, Any]]:
|
| 95 |
+
"""Convert a FOSE result row to a lightweight course dict.
|
| 96 |
+
Returns None for rows that should be skipped (recitations, cancelled sections, etc.).
|
| 97 |
+
"""
|
| 98 |
+
schd = item.get("schd", "")
|
| 99 |
+
if schd and schd not in ("LEC", "SEM", ""):
|
| 100 |
+
return None
|
| 101 |
+
if item.get("isCancelled"):
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
code = item.get("code", "").strip()
|
| 105 |
+
if not code:
|
| 106 |
+
code = (
|
| 107 |
+
f"{item.get('subject', '')} "
|
| 108 |
+
f"{item.get('catalog_nbr', item.get('catalogNbr', ''))}"
|
| 109 |
+
).strip()
|
| 110 |
+
|
| 111 |
+
return {
|
| 112 |
+
"course_code": code,
|
| 113 |
+
"title": item.get("title", ""),
|
| 114 |
+
"section": item.get("no", "") or item.get("section", ""),
|
| 115 |
+
"instructor": item.get("instr", "") or item.get("instructor", "Staff"),
|
| 116 |
+
"schedule": _parse_schedule(item.get("meets", "") or ""),
|
| 117 |
+
"location": item.get("bldg", item.get("location", "")),
|
| 118 |
+
"semester": term,
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
async def execute(
    *,
    name: str = "",
    subject: str,
    course_number: str = "",
    semester: str = "Spring 2026",
) -> Dict[str, Any]:
    """Query the CU Boulder FOSE API and return matching courses.

    All sections for *subject* in *semester* are fetched, converted with
    ``_row_to_course`` (which drops skipped rows), optionally narrowed to
    those whose course code contains *course_number*, and truncated to the
    ``max_results`` configured for the ``search_courses`` tool (default 20).

    Args:
        name: Tool name supplied by the dispatcher; ignored here.
        subject: Subject code; upper-cased and stripped before use.
        course_number: Optional catalog number; matched as a substring of the
            course code.
        semester: Semester name, converted to a ``srcdb`` code for the API.

    Returns:
        On success: ``{"courses", "total_results", "truncated", "query"}``.
        On a non-200 response or an exception: ``{"courses": [], "error",
        "query"}`` so callers can tell an upstream failure from an empty
        result.
    """
    srcdb = _term_to_srcdb(semester)
    subject = subject.upper().strip()

    payload = {
        "other": {"srcdb": srcdb},
        "criteria": [{"field": "subject", "value": subject}],
    }
    headers = {
        "User-Agent": BROWSER_UA,
        "Content-Type": "application/json",
        "Referer": CLASSES_BASE_URL,
        "Origin": CLASSES_BASE_URL,
    }

    courses: List[Dict[str, Any]] = []

    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.post(
                FOSE_SEARCH_URL, json=payload, headers=headers,
            )
            if resp.status_code != 200:
                logger.warning("FOSE API returned %s for %s", resp.status_code, subject)
                # Report the failure explicitly; without an "error" key this
                # result would be indistinguishable from "no courses found".
                return {
                    "courses": [],
                    "error": f"FOSE API returned HTTP {resp.status_code}",
                    "query": {"subject": subject, "semester": semester},
                }

            body = resp.json()
            # The row list is read from "results", falling back to "data".
            results = body.get("results", body.get("data", []))

            for item in results:
                row = _row_to_course(item, semester)
                if row:
                    courses.append(row)

    except Exception as exc:
        # Tool boundary: surface the failure as data rather than raising.
        logger.error("FOSE API error for %s: %s", subject, exc)
        return {"courses": [], "error": str(exc), "query": {"subject": subject, "semester": semester}}

    if course_number:
        cn = course_number.strip()
        courses = [c for c in courses if cn in c["course_code"]]

    max_results = get_settings().tools.get_tool_config("search_courses").get("max_results", 20)

    # total_results counts matches before truncation.
    total = len(courses)
    truncated = total > max_results
    courses = courses[:max_results]

    return {
        "courses": courses,
        "total_results": total,
        "truncated": truncated,
        "query": {
            "subject": subject,
            "course_number": course_number or None,
            "semester": semester,
        },
    }
|
|
@@ -1,46 +1,43 @@
|
|
| 1 |
# Core FastAPI framework
|
| 2 |
-
fastapi
|
| 3 |
-
uvicorn[standard]
|
| 4 |
-
python-multipart
|
| 5 |
|
| 6 |
# HTTP client for LLM APIs
|
| 7 |
-
httpx
|
| 8 |
openai~=2.30
|
| 9 |
|
| 10 |
# Document processing
|
| 11 |
-
PyPDF2
|
| 12 |
-
docx2txt
|
| 13 |
-
python-docx
|
| 14 |
|
| 15 |
# Environment configuration
|
| 16 |
-
python-dotenv
|
| 17 |
pyyaml~=6.0
|
| 18 |
|
| 19 |
# Persona color generation
|
| 20 |
colorhash~=2.3
|
| 21 |
|
| 22 |
# Vector database and embeddings
|
| 23 |
-
chromadb
|
| 24 |
-
sentence-transformers
|
| 25 |
|
| 26 |
# Natural language processing
|
| 27 |
-
nltk
|
| 28 |
-
tiktoken
|
| 29 |
|
| 30 |
# PDF generation and export
|
| 31 |
-
reportlab
|
| 32 |
|
| 33 |
# Database (MongoDB)
|
| 34 |
-
pymongo
|
| 35 |
-
motor
|
| 36 |
|
| 37 |
# Authentication and security
|
| 38 |
-
passlib[bcrypt]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
python-jose[cryptography]
|
| 42 |
|
| 43 |
# Data validation (required for EmailStr in Pydantic)
|
| 44 |
-
email-validator
|
| 45 |
-
python-docx
|
| 46 |
-
python-multipart
|
|
|
|
| 1 |
# Core FastAPI framework
|
| 2 |
+
fastapi~=0.135
|
| 3 |
+
uvicorn[standard]~=0.44
|
| 4 |
+
python-multipart~=0.0
|
| 5 |
|
| 6 |
# HTTP client for LLM APIs
|
| 7 |
+
httpx~=0.28
|
| 8 |
openai~=2.30
|
| 9 |
|
| 10 |
# Document processing
|
| 11 |
+
PyPDF2~=3.0
|
| 12 |
+
docx2txt~=0.9
|
| 13 |
+
python-docx~=1.2
|
| 14 |
|
| 15 |
# Environment configuration
|
| 16 |
+
python-dotenv~=1.2
|
| 17 |
pyyaml~=6.0
|
| 18 |
|
| 19 |
# Persona color generation
|
| 20 |
colorhash~=2.3
|
| 21 |
|
| 22 |
# Vector database and embeddings
|
| 23 |
+
chromadb~=1.5
|
| 24 |
+
sentence-transformers~=5.3
|
| 25 |
|
| 26 |
# Natural language processing
|
| 27 |
+
nltk~=3.9
|
| 28 |
+
tiktoken~=0.12
|
| 29 |
|
| 30 |
# PDF generation and export
|
| 31 |
+
reportlab~=4.4
|
| 32 |
|
| 33 |
# Database (MongoDB)
|
| 34 |
+
pymongo~=4.16
|
| 35 |
+
motor~=3.7
|
| 36 |
|
| 37 |
# Authentication and security
|
| 38 |
+
passlib[bcrypt]~=1.7
|
| 39 |
+
bcrypt~=4.3
|
| 40 |
+
python-jose[cryptography]~=3.5
|
|
|
|
| 41 |
|
| 42 |
# Data validation (required for EmailStr in Pydantic)
|
| 43 |
+
email-validator~=2.3
|
|
|
|
|
|
|
@@ -40,7 +40,11 @@ const buildAdvisors = (personaItems) => {
|
|
| 40 |
*/
|
| 41 |
const buildGetAdvisorColors = (advisors) => (advisorId, isDark = false) => {
|
| 42 |
const advisor = advisors[advisorId];
|
| 43 |
-
if (!advisor)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
return {
|
| 45 |
color: isDark ? advisor.darkColor : advisor.color,
|
| 46 |
bgColor: isDark ? advisor.darkBgColor : advisor.bgColor,
|
|
|
|
| 40 |
*/
|
| 41 |
const buildGetAdvisorColors = (advisors) => (advisorId, isDark = false) => {
|
| 42 |
const advisor = advisors[advisorId];
|
| 43 |
+
if (!advisor) {
|
| 44 |
+
return isDark
|
| 45 |
+
? { color: '#9CA3AF', bgColor: '#374151', textColor: '#F9FAFB' }
|
| 46 |
+
: { color: '#6B7280', bgColor: '#F3F4F6', textColor: '#111827' };
|
| 47 |
+
}
|
| 48 |
return {
|
| 49 |
color: isDark ? advisor.darkColor : advisor.color,
|
| 50 |
bgColor: isDark ? advisor.darkBgColor : advisor.bgColor,
|
|
@@ -150,3 +150,13 @@ llm:
|
|
| 150 |
rag:
|
| 151 |
embedding_model: "all-MiniLM-L6-v2"
|
| 152 |
chroma_collection: "phd_advisor_documents"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
rag:
|
| 151 |
embedding_model: "all-MiniLM-L6-v2"
|
| 152 |
chroma_collection: "phd_advisor_documents"
|
| 153 |
+
|
| 154 |
+
# TODO: For development/testing only. PhD Advisor will likely not use these tools.
|
| 155 |
+
tools:
|
| 156 |
+
search_courses:
|
| 157 |
+
enabled: true
|
| 158 |
+
max_results: 20 # Limit the number of results returned by the search_courses tool
|
| 159 |
+
rate_my_professor:
|
| 160 |
+
enabled: true
|
| 161 |
+
# Run `python3 scripts/rmp_school_lookup.py "<school name>"` to find this value
|
| 162 |
+
school_id: "U2Nob29sLTEwODc=" # CU Boulder school ID
|
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Look up a RateMyProfessors school ID by name.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python3 scripts/rmp_school_lookup.py "University of Colorado"
|
| 6 |
+
|
| 7 |
+
Prints matching schools with their GraphQL IDs (the value to put in
|
| 8 |
+
your config.yaml under tools.rate_my_professor.school_id).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import asyncio
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
import httpx
|
| 16 |
+
|
| 17 |
+
RMP_GRAPHQL_URL = "https://www.ratemyprofessors.com/graphql"
|
| 18 |
+
RMP_LANDING_URL = "https://www.ratemyprofessors.com/"
|
| 19 |
+
BROWSER_UA = (
|
| 20 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
| 21 |
+
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
| 22 |
+
"Chrome/131.0.0.0 Safari/537.36"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
SCHOOL_SEARCH_QUERY = """
|
| 26 |
+
query SchoolSearchQuery($query: SchoolSearchQuery!) {
|
| 27 |
+
newSearch {
|
| 28 |
+
schools(query: $query) {
|
| 29 |
+
edges {
|
| 30 |
+
node {
|
| 31 |
+
id
|
| 32 |
+
name
|
| 33 |
+
city
|
| 34 |
+
state
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
async def _extract_auth_token(client: httpx.AsyncClient) -> str:
    """Return the ``Authorization`` header value for RMP GraphQL requests.

    Fetches the RMP landing page and looks for an embedded
    ``"Authorization": "Basic ..."`` pair. If the fetch fails or nothing
    matches, the well-known ``Basic test:test`` token is returned instead.
    """
    try:
        resp = await client.get(RMP_LANDING_URL, headers={"User-Agent": BROWSER_UA})
        m = re.search(r'"Authorization"\s*:\s*"(Basic\s+[A-Za-z0-9+/=]+)"', resp.text)
        if m:
            return m.group(1)
    except Exception as exc:
        # Still best-effort, but no longer silent: if the later GraphQL call
        # fails, the user can see that the token lookup already went wrong.
        print(
            f"Warning: could not fetch RMP auth token ({exc}); using fallback token",
            file=sys.stderr,
        )
    return "Basic dGVzdDp0ZXN0"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
async def search_schools(school_name: str) -> list:
    """Search RateMyProfessors for schools whose name matches *school_name*.

    Returns a list of dicts with ``school_id``, ``name``, ``city`` and
    ``state`` for every edge that has a node. Raises the HTTP client's error
    if the GraphQL endpoint answers with an error status.
    """
    request_body = {
        "query": SCHOOL_SEARCH_QUERY,
        "variables": {"query": {"text": school_name}},
    }

    async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
        token = await _extract_auth_token(client)
        request_headers = {
            "User-Agent": BROWSER_UA,
            "Authorization": token,
            "Content-Type": "application/json",
        }
        resp = await client.post(
            RMP_GRAPHQL_URL, json=request_body, headers=request_headers,
        )
        resp.raise_for_status()
        data = resp.json()

        search_root = data.get("data", {}).get("newSearch", {})
        edges = search_root.get("schools", {}).get("edges", [])

        schools = []
        for edge in edges:
            if not edge.get("node"):
                continue
            node = edge["node"]
            schools.append({
                "school_id": node["id"],
                "name": node["name"],
                "city": node.get("city", ""),
                "state": node.get("state", ""),
            })
        return schools
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main():
    """CLI entry point: print the RMP schools matching the given name.

    All command-line arguments are joined with spaces to form the query.
    Exits with status 1 (after printing usage) when no argument is given and
    with status 0 when the search finds nothing.
    """
    name_parts = sys.argv[1:]
    if not name_parts:
        print("Usage: python3 scripts/rmp_school_lookup.py <school name>")
        print('Example: python3 scripts/rmp_school_lookup.py "University of Colorado"')
        sys.exit(1)

    query = " ".join(name_parts)
    results = asyncio.run(search_schools(query))

    if not results:
        print(f"No schools found matching '{query}'")
        sys.exit(0)

    print(f"Found {len(results)} school(s) matching '{query}':\n")
    for school in results:
        # Omit whichever of city/state is empty so no stray comma is printed.
        place_parts = (school["city"], school["state"])
        location = ", ".join(part for part in place_parts if part)
        print(f" {school['name']}")
        print(f" Location: {location}")
        print(f" school_id: {school['school_id']}")
        print()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
|
| 112 |
+
main()
|