decodingdatascience committed on
Commit
3d3b91c
·
verified ·
1 Parent(s): 36e73d7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +555 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ============================================================
2
+ # DDS SQL Agent with Modern LangChain Memory + Gradio UI
3
+ # Hugging Face Spaces version
4
+ # ============================================================
5
+
6
+ import os
7
+ import re
8
+ import sqlite3
9
+ from pathlib import Path
10
+ from uuid import uuid4
11
+
12
+ import gradio as gr
13
+ from langchain.agents import create_agent
14
+ from langchain.tools import tool
15
+ from langgraph.checkpoint.memory import InMemorySaver
16
+
17
+
18
+ # ------------------------------------------------------------
19
+ # 1. Environment configuration
20
+ # ------------------------------------------------------------
21
+ # Add this in Hugging Face Space Settings -> Variables and Secrets:
22
+ # Secret name: OPENAI_API_KEY
23
+ #
24
+ # Optional Space variables:
25
+ # MODEL_NAME = openai:gpt-5.4
26
+ # DATABASE_PATH = data/Chinook_Sqlite.sqlite
27
+
28
# Runtime settings, read once at import time from the process environment.
# OPENAI_API_KEY stays None when the secret is absent; the other two fall back
# to the defaults shown here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME", "openai:gpt-5.4")
DATABASE_PATH = Path(
    os.environ.get("DATABASE_PATH", "data/Chinook_Sqlite.sqlite")
)
31
+
32
+
33
+ # ------------------------------------------------------------
34
+ # 2. Database helpers
35
+ # ------------------------------------------------------------
36
+
37
def resolve_database_path() -> Path:
    """
    Resolve the SQLite database path.

    The configured DATABASE_PATH is tried first, followed by a short list of
    conventional Chinook file names relative to the working directory.

    Default:
    - data/Chinook_Sqlite.sqlite

    You can override it in Hugging Face Spaces with:
    DATABASE_PATH=/path/to/your/database.sqlite

    Returns:
        The first candidate that is an existing regular file.

    Raises:
        FileNotFoundError: If none of the candidates is an existing file.
    """

    candidates = (
        DATABASE_PATH,
        Path("Chinook_Sqlite.sqlite"),
        Path("chinook.db"),
        Path("Chinook.db"),
        Path("data/chinook.db"),
        Path("data/Chinook.db"),
    )

    for candidate in candidates:
        # is_file() rather than exists(): a directory (for example an empty
        # DATABASE_PATH value, which becomes Path(".")) is not a usable
        # database and must not short-circuit the search.
        if candidate.is_file():
            return candidate

    raise FileNotFoundError(
        "SQLite database file was not found. "
        "Upload your database file or set DATABASE_PATH in Hugging Face Variables."
    )
