NeonCharlie-24 committed on
Commit
d16b388
·
unverified ·
1 Parent(s): 2fabb8c

Feat/advisor tools (#38)

Browse files

* added /scraper directory from Clary branch for inspection.

* Added Gemini function-calling MVP with course search tool.

* Added Rate My Professor tool and multi-tool dispatcher.

* Added tool registry with auto-discovery and tool filtering.

* fixed test warning in RMP tool test.

* added structured tool-call return result and get_tool_response method to orchestrator.

* wired tool calling into orchestrator chat routes.

* cleaned up standalone test routes and did minor refactoring.

* fixed context manager dropping non-allowlist persona and tool responses.

* pinned backend requirements to <pkg>~=M.m from latest container build.

* removed license notices and cu_info_scraper.

* moved the tool-calling loop into LLMClient, with provider-specific overrides implemented in each client class.

* refactored tools config to support per-tool settings.

* added rmp school_id lookup script for CLI usage.

* added dark-mode awareness to the advisor color fallback, which the orchestrator uses when a tool call is made.

* fixed vLLM tool calling to handle multiple tool calls.

* fixed Gemini tool calling to handle multiple tool calls.

* deleted reference /scrapers directory.

* revert Gemini base url to be defined in init.

* refactored Gemini tool calling to use the /openai endpoint format.

multi_llm_chatbot_backend/app/api/routes/chat.py CHANGED
@@ -115,6 +115,27 @@ async def chat_stream(
115
  ).to_ndjson()
116
  return
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # Get personas most relevant to the current session
119
  top_personas = await chat_orchestrator.get_top_personas(
120
  session_id=sid,
@@ -363,7 +384,26 @@ async def chat_sequential_enhanced(
363
  "trigger": "vague_input"
364
  }
365
  }
366
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  # RESTORED: Get intelligently ordered personas based on context
368
  top_personas = await chat_orchestrator.get_top_personas(
369
  session_id=session_id,
 
115
  ).to_ndjson()
116
  return
117
 
118
+ # If an enabled tool can handle this query, return its response
119
+ # directly and skip persona generation.
120
+ tool_result = await chat_orchestrator.get_tool_response(message.user_input)
121
+ if tool_result.used_tool:
122
+ session.append_message("orchestrator", tool_result.text)
123
+ yield ChatStreamLine(
124
+ type="advisor",
125
+ data={
126
+ "persona_id": "orchestrator",
127
+ "persona_name": "Orchestrator",
128
+ "content": tool_result.text,
129
+ "used_documents": False,
130
+ "document_chunks_used": 0,
131
+ },
132
+ ).to_ndjson()
133
+ yield ChatStreamLine(
134
+ type="progress",
135
+ data={"phase": "complete"},
136
+ ).to_ndjson()
137
+ return
138
+
139
  # Get personas most relevant to the current session
140
  top_personas = await chat_orchestrator.get_top_personas(
141
  session_id=sid,
 
384
  "trigger": "vague_input"
385
  }
386
  }
387
+
388
+ # If an enabled tool can handle this query, return its response
389
+ # directly and skip persona generation.
390
+ tool_result = await chat_orchestrator.get_tool_response(message.user_input)
391
+ if tool_result.used_tool:
392
+ session.append_message("orchestrator", tool_result.text)
393
+ return {
394
+ "responses": [{
395
+ "persona_id": "orchestrator",
396
+ "persona_name": "Orchestrator",
397
+ "content": tool_result.text,
398
+ "used_documents": False,
399
+ "document_chunks_used": 0,
400
+ }],
401
+ "session_debug": {
402
+ "session_id": session_id,
403
+ "tool_used": True,
404
+ }
405
+ }
406
+
407
  # RESTORED: Get intelligently ordered personas based on context
408
  top_personas = await chat_orchestrator.get_top_personas(
409
  session_id=session_id,
multi_llm_chatbot_backend/app/api/routes/provider.py CHANGED
@@ -69,6 +69,8 @@ async def switch_provider(provider_data: ProviderSwitch):
69
  new_llm = create_llm_client(current_provider)
70
  llm = new_llm
71
 
 
 
72
  new_personas = get_default_personas(new_llm)
73
  chat_orchestrator.personas.clear()
74
  for persona in new_personas:
 
69
  new_llm = create_llm_client(current_provider)
70
  llm = new_llm
71
 
72
+ chat_orchestrator.llm_client = new_llm
73
+
74
  new_personas = get_default_personas(new_llm)
75
  chat_orchestrator.personas.clear()
76
  for persona in new_personas:
multi_llm_chatbot_backend/app/config.py CHANGED
@@ -11,7 +11,7 @@ import os
11
  import logging
12
  import colorsys
13
  from pathlib import Path
14
- from typing import List, Optional
15
  from colorhash import ColorHash
16
 
17
  import yaml
@@ -234,6 +234,23 @@ class RAGConfig(BaseModel):
234
  chroma_collection: str = "phd_advisor_documents"
235
 
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  class AppSettings(BaseModel):
238
  """Top-level container that mirrors the YAML structure."""
239
  app: AppConfig = AppConfig()
@@ -246,6 +263,7 @@ class AppSettings(BaseModel):
246
  mongodb: MongoDBConfig = MongoDBConfig()
247
  llm: LLMConfig = LLMConfig()
248
  rag: RAGConfig = RAGConfig()
 
249
 
250
  # ------------------------------------------------------------------
251
  # Convenience helpers
 
11
  import logging
12
  import colorsys
13
  from pathlib import Path
14
+ from typing import Any, Dict, List, Optional
15
  from colorhash import ColorHash
16
 
17
  import yaml
 
234
  chroma_collection: str = "phd_advisor_documents"
235
 
236
 
237
class ToolsConfig(BaseModel):
    """Free-form ``tools`` section of the settings.

    No fields are declared: every key is accepted as a pydantic extra
    (``extra = "allow"``).  The accessors below treat each extra as a
    mapping from a tool name to that tool's settings dict.
    """

    model_config = {"extra": "allow"}

    def get_enabled_names(self) -> List[str]:
        """Return the names of the tools that are enabled.

        An entry is enabled when its value is a dict whose ``enabled`` key
        is truthy; a dict with no ``enabled`` key also counts as enabled.
        Entries whose value is not a dict are ignored.
        """
        enabled_names: List[str] = []
        for tool_name, tool_cfg in self.__pydantic_extra__.items():
            if not isinstance(tool_cfg, dict):
                continue
            if tool_cfg.get("enabled", True):
                enabled_names.append(tool_name)
        return enabled_names

    def get_tool_config(self, name: str) -> Dict[str, Any]:
        """Return the raw config dict for tool *name*, or ``{}`` if it is missing or not a dict."""
        tool_cfg = self.__pydantic_extra__.get(name, {})
        if isinstance(tool_cfg, dict):
            return tool_cfg
        return {}
252
+
253
+
254
  class AppSettings(BaseModel):
255
  """Top-level container that mirrors the YAML structure."""
256
  app: AppConfig = AppConfig()
 
263
  mongodb: MongoDBConfig = MongoDBConfig()
264
  llm: LLMConfig = LLMConfig()
265
  rag: RAGConfig = RAGConfig()
266
+ tools: ToolsConfig = ToolsConfig()
267
 
268
  # ------------------------------------------------------------------
269
  # Convenience helpers
multi_llm_chatbot_backend/app/core/bootstrap.py CHANGED
@@ -30,7 +30,7 @@ def create_llm_client(provider=None):
30
  )
31
 
32
  llm = create_llm_client()
33
- chat_orchestrator = ImprovedChatOrchestrator()
34
 
35
  DEFAULT_PERSONAS = get_default_personas(llm)
36
  for persona in DEFAULT_PERSONAS:
 
30
  )
31
 
32
  llm = create_llm_client()
33
+ chat_orchestrator = ImprovedChatOrchestrator(llm_client=llm)
34
 
35
  DEFAULT_PERSONAS = get_default_personas(llm)
36
  for persona in DEFAULT_PERSONAS:
multi_llm_chatbot_backend/app/core/context_manager.py CHANGED
@@ -209,17 +209,16 @@ class ContextManager:
209
  "role": "user",
210
  "parts": [{"text": content}]
211
  })
212
- elif role in ['assistant', 'methodologist', 'theorist', 'pragmatist']:
213
- formatted.append({
214
- "role": "model",
215
- "parts": [{"text": content}]
216
- })
217
  elif role == 'document':
218
- # Add document as user context
219
  formatted.append({
220
  "role": "user",
221
  "parts": [{"text": f"[Context Document] {content}"}]
222
  })
 
 
 
 
 
223
 
224
  return formatted
225
 
 
209
  "role": "user",
210
  "parts": [{"text": content}]
211
  })
 
 
 
 
 
212
  elif role == 'document':
 
213
  formatted.append({
214
  "role": "user",
215
  "parts": [{"text": f"[Context Document] {content}"}]
216
  })
217
+ else:
218
+ formatted.append({
219
+ "role": "model",
220
+ "parts": [{"text": content}]
221
+ })
222
 
223
  return formatted
224
 
multi_llm_chatbot_backend/app/core/improved_orchestrator.py CHANGED
@@ -4,6 +4,8 @@ from app.core.session_manager import ConversationContext, get_session_manager
4
  from app.core.context_manager import get_context_manager
5
  from app.core.rag_manager import get_rag_manager
6
  from app.config import get_settings
 
 
7
 
8
  import json
9
  import logging
@@ -16,8 +18,9 @@ class ImprovedChatOrchestrator:
16
  Enhanced orchestrator with document awareness and improved context handling
17
  """
18
 
19
- def __init__(self):
20
  self.personas: Dict[str, Persona] = {}
 
21
  self.session_manager = get_session_manager()
22
  self.context_manager = get_context_manager()
23
 
@@ -33,7 +36,49 @@ class ImprovedChatOrchestrator:
33
  def list_personas(self) -> List[str]:
34
  """List all available persona IDs"""
35
  return list(self.personas.keys())
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  async def process_message(self,
38
  user_input: str,
39
  session_id: Optional[str] = None,
 
4
  from app.core.context_manager import get_context_manager
5
  from app.core.rag_manager import get_rag_manager
6
  from app.config import get_settings
7
+ from app.llm.llm_client import LLMClient, ToolCallResult
8
+ from app.tools import get_tool_definitions, get_tool_executor
9
 
10
  import json
11
  import logging
 
18
  Enhanced orchestrator with document awareness and improved context handling
19
  """
20
 
21
+ def __init__(self, llm_client: LLMClient = None):
22
  self.personas: Dict[str, Persona] = {}
23
+ self.llm_client = llm_client
24
  self.session_manager = get_session_manager()
25
  self.context_manager = get_context_manager()
26
 
 
36
def list_personas(self) -> List[str]:
    """Return the keys of ``self.personas`` (the persona IDs) as a list."""
    return [persona_id for persona_id in self.personas]
39
+
40
async def get_tool_response(self, user_message: str) -> ToolCallResult:
    """Try to answer *user_message* through the tool-calling path.

    Returns ``ToolCallResult(text="", used_tool=False)`` without calling
    the model when no LLM client is attached, when the settings list no
    enabled tools, or when the tool registry yields no definitions for
    the enabled names.  Otherwise the request is delegated to
    ``llm_client.generate_with_tools`` and its result is returned as-is.
    """
    client = self.llm_client
    if client is None:
        return ToolCallResult(text="", used_tool=False)

    enabled_names = get_settings().tools.get_enabled_names()
    if not enabled_names:
        return ToolCallResult(text="", used_tool=False)

    # Both registry lookups happen before the emptiness check on the definitions.
    definitions = get_tool_definitions(enabled=enabled_names)
    executor = get_tool_executor(enabled=enabled_names)
    if not definitions:
        return ToolCallResult(text="", used_tool=False)

    # One literal per sentence; the concatenated prompt text is unchanged.
    instructions = (
        "You are a helpful assistant with access to external tools. "
        "Use the available tools when the user's question can be answered by one of them. "
        "If no tool is relevant, respond with a brief text answer. "
        "If a tool response includes 'truncated': true, let the user know "
        "how many total results were found and suggest they narrow their "
        "search for more specific results. "
        "Format your responses using markdown. "
        "Use bullet points to present structured data like course listings or professor ratings."
    )

    return await client.generate_with_tools(
        system_prompt=instructions,
        user_message=user_message,
        tool_definitions=definitions,
        tool_executor=executor,
    )
