Spaces:
Running
Running
| # ============================================================ | |
| # DDS SQL Agent with Modern LangChain Memory + Gradio UI | |
| # Hugging Face Spaces version | |
| # ============================================================ | |
| import os | |
| import re | |
| import sqlite3 | |
| from pathlib import Path | |
| from uuid import uuid4 | |
| import gradio as gr | |
| from langchain.agents import create_agent | |
| from langchain.tools import tool | |
| from langgraph.checkpoint.memory import InMemorySaver | |
| # ------------------------------------------------------------ | |
| # 1. Environment configuration | |
| # ------------------------------------------------------------ | |
| # Add this in Hugging Face Space Settings -> Variables and Secrets: | |
| # Secret name: OPENAI_API_KEY | |
| # | |
| # Optional Space variables: | |
| # MODEL_NAME = openai:gpt-5.4 | |
| # DATABASE_PATH = data/Chinook_Sqlite.sqlite | |
# Runtime configuration is read once, at import time, from the process
# environment. OPENAI_API_KEY has no fallback and is None when unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME", "openai:gpt-5.4")
DATABASE_PATH = Path(
    os.environ.get("DATABASE_PATH", "data/Chinook_Sqlite.sqlite")
)
| # ------------------------------------------------------------ | |
| # 2. Database helpers | |
| # ------------------------------------------------------------ | |
def resolve_database_path() -> Path:
    """
    Resolve the SQLite database path.

    The configured DATABASE_PATH is tried first (default:
    data/Chinook_Sqlite.sqlite), followed by a short list of common
    Chinook file names relative to the working directory. You can override
    the configured path in Hugging Face Spaces with:
    DATABASE_PATH=/path/to/your/database.sqlite

    Returns:
        The first candidate that is an existing regular file.

    Raises:
        FileNotFoundError: if none of the candidates is a file.
    """
    candidates = [
        DATABASE_PATH,
        Path("Chinook_Sqlite.sqlite"),
        Path("chinook.db"),
        Path("Chinook.db"),
        Path("data/chinook.db"),
        Path("data/Chinook.db"),
    ]
    for candidate in candidates:
        # is_file() rather than exists(): a directory (for example an empty
        # DATABASE_PATH, which becomes Path(".")) must not be accepted as a
        # database, otherwise the failure only surfaces later in sqlite.
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(
        "SQLite database file was not found. "
        "Upload your database file or set DATABASE_PATH in Hugging Face Variables."
    )
| DB_PATH = resolve_database_path() | |
def get_database_schema(db_path: Path) -> str:
    """
    Extract table and column information from the SQLite database.
    This schema is injected into the system prompt so the agent knows the DB structure.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        One "Table: <name>" line per user table (each preceded by a blank
        line), followed by "- <column>: <type>" lines carrying optional
        "(PRIMARY KEY, NOT NULL)" flags. Empty string if there are no tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table'
            AND name NOT LIKE 'sqlite_%'
            ORDER BY name;
            """
        )
        tables = [row[0] for row in cursor.fetchall()]
        schema_lines = []
        for table in tables:
            schema_lines.append(f"\nTable: {table}")
            # Quote the identifier (doubling embedded quotes) so table names
            # containing spaces, punctuation or keywords do not break the PRAGMA.
            quoted_table = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_table});")
            # PRAGMA table_info columns:
            # cid, name, type, notnull, dflt_value, pk
            for _, name, col_type, notnull, _, pk in cursor.fetchall():
                flags = []
                if pk:
                    flags.append("PRIMARY KEY")
                if notnull:
                    flags.append("NOT NULL")
                flag_text = f" ({', '.join(flags)})" if flags else ""
                schema_lines.append(f"- {name}: {col_type}{flag_text}")
    finally:
        # Always release the connection, even when a query above raises.
        conn.close()
    return "\n".join(schema_lines)
| DATABASE_SCHEMA = get_database_schema(DB_PATH) | |
def strip_sql_code_fences(query: str) -> str:
    """
    Removes markdown code fences if the model returns SQL inside ```sql ... ```.

    The opening fence may carry no language tag or one of `sql`, `sqlite`,
    `sqlite3` (case-insensitive). Text that does not start with a fence is
    only stripped of surrounding whitespace.

    Args:
        query: Raw SQL text, possibly wrapped in a fenced code block.

    Returns:
        The SQL text without the fence markers and surrounding whitespace.
    """
    query = query.strip()
    if query.startswith("```"):
        # Try the longer tags first; otherwise "```sqlite" would be consumed
        # only up to "sql" and leave a stray "ite" in front of the statement.
        query = re.sub(
            r"^```(?:sqlite3?|sql)?", "", query, flags=re.IGNORECASE
        ).strip()
        query = re.sub(r"```$", "", query).strip()
    return query