67
+
68
+
69
# Resolved once at import time; resolve_database_path() raises
# FileNotFoundError here when no database file can be located.
DB_PATH = resolve_database_path()
70
+
71
+
72
def get_database_schema(db_path: Path) -> str:
    """
    Extract table and column information from the SQLite database.
    This schema is injected into the system prompt so the agent knows the DB structure.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        Text with one "Table: <name>" heading per user table (internal
        "sqlite_*" tables are skipped, tables are ordered by name). Each
        heading is preceded by a blank line and followed by one
        "- <column>: <type>" line per column, annotated with
        "PRIMARY KEY" and/or "NOT NULL" where applicable.
    """

    conn = sqlite3.connect(db_path)

    # try/finally so the connection is released even when a query fails.
    try:
        cursor = conn.cursor()

        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table'
            AND name NOT LIKE 'sqlite_%'
            ORDER BY name;
            """
        )

        tables = [row[0] for row in cursor.fetchall()]
        schema_lines = []

        for table in tables:
            schema_lines.append(f"\nTable: {table}")

            # Quote the identifier (doubling embedded quotes) so table names
            # containing spaces or punctuation do not break the PRAGMA.
            quoted_table = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_table});")

            # PRAGMA table_info columns:
            # cid, name, type, notnull, dflt_value, pk
            for _, name, col_type, notnull, _, pk in cursor.fetchall():
                flags = []
                if pk:
                    flags.append("PRIMARY KEY")
                if notnull:
                    flags.append("NOT NULL")

                flag_text = f" ({', '.join(flags)})" if flags else ""
                schema_lines.append(f"- {name}: {col_type}{flag_text}")
    finally:
        conn.close()

    return "\n".join(schema_lines)
117
+
118
+
119
# Schema text computed once at import time from the resolved database file;
# it is interpolated into SYSTEM_PROMPT further down.
DATABASE_SCHEMA = get_database_schema(DB_PATH)
120
+
121
+
122
def strip_sql_code_fences(query: str) -> str:
    """
    Removes markdown code fences if the model returns SQL inside ```sql ... ```.

    The opening fence may carry no language tag or one of "sql", "sqlite" or
    "sqlite3" (case-insensitive). Surrounding whitespace is always trimmed.

    Args:
        query: Raw text produced by the model.

    Returns:
        The query text without the leading/trailing fence markers.
    """

    query = query.strip()

    if query.startswith("```"):
        # "sqlite3?" has to come before "sql" in the alternation; otherwise a
        # "```sqlite" fence would leave "ite" glued to the front of the query.
        query = re.sub(r"^```(?:sqlite3?|sql)?", "", query, flags=re.IGNORECASE).strip()
        query = re.sub(r"```$", "", query).strip()

    return query
134
+
135
+
136
def is_read_only_sql(query: str) -> bool:
    """
    Basic read-only protection.
    Allows SELECT, WITH, PRAGMA, and EXPLAIN.
    Blocks INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, etc.

    String literals, double-quoted identifiers and comments are blanked out
    before the keyword scan, and keywords are matched as whole words. This
    closes bypasses such as "insert\\ninto" or "insert/**/into" and stops
    identifiers like "last_update" from being rejected.

    Args:
        query: SQL text, optionally wrapped in markdown code fences.

    Returns:
        True when the statement starts with an allowed keyword and contains
        no blocked keyword; False otherwise (including for empty input).
    """

    cleaned = strip_sql_code_fences(query)

    # Single left-to-right pass: whichever construct starts first wins, so a
    # "--" inside a string literal is not mistaken for a comment (and vice
    # versa). Each match becomes a space so neighbouring tokens never merge.
    cleaned = re.sub(
        r"'(?:[^']|'')*'"      # 'string literal' with '' escapes
        r'|"(?:[^"]|"")*"'      # "quoted identifier" with "" escapes
        r"|--[^\n]*"            # line comment
        r"|/\*.*?\*/",          # block comment
        " ",
        cleaned,
        flags=re.DOTALL,
    )
    cleaned = cleaned.strip().lower()

    if not re.match(r"(?:select|with|pragma|explain)\b", cleaned):
        return False

    # NOTE(review): PRAGMA statements that assign a value are still accepted
    # by this filter; tighten here if that matters for the deployment.
    blocked = re.search(
        r"\b(?:insert|update|delete|drop|alter|create|truncate"
        r"|attach|detach|vacuum|reindex)\b"
        # REPLACE is only a write as "REPLACE INTO"; replace(...) is a
        # read-only scalar function and stays allowed.
        r"|\breplace\s+into\b",
        cleaned,
    )

    return blocked is None
169
+
170
+
171
def rows_to_markdown(columns, rows, max_rows: int = 50) -> str:
    """
    Convert SQL rows to a Markdown table for readable chatbot output.

    Args:
        columns: Column headings, one per result column.
        rows: Sequence of row tuples.
        max_rows: Maximum number of rows rendered; extra rows are dropped.

    Returns:
        A Markdown table (header, separator, then one line per row), or a
        fixed notice when ``rows`` is empty.
    """

    if not rows:
        return "Query executed successfully, but returned no rows."

    def clean_cell(value):
        # None renders as an empty cell; line breaks and pipes would break
        # the table layout, so they are flattened / escaped.
        if value is None:
            return ""
        text = str(value)
        text = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
        return text.replace("|", "\\|")

    # Headers go through the same cleaning as data cells: an alias such as
    # "a|b" would otherwise add a phantom column.
    header = "| " + " | ".join(clean_cell(column) for column in columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"

    body_lines = [
        "| " + " | ".join(clean_cell(value) for value in row) + " |"
        for row in rows[:max_rows]
    ]

    return "\n".join([header, separator, *body_lines])
196
+
197
+
198
+ # ------------------------------------------------------------
199
+ # 3. SQL tool
200
+ # ------------------------------------------------------------
201
+
202
@tool
def execute_sql(query: str) -> str:
    """
    Execute a read-only SQLite SQL query against the Chinook database.

    Use this tool when the user asks analytical questions that require database access.
    Only SELECT, WITH, PRAGMA, and EXPLAIN queries are allowed.
    """
    # The docstring above is left word-for-word: @tool presumably exposes it
    # to the model as the tool description — TODO confirm before rewording.
    #
    # Returns a Markdown table, a short status sentence, or an error string;
    # it never raises, so the agent can read the failure and retry.

    max_rows = 50  # rows rendered into the reply; the total is still reported

    query = strip_sql_code_fences(query)

    if not is_read_only_sql(query):
        return (
            "Blocked for safety. Only read-only SQL is allowed. "
            "Please use SELECT, WITH, PRAGMA, or EXPLAIN queries."
        )

    try:
        # Defense in depth: open the file read-only so a statement that slips
        # past the keyword filter still cannot modify the database.
        read_only_uri = f"{Path(DB_PATH).resolve().as_uri()}?mode=ro"
        conn = sqlite3.connect(read_only_uri, uri=True)

        # Close the connection on every path, including failed queries.
        try:
            cursor = conn.cursor()
            cursor.execute(query)

            rows = cursor.fetchall()
            columns = (
                [description[0] for description in cursor.description]
                if cursor.description
                else []
            )
        finally:
            conn.close()

        if not columns:
            return "Query executed successfully."

        result_table = rows_to_markdown(columns, rows, max_rows=max_rows)

        if len(rows) > max_rows:
            result_table += f"\n\nShowing first {max_rows} rows out of {len(rows)} rows."

        return result_table

    except Exception as e:
        # Deliberately broad: this is the tool boundary, and any failure is
        # reported back to the agent as text instead of aborting the run.
        return f"SQL execution error: {str(e)}"
241
+
242
+
243
+ # ------------------------------------------------------------
244
+ # 4. System prompt
245
+ # ------------------------------------------------------------
246
+
247
# Instructions handed to the agent as its system prompt. Built once at import
# time; the f-string embeds the schema text held in DATABASE_SCHEMA.
SYSTEM_PROMPT = f"""
You are a helpful SQL data analyst for the Chinook SQLite database.