81
+
82
  async def process_message(self,
83
  user_input: str,
84
  session_id: Optional[str] = None,
multi_llm_chatbot_backend/app/llm/improved_gemini_client.py CHANGED
@@ -1,10 +1,14 @@
1
  import httpx
2
- import os
3
- from typing import List
4
- from app.llm.llm_client import LLMClient
 
 
 
 
 
5
  from app.core.context_manager import get_context_manager
6
  from app.config import get_settings
7
- import logging
8
 
9
  logger = logging.getLogger(__name__)
10
 
@@ -20,8 +24,16 @@ class ImprovedGeminiClient(LLMClient):
20
  if not self.api_key:
21
  raise ValueError("Gemini API key not set. Provide it in config.yaml (llm.gemini.api_key).")
22
 
 
23
  self.base_url = "https://generativelanguage.googleapis.com/v1beta/models"
24
  self.context_manager = get_context_manager()
 
 
 
 
 
 
 
25
 
26
  async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
27
  """
@@ -120,3 +132,111 @@ class ImprovedGeminiClient(LLMClient):
120
  except Exception as e:
121
  logger.error(f"Unexpected error in Gemini client: {str(e)}")
122
  return "I encountered an unexpected error. Please try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import httpx
2
+ import json
3
+ import logging
4
+ from typing import Any, Callable, Dict, List, Optional
5
+
6
+ from openai import AsyncOpenAI, APIConnectionError, APIStatusError
7
+
8
+ from app.llm.llm_client import LLMClient, ToolCallInfo, ToolCallResult
9
+
10
  from app.core.context_manager import get_context_manager
11
  from app.config import get_settings
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
24
  if not self.api_key:
25
  raise ValueError("Gemini API key not set. Provide it in config.yaml (llm.gemini.api_key).")
26
 
27
+ # Native Gemini REST API
28
  self.base_url = "https://generativelanguage.googleapis.com/v1beta/models"
29
  self.context_manager = get_context_manager()
30
+
31
+ # OpenAI-compatible endpoint (for tool calling)
32
+ self.openai_client = AsyncOpenAI(
33
+ base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
34
+ api_key=self.api_key,
35
+ timeout=90.0,
36
+ )
37
 
38
  async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
39
  """
 
132
  except Exception as e:
133
  logger.error(f"Unexpected error in Gemini client: {str(e)}")
134
  return "I encountered an unexpected error. Please try again."
135
+
136
+ # ------------------------------------------------------------------
137
+ # Tool-calling support (via Gemini OpenAI-compatible endpoint)
138
+ # ------------------------------------------------------------------
139
+
140
+ _MAX_TOOL_ROUNDS = 5
141
+
142
async def generate_with_tools(
    self,
    system_prompt: str,
    user_message: str,
    tool_definitions: Optional[List[Dict[str, Any]]] = None,
    tool_executor: Optional[Callable] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
) -> ToolCallResult:
    """Run an OpenAI-style tool-calling loop against Gemini's ``/openai/`` endpoint.

    Each round sends the conversation to the model.  When the reply
    carries ``tool_calls``, every call in that reply is executed through
    ``tool_executor`` and its JSON-encoded result is appended as a
    ``tool`` message before the next round, so several tool calls in one
    reply are all answered.  The loop ends when the model replies without
    tool calls, or after ``_MAX_TOOL_ROUNDS`` rounds.

    A failure confined to a single tool call (arguments that are not a
    JSON object, an exception raised by the tool, a result that cannot be
    JSON-encoded) is reported back to the model as ``{"error": ...}``
    instead of aborting the whole request.

    Args:
        system_prompt: Content of the leading ``system`` message.
        user_message: Content of the ``user`` message.
        tool_definitions: Tool schemas in OpenAI format; ``None`` or an
            empty list sends the request without tools.
        tool_executor: Awaitable callable invoked as
            ``tool_executor(name=<tool name>, **<parsed arguments>)``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Token limit forwarded to the API.

    Returns:
        A ``ToolCallResult`` holding the final text.  ``used_tool`` is true
        when the model requested at least one tool call; ``tool_name`` and
        ``tool_args`` describe the first call and ``tool_calls_made``
        lists all of them.  Connection errors, API status errors and any
        other exception yield a fixed fallback text with ``used_tool=False``.
    """
    messages: List[Dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    openai_tools = tool_definitions or []
    all_tool_calls: List[ToolCallInfo] = []

    def _build_result(text: str) -> ToolCallResult:
        # tool_name / tool_args mirror the first recorded call, if any.
        first = all_tool_calls[0] if all_tool_calls else None
        return ToolCallResult(
            text=text,
            used_tool=bool(all_tool_calls),
            tool_name=first.name if first is not None else None,
            tool_args=first.args if first is not None else {},
            tool_calls_made=all_tool_calls,
        )

    try:
        for _round in range(self._MAX_TOOL_ROUNDS):
            response = await self.openai_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                tools=openai_tools or None,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            choice = response.choices[0].message

            if not choice.tool_calls:
                return _build_result(choice.content or "")

            # Echo the assistant turn back so the tool results can refer to it.
            messages.append(choice.model_dump(exclude_none=True))

            for tc in choice.tool_calls:
                fn_name = tc.function.name

                # The argument string is model-generated and may be empty,
                # malformed, or not a JSON object; keep that failure local
                # to this call instead of failing the whole request.
                fn_args: Dict[str, Any] = {}
                arg_error: Optional[str] = None
                try:
                    parsed = json.loads(tc.function.arguments or "{}")
                except (TypeError, ValueError) as exc:
                    arg_error = f"Invalid JSON in tool arguments: {exc}"
                else:
                    if parsed is None:
                        fn_args = {}
                    elif isinstance(parsed, dict):
                        fn_args = parsed
                    else:
                        arg_error = "Tool arguments must be a JSON object"

                logger.info("Gemini requested tool call: %s(%s)", fn_name, fn_args)
                all_tool_calls.append(ToolCallInfo(name=fn_name, args=fn_args))

                if arg_error is not None:
                    logger.error("Tool %s failed: %s", fn_name, arg_error)
                    tool_result = {"error": arg_error}
                else:
                    try:
                        tool_result = await tool_executor(name=fn_name, **fn_args)
                    except Exception as exc:
                        logger.error("Tool %s failed: %s", fn_name, exc)
                        tool_result = {"error": str(exc)}

                try:
                    content = json.dumps(tool_result)
                except (TypeError, ValueError) as exc:
                    logger.error(
                        "Tool %s returned a result that is not JSON-serializable: %s",
                        fn_name, exc,
                    )
                    content = json.dumps(
                        {"error": f"Tool result could not be serialized: {exc}"}
                    )

                # Every tool_call id must be answered by a matching tool message.
                messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": content,
                })

        logger.warning(
            "Tool-calling loop exhausted after %d rounds", self._MAX_TOOL_ROUNDS,
        )
        last_content = response.choices[0].message.content or ""
        return _build_result(
            last_content or "I was unable to finish looking that up. Please try again."
        )

    except APIConnectionError:
        logger.error("Unable to connect to Gemini OpenAI-compat endpoint")
        return ToolCallResult(
            text="I'm unable to connect to the AI service. Please try again.",
            used_tool=False,
        )
    except APIStatusError as e:
        logger.error("Gemini tool-call API error: %s - %s", e.status_code, e.message)
        return ToolCallResult(
            text="The AI service encountered an error. Please try again.",
            used_tool=False,
        )
    except Exception as e:
        logger.error("Unexpected error in Gemini tool-calling: %s", e)
        return ToolCallResult(
            text="I encountered an unexpected error. Please try again.",
            used_tool=False,
        )
multi_llm_chatbot_backend/app/llm/improved_vllm_client.py CHANGED
@@ -1,8 +1,11 @@
1
- from typing import List
 
 
 
2
  from openai import AsyncOpenAI, APIConnectionError, APIStatusError
3
- from app.llm.llm_client import LLMClient
 
4
  from app.core.context_manager import get_context_manager
5
- import logging
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -15,7 +18,7 @@ class ImprovedVllmClient(LLMClient):
15
  self.client = AsyncOpenAI(
16
  base_url=f"{api_url}/v1",
17
  api_key=api_key,
18
- timeout=30.0,
19
  )
20
  self.context_manager = get_context_manager()
21
 
@@ -70,3 +73,118 @@ class ImprovedVllmClient(LLMClient):
70
  logger.error(f"Unexpected error in vLLM client: {str(e)}")
71
  return "I encountered an unexpected error. Please try again."
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
5
  from openai import AsyncOpenAI, APIConnectionError, APIStatusError
6
+
7
+ from app.llm.llm_client import LLMClient, ToolCallInfo, ToolCallResult
8
  from app.core.context_manager import get_context_manager
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
18
  self.client = AsyncOpenAI(
19
  base_url=f"{api_url}/v1",
20
  api_key=api_key,
21
+ timeout=90.0,
22
  )
23
  self.context_manager = get_context_manager()
24
 
 
73
  logger.error(f"Unexpected error in vLLM client: {str(e)}")
74
  return "I encountered an unexpected error. Please try again."
75
 
76
+ # ------------------------------------------------------------------
77
+ # Tool-calling support (OpenAI-compatible format)
78
+ # ------------------------------------------------------------------
79
+
80
+ _MAX_TOOL_ROUNDS = 5
81
+
82
async def generate_with_tools(
    self,
    system_prompt: str,
    user_message: str,
    tool_definitions: Optional[List[Dict[str, Any]]] = None,
    tool_executor: Optional[Callable] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
) -> ToolCallResult:
    """Run an OpenAI-style tool-calling loop against the vLLM endpoint.

    If ``self.model_name`` is not set, ``refresh_model()`` is awaited
    first.  Each round then sends the conversation to the model.  When
    the reply carries ``tool_calls``, every call in that reply is
    executed through ``tool_executor`` and its JSON-encoded result is
    appended as a ``tool`` message before the next round.  The loop ends
    when the model replies without tool calls, or after
    ``_MAX_TOOL_ROUNDS`` rounds.

    A failure confined to a single tool call (arguments that are not a
    JSON object, an exception raised by the tool, a result that cannot be
    JSON-encoded) is reported back to the model as ``{"error": ...}``
    instead of aborting the whole request.

    Args:
        system_prompt: Content of the leading ``system`` message.
        user_message: Content of the ``user`` message.
        tool_definitions: Tool schemas in OpenAI format; ``None`` or an
            empty list sends the request without tools.
        tool_executor: Awaitable callable invoked as
            ``tool_executor(name=<tool name>, **<parsed arguments>)``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Token limit forwarded to the API.

    Returns:
        A ``ToolCallResult`` holding the final text.  ``used_tool`` is true
        when the model requested at least one tool call; ``tool_name`` and
        ``tool_args`` describe the first call and ``tool_calls_made``
        lists all of them.  Connection errors, API status errors and any
        other exception yield a fixed fallback text with ``used_tool=False``.
    """
    if not self.model_name:
        await self.refresh_model()

    messages: List[Dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    openai_tools = tool_definitions or []
    all_tool_calls: List[ToolCallInfo] = []

    def _build_result(text: str) -> ToolCallResult:
        # tool_name / tool_args mirror the first recorded call, if any.
        first = all_tool_calls[0] if all_tool_calls else None
        return ToolCallResult(
            text=text,
            used_tool=bool(all_tool_calls),
            tool_name=first.name if first is not None else None,
            tool_args=first.args if first is not None else {},
            tool_calls_made=all_tool_calls,
        )

    try:
        for _round in range(self._MAX_TOOL_ROUNDS):
            response = await self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                tools=openai_tools or None,
                temperature=temperature,
                max_tokens=max_tokens,
            )

            choice = response.choices[0].message

            if not choice.tool_calls:
                return _build_result(choice.content or "")

            # Echo the assistant turn back so the tool results can refer to it.
            messages.append(choice.model_dump())

            for tc in choice.tool_calls:
                fn_name = tc.function.name

                # The argument string is model-generated and may be empty,
                # malformed, or not a JSON object; keep that failure local
                # to this call instead of failing the whole request.
                fn_args: Dict[str, Any] = {}
                arg_error: Optional[str] = None
                try:
                    parsed = json.loads(tc.function.arguments or "{}")
                except (TypeError, ValueError) as exc:
                    arg_error = f"Invalid JSON in tool arguments: {exc}"
                else:
                    if parsed is None:
                        fn_args = {}
                    elif isinstance(parsed, dict):
                        fn_args = parsed
                    else:
                        arg_error = "Tool arguments must be a JSON object"

                logger.info("vLLM requested tool call: %s(%s)", fn_name, fn_args)
                all_tool_calls.append(ToolCallInfo(name=fn_name, args=fn_args))

                if arg_error is not None:
                    logger.error("Tool %s failed: %s", fn_name, arg_error)
                    tool_result = {"error": arg_error}
                else:
                    try:
                        tool_result = await tool_executor(name=fn_name, **fn_args)
                    except Exception as exc:
                        logger.error("Tool %s failed: %s", fn_name, exc)
                        tool_result = {"error": str(exc)}

                try:
                    content = json.dumps(tool_result)
                except (TypeError, ValueError) as exc:
                    logger.error(
                        "Tool %s returned a result that is not JSON-serializable: %s",
                        fn_name, exc,
                    )
                    content = json.dumps(
                        {"error": f"Tool result could not be serialized: {exc}"}
                    )

                # Every tool_call id must be answered by a matching tool message.
                messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": content,
                })

        logger.warning(
            "Tool-calling loop exhausted after %d rounds", self._MAX_TOOL_ROUNDS,
        )
        last_content = response.choices[0].message.content or ""
        return _build_result(
            last_content or "I was unable to finish looking that up. Please try again."
        )

    except APIConnectionError:
        logger.error("Unable to connect to vLLM at %s", self.api_url)
        return ToolCallResult(
            text="I'm unable to connect to the AI service. Please ensure the vLLM endpoint is available.",
            used_tool=False,
        )
    except APIStatusError as e:
        logger.error("vLLM tool-call API error: %s - %s", e.status_code, e.message)
        if e.status_code == 404:
            # Clear the cached model name; the next call will then hit the
            # refresh_model() branch at the top of this method.
            self.model_name = None
        return ToolCallResult(
            text="The AI service encountered an error. Please try again.",
            used_tool=False,
        )
    except Exception as e:
        logger.error("Unexpected error in vLLM tool-calling: %s", e)
        return ToolCallResult(
            text="I encountered an unexpected error. Please try again.",
            used_tool=False,
        )
189
+
190
+
multi_llm_chatbot_backend/app/llm/llm_client.py CHANGED
@@ -1,27 +1,72 @@
1
  from abc import ABC, abstractmethod
2
- from typing import List
 
3
  import re
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  class LLMClient(ABC):
6
  """Abstract base class for all LLM clients"""
7
-
8
  @abstractmethod
9
  async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
10
  """
11
  Generate a response using the LLM.
12
-
13
  Args:
14
  system_prompt (str): The system prompt defining the persona/role
15
  context (List[dict]): List of conversation messages with 'role' and 'content' keys
16
  temperature (float): Sampling temperature for generation
17
  max_tokens (int): Maximum number of tokens to generate
18
  response_mime_type (str, optional): MIME type for the response format. Defaults to None.
19
-
20
  Returns:
21
  str: The generated response text
22
  """