# Tokens whose content can never be an executable keyword: string literals,
# quoted identifiers ("..", `..`, [..]) and comments. They are matched in a
# single left-to-right pass so that a comment marker inside a literal (or a
# quote inside a comment) cannot hide real SQL from the keyword scan.
_SQL_SKIPPABLE_TOKEN = re.compile(
    r"'(?:[^']|'')*'"
    r'|"(?:[^"]|"")*"'
    r"|`(?:[^`]|``)*`"
    r"|\[[^\]]*\]"
    r"|--[^\n]*"
    r"|/\*.*?(?:\*/|\Z)",
    flags=re.DOTALL,
)

# Statement keywords that modify data or schema. REPLACE is only blocked in
# its statement form ("REPLACE INTO") so the replace() SQL function still works.
_SQL_WRITE_KEYWORD = re.compile(
    r"\b(?:insert|update|delete|drop|alter|create|truncate"
    r"|attach|detach|vacuum|reindex)\b"
    r"|\breplace\s+into\b"
)


def is_read_only_sql(query: str) -> bool:
    """
    Basic read-only protection.
    Allows SELECT, WITH, PRAGMA, and EXPLAIN.
    Blocks INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, etc.

    Args:
        query: SQL text, optionally wrapped in a markdown code fence.

    Returns:
        True when the statement starts with an allowed keyword and contains
        no write keyword outside of literals, quoted identifiers and comments.
    """
    cleaned = strip_sql_code_fences(query)
    # Replace skipped tokens with a space, not "": SQLite treats a comment as
    # whitespace, so "insert/**/into" must still be seen as two words.
    cleaned = _SQL_SKIPPABLE_TOKEN.sub(" ", cleaned)
    cleaned = cleaned.strip().lower()
    allowed_starts = ("select", "with", "pragma", "explain")
    if not cleaned.startswith(allowed_starts):
        return False
    # Word-boundary matching catches keywords followed by any whitespace
    # (newline, tab) and no longer trips on names such as "LastUpdate".
    # NOTE(review): PRAGMA statements that assign values are still accepted
    # here; callers needing a hard guarantee should also open the connection
    # read-only.
    return _SQL_WRITE_KEYWORD.search(cleaned) is None
def rows_to_markdown(columns, rows, max_rows: int = 50) -> str:
    """
    Convert SQL rows to a Markdown table for readable chatbot output.

    Args:
        columns: Column names, used as the table header.
        rows: Sequence of row tuples; only the first `max_rows` are rendered.
        max_rows: Maximum number of body rows in the table.

    Returns:
        A Markdown table, or a fixed notice when `rows` is empty.
    """
    if not rows:
        return "Query executed successfully, but returned no rows."

    def clean_cell(value):
        # NULL renders as an empty cell. Line breaks and pipes would otherwise
        # split the row or the cell in Markdown.
        if value is None:
            return ""
        text = str(value)
        for line_break in ("\r\n", "\r", "\n"):
            text = text.replace(line_break, " ")
        return text.replace("|", "\\|")

    # Header cells get the same escaping as body cells: a column alias may
    # legitimately contain a pipe or a line break.
    header = "| " + " | ".join(clean_cell(name) for name in columns) + " |"
    separator = "| " + " | ".join(["---"] * len(columns)) + " |"
    body_lines = [
        "| " + " | ".join(clean_cell(value) for value in row) + " |"
        for row in rows[:max_rows]
    ]
    return "\n".join([header, separator] + body_lines)
| # ------------------------------------------------------------ | |
| # 3. SQL tool | |
| # ------------------------------------------------------------ | |
def execute_sql(query: str) -> str:
    """
    Execute a read-only SQLite SQL query against the Chinook database.
    Use this tool when the user asks analytical questions that require database access.
    Only SELECT, WITH, PRAGMA, and EXPLAIN queries are allowed.
    """
    # NOTE(review): the docstring above is kept word-for-word. This function
    # is handed to create_agent(tools=[...]) and the agent framework
    # presumably uses the docstring as the tool description -- confirm before
    # rewording it.
    #
    # Returns a Markdown table, a short status message, or an error string;
    # exceptions are reported as text instead of being raised.
    max_rows = 50
    query = strip_sql_code_fences(query)
    if not is_read_only_sql(query):
        return (
            "Blocked for safety. Only read-only SQL is allowed. "
            "Please use SELECT, WITH, PRAGMA, or EXPLAIN queries."
        )
    try:
        # Open the file in SQLite's read-only mode so that a statement which
        # slips past the textual check still cannot modify the database.
        read_only_uri = f"{Path(DB_PATH).resolve().as_uri()}?mode=ro"
        conn = sqlite3.connect(read_only_uri, uri=True)
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            rows = cursor.fetchall()
            columns = (
                [description[0] for description in cursor.description]
                if cursor.description
                else []
            )
        finally:
            # Close the connection on failure as well as on success.
            conn.close()
        if not columns:
            return "Query executed successfully."
        result_table = rows_to_markdown(columns, rows, max_rows=max_rows)
        if len(rows) > max_rows:
            result_table += (
                f"\n\nShowing first {max_rows} rows out of {len(rows)} rows."
            )
        return result_table
    except Exception as e:
        # Tool boundary: report the failure as text so the agent can react.
        return f"SQL execution error: {str(e)}"
