# NOTE: "Spaces: Build error / Build error" banners from the Hugging Face
# Spaces UI were captured in this paste; they are not part of the app.
import streamlit as st
import os
# NOTE(review): MessagesState imported here is immediately shadowed by the
# local TypedDict defined below; the import is kept but unused.
from langgraph.graph import MessagesState, StateGraph, START, END
from typing_extensions import TypedDict
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
from typing import Annotated
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langchain_groq import ChatGroq
# Define the state
class MessagesState(TypedDict):
    """Graph state carrying the running conversation.

    Shadows the ``MessagesState`` imported from ``langgraph.graph`` above.
    The ``add_messages`` reducer appends/merges new messages into the list
    instead of replacing it on each node update.
    """
    messages: Annotated[list[AnyMessage], add_messages]
# Create graph function
def create_chat_graph(system_prompt, model_name):
    """Build a single-node LangGraph chat graph backed by ChatGroq.

    Args:
        system_prompt: Persona text injected as a SystemMessage before every turn.
        model_name: Groq model identifier passed to ChatGroq.

    Returns:
        A compiled graph with an in-memory checkpointer, so conversation
        history persists per ``thread_id`` across invocations.
    """
    chat_model = ChatGroq(model=model_name)
    persona = SystemMessage(content=system_prompt)

    def assistant(state: MessagesState):
        # Prepend the persona to the conversation so far and ask the model;
        # returning the reply lets add_messages append it to the state.
        reply = chat_model.invoke([persona, *state["messages"]])
        return {"messages": [reply]}

    # START -> assistant -> END: one LLM call per turn.
    workflow = StateGraph(MessagesState)
    workflow.add_node("assistant", assistant)
    workflow.add_edge(START, "assistant")
    workflow.add_edge("assistant", END)

    # MemorySaver checkpointing is what gives the bot its per-thread memory.
    return workflow.compile(checkpointer=MemorySaver())
# Set up Streamlit page
st.set_page_config(page_title="Conversational AI Assistant", page_icon="💬")
st.title("AI Chatbot with Memory")

# Sidebar configuration
st.sidebar.header("Configuration")

# API Key input (using st.secrets in production)
# Prompt only when the key is absent; once entered it lives in os.environ
# for the remainder of the process, so the input disappears on rerun.
if "GROQ_API_KEY" not in os.environ:
    api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")
    if api_key:
        os.environ["GROQ_API_KEY"] = api_key
    else:
        st.sidebar.warning("Please enter your Groq API key to continue.")

# Model selection
# NOTE(review): these Groq model IDs may be deprecated — verify against
# Groq's current model list before deploying.
model_options = ["llama3-70b-8192", "mixtral-8x7b-32768", "gemma-7b-it"]
selected_model = st.sidebar.selectbox("Select Model:", model_options)

# System prompt
default_prompt = "You are a helpful and friendly assistant. Maintain a conversational tone and remember previous interactions with the user."
system_prompt = st.sidebar.text_area("System Prompt:", value=default_prompt, height=150)
# Session ID for this conversation
# Used as the LangGraph thread_id so the MemorySaver checkpointer isolates
# this browser session's conversation history.
if "session_id" not in st.session_state:
    import uuid
    st.session_state.session_id = str(uuid.uuid4())

# Initialize or get chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Initialize the graph on first run or when config changes
# NOTE(review): because `or` short-circuits, the "Reset Conversation" button is
# only rendered once a graph already exists in session_state. Also, sidebar
# changes to model/system prompt take effect only after a Reset, since they
# are captured here at graph-creation time.
if "chat_graph" not in st.session_state or st.sidebar.button("Reset Conversation"):
    if "GROQ_API_KEY" in os.environ:
        with st.spinner("Initializing chatbot..."):
            st.session_state.chat_graph = create_chat_graph(system_prompt, selected_model)
            st.session_state.messages = []  # Clear messages on reset
        st.success("Chatbot initialized!")
    else:
        st.sidebar.error("API key required to initialize chatbot.")
| # Display chat history | |
| for message in st.session_state.messages: | |
| if isinstance(message, dict): # Handle dict format | |
| role = message.get("role", "") | |
| content = message.get("content", "") | |
| else: # Handle direct string format | |
| role = "user" if message.startswith("User: ") else "assistant" | |
| content = message.replace("User: ", "").replace("Assistant: ", "") | |
| with st.chat_message(role): | |
| st.write(content) | |
# Input for new message
# Only accept input once a graph exists and the API key is configured.
if "chat_graph" in st.session_state and "GROQ_API_KEY" in os.environ:
    prompt = st.chat_input("Type your message here...")
    if prompt:
        # Echo the user's message and record it in the display history.
        with st.chat_message("user"):
            st.write(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.spinner("Thinking..."):
            # The checkpointer keys memory by thread_id, so only the new
            # message is sent; prior turns are restored inside the graph.
            run_config = {"configurable": {"thread_id": st.session_state.session_id}}
            graph_output = st.session_state.chat_graph.invoke(
                {"messages": [HumanMessage(content=prompt)]}, run_config
            )
            assistant_reply = graph_output["messages"][-1].content

        # Show the assistant's reply and record it in the display history.
        with st.chat_message("assistant"):
            st.write(assistant_reply)
        st.session_state.messages.append({"role": "assistant", "content": assistant_reply})
# Add some additional info in the sidebar
# Static "About" panel — rendered on every rerun below the config widgets.
st.sidebar.markdown("---")
st.sidebar.subheader("About")
st.sidebar.info(
    """
    This chatbot uses LangGraph for maintaining conversation context and
    ChatGroq's language models for generating responses. Each conversation
    has a unique session ID to maintain history.
    """
)
| # Download chat history | |
| if st.sidebar.button("Download Chat History"): | |
| import json | |
| from datetime import datetime | |
| # Convert chat history to downloadable format | |
| chat_export = "\n".join([f"{m['role']}: {m['content']}" for m in st.session_state.messages]) | |
| # Create download button | |
| st.sidebar.download_button( | |
| label="Download as Text", | |
| data=chat_export, | |
| file_name=f"chat_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt", | |
| mime="text/plain" | |
| ) |