23
  pass
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def _clean_response(self, response: str) -> str:
26
  """Clean up response text, preserving Markdown formatting."""
27
  response = response.replace("\r\n", "\n").replace("\r", "\n")
 
1
  from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Callable, Dict, List, Optional
4
  import re
5
 
6
+
7
@dataclass
class ToolCallInfo:
    """One tool invocation: the tool's name and the arguments it was given."""

    # Name of the tool.
    name: str
    # Arguments for the call; each instance gets its own empty dict by default.
    args: Dict[str, Any] = field(default_factory=dict)
13
+
14
+
15
@dataclass
class ToolCallResult:
    """Value returned by ``generate_with_tools``."""

    # Text of the response.
    text: str
    # Whether a tool was used while producing the response.
    used_tool: bool
    # Name of a tool that was called, or None.
    tool_name: Optional[str] = None
    # Arguments of that call; defaults to a fresh empty dict.
    tool_args: Dict[str, Any] = field(default_factory=dict)
    # Record of the tool calls made; defaults to a fresh empty list.
    tool_calls_made: List["ToolCallInfo"] = field(default_factory=list)
24
+
25
+
26
  class LLMClient(ABC):
27
  """Abstract base class for all LLM clients"""
28
+
29
@abstractmethod
async def generate(self, system_prompt: str, context: List[dict], temperature: float, max_tokens: int, response_mime_type: str = None) -> str:
    """
    Produce a response from the LLM; concrete clients must override this.

    Args:
        system_prompt (str): System prompt that defines the persona/role.
        context (List[dict]): Conversation messages, each with 'role' and 'content' keys.
        temperature (float): Sampling temperature used for generation.
        max_tokens (int): Upper bound on the number of generated tokens.
        response_mime_type (str, optional): MIME type of the desired response format. Defaults to None.

    Returns:
        str: The generated response text.
    """
    return None
45
 