| # ------------------------------------------------------------ | |
| # 4. System prompt | |
| # ------------------------------------------------------------ | |
| SYSTEM_PROMPT = f""" | |
| You are a helpful SQL data analyst for the Chinook SQLite database. | |
| Your job: | |
| - Understand the user's business/data question. | |
| - Write correct SQLite queries. | |
| - Use the execute_sql tool to query the database. | |
| - Explain the result clearly and concisely. | |
| - For follow-up questions, use the conversation memory. | |
| Important rules: | |
| - Use only read-only SQL. | |
| - Never modify the database. | |
| - Prefer clear SQL with explicit table joins. | |
| - When useful, explain the SQL logic briefly. | |
| - If the user asks a vague question, make a reasonable interpretation and proceed. | |
| - If the database does not contain enough information, say that clearly. | |
| Available database schema: | |
| {DATABASE_SCHEMA} | |
| """ | |
| # ------------------------------------------------------------ | |
| # 5. Create LangChain agent with short-term memory | |
| # ------------------------------------------------------------ | |
| # InMemorySaver gives thread-level memory during the live Space session. | |
| # For production-grade persistent memory, replace this with a database-backed checkpointer. | |
| checkpointer = InMemorySaver() | |
| sql_agent_with_memory = create_agent( | |
| model=MODEL_NAME, | |
| tools=[execute_sql], | |
| system_prompt=SYSTEM_PROMPT, | |
| checkpointer=checkpointer, | |
| ) | |
| # ------------------------------------------------------------ | |
| # 6. Gradio helpers | |
| # ------------------------------------------------------------ | |
def content_to_text(content):
    """
    Flatten LangChain message content into a single display string.

    A string is returned unchanged. A list is rendered item by item and the
    pieces are joined with newlines: a dict contributes its "text" value
    as-is, otherwise its "content" value converted with str(); any other
    item is converted with str(). Every other input is converted with str().
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    def part_to_text(part):
        # Dicts are searched for "text" first, then "content"; dicts without
        # either key fall through to the generic str() conversion.
        if isinstance(part, dict):
            if "text" in part:
                return part["text"]
            if "content" in part:
                return str(part["content"])
        return str(part)

    return "\n".join([part_to_text(part) for part in content])
def create_thread_id():
    """
    Build a new conversation thread identifier.

    Same thread_id = same LangGraph memory.
    New thread_id = fresh conversation.
    """
    unique_suffix = str(uuid4())
    return "dds-sql-agent-" + unique_suffix
def normalize_history_to_messages(history):
    """
    Reduce a Gradio chat history to plain role/content messages.

    Gradio expects messages format:
    [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
    ]

    Only dict entries that carry both "role" and "content" and whose role is
    "user" or "assistant" are kept; their content is flattened to text.
    A history of None yields an empty list.
    """
    if history is None:
        return []
    accepted_roles = ("user", "assistant")
    return [
        {
            "role": entry.get("role"),
            "content": content_to_text(entry.get("content", "")),
        }
        for entry in history
        if isinstance(entry, dict)
        and "role" in entry
        and "content" in entry
        and entry.get("role") in accepted_roles
    ]
| # ------------------------------------------------------------ | |
| # 7. Gradio chat function | |
| # ------------------------------------------------------------ | |
def chat_with_sql_agent(message, history, thread_id):
    """
    Handles one user message from Gradio.

    This returns messages format without passing type="messages"
    to gr.Chatbot, because some Gradio 6 runtimes expect messages
    but do not accept the type argument.

    Returns a 3-tuple: the updated chat history, an empty string for the
    textbox, and the thread id used for this turn (created here when the
    incoming one is empty).
    """
    history = normalize_history_to_messages(history)

    def with_turn(user_text, assistant_text):
        # Append one user/assistant exchange to the normalized history.
        return history + [
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": assistant_text},
        ]

    # Without an API key the agent is never invoked; show setup instructions.
    if not OPENAI_API_KEY:
        missing_key_notice = (
            "OPENAI_API_KEY is missing. In Hugging Face Spaces, go to "
            "Settings → Variables and Secrets → New Secret, then add:\n\n"
            "`OPENAI_API_KEY = your_openai_api_key`"
        )
        return (
            with_turn(message or "", missing_key_notice),
            "",
            thread_id or create_thread_id(),
        )

    thread_id = thread_id or create_thread_id()

    # Blank input: leave the conversation untouched.
    if not (message and message.strip()):
        return history, "", thread_id

    user_message = message.strip()
    agent_input = {"messages": [{"role": "user", "content": user_message}]}
    # The thread id selects which checkpointed conversation the agent resumes.
    agent_config = {"configurable": {"thread_id": thread_id}}
    try:
        result = sql_agent_with_memory.invoke(agent_input, config=agent_config)
        final_message = result["messages"][-1]
        assistant_message = content_to_text(final_message.content)
    except Exception as e:
        # Any failure is surfaced in the chat instead of being raised to Gradio.
        assistant_message = f"""