Your job:
- Understand the user's business/data question.
- Write correct SQLite queries.
- Use the execute_sql tool to query the database.
- Explain the result clearly and concisely.
- For follow-up questions, use the conversation memory.

Important rules:
- Use only read-only SQL.
- Never modify the database.
- Prefer clear SQL with explicit table joins.
- When useful, explain the SQL logic briefly.
- If the user asks a vague question, make a reasonable interpretation and proceed.
- If the database does not contain enough information, say that clearly.

Available database schema:
{DATABASE_SCHEMA}
"""
268
+
269
+
270
+ # ------------------------------------------------------------
271
+ # 5. Create LangChain agent with short-term memory
272
+ # ------------------------------------------------------------
273
+ # InMemorySaver gives thread-level memory during the live Space session.
274
+ # For production-grade persistent memory, replace this with a database-backed checkpointer.
275
+
276
# A single checkpointer instance is shared by the whole process; individual
# conversations are told apart by the thread_id passed in the invoke config.
checkpointer = InMemorySaver()

# Agent wiring: model identifier string, the one SQL tool, the system prompt
# and the checkpointer that backs conversation memory.
sql_agent_with_memory = create_agent(
    model=MODEL_NAME,
    tools=[execute_sql],
    system_prompt=SYSTEM_PROMPT,
    checkpointer=checkpointer,
)
284
+
285
+
286
+ # ------------------------------------------------------------
287
+ # 6. Gradio helpers
288
+ # ------------------------------------------------------------
289
+
290
def content_to_text(content):
    """
    Convert LangChain message content into displayable text.

    Args:
        content: A plain string, a list of content parts (dicts or other
            values), or any other object.

    Returns:
        - the string itself when ``content`` is a string;
        - for a list, one line per part: a dict contributes its "text" value
          if present, otherwise its "content" value, otherwise its own
          ``str()``; non-dict parts contribute their ``str()``;
        - ``str(content)`` for anything else.
    """

    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        return str(content)

    text_parts = []

    for item in content:
        if isinstance(item, dict) and "text" in item:
            # Coerced like the other branches so a non-string "text" value
            # cannot make the join below raise TypeError.
            text_parts.append(str(item["text"]))
        elif isinstance(item, dict) and "content" in item:
            text_parts.append(str(item["content"]))
        else:
            text_parts.append(str(item))

    return "\n".join(text_parts)
315
+
316
+
317
def create_thread_id():
    """
    Build a fresh conversation identifier.

    Reusing an id continues the same LangGraph memory thread;
    a newly generated id starts an empty conversation.
    """

    unique_suffix = uuid4()
    return "dds-sql-agent-" + str(unique_suffix)
324
+
325
+
326
def normalize_history_to_messages(history):
    """
    Reduce a Gradio chat history to plain role/content message dicts.

    Only dict entries that carry both a "role" and a "content" key, and whose
    role is "user" or "assistant", are kept; their content is flattened to
    text. Everything else is dropped. ``None`` yields an empty list.
    """

    if history is None:
        return []

    accepted_roles = ("user", "assistant")

    return [
        {
            "role": entry["role"],
            "content": content_to_text(entry["content"]),
        }
        for entry in history
        if isinstance(entry, dict)
        and "role" in entry
        and "content" in entry
        and entry["role"] in accepted_roles
    ]
352
+
353
+
354
+ # ------------------------------------------------------------
355
+ # 7. Gradio chat function
356
+ # ------------------------------------------------------------
357
+
358
def chat_with_sql_agent(message, history, thread_id):
    """
    Process a single chat turn coming from the Gradio UI.

    The history is exchanged in messages format (role/content dicts) without
    passing type="messages" to gr.Chatbot, because some Gradio 6 runtimes
    expect messages but do not accept the type argument.

    Returns a 3-tuple: the updated history, an empty string (used to clear
    the textbox), and the thread id to keep for the next turn.
    """

    history = normalize_history_to_messages(history)

    # No API key: answer with setup instructions instead of calling the agent.
    if not OPENAI_API_KEY:
        setup_hint = (
            "OPENAI_API_KEY is missing. In Hugging Face Spaces, go to "
            "Settings → Variables and Secrets → New Secret, then add:\n\n"
            "`OPENAI_API_KEY = your_openai_api_key`"
        )
        setup_turn = [
            {"role": "user", "content": message or ""},
            {"role": "assistant", "content": setup_hint},
        ]
        return history + setup_turn, "", thread_id or create_thread_id()

    # Make sure the turn is attached to some memory thread.
    thread_id = thread_id or create_thread_id()

    # Blank submissions leave the conversation untouched.
    if not (message and message.strip()):
        return history, "", thread_id

    user_message = message.strip()

    agent_input = {"messages": [{"role": "user", "content": user_message}]}
    run_config = {"configurable": {"thread_id": thread_id}}

    try:
        result = sql_agent_with_memory.invoke(agent_input, config=run_config)
        # The last message of the run is shown as the reply.
        assistant_message = content_to_text(result["messages"][-1].content)

    except Exception as e:
        # Any failure is shown in the chat together with a troubleshooting list.
        assistant_message = (
            "\n"
            "Something went wrong while running the SQL agent.\n"
            "\n"
            "Error:\n"
            "\n"
            "```text\n"
            f"{str(e)}\n"
            "```\n"
            "\n"
            "Check:\n"
            "1. OPENAI_API_KEY is set in Hugging Face Secrets.\n"
            "2. MODEL_NAME is available in your OpenAI account.\n"
            f"3. The SQLite database file exists at: `{DB_PATH}`\n"
        )

    new_turn = [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]

    return history + new_turn, "", thread_id
436
+
437
+
438
def reset_chat():
    """
    Start over: return an empty chat history together with a newly
    generated thread id, so the next turn runs on a fresh memory thread.
    """

    empty_history = []
    fresh_thread_id = create_thread_id()
    return empty_history, fresh_thread_id
444
+
445
+
446
def example_question(question):
    """
    Pass-through callback for the example buttons: the given question is
    returned unchanged so Gradio can write it into the textbox.
    """

    return question
452
+
453
+
454
+ # ------------------------------------------------------------
455
+ # 8. Build Gradio app
456
+ # ------------------------------------------------------------
457
+
458
# Stylesheet passed to gr.Blocks: caps the width of the element with id
# "main-container" and defines the ".dds-note" helper class.
custom_css = """
#main-container {
max-width: 1100px;
margin: 0 auto;
}