46
async def generate_with_tools(
    self,
    system_prompt: str,
    user_message: str,
    tool_definitions: Optional[List[Dict[str, Any]]] = None,
    tool_executor: Optional[Callable] = None,
    temperature: float = 0.7,
    max_tokens: int = 2048,
) -> ToolCallResult:
    """Fallback for providers without native tool calling.

    ``tool_definitions`` and ``tool_executor`` are accepted but ignored:
    the user message is sent through a plain ``generate()`` call as a
    single-message context and the reply is wrapped in a
    ``ToolCallResult`` with ``used_tool=False``.  Clients that support
    tool calling override this method.
    """
    single_turn_context = [{"role": "user", "content": user_message}]
    reply = await self.generate(
        system_prompt=system_prompt,
        context=single_turn_context,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return ToolCallResult(text=reply, used_tool=False)
69
+
70
  def _clean_response(self, response: str) -> str:
71
  """Clean up response text, preserving Markdown formatting."""
72
  response = response.replace("\r\n", "\n").replace("\r", "\n")
multi_llm_chatbot_backend/app/main.py CHANGED
@@ -58,6 +58,7 @@ app.include_router(auth_router, prefix="/auth", tags=["authentication"])
58
  app.include_router(chat_sessions_router, prefix="/api", tags=["chat-sessions"])
59
  app.include_router(phd_canvas_router, prefix="/api", tags=["phd-canvas"])
60
 
 
61
  # ---------------------------------------------------------------------------
62
  # Public configuration endpoint — serves the frontend-safe subset
63
  # ---------------------------------------------------------------------------
 
58
  app.include_router(chat_sessions_router, prefix="/api", tags=["chat-sessions"])
59
  app.include_router(phd_canvas_router, prefix="/api", tags=["phd-canvas"])
60
 
61
+
62
  # ---------------------------------------------------------------------------
63
  # Public configuration endpoint — serves the frontend-safe subset
64
  # ---------------------------------------------------------------------------
multi_llm_chatbot_backend/app/tests/unit/test_course_search_tool.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import unittest
3
+
4
+ from app.tools.search_courses import TOOL_DEFINITION, execute
5
+
6
+
7
class TestSearchCoursesContract(unittest.TestCase):
    """Contract checks for the search_courses tool module: it must export
    a valid OpenAI tool definition and an async executor."""

    def test_tool_definition_has_required_fields(self):
        # Top-level shape first, then the keys of the nested function spec.
        self.assertEqual(TOOL_DEFINITION["type"], "function")
        self.assertIn("function", TOOL_DEFINITION)
        function_spec = TOOL_DEFINITION["function"]
        for required_key in ("name", "description", "parameters"):
            self.assertIn(required_key, function_spec)

    def test_tool_definition_name(self):
        declared_name = TOOL_DEFINITION["function"]["name"]
        self.assertEqual(declared_name, "search_courses")

    def test_tool_definition_has_nonempty_description(self):
        self.assertIsInstance(TOOL_DEFINITION["function"]["description"], str)
        description_length = len(TOOL_DEFINITION["function"]["description"])
        self.assertGreater(description_length, 0)

    def test_tool_definition_parameters_is_valid_schema(self):
        schema = TOOL_DEFINITION["function"]["parameters"]
        self.assertEqual(schema["type"], "object")
        self.assertIn("properties", schema)
        self.assertIn("subject", schema["properties"])

    def test_execute_is_async_callable(self):
        is_coroutine_fn = asyncio.iscoroutinefunction(execute)
        self.assertTrue(is_coroutine_fn)
multi_llm_chatbot_backend/app/tests/unit/test_gemini_client.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import unittest
4
+ from unittest.mock import AsyncMock, MagicMock, patch
5
+
6
+ from openai import APIConnectionError, APIStatusError
7
+
8
+ from app.llm.llm_client import ToolCallResult
9
+ from app.llm.improved_gemini_client import ImprovedGeminiClient
10
+
11
+
12
+ FAKE_TOOL = {
13
+ "type": "function",
14
+ "function": {
15
+ "name": "search_courses",
16
+ "description": "Search courses",
17
+ "parameters": {
18
+ "type": "object",
19
+ "properties": {
20
+ "subject": {"type": "string", "description": "Subject code"},
21
+ },
22
+ },
23
+ },
24
+ }
25
+
26
+
27
+ def _make_text_completion_mock(content="Response"):
28
+ """Build a ChatCompletion mock with no tool calls."""
29
+ mock_message = MagicMock()
30
+ mock_message.content = content
31
+ mock_message.tool_calls = None
32
+ mock_choice = MagicMock()
33
+ mock_choice.message = mock_message
34
+ return MagicMock(choices=[mock_choice])
35
+
36
+
37
+ def _make_tool_call_mock(fn_name, fn_args_dict, tool_call_id="call_123"):
38
+ """Build a ChatCompletion mock where the model requests a tool call."""
39
+ fn_args_json = json.dumps(fn_args_dict)
40
+
41
+ tool_call = MagicMock()
42
+ tool_call.id = tool_call_id
43
+ tool_call.function.name = fn_name
44
+ tool_call.function.arguments = fn_args_json
45
+
46
+ mock_message = MagicMock()
47
+ mock_message.content = None
48
+ mock_message.tool_calls = [tool_call]
49
+ mock_message.model_dump.return_value = {
50
+ "role": "assistant",
51
+ "content": None,
52
+ "tool_calls": [{
53
+ "id": tool_call_id,
54
+ "type": "function",
55
+ "function": {"name": fn_name, "arguments": fn_args_json},
56
+ }],
57
+ }
58
+
59
+ mock_choice = MagicMock()
60
+ mock_choice.message = mock_message
61
+ return MagicMock(choices=[mock_choice])
62
+
63
+
64
+ def _make_multi_tool_call_mock(calls):
65
+ """Build a ChatCompletion mock with multiple parallel tool calls.
66
+
67
+ *calls* is a list of (fn_name, fn_args_dict, tool_call_id) tuples.
68
+ """
69
+ tool_calls = []
70
+ dump_calls = []
71
+ for fn_name, fn_args_dict, tool_call_id in calls:
72
+ fn_args_json = json.dumps(fn_args_dict)
73
+ tc = MagicMock()
74
+ tc.id = tool_call_id
75
+ tc.function.name = fn_name
76
+ tc.function.arguments = fn_args_json
77
+ tool_calls.append(tc)
78
+ dump_calls.append({
79
+ "id": tool_call_id,
80
+ "type": "function",
81
+ "function": {"name": fn_name, "arguments": fn_args_json},
82
+ })
83
+
84
+ mock_message = MagicMock()
85
+ mock_message.content = None
86
+ mock_message.tool_calls = tool_calls
87
+ mock_message.model_dump.return_value = {
88
+ "role": "assistant",
89
+ "content": None,
90
+ "tool_calls": dump_calls,
91
+ }
92
+
93
+ mock_choice = MagicMock()
94
+ mock_choice.message = mock_message
95
+ return MagicMock(choices=[mock_choice])
96
+
97
+
98
+ def _make_gemini_client(MockSettings, MockCtxMgr):
99
+ """Instantiate an ImprovedGeminiClient with mocked dependencies."""
100
+ mock_settings = MagicMock()
101
+ mock_settings.llm.gemini.api_key = "fake-key"
102
+ mock_settings.llm.gemini.model = "gemini-2.0-flash"
103
+ MockSettings.return_value = mock_settings
104
+ MockCtxMgr.return_value = MagicMock()
105
+ return ImprovedGeminiClient()
106
+
107
+
108
+ @patch("app.llm.improved_gemini_client.get_context_manager")
109
+ @patch("app.llm.improved_gemini_client.get_settings")
110
+ class TestGeminiGenerateWithTools(unittest.TestCase):
111
+ """Unit tests for ImprovedGeminiClient.generate_with_tools()
112
+ using the OpenAI-compatible endpoint."""
113
+
114
+ # ------------------------------------------------------------------
115
+ # Happy path — no tool call
116
+ # ------------------------------------------------------------------
117
+
118
+ def test_direct_text_response_returns_text(self, MockSettings, MockCtxMgr):
119
+ """When the model responds with text (no tool call), return it."""
120
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
121
+ gemini.openai_client.chat.completions.create = AsyncMock(
122
+ return_value=_make_text_completion_mock("Hello, world!"),
123
+ )
124
+ mock_executor = AsyncMock()
125
+
126
+ result = asyncio.run(gemini.generate_with_tools(
127
+ system_prompt="You are helpful.",
128
+ user_message="Hi there",
129
+ tool_definitions=[FAKE_TOOL],
130
+ tool_executor=mock_executor,
131
+ ))
132
+
133
+ self.assertIsInstance(result, ToolCallResult)
134
+ self.assertEqual(result.text, "Hello, world!")
135
+ self.assertFalse(result.used_tool)
136
+ mock_executor.assert_not_called()
137
+
138
+ # ------------------------------------------------------------------
139
+ # Happy path — tool call
140
+ # ------------------------------------------------------------------
141
+
142
+ def test_function_call_triggers_executor_and_returns_final_text(self, MockSettings, MockCtxMgr):
143
+ """When the model requests a tool call, execute it and return
144
+ the text from the follow-up completion."""
145
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
146
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
147
+ _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
148
+ _make_text_completion_mock("CSCI 1300 is available MWF 10-10:50."),
149
+ ])
150
+ mock_executor = AsyncMock(
151
+ return_value={"courses": [{"title": "Intro to CS"}]}
152
+ )
153
+
154
+ result = asyncio.run(gemini.generate_with_tools(
155
+ system_prompt="You are helpful.",
156
+ user_message="What CSCI classes are there?",
157
+ tool_definitions=[FAKE_TOOL],
158
+ tool_executor=mock_executor,
159
+ ))
160
+
161
+ mock_executor.assert_called_once_with(
162
+ name="search_courses", subject="CSCI",
163
+ )
164
+ self.assertIsInstance(result, ToolCallResult)
165
+ self.assertEqual(result.text, "CSCI 1300 is available MWF 10-10:50.")
166
+ self.assertTrue(result.used_tool)
167
+ self.assertEqual(result.tool_name, "search_courses")
168
+ self.assertEqual(result.tool_args, {"subject": "CSCI"})
169
+ self.assertEqual(len(result.tool_calls_made), 1)
170
+ self.assertEqual(result.tool_calls_made[0].name, "search_courses")
171
+ self.assertEqual(gemini.openai_client.chat.completions.create.call_count, 2)
172
+
173
+ # ------------------------------------------------------------------
174
+ # Payload format
175
+ # ------------------------------------------------------------------
176
+
177
+ def test_tool_definitions_passed_through_in_openai_format(self, MockSettings, MockCtxMgr):
178
+ """Tool definitions (already in OpenAI format) are passed through
179
+ directly to the completions API."""
180
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
181
+ gemini.openai_client.chat.completions.create = AsyncMock(
182
+ return_value=_make_text_completion_mock("Ok"),
183
+ )
184
+
185
+ asyncio.run(gemini.generate_with_tools(
186
+ system_prompt="You are helpful.",
187
+ user_message="Hello",
188
+ tool_definitions=[FAKE_TOOL],
189
+ tool_executor=AsyncMock(),
190
+ ))
191
+
192
+ call_kwargs = gemini.openai_client.chat.completions.create.call_args[1]
193
+ tools = call_kwargs["tools"]
194
+ self.assertEqual(len(tools), 1)
195
+ self.assertEqual(tools[0]["type"], "function")
196
+ self.assertEqual(tools[0]["function"]["name"], "search_courses")
197
+ self.assertIn("parameters", tools[0]["function"])
198
+
199
+ def test_tool_result_appended_to_followup(self, MockSettings, MockCtxMgr):
200
+ """After executing a tool, the follow-up call must include
201
+ the assistant message, a ``role: tool`` message, and ``tools=``."""
202
+ tool_output = {"courses": [{"title": "Algorithms"}]}
203
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
204
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
205
+ _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
206
+ _make_text_completion_mock("Here are the results."),
207
+ ])
208
+ mock_executor = AsyncMock(return_value=tool_output)
209
+
210
+ asyncio.run(gemini.generate_with_tools(
211
+ system_prompt="You are helpful.",
212
+ user_message="Find CSCI courses",
213
+ tool_definitions=[FAKE_TOOL],
214
+ tool_executor=mock_executor,
215
+ ))
216
+
217
+ second_call_kwargs = gemini.openai_client.chat.completions.create.call_args_list[1][1]
218
+ messages = second_call_kwargs["messages"]
219
+
220
+ assistant_msg = messages[-2]
221
+ self.assertEqual(assistant_msg["role"], "assistant")
222
+
223
+ tool_msg = messages[-1]
224
+ self.assertEqual(tool_msg["role"], "tool")
225
+ self.assertEqual(tool_msg["tool_call_id"], "call_123")
226
+ self.assertEqual(json.loads(tool_msg["content"]), tool_output)
227
+
228
+ self.assertIn("tools", second_call_kwargs,
229
+ "Follow-up call must include tools= so the model can "
230
+ "request additional tool calls if needed")
231
+
232
+ # ------------------------------------------------------------------
233
+ # Error handling
234
+ # ------------------------------------------------------------------
235
+
236
+ def test_tool_executor_failure_serialises_error_and_continues(self, MockSettings, MockCtxMgr):
237
+ """If the tool executor raises, the error is serialised as the
238
+ tool result and the loop continues to the follow-up completion."""
239
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
240
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
241
+ _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
242
+ _make_text_completion_mock("Sorry, I couldn't look that up."),
243
+ ])
244
+ mock_executor = AsyncMock(side_effect=RuntimeError("network down"))
245
+
246
+ result = asyncio.run(gemini.generate_with_tools(
247
+ system_prompt="You are helpful.",
248
+ user_message="Find CSCI courses",
249
+ tool_definitions=[FAKE_TOOL],
250
+ tool_executor=mock_executor,
251
+ ))
252
+
253
+ self.assertTrue(result.used_tool)
254
+ self.assertEqual(result.tool_name, "search_courses")
255
+ self.assertEqual(len(result.tool_calls_made), 1)
256
+
257
+ second_call_msgs = gemini.openai_client.chat.completions.create.call_args_list[1][1]["messages"]
258
+ tool_msg = [m for m in second_call_msgs if m.get("role") == "tool"][0]
259
+ self.assertIn("network down", json.loads(tool_msg["content"])["error"])
260
+
261
+ def test_connection_error_returns_not_used(self, MockSettings, MockCtxMgr):
262
+ """APIConnectionError during tool calling returns used_tool=False."""
263
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
264
+ gemini.openai_client.chat.completions.create = AsyncMock(
265
+ side_effect=APIConnectionError(request=MagicMock()),
266
+ )
267
+
268
+ result = asyncio.run(gemini.generate_with_tools(
269
+ system_prompt="Test",
270
+ user_message="Hi",
271
+ tool_definitions=[FAKE_TOOL],
272
+ tool_executor=AsyncMock(),
273
+ ))
274
+
275
+ self.assertIsInstance(result, ToolCallResult)
276
+ self.assertFalse(result.used_tool)
277
+ self.assertIn("unable to connect", result.text.lower())
278
+
279
+ def test_status_error_returns_not_used(self, MockSettings, MockCtxMgr):
280
+ """APIStatusError during tool calling returns used_tool=False."""
281
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
282
+ mock_response = MagicMock()
283
+ mock_response.status_code = 500
284
+ gemini.openai_client.chat.completions.create = AsyncMock(
285
+ side_effect=APIStatusError(
286
+ message="Server error", response=mock_response, body=None,
287
+ )
288
+ )
289
+
290
+ result = asyncio.run(gemini.generate_with_tools(
291
+ system_prompt="Test",
292
+ user_message="Hi",
293
+ tool_definitions=[FAKE_TOOL],
294
+ tool_executor=AsyncMock(),
295
+ ))
296
+
297
+ self.assertIsInstance(result, ToolCallResult)
298
+ self.assertFalse(result.used_tool)
299
+ self.assertIn("error", result.text.lower())
300
+
301
+ # ------------------------------------------------------------------
302
+ # Multi-tool call in a single response
303
+ # ------------------------------------------------------------------
304
+
305
+ def test_parallel_tool_calls_all_executed(self, MockSettings, MockCtxMgr):
306
+ """When the model requests multiple tool calls in one response,
307
+ all of them are executed and their results fed back."""
308
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
309
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
310
+ _make_multi_tool_call_mock([
311
+ ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
312
+ ("rate_my_professor", {"professor_name": "West"}, "call_b"),
313
+ ]),
314
+ _make_text_completion_mock("Dubson has a 4.5 rating. West has a 3.8 rating."),
315
+ ])
316
+ mock_executor = AsyncMock(side_effect=[
317
+ {"professors": [{"name": "Dubson", "rating": 4.5}]},
318
+ {"professors": [{"name": "West", "rating": 3.8}]},
319
+ ])
320
+
321
+ result = asyncio.run(gemini.generate_with_tools(
322
+ system_prompt="You are helpful.",
323
+ user_message="Is professor Dubson or West rated better?",
324
+ tool_definitions=[FAKE_TOOL],
325
+ tool_executor=mock_executor,
326
+ ))
327
+
328
+ self.assertEqual(mock_executor.call_count, 2)
329
+ self.assertTrue(result.used_tool)
330
+ self.assertIn("Dubson", result.text)
331
+ self.assertIn("West", result.text)
332
+ self.assertEqual(len(result.tool_calls_made), 2)
333
+ self.assertEqual(result.tool_calls_made[0].name, "rate_my_professor")
334
+ self.assertEqual(result.tool_calls_made[1].args, {"professor_name": "West"})
335
+ self.assertEqual(gemini.openai_client.chat.completions.create.call_count, 2)
336
+
337
+ def test_parallel_tool_results_all_in_followup_messages(self, MockSettings, MockCtxMgr):
338
+ """All tool results must appear as separate role:tool messages
339
+ in the follow-up request."""
340
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
341
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
342
+ _make_multi_tool_call_mock([
343
+ ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
344
+ ("rate_my_professor", {"professor_name": "West"}, "call_b"),
345
+ ]),
346
+ _make_text_completion_mock("Comparison complete."),
347
+ ])
348
+ mock_executor = AsyncMock(side_effect=[
349
+ {"professors": [{"name": "Dubson"}]},
350
+ {"professors": [{"name": "West"}]},
351
+ ])
352
+
353
+ asyncio.run(gemini.generate_with_tools(
354
+ system_prompt="You are helpful.",
355
+ user_message="Compare",
356
+ tool_definitions=[FAKE_TOOL],
357
+ tool_executor=mock_executor,
358
+ ))
359
+
360
+ second_call_msgs = gemini.openai_client.chat.completions.create.call_args_list[1][1]["messages"]
361
+ tool_msgs = [m for m in second_call_msgs if m.get("role") == "tool"]
362
+ self.assertEqual(len(tool_msgs), 2)
363
+ self.assertEqual(tool_msgs[0]["tool_call_id"], "call_a")
364
+ self.assertEqual(tool_msgs[1]["tool_call_id"], "call_b")
365
+
366
+ # ------------------------------------------------------------------
367
+ # Multi-round tool calling
368
+ # ------------------------------------------------------------------
369
+
370
+ def test_sequential_tool_rounds(self, MockSettings, MockCtxMgr):
371
+ """The loop handles a second round of tool calls after the first
372
+ results are fed back."""
373
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
374
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
375
+ _make_tool_call_mock("rate_my_professor", {"professor_name": "Dubson"}, "call_1"),
376
+ _make_tool_call_mock("rate_my_professor", {"professor_name": "West"}, "call_2"),
377
+ _make_text_completion_mock("Dubson is rated higher than West."),
378
+ ])
379
+ mock_executor = AsyncMock(side_effect=[
380
+ {"professors": [{"name": "Dubson", "rating": 4.5}]},
381
+ {"professors": [{"name": "West", "rating": 3.8}]},
382
+ ])
383
+
384
+ result = asyncio.run(gemini.generate_with_tools(
385
+ system_prompt="You are helpful.",
386
+ user_message="Compare Dubson and West",
387
+ tool_definitions=[FAKE_TOOL],
388
+ tool_executor=mock_executor,
389
+ ))
390
+
391
+ self.assertEqual(mock_executor.call_count, 2)
392
+ self.assertTrue(result.used_tool)
393
+ self.assertEqual(len(result.tool_calls_made), 2)
394
+ self.assertEqual(result.tool_name, "rate_my_professor")
395
+ self.assertEqual(gemini.openai_client.chat.completions.create.call_count, 3)
396
+
397
+ # ------------------------------------------------------------------
398
+ # Partial failure in multi-tool context
399
+ # ------------------------------------------------------------------
400
+
401
+ def test_partial_tool_failure_continues(self, MockSettings, MockCtxMgr):
402
+ """If one tool call in a batch fails, the error is serialised
403
+ and the loop continues to the follow-up."""
404
+ gemini = _make_gemini_client(MockSettings, MockCtxMgr)
405
+ gemini.openai_client.chat.completions.create = AsyncMock(side_effect=[
406
+ _make_multi_tool_call_mock([
407
+ ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
408
+ ("rate_my_professor", {"professor_name": "West"}, "call_b"),
409
+ ]),
410
+ _make_text_completion_mock("Only Dubson data available."),
411
+ ])
412
+ mock_executor = AsyncMock(side_effect=[
413
+ {"professors": [{"name": "Dubson", "rating": 4.5}]},
414
+ RuntimeError("network down"),
415
+ ])
416
+
417
+ result = asyncio.run(gemini.generate_with_tools(
418
+ system_prompt="You are helpful.",
419
+ user_message="Compare",
420
+ tool_definitions=[FAKE_TOOL],
421
+ tool_executor=mock_executor,
422
+ ))
423
+
424
+ self.assertTrue(result.used_tool)
425
+ self.assertEqual(len(result.tool_calls_made), 2)
426
+ self.assertEqual(gemini.openai_client.chat.completions.create.call_count, 2)
427
+
428
+ second_call_msgs = gemini.openai_client.chat.completions.create.call_args_list[1][1]["messages"]
429
+ tool_msgs = [m for m in second_call_msgs if m.get("role") == "tool"]
430
+ self.assertEqual(len(tool_msgs), 2)
431
+ error_content = json.loads(tool_msgs[1]["content"])
432
+ self.assertIn("error", error_content)
multi_llm_chatbot_backend/app/tests/unit/test_rmp_tool.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import unittest
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ from app.tools.rate_my_professor import TOOL_DEFINITION, execute
6
+
7
+
8
+ def _graphql_success_response(nodes):
9
+ """Build a mock RMP GraphQL response containing the given teacher nodes."""
10
+ edges = [{"cursor": f"c{i}", "node": n} for i, n in enumerate(nodes)]
11
+ return {
12
+ "data": {
13
+ "search": {
14
+ "teachers": {
15
+ "didFallback": False,
16
+ "edges": edges,
17
+ "pageInfo": {"hasNextPage": False, "endCursor": ""},
18
+ }
19
+ }
20
+ }
21
+ }
22
+
23
+
24
+ SAMPLE_NODE = {
25
+ "id": "VGVhY2hlci0xMjM0",
26
+ "legacyId": 1234,
27
+ "firstName": "Jane",
28
+ "lastName": "Smith",
29
+ "department": "Computer Science",
30
+ "school": {"id": "U2Nob29sLTEwODc=", "name": "University of Colorado Boulder"},
31
+ "avgRating": 4.2,
32
+ "avgDifficulty": 3.1,
33
+ "wouldTakeAgainPercent": 85.0,
34
+ "numRatings": 42,
35
+ }
36
+
37
+
38
+ class TestRMPToolContract(unittest.TestCase):
39
+ """The rate_my_professor tool module must export a valid OpenAI
40
+ tool definition and an async executor."""
41
+
42
+ def test_tool_definition_has_required_fields(self):
43
+ self.assertEqual(TOOL_DEFINITION["type"], "function")
44
+ self.assertIn("function", TOOL_DEFINITION)
45
+ fn = TOOL_DEFINITION["function"]
46
+ self.assertIn("name", fn)
47
+ self.assertIn("description", fn)
48
+ self.assertIn("parameters", fn)
49
+
50
+ def test_tool_definition_name(self):
51
+ self.assertEqual(TOOL_DEFINITION["function"]["name"], "rate_my_professor")
52
+
53
+ def test_tool_definition_has_nonempty_description(self):
54
+ self.assertIsInstance(TOOL_DEFINITION["function"]["description"], str)
55
+ self.assertGreater(len(TOOL_DEFINITION["function"]["description"]), 0)
56
+
57
+ def test_tool_definition_parameters_schema(self):
58
+ params = TOOL_DEFINITION["function"]["parameters"]
59
+ self.assertEqual(params["type"], "object")
60
+ self.assertIn("properties", params)
61
+ self.assertIn("professor_name", params["properties"])
62
+
63
+ def test_execute_is_async_callable(self):
64
+ self.assertTrue(asyncio.iscoroutinefunction(execute))
65
+
66
+
67
+ def _fake_tool_config(name):
68
+ """Return a fake tool config dict with school_id set."""
69
+ if name == "rate_my_professor":
70
+ return {"enabled": True, "school_id": "U2Nob29sLTEwODc="}
71
+ return {}
72
+
73
+
74
+ @patch("app.tools.rate_my_professor.get_settings")
75
+ class TestRMPToolExecutor(unittest.TestCase):
76
+ """Unit tests for rate_my_professor.execute() with mocked HTTP."""
77
+
78
+ def _mock_client(self, get_response, post_response):
79
+ """Build a mock httpx.AsyncClient with canned GET and POST responses."""
80
+ get_resp = MagicMock()
81
+ get_resp.text = '<script>"Authorization":"Basic dGVzdDp0ZXN0"</script>'
82
+ get_resp.raise_for_status = MagicMock()
83
+ if get_response is not None:
84
+ get_resp.text = get_response
85
+
86
+ post_resp = MagicMock()
87
+ post_resp.status_code = 200
88
+ post_resp.json.return_value = post_response
89
+ post_resp.raise_for_status = MagicMock()
90
+
91
+ client_instance = AsyncMock()
92
+ client_instance.get = AsyncMock(return_value=get_resp)
93
+ client_instance.post = AsyncMock(return_value=post_resp)
94
+
95
+ ctx = MagicMock()
96
+ ctx.__aenter__ = AsyncMock(return_value=client_instance)
97
+ ctx.__aexit__ = AsyncMock(return_value=False)
98
+ return ctx, client_instance
99
+
100
+ def test_execute_returns_professor_data(self, mock_get_settings):
101
+ """Successful GraphQL response returns structured professor data."""
102
+ mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
103
+ ctx, client = self._mock_client(
104
+ get_response=None,
105
+ post_response=_graphql_success_response([SAMPLE_NODE]),
106
+ )
107
+
108
+ with patch("httpx.AsyncClient", return_value=ctx):
109
+ result = asyncio.run(execute(professor_name="Smith"))
110
+
111
+ self.assertIn("professors", result)
112
+ self.assertEqual(len(result["professors"]), 1)
113
+
114
+ prof = result["professors"][0]
115
+ self.assertEqual(prof["name"], "Jane Smith")
116
+ self.assertEqual(prof["department"], "Computer Science")
117
+ self.assertAlmostEqual(prof["rating"], 4.2)
118
+ self.assertAlmostEqual(prof["difficulty"], 3.1)
119
+ self.assertEqual(prof["num_ratings"], 42)
120
+
121
+ def test_execute_returns_empty_on_no_results(self, mock_get_settings):
122
+ """When the GraphQL API returns no matching professors, return
123
+ an empty list — not an error."""
124
+ mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
125
+ ctx, _ = self._mock_client(
126
+ get_response=None,
127
+ post_response=_graphql_success_response([]),
128
+ )
129
+
130
+ with patch("httpx.AsyncClient", return_value=ctx):
131
+ result = asyncio.run(execute(professor_name="Nonexistent"))
132
+
133
+ self.assertIn("professors", result)
134
+ self.assertEqual(len(result["professors"]), 0)
135
+
136
+ def test_execute_returns_error_on_api_failure(self, mock_get_settings):
137
+ """When the HTTP request fails, return an error payload instead
138
+ of raising an exception."""
139
+ mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
140
+ ctx = MagicMock()
141
+ client_instance = AsyncMock()
142
+ client_instance.get = AsyncMock(side_effect=Exception("connection refused"))
143
+ client_instance.post = AsyncMock(side_effect=Exception("connection refused"))
144
+ ctx.__aenter__ = AsyncMock(return_value=client_instance)
145
+ ctx.__aexit__ = AsyncMock(return_value=False)
146
+
147
+ with patch("httpx.AsyncClient", return_value=ctx):
148
+ result = asyncio.run(execute(professor_name="Smith"))
149
+
150
+ self.assertIn("professors", result)
151
+ self.assertEqual(len(result["professors"]), 0)
152
+ self.assertIn("error", result)
153
+
154
+ def test_execute_accepts_name_kwarg(self, mock_get_settings):
155
+ """The dispatcher passes name= as a kwarg; execute must accept
156
+ and ignore it without error."""
157
+ mock_get_settings.return_value.tools.get_tool_config = _fake_tool_config
158
+ ctx, _ = self._mock_client(
159
+ get_response=None,
160
+ post_response=_graphql_success_response([SAMPLE_NODE]),
161
+ )
162
+
163
+ with patch("httpx.AsyncClient", return_value=ctx):
164
+ result = asyncio.run(
165
+ execute(name="rate_my_professor", professor_name="Smith")
166
+ )
167
+
168
+ self.assertIn("professors", result)
169
+ self.assertEqual(len(result["professors"]), 1)
multi_llm_chatbot_backend/app/tests/unit/test_tool_registry.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import unittest
3
+ from unittest.mock import AsyncMock
4
+
5
+ from app.tools import (
6
+ get_tool_definitions,
7
+ get_tool_executor,
8
+ list_registered_tools,
9
+ _REGISTRY,
10
+ )
11
+
12
+
13
+ KNOWN_TOOLS = {"search_courses", "rate_my_professor"}
14
+
15
+
16
+ class TestToolDiscovery(unittest.TestCase):
17
+ """Auto-discovery should find every tool module that exports
18
+ TOOL_DEFINITION + execute."""
19
+
20
+ def test_known_tools_are_discovered(self):
21
+ registered = set(list_registered_tools())
22
+ for name in KNOWN_TOOLS:
23
+ self.assertIn(name, registered, f"Tool '{name}' was not discovered")
24
+
25
+ def test_registry_entries_have_definition_and_executor(self):
26
+ for name, entry in _REGISTRY.items():
27
+ self.assertIn("definition", entry, f"'{name}' missing definition")
28
+ self.assertIn("executor", entry, f"'{name}' missing executor")
29
+
30
+ def test_definitions_have_required_fields(self):
31
+ for name, entry in _REGISTRY.items():
32
+ defn = entry["definition"]
33
+ self.assertEqual(defn["type"], "function")
34
+ self.assertIn("function", defn)
35
+ fn = defn["function"]
36
+ self.assertIn("name", fn)
37
+ self.assertIn("description", fn)
38
+ self.assertIn("parameters", fn)
39
+ self.assertEqual(fn["name"], name)
40
+
41
+ def test_executors_are_async_callables(self):
42
+ for name, entry in _REGISTRY.items():
43
+ self.assertTrue(
44
+ asyncio.iscoroutinefunction(entry["executor"]),
45
+ f"Executor for '{name}' is not an async function",
46
+ )
47
+
48
+
49
+ class TestGetToolDefinitions(unittest.TestCase):
50
+ """get_tool_definitions() returns OpenAI-format tool dicts,
51
+ optionally filtered."""
52
+
53
+ def test_returns_all_when_no_filter(self):
54
+ defs = get_tool_definitions()
55
+ names = {d["function"]["name"] for d in defs}
56
+ self.assertTrue(KNOWN_TOOLS.issubset(names))
57
+
58
+ def test_filter_to_single_tool(self):
59
+ defs = get_tool_definitions(enabled=["search_courses"])
60
+ self.assertEqual(len(defs), 1)
61
+ self.assertEqual(defs[0]["function"]["name"], "search_courses")
62
+
63
+ def test_filter_to_multiple_tools(self):
64
+ defs = get_tool_definitions(enabled=["search_courses", "rate_my_professor"])
65
+ names = {d["function"]["name"] for d in defs}
66
+ self.assertEqual(names, KNOWN_TOOLS)
67
+
68
+ def test_filter_with_unknown_name_returns_empty(self):
69
+ defs = get_tool_definitions(enabled=["nonexistent_tool"])
70
+ self.assertEqual(defs, [])
71
+
72
+ def test_filter_with_empty_list_returns_empty(self):
73
+ defs = get_tool_definitions(enabled=[])
74
+ self.assertEqual(defs, [])
75
+
76
+ def test_filter_ignores_unknown_names_keeps_valid(self):
77
+ defs = get_tool_definitions(enabled=["search_courses", "bogus"])
78
+ self.assertEqual(len(defs), 1)
79
+ self.assertEqual(defs[0]["function"]["name"], "search_courses")
80
+
81
+
82
+ class TestGetToolExecutor(unittest.TestCase):
83
+ """get_tool_executor() returns a dispatcher that routes to the
84
+ correct tool executor."""
85
+
86
+ def test_dispatch_known_tool(self):
87
+ mock_exec = AsyncMock(return_value={"courses": []})
88
+ original = _REGISTRY["search_courses"]["executor"]
89
+ _REGISTRY["search_courses"]["executor"] = mock_exec
90
+ try:
91
+ dispatch = get_tool_executor()
92
+ result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))
93
+ mock_exec.assert_called_once_with(name="search_courses", subject="CSCI")
94
+ self.assertEqual(result, {"courses": []})
95
+ finally:
96
+ _REGISTRY["search_courses"]["executor"] = original
97
+
98
+ def test_dispatch_unknown_tool_returns_error(self):
99
+ dispatch = get_tool_executor()
100
+ result = asyncio.run(dispatch(name="nonexistent"))
101
+ self.assertIn("error", result)
102
+
103
+ def test_filtered_executor_allows_enabled_tool(self):
104
+ mock_exec = AsyncMock(return_value={"courses": []})
105
+ original = _REGISTRY["search_courses"]["executor"]
106
+ _REGISTRY["search_courses"]["executor"] = mock_exec
107
+ try:
108
+ dispatch = get_tool_executor(enabled=["search_courses"])
109
+ result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))
110
+ self.assertNotIn("error", result)
111
+ finally:
112
+ _REGISTRY["search_courses"]["executor"] = original
113
+
114
+ def test_filtered_executor_blocks_disabled_tool(self):
115
+ dispatch = get_tool_executor(enabled=["search_courses"])
116
+ result = asyncio.run(dispatch(name="rate_my_professor", professor_name="Smith"))
117
+ self.assertIn("error", result)
118
+ self.assertIn("not enabled", result["error"])
119
+
120
+ def test_filtered_executor_with_empty_list_blocks_all(self):
121
+ dispatch = get_tool_executor(enabled=[])
122
+ result = asyncio.run(dispatch(name="search_courses", subject="CSCI"))
123
+ self.assertIn("error", result)
multi_llm_chatbot_backend/app/tests/unit/test_vllm_client.py CHANGED
@@ -1,15 +1,31 @@
1
  import asyncio
 