Something went wrong while running the SQL agent.
Error:
```text
{str(e)}
```
Check:
1. OPENAI_API_KEY is set in Hugging Face Secrets.
2. MODEL_NAME is available in your OpenAI account.
3. The SQLite database file exists at: `{DB_PATH}`
"""
    return with_turn(user_message, assistant_message), "", thread_id
def reset_chat():
    """
    Start over: return an empty chat history for the UI together with a
    newly generated thread id, so the next message uses a fresh memory thread.
    """
    empty_history = []
    fresh_thread_id = create_thread_id()
    return empty_history, fresh_thread_id
def example_question(question):
    """
    Pass the chosen example question through unchanged so an event handler
    can write it into the textbox.
    """
    selected_question = question
    return selected_question
| # ------------------------------------------------------------ | |
| # 8. Build Gradio app | |
| # ------------------------------------------------------------ | |
| custom_css = """ | |
| #main-container { | |
| max-width: 1100px; | |
| margin: 0 auto; | |
| } | |
| .dds-note { | |
| font-size: 0.95rem; | |
| opacity: 0.85; | |
| } | |
| """ | |
# Layout and event wiring for the chat UI.
with gr.Blocks(title="DDS SQL Agent", css=custom_css) as demo:
    # Start every session WITHOUT a thread id. Calling create_thread_id() here
    # would run once while the app is built, so all visitors would begin with
    # the same id and therefore share one LangGraph memory thread.
    # chat_with_sql_agent creates an id on the first message when it receives
    # an empty one and stores it back into this state.
    thread_id_state = gr.State(value=None)
    with gr.Column(elem_id="main-container"):
        gr.Markdown(
            f"""
# DDS SQL Agent with Memory
Ask questions about the Chinook SQLite database.
The agent can generate SQL, execute read-only queries, and remember follow-up questions in the same session.
**Model:** `{MODEL_NAME}`
**Database:** `{DB_PATH}`
"""
        )
        if not OPENAI_API_KEY:
            gr.Markdown(
                """
> **Setup needed:** `OPENAI_API_KEY` is not set.
> Add it in Hugging Face Spaces under **Settings → Variables and Secrets → New Secret**.
"""
            )
        chatbot = gr.Chatbot(
            value=[],
            height=560,
            label="SQL Agent Chat",
            placeholder="Ask a question about the database...",
        )
        with gr.Row():
            user_input = gr.Textbox(
                placeholder="Example: Which customer spent the most money?",
                label="Your question",
                scale=8,
            )
            submit_btn = gr.Button(
                "Ask",
                scale=1,
                variant="primary",
            )
        with gr.Row():
            clear_btn = gr.Button("New Chat / Reset Memory")
        gr.Markdown("### Example questions")
        example_questions = (
            "Which customer spent the most money?",
            "Show total sales by country.",
            "Which genre has the most tracks?",
            "What are the top-selling tracks?",
        )
        with gr.Row():
            example_buttons = [
                gr.Button(question) for question in example_questions
            ]
    # Each example button copies its own question into the textbox; the
    # question travels through a gr.State input, so no closure is involved.
    for button, question in zip(example_buttons, example_questions):
        button.click(
            example_question,
            inputs=[gr.State(question)],
            outputs=[user_input],
        )
    # The "Ask" button and pressing Enter in the textbox run the same handler.
    submit_btn.click(
        fn=chat_with_sql_agent,
        inputs=[user_input, chatbot, thread_id_state],
        outputs=[chatbot, user_input, thread_id_state],
    )
    user_input.submit(
        fn=chat_with_sql_agent,
        inputs=[user_input, chatbot, thread_id_state],
        outputs=[chatbot, user_input, thread_id_state],
    )
    clear_btn.click(
        fn=reset_chat,
        inputs=[],
        outputs=[chatbot, thread_id_state],
    )
| # ------------------------------------------------------------ | |
| # 9. Launch for Hugging Face Spaces | |
| # ------------------------------------------------------------ | |
| if __name__ == "__main__": | |
| demo.queue().launch() | |