.dds-note {
font-size: 0.95rem;
opacity: 0.85;
}
"""
469
+
470
# NOTE(review): newer Gradio releases may expect `css` to be passed to
# launch() rather than gr.Blocks() — confirm against the installed version.
with gr.Blocks(title="DDS SQL Agent", css=custom_css) as demo:

    # Per-session conversation thread id. It deliberately starts empty:
    # calling create_thread_id() here would run once while the UI is built,
    # so every visitor would begin on the same thread and share conversation
    # memory through the global checkpointer. chat_with_sql_agent creates an
    # id on the first turn and hands it back into this state.
    thread_id_state = gr.State(value=None)

    with gr.Column(elem_id="main-container"):
        # Page header showing the active model and database file.
        gr.Markdown(
            "\n"
            "# DDS SQL Agent with Memory\n"
            "\n"
            "Ask questions about the Chinook SQLite database.\n"
            "The agent can generate SQL, execute read-only queries, and remember follow-up questions in the same session.\n"
            "\n"
            f"**Model:** `{MODEL_NAME}`\n"
            f"**Database:** `{DB_PATH}`\n"
        )

        # Setup banner, rendered only when the API key is absent at startup.
        if not OPENAI_API_KEY:
            gr.Markdown(
                "\n"
                "> **Setup needed:** `OPENAI_API_KEY` is not set.\n"
                "> Add it in Hugging Face Spaces under **Settings → Variables and Secrets → New Secret**.\n"
            )

        chatbot = gr.Chatbot(
            value=[],
            height=560,
            label="SQL Agent Chat",
            placeholder="Ask a question about the database...",
        )

        with gr.Row():
            user_input = gr.Textbox(
                placeholder="Example: Which customer spent the most money?",
                label="Your question",
                scale=8,
            )

            submit_btn = gr.Button(
                "Ask",
                scale=1,
                variant="primary",
            )

        with gr.Row():
            clear_btn = gr.Button("New Chat / Reset Memory")

        gr.Markdown("### Example questions")

        with gr.Row():
            ex1 = gr.Button("Which customer spent the most money?")
            ex2 = gr.Button("Show total sales by country.")
            ex3 = gr.Button("Which genre has the most tracks?")
            ex4 = gr.Button("What are the top-selling tracks?")

    # Example buttons only fill the textbox; the user still has to submit.
    ex1.click(example_question, inputs=[gr.State("Which customer spent the most money?")], outputs=[user_input])
    ex2.click(example_question, inputs=[gr.State("Show total sales by country.")], outputs=[user_input])
    ex3.click(example_question, inputs=[gr.State("Which genre has the most tracks?")], outputs=[user_input])
    ex4.click(example_question, inputs=[gr.State("What are the top-selling tracks?")], outputs=[user_input])

    # Clicking "Ask" and pressing Enter in the textbox run the same handler.
    submit_btn.click(
        fn=chat_with_sql_agent,
        inputs=[user_input, chatbot, thread_id_state],
        outputs=[chatbot, user_input, thread_id_state],
    )

    user_input.submit(
        fn=chat_with_sql_agent,
        inputs=[user_input, chatbot, thread_id_state],
        outputs=[chatbot, user_input, thread_id_state],
    )

    # Reset clears the visible history and switches to a new thread id.
    clear_btn.click(
        fn=reset_chat,
        inputs=[],
        outputs=[chatbot, thread_id_state],
    )
548
+
549
+
550
+ # ------------------------------------------------------------
551
+ # 9. Launch for Hugging Face Spaces
552
+ # ------------------------------------------------------------
553
+
554
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving.
    demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=6.0.0
2
+ langchain>=1.0.0
3
+ langchain-openai>=1.0.0
4
+ langgraph>=1.0.0
5
+ openai>=2.0.0
6
+ typing-extensions>=4.10.0