2
  import unittest
3
  from unittest.mock import AsyncMock, MagicMock, patch
4
 
5
  from openai import APIConnectionError, APIStatusError
6
 
 
7
  from app.llm.improved_vllm_client import ImprovedVllmClient
8
 
9
 
10
  FAKE_URL = "https://fake.example.com/vllm0"
11
  FAKE_KEY = "test-key"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def _make_completion_mock(content="Response"):
15
  """Build a mock that looks like an OpenAI ChatCompletion."""
@@ -20,6 +36,77 @@ def _make_completion_mock(content="Response"):
20
  return MagicMock(choices=[mock_choice])
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @patch("app.llm.improved_vllm_client.get_context_manager")
24
  @patch("app.llm.improved_vllm_client.AsyncOpenAI")
25
  class TestImprovedVllmClient(unittest.TestCase):
@@ -214,3 +301,328 @@ class TestImprovedVllmClient(unittest.TestCase):
214
  self.assertNotIn("\r", cleaned)
215
  self.assertNotIn("\n\n\n", cleaned)
216
  self.assertEqual(cleaned, "Line one.\n\nLine two.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import json
3
  import unittest
4
  from unittest.mock import AsyncMock, MagicMock, patch
5
 
6
  from openai import APIConnectionError, APIStatusError
7
 
8
+ from app.llm.llm_client import ToolCallResult
9
  from app.llm.improved_vllm_client import ImprovedVllmClient
10
 
11
 
12
  FAKE_URL = "https://fake.example.com/vllm0"
13
  FAKE_KEY = "test-key"
14
 
15
+ FAKE_TOOL = {
16
+ "type": "function",
17
+ "function": {
18
+ "name": "search_courses",
19
+ "description": "Search courses",
20
+ "parameters": {
21
+ "type": "object",
22
+ "properties": {
23
+ "subject": {"type": "string", "description": "Subject code"},
24
+ },
25
+ },
26
+ },
27
+ }
28
+
29
 
30
  def _make_completion_mock(content="Response"):
31
  """Build a mock that looks like an OpenAI ChatCompletion."""
 
36
  return MagicMock(choices=[mock_choice])
37
 
38
 
39
def _make_text_completion_mock(content="Response"):
    """Return a fake ChatCompletion whose single choice is plain text.

    The message carries *content* and ``tool_calls = None``, i.e. the model
    did not request any tool.
    """
    message = MagicMock(content=content, tool_calls=None)
    only_choice = MagicMock(message=message)
    return MagicMock(choices=[only_choice])
47
+
48
+
49
def _make_tool_call_mock(fn_name, fn_args_dict, tool_call_id="call_123"):
    """Return a fake ChatCompletion in which the model requests one tool call.

    *fn_args_dict* is JSON-encoded into the call's ``function.arguments``.
    ``message.model_dump()`` returns the matching assistant-message dict.
    """
    encoded_args = json.dumps(fn_args_dict)

    requested_call = MagicMock(id=tool_call_id)
    # ``name`` is reserved in the Mock constructor, so set it via configure_mock.
    requested_call.function.configure_mock(name=fn_name, arguments=encoded_args)

    message = MagicMock(content=None, tool_calls=[requested_call])
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": tool_call_id,
                "type": "function",
                "function": {"name": fn_name, "arguments": encoded_args},
            },
        ],
    }

    return MagicMock(choices=[MagicMock(message=message)])
74
+
75
+
76
def _make_multi_tool_call_mock(calls):
    """Return a fake ChatCompletion requesting several tool calls at once.

    *calls* is a list of (fn_name, fn_args_dict, tool_call_id) tuples; the
    resulting ``tool_calls`` list keeps the same order.
    """
    # Build the serialised form first; the mock objects are derived from it.
    serialised = [
        {
            "id": call_id,
            "type": "function",
            "function": {"name": fn_name, "arguments": json.dumps(fn_args)},
        }
        for fn_name, fn_args, call_id in calls
    ]

    requested = []
    for spec in serialised:
        call = MagicMock(id=spec["id"])
        # Sets ``function.name`` and ``function.arguments`` on the child mock.
        call.function.configure_mock(**spec["function"])
        requested.append(call)

    message = MagicMock(content=None, tool_calls=requested)
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": serialised,
    }

    return MagicMock(choices=[MagicMock(message=message)])
108
+
109
+
110
  @patch("app.llm.improved_vllm_client.get_context_manager")
111
  @patch("app.llm.improved_vllm_client.AsyncOpenAI")
112
  class TestImprovedVllmClient(unittest.TestCase):
 
301
  self.assertNotIn("\r", cleaned)
302
  self.assertNotIn("\n\n\n", cleaned)
303
  self.assertEqual(cleaned, "Line one.\n\nLine two.")
304
+
305
+
306
@patch("app.llm.improved_vllm_client.get_context_manager")
@patch("app.llm.improved_vllm_client.AsyncOpenAI")
class TestVllmGenerateWithTools(unittest.TestCase):
    """Unit tests for ImprovedVllmClient.generate_with_tools().

    ``AsyncOpenAI`` and ``get_context_manager`` are patched in the client
    module for every test method; each test receives both mocks as
    positional arguments. The helpers below do not start with ``test`` and
    therefore take no mock arguments.
    """

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _client(create_mock):
        """Build a client whose chat-completions ``create`` is *create_mock*."""
        vllm = ImprovedVllmClient(
            api_url=FAKE_URL, api_key=FAKE_KEY, model_name="test-model",
        )
        vllm.client.chat.completions.create = create_mock
        return vllm

    @staticmethod
    def _ask(vllm, user_message, executor, system_prompt="You are helpful."):
        """Run generate_with_tools() to completion with the single FAKE_TOOL."""
        return asyncio.run(vllm.generate_with_tools(
            system_prompt=system_prompt,
            user_message=user_message,
            tool_definitions=[FAKE_TOOL],
            tool_executor=executor,
        ))

    @staticmethod
    def _dubson_west_batch():
        """Completion mock requesting two professor look-ups in one turn."""
        return _make_multi_tool_call_mock([
            ("rate_my_professor", {"professor_name": "Dubson"}, "call_a"),
            ("rate_my_professor", {"professor_name": "West"}, "call_b"),
        ])

    # ------------------------------------------------------------------
    # Happy path — no tool call
    # ------------------------------------------------------------------

    def test_text_response_returns_not_used(self, MockAsyncOpenAI, mock_get_ctx):
        """When the model responds with plain text, return used_tool=False."""
        create = AsyncMock(
            return_value=_make_text_completion_mock("Hello, world!"),
        )
        vllm = self._client(create)

        outcome = self._ask(vllm, "Hi there", AsyncMock())

        self.assertIsInstance(outcome, ToolCallResult)
        self.assertEqual(outcome.text, "Hello, world!")
        self.assertFalse(outcome.used_tool)

    # ------------------------------------------------------------------
    # Happy path — tool call
    # ------------------------------------------------------------------

    def test_tool_call_executes_and_returns_final_text(self, MockAsyncOpenAI, mock_get_ctx):
        """When the model requests a tool call, execute it and return
        the text from the follow-up completion."""
        create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("CSCI 1300 is available MWF 10-10:50."),
        ])
        vllm = self._client(create)
        executor = AsyncMock(
            return_value={"courses": [{"title": "Intro to CS"}]},
        )

        outcome = self._ask(vllm, "What CSCI classes are there?", executor)

        executor.assert_called_once_with(
            name="search_courses", subject="CSCI",
        )
        self.assertIsInstance(outcome, ToolCallResult)
        self.assertEqual(outcome.text, "CSCI 1300 is available MWF 10-10:50.")
        self.assertTrue(outcome.used_tool)
        self.assertEqual(outcome.tool_name, "search_courses")
        self.assertEqual(outcome.tool_args, {"subject": "CSCI"})
        self.assertEqual(len(outcome.tool_calls_made), 1)
        self.assertEqual(outcome.tool_calls_made[0].name, "search_courses")
        self.assertEqual(create.call_count, 2)

    # ------------------------------------------------------------------
    # Payload format
    # ------------------------------------------------------------------

    def test_tool_definitions_passed_through_in_openai_format(self, MockAsyncOpenAI, mock_get_ctx):
        """Tool definitions (already in OpenAI format) are passed through
        directly to the completions API."""
        create = AsyncMock(return_value=_make_text_completion_mock("Ok"))
        vllm = self._client(create)

        self._ask(vllm, "Hello", AsyncMock())

        sent_tools = create.call_args[1]["tools"]
        self.assertEqual(len(sent_tools), 1)
        self.assertEqual(sent_tools[0]["type"], "function")
        self.assertEqual(sent_tools[0]["function"]["name"], "search_courses")
        self.assertIn("parameters", sent_tools[0]["function"])

    def test_tool_result_appended_to_followup(self, MockAsyncOpenAI, mock_get_ctx):
        """After executing a tool, the follow-up call must include
        the assistant message, a ``role: tool`` message, and ``tools=``."""
        tool_output = {"courses": [{"title": "Algorithms"}]}
        create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("Here are the results."),
        ])
        vllm = self._client(create)

        self._ask(vllm, "Find CSCI courses", AsyncMock(return_value=tool_output))

        followup_kwargs = create.call_args_list[1][1]
        assistant_msg, tool_msg = followup_kwargs["messages"][-2:]

        self.assertEqual(assistant_msg["role"], "assistant")

        self.assertEqual(tool_msg["role"], "tool")
        self.assertEqual(tool_msg["tool_call_id"], "call_123")
        self.assertEqual(json.loads(tool_msg["content"]), tool_output)

        self.assertIn("tools", followup_kwargs,
                      "Follow-up call must include tools= so the model can "
                      "request additional tool calls if needed")

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    def test_tool_executor_failure_serialises_error_and_continues(self, MockAsyncOpenAI, mock_get_ctx):
        """If the tool executor raises, the error is serialised as the
        tool result and the loop continues to the follow-up completion."""
        create = AsyncMock(side_effect=[
            _make_tool_call_mock("search_courses", {"subject": "CSCI"}),
            _make_text_completion_mock("Sorry, I couldn't look that up."),
        ])
        vllm = self._client(create)
        failing_executor = AsyncMock(side_effect=RuntimeError("network down"))

        outcome = self._ask(vllm, "Find CSCI courses", failing_executor)

        self.assertTrue(outcome.used_tool)
        self.assertEqual(outcome.tool_name, "search_courses")
        self.assertEqual(len(outcome.tool_calls_made), 1)

        followup_msgs = create.call_args_list[1][1]["messages"]
        tool_msg = [m for m in followup_msgs if m.get("role") == "tool"][0]
        self.assertIn("network down", json.loads(tool_msg["content"])["error"])

    def test_connection_error_returns_not_used(self, MockAsyncOpenAI, mock_get_ctx):
        """APIConnectionError during tool calling returns used_tool=False."""
        create = AsyncMock(
            side_effect=APIConnectionError(request=MagicMock()),
        )
        vllm = self._client(create)

        outcome = self._ask(vllm, "Hi", AsyncMock(), system_prompt="Test")

        self.assertIsInstance(outcome, ToolCallResult)
        self.assertFalse(outcome.used_tool)
        self.assertIn("unable to connect", outcome.text.lower())

    # ------------------------------------------------------------------
    # Multi-tool call in a single response
    # ------------------------------------------------------------------

    def test_parallel_tool_calls_all_executed(self, MockAsyncOpenAI, mock_get_ctx):
        """When the model requests multiple tool calls in one response,
        all of them are executed and their results fed back."""
        create = AsyncMock(side_effect=[
            self._dubson_west_batch(),
            _make_text_completion_mock("Dubson has a 4.5 rating. West has a 3.8 rating."),
        ])
        vllm = self._client(create)
        executor = AsyncMock(
            side_effect=[
                {"professors": [{"name": "Dubson", "rating": 4.5}]},
                {"professors": [{"name": "West", "rating": 3.8}]},
            ],
        )

        outcome = self._ask(
            vllm, "Is professor Dubson or West rated better?", executor,
        )

        self.assertEqual(executor.call_count, 2)
        self.assertTrue(outcome.used_tool)
        self.assertIn("Dubson", outcome.text)
        self.assertIn("West", outcome.text)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(outcome.tool_calls_made[0].name, "rate_my_professor")
        self.assertEqual(outcome.tool_calls_made[1].args, {"professor_name": "West"})
        self.assertEqual(create.call_count, 2)

    def test_parallel_tool_results_all_in_followup_messages(self, MockAsyncOpenAI, mock_get_ctx):
        """All tool results must appear as separate role:tool messages
        in the follow-up request."""
        create = AsyncMock(side_effect=[
            self._dubson_west_batch(),
            _make_text_completion_mock("Comparison complete."),
        ])
        vllm = self._client(create)
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson"}]},
            {"professors": [{"name": "West"}]},
        ])

        self._ask(vllm, "Compare", executor)

        followup_msgs = create.call_args_list[1][1]["messages"]
        tool_msgs = [m for m in followup_msgs if m.get("role") == "tool"]
        self.assertEqual(len(tool_msgs), 2)
        self.assertEqual(tool_msgs[0]["tool_call_id"], "call_a")
        self.assertEqual(tool_msgs[1]["tool_call_id"], "call_b")

    # ------------------------------------------------------------------
    # Multi-round tool calling
    # ------------------------------------------------------------------

    def test_sequential_tool_rounds(self, MockAsyncOpenAI, mock_get_ctx):
        """The loop handles a second round of tool calls after the first
        results are fed back."""
        create = AsyncMock(side_effect=[
            _make_tool_call_mock("rate_my_professor", {"professor_name": "Dubson"}, "call_1"),
            _make_tool_call_mock("rate_my_professor", {"professor_name": "West"}, "call_2"),
            _make_text_completion_mock("Dubson is rated higher than West."),
        ])
        vllm = self._client(create)
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson", "rating": 4.5}]},
            {"professors": [{"name": "West", "rating": 3.8}]},
        ])

        outcome = self._ask(vllm, "Compare Dubson and West", executor)

        self.assertEqual(executor.call_count, 2)
        self.assertTrue(outcome.used_tool)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(outcome.tool_name, "rate_my_professor")
        self.assertEqual(create.call_count, 3)

    # ------------------------------------------------------------------
    # Tool executor failure in multi-tool context
    # ------------------------------------------------------------------

    def test_partial_tool_failure_continues(self, MockAsyncOpenAI, mock_get_ctx):
        """If one tool call in a batch fails, the error is serialised
        and the loop continues to the follow-up."""
        create = AsyncMock(side_effect=[
            self._dubson_west_batch(),
            _make_text_completion_mock("Only Dubson data available."),
        ])
        vllm = self._client(create)
        # First look-up succeeds, second raises.
        executor = AsyncMock(side_effect=[
            {"professors": [{"name": "Dubson", "rating": 4.5}]},
            RuntimeError("network down"),
        ])

        outcome = self._ask(vllm, "Compare", executor)

        self.assertTrue(outcome.used_tool)
        self.assertEqual(len(outcome.tool_calls_made), 2)
        self.assertEqual(create.call_count, 2)

        followup_msgs = create.call_args_list[1][1]["messages"]
        tool_msgs = [m for m in followup_msgs if m.get("role") == "tool"]
        self.assertEqual(len(tool_msgs), 2)
        error_content = json.loads(tool_msgs[1]["content"])
        self.assertIn("error", error_content)
628
+
multi_llm_chatbot_backend/app/tools/__init__.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool registry — auto-discovers tool modules in app.tools and provides
3
+ a central API for retrieving definitions and dispatching calls.
4
+
5
+ Every tool module in this package must export:
6
+ TOOL_DEFINITION : Dict[str, Any] — OpenAI tool format
7
+ {"type": "function", "function": {"name": ..., ...}}
8
+ execute : async (**kwargs) — returns Dict[str, Any]
9
+
10
+ Modules that don't export both are silently skipped.
11
+
12
+ Filtering semantics for the ``enabled`` parameter:
13
+ None — no filter; all registered tools are available (default)
14
+ [] — explicit empty list; no tools are available
15
+ [ids] — only the named tools are available
16
+ """
17
+
18
+ import importlib
19
+ import inspect
20
+ import logging
21
+ import pkgutil
22
+ from typing import Any, Callable, Dict, List, Optional
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Shared User-Agent for HTTP clients in tool modules (FOSE, RMP, etc.).
27
+ BROWSER_UA = (
28
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
29
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
30
+ "Chrome/131.0.0.0 Safari/537.36"
31
+ )
32
+
33
+ _REGISTRY: Dict[str, Dict[str, Any]] = {}
34
+
35
+
36
def _discover_tools() -> None:
    """Import every module in the ``app.tools`` package and register valid tools.

    A module is registered when it exposes a ``TOOL_DEFINITION`` dict in
    OpenAI tool format and an async ``execute`` callable. Modules missing
    either attribute are skipped silently. Import failures, malformed
    definitions, non-async executors and duplicate tool names are logged
    as warnings and skipped.
    """
    import app.tools as tools_pkg

    for module_info in pkgutil.iter_modules(tools_pkg.__path__):
        qualified = f"app.tools.{module_info.name}"
        try:
            mod = importlib.import_module(qualified)
        except Exception:
            logger.warning("Failed to import tool module: %s", qualified, exc_info=True)
            continue

        defn = getattr(mod, "TOOL_DEFINITION", None)
        executor = getattr(mod, "execute", None)
        if defn is None or executor is None:
            # Not a tool module (e.g. a helper); nothing to report.
            continue

        well_formed = (
            isinstance(defn, dict)
            and defn.get("type") == "function"
            and isinstance(defn.get("function"), dict)
            and "name" in defn["function"]
        )
        if not well_formed:
            logger.warning("Skipping %s: TOOL_DEFINITION not in OpenAI tool format", qualified)
            continue

        if not (callable(executor) and inspect.iscoroutinefunction(executor)):
            logger.warning("Skipping %s: execute is not an async callable", qualified)
            continue

        tool_name = defn["function"]["name"]
        if tool_name in _REGISTRY:
            # First registration wins; later duplicates are ignored.
            logger.warning(
                "Duplicate tool name '%s' from %s — skipping", tool_name, qualified,
            )
            continue

        _REGISTRY[tool_name] = {"definition": defn, "executor": executor}
        logger.info("Registered tool: %s (from %s)", tool_name, qualified)
75
+
76
+
77
def get_tool_definitions(enabled: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Return the OpenAI-format definitions of registered tools.

    With ``enabled=None`` every registered tool is returned. Otherwise only
    the names in *enabled* that are registered are returned, in the order
    given; unregistered names are ignored.
    """
    if enabled is None:
        names = _REGISTRY.keys()
    else:
        names = (name for name in enabled if name in _REGISTRY)
    return [_REGISTRY[name]["definition"] for name in names]
91
+
92
+
93
def get_tool_executor(enabled: Optional[List[str]] = None) -> Callable:
    """Build the dispatcher passed as ``generate_with_tools(tool_executor=...)``.

    The returned coroutine function takes ``(name, **kwargs)`` and forwards
    the call to the registered executor for *name*. When *enabled* is given,
    only registered names from that list may be dispatched; any other name
    yields an ``{"error": ...}`` dict instead of a call.
    """
    # The allow-list is computed once, when the dispatcher is created.
    allowed = None if enabled is None else set(enabled).intersection(_REGISTRY)

    async def dispatch(name: str, **kwargs: Any) -> Dict[str, Any]:
        blocked = allowed is not None and name not in allowed
        if blocked:
            logger.warning("Tool '%s' is not enabled", name)
            return {"error": f"Tool not enabled: {name}"}

        entry = _REGISTRY.get(name)
        if entry is None:
            logger.warning("Unknown tool requested: %s", name)
            return {"error": f"Unknown tool: {name}"}

        # The tool name is forwarded too; executors accept it as ``name``.
        run_tool = entry["executor"]
        return await run_tool(name=name, **kwargs)

    return dispatch
118
+
119
+
120
def list_registered_tools() -> List[str]:
    """Return the names of all registered tools, in registration order."""
    return [*_REGISTRY]
123
+
124
+
125
+ _discover_tools()
multi_llm_chatbot_backend/app/tools/rate_my_professor.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ rate_my_professor tool — live query against RateMyProfessors' GraphQL API.
3
+
4
+ Exposes TOOL_DEFINITION (OpenAI tool format) and an execute() coroutine
5
+ that the tool-calling loop dispatches to.
6
+
7
+ Requires ``school_id`` in the tool config (see phd_config.yaml).
8
+ Use ``scripts/rmp_school_lookup.py`` to find the ID for a given school.
9
+ """
10
+
11
+ import logging
12
+ import re
13
+ from typing import Any, Dict, List
14
+ import httpx
15
+ from app.tools import BROWSER_UA
16
+ from app.config import get_settings
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ RMP_GRAPHQL_URL = "https://www.ratemyprofessors.com/graphql"
21
+ RMP_LANDING_URL = "https://www.ratemyprofessors.com/"
22
+ RMP_SEARCH_URL = "https://www.ratemyprofessors.com/search/professors/1087"
23
+
24
+
25
+ TEACHER_SEARCH_QUERY = """
26
+ query TeacherSearchPaginationQuery(
27
+ $count: Int!
28
+ $cursor: String
29
+ $query: TeacherSearchQuery!
30
+ ) {
31
+ search: newSearch {
32
+ teachers(query: $query, first: $count, after: $cursor) {
33
+ didFallback
34
+ edges {
35
+ cursor
36
+ node {
37
+ id
38
+ legacyId
39
+ firstName
40
+ lastName
41
+ department
42
+ school { id name }
43
+ avgRating
44
+ avgDifficulty
45
+ wouldTakeAgainPercent
46
+ numRatings
47
+ }
48
+ }
49
+ pageInfo {
50
+ hasNextPage
51
+ endCursor
52
+ }
53
+ }
54
+ }
55
+ }
56
+ """
57
+
58
+ TOOL_DEFINITION: Dict[str, Any] = {
59
+ "type": "function",
60
+ "function": {
61
+ "name": "rate_my_professor",
62
+ "description": (
63
+ "Look up RateMyProfessors ratings for a CU Boulder professor. "
64
+ "Returns rating, difficulty, percentage of students who would "
65
+ "take the professor again, and number of ratings."
66
+ ),
67
+ "parameters": {
68
+ "type": "object",
69
+ "properties": {
70
+ "professor_name": {
71
+ "type": "string",
72
+ "description": (
73
+ "Full or partial name of the professor to search for, "
74
+ "e.g. 'Hoenigman', 'Jane Smith'."
75
+ ),
76
+ },
77
+ },
78
+ "required": ["professor_name"],
79
+ },
80
+ },
81
+ }
82
+
83
+
84
def _node_to_professor(node: Dict[str, Any]) -> Dict[str, Any]:
    """Flatten a GraphQL teacher node into the dict returned to the model.

    Missing keys fall back to empty strings, ``0``, or ``-1`` for the
    would-take-again percentage.
    """
    first = node.get("firstName", "")
    last = node.get("lastName", "")
    full_name = f"{first} {last}".strip()

    summary: Dict[str, Any] = {"name": full_name}
    summary["department"] = node.get("department", "")
    summary["rating"] = node.get("avgRating", 0)
    summary["difficulty"] = node.get("avgDifficulty", 0)
    summary["would_take_again_pct"] = node.get("wouldTakeAgainPercent", -1)
    summary["num_ratings"] = node.get("numRatings", 0)
    summary["rmp_id"] = node.get("id", "")
    return summary
95
+
96
+
97
async def _extract_auth_token(client: httpx.AsyncClient) -> str:
    """Scrape a ``Basic ...`` Authorization value from the RMP landing page.

    Any failure while fetching or scanning the page is logged at debug
    level. If no token is found, the hard-coded ``Basic dGVzdDp0ZXN0``
    value (base64 of ``test:test``) is returned instead.
    """
    fallback_token = "Basic dGVzdDp0ZXN0"
    try:
        landing = await client.get(
            RMP_LANDING_URL, headers={"User-Agent": BROWSER_UA},
        )
        found = re.search(
            r'"Authorization"\s*:\s*"(Basic\s+[A-Za-z0-9+/=]+)"', landing.text,
        )
        if found is not None:
            logger.info("Extracted RMP auth token from page JS")
            return found.group(1)
    except Exception as exc:
        # Best effort only: fall through to the default token below.
        logger.debug("RMP auth token extraction failed: %s", exc)

    return fallback_token
116
+
117
+
118
async def execute(
    *,
    name: str = "",
    professor_name: str,
) -> Dict[str, Any]:
    """Query RateMyProfessors for professors matching *professor_name*.

    The search is restricted to the school whose ``school_id`` is set in the
    ``rate_my_professor`` tool config. The ``name`` kwarg is passed by the
    dispatch loop and ignored here.

    Returns ``{"professors": [...], "query": {...}}``. On failure (missing
    ``school_id``, HTTP 403, or any exception during the request) the
    ``professors`` list is empty and an ``"error"`` string is included.
    """
    # Function-scope import: only needed to build the Referer header below.
    from urllib.parse import quote

    tool_cfg = get_settings().tools.get_tool_config("rate_my_professor")
    school_id = tool_cfg.get("school_id")
    if not school_id:
        logger.error("No school_id configured for rate_my_professor")
        return {
            "professors": [],
            "error": "No school_id configured for rate_my_professor",
            "query": {"professor_name": professor_name},
        }

    professors: List[Dict[str, Any]] = []

    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            auth_token = await _extract_auth_token(client)

            headers = {
                "User-Agent": BROWSER_UA,
                "Authorization": auth_token,
                "Content-Type": "application/json",
                # Percent-encode the name: header values must be ASCII, so a
                # raw accented name (e.g. "José") would fail to encode and
                # abort the lookup before the request is sent.
                "Referer": f"{RMP_SEARCH_URL}?q={quote(professor_name, safe='')}",
                "Origin": "https://www.ratemyprofessors.com",
            }

            variables = {
                "count": 20,
                "cursor": "",
                "query": {
                    "text": professor_name,
                    "schoolID": school_id,
                    "fallback": True,
                    "departmentID": None,
                },
            }

            resp = await client.post(
                RMP_GRAPHQL_URL,
                json={"query": TEACHER_SEARCH_QUERY, "variables": variables},
                headers=headers,
            )

            if resp.status_code == 403:
                logger.warning("RMP GraphQL returned 403 — auth may be invalid")
                return {
                    "professors": [],
                    "error": "RateMyProfessors authentication failed",
                    "query": {"professor_name": professor_name},
                }

            resp.raise_for_status()
            data = resp.json()

            teachers = (
                data.get("data", {})
                .get("search", {})
                .get("teachers", {})
            )

            for edge in teachers.get("edges", []):
                node = edge.get("node", {})
                if node:
                    professors.append(_node_to_professor(node))

    except Exception as exc:
        # Tool boundary: report the failure to the model instead of raising.
        logger.error("RMP API error for %s: %s", professor_name, exc)
        return {
            "professors": [],
            "error": str(exc),
            "query": {"professor_name": professor_name},
        }

    return {
        "professors": professors,
        "query": {"professor_name": professor_name},
    }
multi_llm_chatbot_backend/app/tools/search_courses.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ search_courses tool — live query against CU Boulder's FOSE class-search API.
3
+
4
+ Exposes TOOL_DEFINITION (OpenAI tool format) and an execute() coroutine
5
+ that the tool-calling loop dispatches to.
6
+ """
7
+
8
+ import logging
9
+ import re
10
+ from typing import Any, Dict, List, Optional
11
+ import httpx
12
+ from app.tools import BROWSER_UA
13
+ from app.config import get_settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ FOSE_SEARCH_URL = "https://classes.colorado.edu/api/?page=fose&route=search"
18
+ CLASSES_BASE_URL = "https://classes.colorado.edu"
19
+
20
+ TOOL_DEFINITION: Dict[str, Any] = {
21
+ "type": "function",
22
+ "function": {
23
+ "name": "search_courses",
24
+ "description": (
25
+ "Search the CU Boulder course catalog for classes in a given "
26
+ "subject, optionally filtered by course number and semester. "
27
+ "Returns a list of matching sections with title, instructor, "
28
+ "schedule, and location."
29
+ ),
30
+ "parameters": {
31
+ "type": "object",
32
+ "properties": {
33
+ "subject": {
34
+ "type": "string",
35
+ "description": (
36
+ "Department / subject code, e.g. 'CSCI', 'MATH', 'PHYS'."
37
+ ),
38
+ },
39
+ "course_number": {
40
+ "type": "string",
41
+ "description": (
42
+ "Catalog number to filter on, e.g. '1300'. "
43
+ "Omit to return all courses in the subject."
44
+ ),
45
+ },
46
+ "semester": {
47
+ "type": "string",
48
+ "description": (
49
+ "Semester name, e.g. 'Spring 2026', 'Fall 2025'. "
50
+ "Defaults to 'Spring 2026' if not provided."
51
+ ),
52
+ },
53
+ },
54
+ "required": ["subject"],
55
+ },
56
+ },
57
+ }
58
+
59
def _term_to_srcdb(term: str) -> str:
    """Translate a semester name into the 4-digit FOSE ``srcdb`` code.

    The code is a literal ``2``, the two-digit year, and a season digit
    (1=Spring, 4=Summer, 7=Fall); e.g. 'Spring 2026' -> '2261' and
    'Fall 2025' -> '2257'. A missing ``20YY`` year defaults to ``26`` and
    an unrecognised season defaults to Spring.
    """
    lowered = term.lower()
    year_match = re.search(r"20(\d{2})", term)
    two_digit_year = year_match.group(1) if year_match else "26"

    season_digit = "1"  # Spring unless another season matches
    # Checked in this order, so the first season name present wins.
    for season, digit in (("spring", "1"), ("summer", "4"), ("fall", "7")):
        if season in lowered:
            season_digit = digit
            break

    return f"2{two_digit_year}{season_digit}"
75
+
76
+
77
def _parse_schedule(meets: str) -> Dict[str, str]:
    """Split a meeting string such as 'MWF 10:00am-10:50am' into fields.

    Returns a dict with ``days``, ``start_time``, ``end_time`` and ``raw``.
    Fields that cannot be found are empty strings; a falsy *meets* yields
    all-empty fields.
    """
    parsed = {"days": "", "start_time": "", "end_time": "", "raw": ""}
    if not meets:
        return parsed
    parsed["raw"] = meets

    # Days are the run of letters at the very start of the string.
    leading_letters = re.match(r"([A-Za-z]+)", meets)
    if leading_letters:
        parsed["days"] = leading_letters.group(1)

    time_span = re.search(
        r"(\d{1,2}:\d{2}\s*[ap]m)\s*-\s*(\d{1,2}:\d{2}\s*[ap]m)", meets, re.I
    )
    if time_span:
        begin, finish = time_span.groups()
        parsed["start_time"] = begin.strip()
        parsed["end_time"] = finish.strip()

    return parsed
92
+
93
+
94
def _row_to_course(item: Dict[str, Any], term: str) -> Optional[Dict[str, Any]]:
    """Map one FOSE result row to the course dict returned to the model.

    Returns None for rows whose ``schd`` value is set to anything other
    than ``LEC`` or ``SEM``, and for rows flagged ``isCancelled``.
    """
    section_type = item.get("schd", "")
    is_other_type = bool(section_type) and section_type not in ("LEC", "SEM", "")
    if is_other_type or item.get("isCancelled"):
        return None

    course_code = item.get("code", "").strip()
    if not course_code:
        # No explicit code: rebuild it from subject + catalog number.
        subject = item.get("subject", "")
        catalog_nbr = item.get("catalog_nbr", item.get("catalogNbr", ""))
        course_code = f"{subject} {catalog_nbr}".strip()

    section = item.get("no", "") or item.get("section", "")
    instructor = item.get("instr", "") or item.get("instructor", "Staff")
    meets = item.get("meets", "") or ""
    location = item.get("bldg", item.get("location", ""))

    return {
        "course_code": course_code,
        "title": item.get("title", ""),
        "section": section,
        "instructor": instructor,
        "schedule": _parse_schedule(meets),
        "location": location,
        "semester": term,
    }
120
+
121
+
122
+
123
async def execute(
    *,
    name: str = "",
    subject: str,
    course_number: str = "",
    semester: str = "Spring 2026",
) -> Dict[str, Any]:
    """Query the CU Boulder FOSE API and return matching courses.

    The 'name' kwarg is passed by the dispatch loop and ignored here.

    The API is searched by *subject* (upper-cased and stripped) for the
    given *semester*. *course_number*, when provided, is applied afterwards
    as a substring filter on each course code. The list is capped at the
    ``max_results`` setting of the ``search_courses`` tool config
    (default 20).

    Returns ``{"courses": [...], "total_results": int, "truncated": bool,
    "query": {...}}`` on success. If the API answers with a non-200 status
    or the request raises, returns ``{"courses": [], "error": str,
    "query": {...}}``.
    """
    srcdb = _term_to_srcdb(semester)
    subject = subject.upper().strip()

    payload = {
        "other": {"srcdb": srcdb},
        "criteria": [{"field": "subject", "value": subject}],
    }
    headers = {
        "User-Agent": BROWSER_UA,
        "Content-Type": "application/json",
        "Referer": CLASSES_BASE_URL,
        "Origin": CLASSES_BASE_URL,
    }

    courses: List[Dict[str, Any]] = []

    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.post(
                FOSE_SEARCH_URL, json=payload, headers=headers,
            )
            if resp.status_code != 200:
                logger.warning("FOSE API returned %s for %s", resp.status_code, subject)
                # Report the failure so an upstream outage is not mistaken
                # for "no courses exist" by the model.
                return {
                    "courses": [],
                    "error": f"FOSE API returned HTTP {resp.status_code}",
                    "query": {"subject": subject, "semester": semester},
                }

            body = resp.json()
            results = body.get("results", body.get("data", []))

            for item in results:
                row = _row_to_course(item, semester)
                if row:
                    courses.append(row)

    except Exception as exc:
        # Tool boundary: report the failure to the model instead of raising.
        logger.error("FOSE API error for %s: %s", subject, exc)
        return {"courses": [], "error": str(exc), "query": {"subject": subject, "semester": semester}}

    if course_number:
        cn = course_number.strip()
        courses = [c for c in courses if cn in c["course_code"]]

    max_results = get_settings().tools.get_tool_config("search_courses").get("max_results", 20)

    # Count before truncating so the caller can tell results were cut.
    total = len(courses)
    truncated = total > max_results
    courses = courses[:max_results]

    return {
        "courses": courses,
        "total_results": total,
        "truncated": truncated,
        "query": {
            "subject": subject,
            "course_number": course_number or None,
            "semester": semester,
        },
    }
multi_llm_chatbot_backend/requirements.txt CHANGED
@@ -1,46 +1,43 @@
1
  # Core FastAPI framework
2
- fastapi
3
- uvicorn[standard]
4
- python-multipart
5
 
6
  # HTTP client for LLM APIs
7
- httpx
8
  openai~=2.30
9
 
10
  # Document processing
11
- PyPDF2
12
- docx2txt
13
- python-docx
14
 
15
  # Environment configuration
16
- python-dotenv
17
  pyyaml~=6.0
18
 
19
  # Persona color generation
20
  colorhash~=2.3
21
 
22
  # Vector database and embeddings
23
- chromadb
24
- sentence-transformers
25
 
26
  # Natural language processing
27
- nltk
28
- tiktoken
29
 
30
  # PDF generation and export
31
- reportlab
32
 
33
  # Database (MongoDB)
34
- pymongo
35
- motor
36
 
37
  # Authentication and security
38
- passlib[bcrypt]
39
- # `bcrypt` pinned for compat.
40
- bcrypt~=4.0
41
- python-jose[cryptography]
42
 
43
  # Data validation (required for EmailStr in Pydantic)
44
- email-validator
45
- python-docx
46
- python-multipart
 
1
  # Core FastAPI framework
2
+ fastapi~=0.135
3
+ uvicorn[standard]~=0.44
4
+ python-multipart~=0.0
5
 
6
  # HTTP client for LLM APIs
7
+ httpx~=0.28
8
  openai~=2.30
9
 
10
  # Document processing
11
+ PyPDF2~=3.0
12
+ docx2txt~=0.9
13
+ python-docx~=1.2
14
 
15
  # Environment configuration
16
+ python-dotenv~=1.2
17
  pyyaml~=6.0
18
 
19
  # Persona color generation
20
  colorhash~=2.3
21
 
22
  # Vector database and embeddings
23
+ chromadb~=1.5
24
+ sentence-transformers~=5.3
25
 
26
  # Natural language processing
27
+ nltk~=3.9
28
+ tiktoken~=0.12
29
 
30
  # PDF generation and export
31
+ reportlab~=4.4
32
 
33
  # Database (MongoDB)
34
+ pymongo~=4.16
35
+ motor~=3.7
36
 
37
  # Authentication and security
38
+ passlib[bcrypt]~=1.7
39
+ bcrypt~=4.3
40
+ python-jose[cryptography]~=3.5
 
41
 
42
  # Data validation (required for EmailStr in Pydantic)
43
+ email-validator~=2.3
 
 
phd-advisor-frontend/src/contexts/AppConfigContext.js CHANGED
@@ -40,7 +40,11 @@ const buildAdvisors = (personaItems) => {
40
  */
41
  const buildGetAdvisorColors = (advisors) => (advisorId, isDark = false) => {
42
  const advisor = advisors[advisorId];
43
- if (!advisor) return { color: '#6B7280', bgColor: '#F3F4F6' };
 
 
 
 
44
  return {
45
  color: isDark ? advisor.darkColor : advisor.color,
46
  bgColor: isDark ? advisor.darkBgColor : advisor.bgColor,
 
40
  */
41
  const buildGetAdvisorColors = (advisors) => (advisorId, isDark = false) => {
42
  const advisor = advisors[advisorId];
43
+ if (!advisor) {
44
+ return isDark
45
+ ? { color: '#9CA3AF', bgColor: '#374151', textColor: '#F9FAFB' }
46
+ : { color: '#6B7280', bgColor: '#F3F4F6', textColor: '#111827' };
47
+ }
48
  return {
49
  color: isDark ? advisor.darkColor : advisor.color,
50
  bgColor: isDark ? advisor.darkBgColor : advisor.bgColor,
phd_config.yaml CHANGED
@@ -150,3 +150,13 @@ llm:
150
  rag:
151
  embedding_model: "all-MiniLM-L6-v2"
152
  chroma_collection: "phd_advisor_documents"
 
 
 
 
 
 
 
 
 
 
 
150
  rag:
151
  embedding_model: "all-MiniLM-L6-v2"
152
  chroma_collection: "phd_advisor_documents"
153
+
154
+ # TODO: For development/testing only. PhD Advisor will likely not use these tools.
155
+ tools:
156
+ search_courses:
157
+ enabled: true
158
+ max_results: 20 # Limit the number of results returned by the search_courses tool
159
+ rate_my_professor:
160
+ enabled: true
161
+ # Run `python3 scripts/rmp_school_lookup.py "<school name>"` to find this value
162
+ school_id: "U2Nob29sLTEwODc=" # CU Boulder school ID
scripts/rmp_school_lookup.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Look up a RateMyProfessors school ID by name.
3
+
4
+ Usage:
5
+ python3 scripts/rmp_school_lookup.py "University of Colorado"
6
+
7
+ Prints matching schools with their GraphQL IDs (the value to put in
8
+ your config.yaml under tools.rate_my_professor.school_id).
9
+ """
10
+
11
+ import asyncio
12
+ import re
13
+ import sys
14
+
15
+ import httpx
16
+
17
+ RMP_GRAPHQL_URL = "https://www.ratemyprofessors.com/graphql"
18
+ RMP_LANDING_URL = "https://www.ratemyprofessors.com/"
19
+ BROWSER_UA = (
20
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
21
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
22
+ "Chrome/131.0.0.0 Safari/537.36"
23
+ )
24
+
25
+ SCHOOL_SEARCH_QUERY = """
26
+ query SchoolSearchQuery($query: SchoolSearchQuery!) {
27
+ newSearch {
28
+ schools(query: $query) {
29
+ edges {
30
+ node {
31
+ id
32
+ name
33
+ city
34
+ state
35
+ }
36
+ }
37
+ }
38
+ }
39
+ }
40
+ """
41
+
42
+
43
async def _extract_auth_token(client: httpx.AsyncClient) -> str:
    """Return the ``Authorization`` header value to use for RMP GraphQL calls.

    Fetches the RateMyProfessors landing page and searches its text for an
    embedded ``"Authorization": "Basic ..."`` value. The lookup is
    best-effort: when the page cannot be fetched, or no such value is present
    in it, a built-in default token is returned instead.

    Args:
        client: Open ``httpx.AsyncClient`` used to request the landing page.

    Returns:
        The ``Basic ...`` value found in the page, or the default token.
    """
    # Base64 of "test:test"; used whenever nothing can be scraped.
    fallback_token = "Basic dGVzdDp0ZXN0"

    try:
        resp = await client.get(RMP_LANDING_URL, headers={"User-Agent": BROWSER_UA})
    except httpx.HTTPError as exc:
        # Keep the fallback behaviour, but report the reason instead of
        # silently discarding it, so a later auth failure can be traced back.
        print(
            f"Warning: could not fetch {RMP_LANDING_URL} ({exc}); "
            "using default auth token.",
            file=sys.stderr,
        )
        return fallback_token

    m = re.search(r'"Authorization"\s*:\s*"(Basic\s+[A-Za-z0-9+/=]+)"', resp.text)
    return m.group(1) if m else fallback_token
52
+
53
+
54
async def search_schools(school_name: str) -> list:
    """Search RateMyProfessors for schools whose name matches ``school_name``.

    Posts ``SCHOOL_SEARCH_QUERY`` to the RMP GraphQL endpoint, authorised
    with the token returned by ``_extract_auth_token``, and flattens the
    returned edges into plain dictionaries.

    Args:
        school_name: Free-text school name sent as the search text.

    Returns:
        A list of dicts with the keys ``school_id``, ``name``, ``city`` and
        ``state``. The list is empty when the response carries no school
        edges (including a response whose ``data`` is null).

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status
            (raised by ``raise_for_status``). Other httpx request errors
            raised by the POST propagate unchanged.
    """
    async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
        auth_token = await _extract_auth_token(client)
        resp = await client.post(
            RMP_GRAPHQL_URL,
            json={
                "query": SCHOOL_SEARCH_QUERY,
                "variables": {"query": {"text": school_name}},
            },
            headers={
                "User-Agent": BROWSER_UA,
                "Authorization": auth_token,
                "Content-Type": "application/json",
            },
        )
        resp.raise_for_status()
        payload = resp.json()

    # A GraphQL response may contain an explicit null at any level (for
    # example "data": null next to "errors"). dict.get's default only covers
    # a missing key, so each level is guarded with "or {}" as well.
    node = payload
    for key in ("data", "newSearch", "schools"):
        node = (node or {}).get(key)
    edges = (node or {}).get("edges") or []

    schools = []
    for edge in edges:
        info = (edge or {}).get("node")
        if not info:
            continue
        schools.append(
            {
                "school_id": info["id"],
                "name": info["name"],
                # Null city/state values are normalised to empty strings.
                "city": info.get("city") or "",
                "state": info.get("state") or "",
            }
        )
    return schools
87
+
88
+
89
def main():
    """Command-line entry point: look up schools and print their RMP IDs.

    All command-line arguments are joined with spaces into one search
    string, which is passed to ``search_schools``. Each match is printed
    with its name, location and ``school_id``.

    Exit status:
        1 when no school name is given or the lookup request fails;
        0 when the lookup succeeds but nothing matches.
    """
    if len(sys.argv) < 2:
        print("Usage: python3 scripts/rmp_school_lookup.py <school name>")
        print('Example: python3 scripts/rmp_school_lookup.py "University of Colorado"')
        sys.exit(1)

    query = " ".join(sys.argv[1:])

    try:
        results = asyncio.run(search_schools(query))
    except httpx.HTTPError as exc:
        # Network and HTTP-status failures are expected for a scraping
        # script; report them in one line instead of a raw traceback.
        print(f"Error: RateMyProfessors lookup failed: {exc}", file=sys.stderr)
        sys.exit(1)

    if not results:
        print(f"No schools found matching '{query}'")
        sys.exit(0)

    print(f"Found {len(results)} school(s) matching '{query}':\n")
    for school in results:
        # filter(None, ...) drops an empty city or state so no stray comma
        # is printed.
        location = ", ".join(filter(None, [school["city"], school["state"]]))
        print(f"  {school['name']}")
        print(f"    Location:  {location}")
        print(f"    school_id: {school['school_id']}")
        print()
108
+ print()
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()