Spaces:
Sleeping
Sleeping
Sync from GitHub via hub-sync
Browse files- README.md +8 -12
- requirements.txt +1 -2
- server_app.py +3 -7
- src/__init__.py +2 -2
- src/database.py +0 -143
- src/qdrant_keepalive.py +0 -83
- src/rag_system.py +198 -191
- src/vector_store.py +99 -105
README.md
CHANGED
|
@@ -23,8 +23,8 @@ FastAPI backend for Code Compass, a personal full-stack RAG project that indexes
|
|
| 23 |
|
| 24 |
- Clone a public GitHub repository into temporary storage
|
| 25 |
- Filter and chunk source files for retrieval
|
| 26 |
-
- Generate embeddings and store chunks in
|
| 27 |
-
- Maintain lightweight repository and session metadata in
|
| 28 |
- Run indexing as a background task
|
| 29 |
- Retrieve evidence with semantic search, lexical search, fusion, and reranking
|
| 30 |
- Generate answers from the selected context and return citations to the UI
|
|
@@ -45,21 +45,17 @@ Production is configured for lower-cost hosting:
|
|
| 45 |
- `EMBEDDING_PROVIDER=local`
|
| 46 |
- Groq-hosted Llama for answer generation
|
| 47 |
- Local sentence-transformer embeddings for retrieval
|
| 48 |
-
-
|
| 49 |
|
| 50 |
-
##
|
| 51 |
|
| 52 |
-
The backend
|
| 53 |
|
| 54 |
Configuration:
|
| 55 |
|
| 56 |
-
- `
|
| 57 |
-
- `
|
| 58 |
-
- `
|
| 59 |
-
- `QDRANT_KEEPALIVE_ENABLED=true`
|
| 60 |
-
- `QDRANT_KEEPALIVE_INTERVAL_SECONDS=43200`
|
| 61 |
-
|
| 62 |
-
The main repository also includes a GitHub Actions keepalive workflow for cases where the backend host is asleep.
|
| 63 |
|
| 64 |
## Metrics
|
| 65 |
|
|
|
|
| 23 |
|
| 24 |
- Clone a public GitHub repository into temporary storage
|
| 25 |
- Filter and chunk source files for retrieval
|
| 26 |
+
- Generate embeddings and store chunks in Chroma DB
|
| 27 |
+
- Maintain lightweight repository and session metadata in memory
|
| 28 |
- Run indexing as a background task
|
| 29 |
- Retrieve evidence with semantic search, lexical search, fusion, and reranking
|
| 30 |
- Generate answers from the selected context and return citations to the UI
|
|
|
|
| 45 |
- `EMBEDDING_PROVIDER=local`
|
| 46 |
- Groq-hosted Llama for answer generation
|
| 47 |
- Local sentence-transformer embeddings for retrieval
|
| 48 |
+
- Chroma DB for vector storage
|
| 49 |
|
| 50 |
+
## Chroma Storage
|
| 51 |
|
| 52 |
+
The backend uses Chroma DB for vector storage in both local development and production. By default it stores the collection under `./data/chroma`, and you can point it somewhere else with `CHROMA_PATH`.
|
| 53 |
|
| 54 |
Configuration:
|
| 55 |
|
| 56 |
+
- `CHROMA_PATH=./data/chroma`
|
| 57 |
+
- `CHROMA_COLLECTION=repo_qa_chunks`
|
| 58 |
+
- `CHROMA_UPSERT_BATCH_SIZE=64`
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
## Metrics
|
| 61 |
|
requirements.txt
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
fastapi==0.109.2
|
| 2 |
uvicorn[standard]==0.27.1
|
| 3 |
-
sqlalchemy==2.0.25
|
| 4 |
pydantic==2.6.1
|
| 5 |
python-dotenv==1.0.1
|
| 6 |
|
|
@@ -11,7 +10,7 @@ google-genai==1.12.1
|
|
| 11 |
httpx==0.28.1
|
| 12 |
numpy==1.26.4
|
| 13 |
rank-bm25==0.2.2
|
| 14 |
-
|
| 15 |
sentence-transformers==2.7.0
|
| 16 |
einops==0.8.1
|
| 17 |
tree-sitter==0.21.3
|
|
|
|
| 1 |
fastapi==0.109.2
|
| 2 |
uvicorn[standard]==0.27.1
|
|
|
|
| 3 |
pydantic==2.6.1
|
| 4 |
python-dotenv==1.0.1
|
| 5 |
|
|
|
|
| 10 |
httpx==0.28.1
|
| 11 |
numpy==1.26.4
|
| 12 |
rank-bm25==0.2.2
|
| 13 |
+
chromadb>=0.5.23
|
| 14 |
sentence-transformers==2.7.0
|
| 15 |
einops==0.8.1
|
| 16 |
tree-sitter==0.21.3
|
server_app.py
CHANGED
|
@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, HttpUrl
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
from src.bedrock_claude import BedrockTransientError, is_bedrock_retryable_error
|
| 11 |
-
from src.qdrant_keepalive import QdrantKeepAliveScheduler
|
| 12 |
from src.rag_system import CodebaseRAGSystem
|
| 13 |
|
| 14 |
load_dotenv(Path(__file__).with_name(".env"))
|
|
@@ -35,7 +34,6 @@ app.add_middleware(
|
|
| 35 |
)
|
| 36 |
|
| 37 |
rag_system: Optional[CodebaseRAGSystem] = None
|
| 38 |
-
qdrant_keepalive: Optional[QdrantKeepAliveScheduler] = None
|
| 39 |
|
| 40 |
|
| 41 |
class RepoIndexRequest(BaseModel):
|
|
@@ -62,17 +60,15 @@ def require_session_id(x_session_id: Optional[str] = Header(None, alias="X-Sessi
|
|
| 62 |
|
| 63 |
@app.on_event("startup")
|
| 64 |
def startup():
|
| 65 |
-
global
|
| 66 |
Path("./data").mkdir(exist_ok=True)
|
| 67 |
rag_system = CodebaseRAGSystem()
|
| 68 |
-
qdrant_keepalive = QdrantKeepAliveScheduler(rag_system.vector_store)
|
| 69 |
-
qdrant_keepalive.start()
|
| 70 |
|
| 71 |
|
| 72 |
@app.on_event("shutdown")
|
| 73 |
def shutdown():
|
| 74 |
-
if
|
| 75 |
-
|
| 76 |
|
| 77 |
|
| 78 |
@app.get("/")
|
|
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
|
| 10 |
from src.bedrock_claude import BedrockTransientError, is_bedrock_retryable_error
|
|
|
|
| 11 |
from src.rag_system import CodebaseRAGSystem
|
| 12 |
|
| 13 |
load_dotenv(Path(__file__).with_name(".env"))
|
|
|
|
| 34 |
)
|
| 35 |
|
| 36 |
rag_system: Optional[CodebaseRAGSystem] = None
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
class RepoIndexRequest(BaseModel):
|
|
|
|
| 60 |
|
| 61 |
@app.on_event("startup")
|
| 62 |
def startup():
|
| 63 |
+
global rag_system
|
| 64 |
Path("./data").mkdir(exist_ok=True)
|
| 65 |
rag_system = CodebaseRAGSystem()
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
@app.on_event("shutdown")
|
| 69 |
def shutdown():
|
| 70 |
+
if rag_system is not None:
|
| 71 |
+
rag_system.vector_store.save()
|
| 72 |
|
| 73 |
|
| 74 |
@app.get("/")
|
src/__init__.py
CHANGED
|
@@ -7,14 +7,14 @@ from .embeddings import EmbeddingGenerator
|
|
| 7 |
from .hybrid_search import HybridSearchEngine
|
| 8 |
from .rag_system import CodebaseRAGSystem
|
| 9 |
from .repo_fetcher import RepoFetcher
|
| 10 |
-
from .vector_store import
|
| 11 |
|
| 12 |
__version__ = "2.0.0"
|
| 13 |
__all__ = [
|
| 14 |
"CodeParser",
|
| 15 |
"CodebaseRAGSystem",
|
| 16 |
"EmbeddingGenerator",
|
| 17 |
-
"
|
| 18 |
"HybridSearchEngine",
|
| 19 |
"RepoFetcher",
|
| 20 |
]
|
|
|
|
| 7 |
from .hybrid_search import HybridSearchEngine
|
| 8 |
from .rag_system import CodebaseRAGSystem
|
| 9 |
from .repo_fetcher import RepoFetcher
|
| 10 |
+
from .vector_store import ChromaVectorStore
|
| 11 |
|
| 12 |
__version__ = "2.0.0"
|
| 13 |
__all__ = [
|
| 14 |
"CodeParser",
|
| 15 |
"CodebaseRAGSystem",
|
| 16 |
"EmbeddingGenerator",
|
| 17 |
+
"ChromaVectorStore",
|
| 18 |
"HybridSearchEngine",
|
| 19 |
"RepoFetcher",
|
| 20 |
]
|
src/database.py
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from datetime import datetime
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
from sqlalchemy import (
|
| 6 |
-
JSON,
|
| 7 |
-
Column,
|
| 8 |
-
DateTime,
|
| 9 |
-
Float,
|
| 10 |
-
ForeignKey,
|
| 11 |
-
Integer,
|
| 12 |
-
String,
|
| 13 |
-
Text,
|
| 14 |
-
create_engine,
|
| 15 |
-
inspect,
|
| 16 |
-
text,
|
| 17 |
-
)
|
| 18 |
-
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
|
| 19 |
-
|
| 20 |
-
Base = declarative_base()
|
| 21 |
-
_ENGINE_CACHE = {}
|
| 22 |
-
_SESSION_FACTORY_CACHE = {}
|
| 23 |
-
SERVER_DIR = Path(__file__).resolve().parents[1]
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class Repository(Base):
|
| 27 |
-
__tablename__ = "repositories"
|
| 28 |
-
|
| 29 |
-
id = Column(Integer, primary_key=True)
|
| 30 |
-
github_url = Column(String(1024), nullable=False, unique=True)
|
| 31 |
-
source_url = Column(String(1024))
|
| 32 |
-
session_key = Column(String(255), index=True)
|
| 33 |
-
session_expires_at = Column(DateTime)
|
| 34 |
-
owner = Column(String(255), nullable=False)
|
| 35 |
-
name = Column(String(255), nullable=False)
|
| 36 |
-
branch = Column(String(255), nullable=False, default="main")
|
| 37 |
-
local_path = Column(String(1024))
|
| 38 |
-
status = Column(String(64), nullable=False, default="queued")
|
| 39 |
-
error_message = Column(Text)
|
| 40 |
-
file_count = Column(Integer, nullable=False, default=0)
|
| 41 |
-
chunk_count = Column(Integer, nullable=False, default=0)
|
| 42 |
-
indexed_at = Column(DateTime)
|
| 43 |
-
created_at = Column(DateTime, default=datetime.utcnow)
|
| 44 |
-
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
| 45 |
-
|
| 46 |
-
chunks = relationship(
|
| 47 |
-
"CodeChunk", back_populates="repository", cascade="all, delete-orphan"
|
| 48 |
-
)
|
| 49 |
-
chat_turns = relationship(
|
| 50 |
-
"ChatTurn", back_populates="repository", cascade="all, delete-orphan"
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
class CodeChunk(Base):
|
| 55 |
-
__tablename__ = "code_chunks"
|
| 56 |
-
|
| 57 |
-
id = Column(Integer, primary_key=True)
|
| 58 |
-
repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
|
| 59 |
-
file_path = Column(String(1024), nullable=False)
|
| 60 |
-
language = Column(String(64), nullable=False)
|
| 61 |
-
symbol_name = Column(String(255))
|
| 62 |
-
symbol_type = Column(String(128), nullable=False, default="chunk")
|
| 63 |
-
line_start = Column(Integer, nullable=False)
|
| 64 |
-
line_end = Column(Integer, nullable=False)
|
| 65 |
-
signature = Column(Text)
|
| 66 |
-
content = Column(Text, nullable=False)
|
| 67 |
-
searchable_text = Column(Text, nullable=False)
|
| 68 |
-
metadata_json = Column(JSON, nullable=False, default=dict)
|
| 69 |
-
embedding_id = Column(Integer)
|
| 70 |
-
rerank_score = Column(Float)
|
| 71 |
-
created_at = Column(DateTime, default=datetime.utcnow)
|
| 72 |
-
|
| 73 |
-
repository = relationship("Repository", back_populates="chunks")
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
class ChatTurn(Base):
|
| 77 |
-
__tablename__ = "chat_turns"
|
| 78 |
-
|
| 79 |
-
id = Column(Integer, primary_key=True)
|
| 80 |
-
repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
|
| 81 |
-
role = Column(String(32), nullable=False)
|
| 82 |
-
content = Column(Text, nullable=False)
|
| 83 |
-
answer_json = Column(JSON)
|
| 84 |
-
created_at = Column(DateTime, default=datetime.utcnow)
|
| 85 |
-
|
| 86 |
-
repository = relationship("Repository", back_populates="chat_turns")
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def init_db(database_url: str = None):
|
| 90 |
-
if database_url is None:
|
| 91 |
-
database_url = os.getenv("DATABASE_URL", "sqlite:///./codebase_rag.db")
|
| 92 |
-
|
| 93 |
-
database_url = resolve_database_url(database_url)
|
| 94 |
-
if database_url in _ENGINE_CACHE:
|
| 95 |
-
return _ENGINE_CACHE[database_url], _SESSION_FACTORY_CACHE[database_url]
|
| 96 |
-
|
| 97 |
-
connect_args = {"check_same_thread": False} if database_url.startswith("sqlite") else {}
|
| 98 |
-
engine = create_engine(database_url, echo=False, connect_args=connect_args)
|
| 99 |
-
Base.metadata.create_all(engine)
|
| 100 |
-
_ensure_runtime_columns(engine)
|
| 101 |
-
session_local = sessionmaker(bind=engine)
|
| 102 |
-
_ENGINE_CACHE[database_url] = engine
|
| 103 |
-
_SESSION_FACTORY_CACHE[database_url] = session_local
|
| 104 |
-
return engine, session_local
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def resolve_database_url(database_url: str) -> str:
|
| 108 |
-
if not database_url.startswith("sqlite:///"):
|
| 109 |
-
return database_url
|
| 110 |
-
|
| 111 |
-
sqlite_path = database_url.removeprefix("sqlite:///")
|
| 112 |
-
if sqlite_path == ":memory:":
|
| 113 |
-
return database_url
|
| 114 |
-
|
| 115 |
-
path = Path(sqlite_path)
|
| 116 |
-
if not path.is_absolute():
|
| 117 |
-
path = SERVER_DIR / path
|
| 118 |
-
path.parent.mkdir(parents=True, exist_ok=True)
|
| 119 |
-
path.touch(exist_ok=True)
|
| 120 |
-
return f"sqlite:///{path.resolve()}"
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def _ensure_runtime_columns(engine):
|
| 124 |
-
inspector = inspect(engine)
|
| 125 |
-
if "repositories" not in inspector.get_table_names():
|
| 126 |
-
return
|
| 127 |
-
|
| 128 |
-
existing = {column["name"] for column in inspector.get_columns("repositories")}
|
| 129 |
-
alterations = {
|
| 130 |
-
"source_url": "ALTER TABLE repositories ADD COLUMN source_url VARCHAR(1024)",
|
| 131 |
-
"session_key": "ALTER TABLE repositories ADD COLUMN session_key VARCHAR(255)",
|
| 132 |
-
"session_expires_at": "ALTER TABLE repositories ADD COLUMN session_expires_at DATETIME",
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
with engine.begin() as connection:
|
| 136 |
-
for column_name, statement in alterations.items():
|
| 137 |
-
if column_name not in existing:
|
| 138 |
-
connection.execute(text(statement))
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def get_db_session(database_url: str = None):
|
| 142 |
-
_, session_local = init_db(database_url)
|
| 143 |
-
return session_local()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/qdrant_keepalive.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import threading
|
| 3 |
-
from typing import Optional
|
| 4 |
-
|
| 5 |
-
from src.vector_store import QdrantVectorStore
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class QdrantKeepAliveScheduler:
|
| 9 |
-
def __init__(self, vector_store: QdrantVectorStore):
|
| 10 |
-
self.vector_store = vector_store
|
| 11 |
-
self.interval_seconds = self._interval_seconds()
|
| 12 |
-
self.run_on_start = self._env_flag("QDRANT_KEEPALIVE_RUN_ON_START", True)
|
| 13 |
-
self.keepalive_enabled = self._env_flag("QDRANT_KEEPALIVE_ENABLED", True)
|
| 14 |
-
self.enabled = self.keepalive_enabled and self.vector_store.is_remote()
|
| 15 |
-
self._stop_event = threading.Event()
|
| 16 |
-
self._thread: Optional[threading.Thread] = None
|
| 17 |
-
|
| 18 |
-
def start(self):
|
| 19 |
-
if not self.enabled:
|
| 20 |
-
reason = (
|
| 21 |
-
"disabled by QDRANT_KEEPALIVE_ENABLED"
|
| 22 |
-
if not self.keepalive_enabled
|
| 23 |
-
else "set QDRANT_URL to enable remote Qdrant pings"
|
| 24 |
-
)
|
| 25 |
-
print(
|
| 26 |
-
f"[qdrant-keepalive] Disabled; {reason}",
|
| 27 |
-
flush=True,
|
| 28 |
-
)
|
| 29 |
-
return
|
| 30 |
-
if self._thread and self._thread.is_alive():
|
| 31 |
-
return
|
| 32 |
-
|
| 33 |
-
self._stop_event.clear()
|
| 34 |
-
self._thread = threading.Thread(
|
| 35 |
-
target=self._run,
|
| 36 |
-
name="qdrant-keepalive",
|
| 37 |
-
daemon=True,
|
| 38 |
-
)
|
| 39 |
-
self._thread.start()
|
| 40 |
-
print(
|
| 41 |
-
f"[qdrant-keepalive] Started interval_seconds={self.interval_seconds}",
|
| 42 |
-
flush=True,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
def stop(self):
|
| 46 |
-
self._stop_event.set()
|
| 47 |
-
if self._thread and self._thread.is_alive():
|
| 48 |
-
self._thread.join(timeout=5)
|
| 49 |
-
self._thread = None
|
| 50 |
-
|
| 51 |
-
def _run(self):
|
| 52 |
-
if self.run_on_start:
|
| 53 |
-
self._ping()
|
| 54 |
-
|
| 55 |
-
while not self._stop_event.wait(self.interval_seconds):
|
| 56 |
-
self._ping()
|
| 57 |
-
|
| 58 |
-
def _ping(self):
|
| 59 |
-
try:
|
| 60 |
-
stats = self.vector_store.keep_alive()
|
| 61 |
-
print(
|
| 62 |
-
"[qdrant-keepalive] Ping succeeded "
|
| 63 |
-
f"collection={stats['collection_name']} "
|
| 64 |
-
f"points={stats['total_vectors']}",
|
| 65 |
-
flush=True,
|
| 66 |
-
)
|
| 67 |
-
except Exception as exc:
|
| 68 |
-
print(f"[qdrant-keepalive] Ping failed: {exc}", flush=True)
|
| 69 |
-
|
| 70 |
-
@staticmethod
|
| 71 |
-
def _env_flag(name: str, default: bool) -> bool:
|
| 72 |
-
value = os.getenv(name)
|
| 73 |
-
if value is None:
|
| 74 |
-
return default
|
| 75 |
-
return value.strip().lower() not in {"0", "false", "no", "off"}
|
| 76 |
-
|
| 77 |
-
@staticmethod
|
| 78 |
-
def _interval_seconds() -> int:
|
| 79 |
-
value = os.getenv("QDRANT_KEEPALIVE_INTERVAL_SECONDS", "43200")
|
| 80 |
-
try:
|
| 81 |
-
return max(60, int(value))
|
| 82 |
-
except ValueError:
|
| 83 |
-
return 43200
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rag_system.py
CHANGED
|
@@ -1,44 +1,57 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
|
|
|
| 3 |
from datetime import datetime, timedelta
|
|
|
|
| 4 |
from typing import Dict, List, Optional
|
| 5 |
|
| 6 |
from openai import OpenAI
|
| 7 |
|
| 8 |
from src.code_parser import CodeParser
|
| 9 |
from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
|
| 10 |
-
from src.database import Repository, get_db_session, init_db, resolve_database_url
|
| 11 |
from src.embeddings import EmbeddingGenerator
|
| 12 |
from src.hybrid_search import HybridSearchEngine
|
| 13 |
from src.repo_fetcher import RepoFetcher
|
| 14 |
-
from src.vector_store import
|
| 15 |
|
| 16 |
|
| 17 |
class SessionCancelledError(RuntimeError):
|
| 18 |
pass
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class CodebaseRAGSystem:
|
| 22 |
def __init__(
|
| 23 |
self,
|
| 24 |
-
database_url: str = None,
|
| 25 |
repo_dir: str = None,
|
| 26 |
index_path: str = None,
|
| 27 |
):
|
| 28 |
-
self.database_url = database_url or os.getenv(
|
| 29 |
-
"DATABASE_URL", "sqlite:///./codebase_rag.db"
|
| 30 |
-
)
|
| 31 |
-
self.database_url = resolve_database_url(self.database_url)
|
| 32 |
-
init_db(self.database_url)
|
| 33 |
-
print(f"[database] Using database_url={self.database_url}", flush=True)
|
| 34 |
-
|
| 35 |
self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
|
| 36 |
self.parser = CodeParser()
|
| 37 |
self.embedder = EmbeddingGenerator()
|
| 38 |
-
self.vector_store =
|
| 39 |
embedding_dim=self.embedder.get_embedding_dim(),
|
| 40 |
-
index_path=index_path or "./data/
|
| 41 |
-
persist=
|
| 42 |
)
|
| 43 |
self.hybrid_search = HybridSearchEngine(
|
| 44 |
reranker_model=os.getenv(
|
|
@@ -51,34 +64,35 @@ class CodebaseRAGSystem:
|
|
| 51 |
self.llm_model = ""
|
| 52 |
self._configure_llm()
|
| 53 |
self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
self.indexing_progress: Dict[int, dict] = {}
|
| 55 |
self.repo_chunks: Dict[int, List[dict]] = {}
|
| 56 |
self.cancelled_repo_ids = set()
|
| 57 |
self.rebuild_indexes()
|
| 58 |
|
| 59 |
def rebuild_indexes(self):
|
| 60 |
-
|
| 61 |
-
try:
|
| 62 |
self.vector_store.clear()
|
|
|
|
|
|
|
|
|
|
| 63 |
self.repo_chunks.clear()
|
| 64 |
self.indexing_progress.clear()
|
| 65 |
self.cancelled_repo_ids.clear()
|
| 66 |
-
repos = session.query(Repository).all()
|
| 67 |
-
self._delete_repositories(session, repos, track_cancellation=False)
|
| 68 |
-
self.cancelled_repo_ids.clear()
|
| 69 |
-
session.commit()
|
| 70 |
-
finally:
|
| 71 |
-
session.close()
|
| 72 |
|
| 73 |
def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
|
| 74 |
info = self.repo_fetcher.parse_github_url(github_url)
|
| 75 |
registry_key = self._build_registry_key(session_key, github_url)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
self.
|
| 79 |
-
repo =
|
| 80 |
if repo is None:
|
| 81 |
repo = Repository(
|
|
|
|
| 82 |
github_url=registry_key,
|
| 83 |
source_url=github_url,
|
| 84 |
session_key=session_key,
|
|
@@ -88,8 +102,9 @@ class CodebaseRAGSystem:
|
|
| 88 |
branch=info["branch"],
|
| 89 |
status="queued",
|
| 90 |
)
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
self.cancelled_repo_ids.discard(repo.id)
|
| 94 |
else:
|
| 95 |
repo.source_url = github_url
|
|
@@ -103,37 +118,39 @@ class CodebaseRAGSystem:
|
|
| 103 |
repo.file_count = 0
|
| 104 |
repo.chunk_count = 0
|
| 105 |
repo.indexed_at = None
|
|
|
|
| 106 |
self.cancelled_repo_ids.discard(repo.id)
|
| 107 |
self.hybrid_search.remove_repository(repo.id)
|
| 108 |
self.vector_store.remove_repository(repo.id)
|
| 109 |
self.repo_chunks.pop(repo.id, None)
|
| 110 |
|
| 111 |
-
session.commit()
|
| 112 |
-
session.refresh(repo)
|
| 113 |
return repo
|
| 114 |
-
finally:
|
| 115 |
-
session.close()
|
| 116 |
|
| 117 |
def index_repository(self, repo_id: int):
|
| 118 |
-
|
| 119 |
try:
|
| 120 |
-
self.
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
repo.status = "indexing"
|
| 128 |
-
repo.error_message = None
|
| 129 |
-
repo.session_expires_at = self._session_expiry()
|
| 130 |
-
session.commit()
|
| 131 |
self._set_progress(repo.id, phase="cloning", message="Cloning repository")
|
| 132 |
|
| 133 |
clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
|
| 134 |
self._ensure_repo_not_cancelled(repo.id)
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
print(
|
| 138 |
f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
|
| 139 |
f"path={clone_info['local_path']}",
|
|
@@ -250,82 +267,74 @@ class CodebaseRAGSystem:
|
|
| 250 |
}
|
| 251 |
created_rows.append(row)
|
| 252 |
|
| 253 |
-
repo.status = "indexed"
|
| 254 |
-
repo.file_count = file_count
|
| 255 |
-
repo.chunk_count = len(created_rows)
|
| 256 |
-
repo.indexed_at = datetime.utcnow()
|
| 257 |
-
repo.session_expires_at = self._session_expiry()
|
| 258 |
-
self._ensure_repo_still_exists(session, repo.id)
|
| 259 |
-
self._ensure_repo_not_cancelled(repo.id)
|
| 260 |
-
session.commit()
|
| 261 |
-
|
| 262 |
serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
|
| 263 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
self.vector_store.save()
|
| 265 |
-
self.
|
| 266 |
-
|
|
|
|
| 267 |
self.repo_fetcher.cleanup_repository(clone_info["local_path"])
|
| 268 |
print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
|
| 269 |
except Exception as exc:
|
| 270 |
print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
|
| 271 |
-
session.rollback()
|
| 272 |
self.vector_store.remove_repository(repo_id)
|
| 273 |
-
self.repo_chunks.pop(repo_id, None)
|
| 274 |
self.hybrid_search.remove_repository(repo_id)
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
| 283 |
try:
|
| 284 |
-
if
|
| 285 |
self.repo_fetcher.cleanup_repository(clone_info["local_path"])
|
| 286 |
except Exception:
|
| 287 |
pass
|
| 288 |
-
self.
|
|
|
|
| 289 |
if isinstance(exc, SessionCancelledError):
|
| 290 |
return
|
| 291 |
raise
|
| 292 |
-
finally:
|
| 293 |
-
session.close()
|
| 294 |
|
| 295 |
def list_repositories(self) -> List[dict]:
|
| 296 |
raise NotImplementedError
|
| 297 |
|
| 298 |
def list_repositories_for_session(self, session_key: str) -> List[dict]:
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
.
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
)
|
| 308 |
-
self._touch_session(session, session_key)
|
| 309 |
return [self._serialize_repo(repo) for repo in repos]
|
| 310 |
-
finally:
|
| 311 |
-
session.close()
|
| 312 |
|
| 313 |
def get_repository(self, repo_id: int) -> Optional[dict]:
|
| 314 |
raise NotImplementedError
|
| 315 |
|
| 316 |
def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]:
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
self.
|
| 320 |
-
repo =
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
.first()
|
| 324 |
-
)
|
| 325 |
-
self._touch_session(session, session_key)
|
| 326 |
return self._serialize_repo(repo) if repo else None
|
| 327 |
-
finally:
|
| 328 |
-
session.close()
|
| 329 |
|
| 330 |
def answer_question(
|
| 331 |
self,
|
|
@@ -335,96 +344,92 @@ class CodebaseRAGSystem:
|
|
| 335 |
top_k: int = 8,
|
| 336 |
history=None,
|
| 337 |
) -> dict:
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
self.
|
| 341 |
-
repo =
|
| 342 |
-
|
| 343 |
-
.filter_by(id=repo_id, session_key=session_key)
|
| 344 |
-
.first()
|
| 345 |
-
)
|
| 346 |
if repo is None:
|
| 347 |
raise ValueError("Repository not found")
|
| 348 |
if repo.status != "indexed":
|
| 349 |
raise ValueError("Repository is not ready for questions yet")
|
| 350 |
if repo_id not in self.repo_chunks:
|
| 351 |
raise ValueError("Session cache expired. Re-index the repository and try again.")
|
| 352 |
-
self.
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
|
| 398 |
-
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
|
| 407 |
-
|
| 408 |
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
finally:
|
| 418 |
-
session.close()
|
| 419 |
|
| 420 |
def end_session(self, session_key: str):
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
|
| 429 |
|
| 430 |
def _generate_answer(
|
|
@@ -800,35 +805,33 @@ Do not leave the answer unfinished.
|
|
| 800 |
return payload
|
| 801 |
|
| 802 |
def _set_progress(self, repo_id: int, **progress):
|
| 803 |
-
self.
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
|
|
|
| 808 |
|
| 809 |
-
def _touch_session(self,
|
| 810 |
expiry = self._session_expiry()
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
|
| 816 |
-
def _cleanup_expired_sessions(self
|
| 817 |
now = datetime.utcnow()
|
| 818 |
-
expired =
|
| 819 |
-
|
| 820 |
-
.
|
| 821 |
-
.
|
| 822 |
-
|
| 823 |
-
)
|
| 824 |
if not expired:
|
| 825 |
return
|
| 826 |
-
self._delete_repositories(
|
| 827 |
-
session.commit()
|
| 828 |
|
| 829 |
def _delete_repositories(
|
| 830 |
self,
|
| 831 |
-
session,
|
| 832 |
repos: List[Repository],
|
| 833 |
track_cancellation: bool = True,
|
| 834 |
):
|
|
@@ -840,13 +843,18 @@ Do not leave the answer unfinished.
|
|
| 840 |
self.vector_store.remove_repository(repo_id)
|
| 841 |
self.repo_chunks.pop(repo_id, None)
|
| 842 |
self.indexing_progress.pop(repo_id, None)
|
| 843 |
-
|
| 844 |
-
|
|
|
|
| 845 |
|
| 846 |
def _ensure_repo_not_cancelled(self, repo_id: int):
|
| 847 |
if repo_id in self.cancelled_repo_ids:
|
| 848 |
raise SessionCancelledError("Session ended before indexing completed.")
|
| 849 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 850 |
def _build_retrieval_query(self, question: str, history: List[dict]) -> str:
|
| 851 |
normalized = " ".join(question.strip().split())
|
| 852 |
if self._is_repo_overview_question(normalized):
|
|
@@ -1660,9 +1668,8 @@ Do not leave the answer unfinished.
|
|
| 1660 |
lines.append(f"{role}: {content[:400]}")
|
| 1661 |
return "\n".join(lines) if lines else "None"
|
| 1662 |
|
| 1663 |
-
|
| 1664 |
-
|
| 1665 |
-
if session.query(Repository.id).filter_by(id=repo_id).first() is None:
|
| 1666 |
raise RuntimeError("Repository was removed before indexing completed.")
|
| 1667 |
|
| 1668 |
def _session_expiry(self) -> datetime:
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
from datetime import datetime, timedelta
|
| 5 |
+
from threading import RLock
|
| 6 |
from typing import Dict, List, Optional
|
| 7 |
|
| 8 |
from openai import OpenAI
|
| 9 |
|
| 10 |
from src.code_parser import CodeParser
|
| 11 |
from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
|
|
|
|
| 12 |
from src.embeddings import EmbeddingGenerator
|
| 13 |
from src.hybrid_search import HybridSearchEngine
|
| 14 |
from src.repo_fetcher import RepoFetcher
|
| 15 |
+
from src.vector_store import ChromaVectorStore
|
| 16 |
|
| 17 |
|
| 18 |
class SessionCancelledError(RuntimeError):
|
| 19 |
pass
|
| 20 |
|
| 21 |
|
| 22 |
+
@dataclass
|
| 23 |
+
class Repository:
|
| 24 |
+
id: int
|
| 25 |
+
github_url: str
|
| 26 |
+
source_url: str
|
| 27 |
+
session_key: str
|
| 28 |
+
session_expires_at: datetime
|
| 29 |
+
owner: str
|
| 30 |
+
name: str
|
| 31 |
+
branch: str = "main"
|
| 32 |
+
local_path: Optional[str] = None
|
| 33 |
+
status: str = "queued"
|
| 34 |
+
error_message: Optional[str] = None
|
| 35 |
+
file_count: int = 0
|
| 36 |
+
chunk_count: int = 0
|
| 37 |
+
indexed_at: Optional[datetime] = None
|
| 38 |
+
created_at: datetime = field(default_factory=datetime.utcnow)
|
| 39 |
+
updated_at: datetime = field(default_factory=datetime.utcnow)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
class CodebaseRAGSystem:
|
| 43 |
def __init__(
|
| 44 |
self,
|
|
|
|
| 45 |
repo_dir: str = None,
|
| 46 |
index_path: str = None,
|
| 47 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
|
| 49 |
self.parser = CodeParser()
|
| 50 |
self.embedder = EmbeddingGenerator()
|
| 51 |
+
self.vector_store = ChromaVectorStore(
|
| 52 |
embedding_dim=self.embedder.get_embedding_dim(),
|
| 53 |
+
index_path=index_path or "./data/chroma",
|
| 54 |
+
persist=True,
|
| 55 |
)
|
| 56 |
self.hybrid_search = HybridSearchEngine(
|
| 57 |
reranker_model=os.getenv(
|
|
|
|
| 64 |
self.llm_model = ""
|
| 65 |
self._configure_llm()
|
| 66 |
self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
|
| 67 |
+
self.repo_lock = RLock()
|
| 68 |
+
self.repositories: Dict[int, Repository] = {}
|
| 69 |
+
self.repository_registry: Dict[str, int] = {}
|
| 70 |
+
self.next_repo_id = 1
|
| 71 |
self.indexing_progress: Dict[int, dict] = {}
|
| 72 |
self.repo_chunks: Dict[int, List[dict]] = {}
|
| 73 |
self.cancelled_repo_ids = set()
|
| 74 |
self.rebuild_indexes()
|
| 75 |
|
| 76 |
def rebuild_indexes(self):
|
| 77 |
+
with self.repo_lock:
|
|
|
|
| 78 |
self.vector_store.clear()
|
| 79 |
+
self.repositories.clear()
|
| 80 |
+
self.repository_registry.clear()
|
| 81 |
+
self.next_repo_id = 1
|
| 82 |
self.repo_chunks.clear()
|
| 83 |
self.indexing_progress.clear()
|
| 84 |
self.cancelled_repo_ids.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
|
| 87 |
info = self.repo_fetcher.parse_github_url(github_url)
|
| 88 |
registry_key = self._build_registry_key(session_key, github_url)
|
| 89 |
+
with self.repo_lock:
|
| 90 |
+
self._cleanup_expired_sessions()
|
| 91 |
+
repo_id = self.repository_registry.get(registry_key)
|
| 92 |
+
repo = self.repositories.get(repo_id) if repo_id else None
|
| 93 |
if repo is None:
|
| 94 |
repo = Repository(
|
| 95 |
+
id=self.next_repo_id,
|
| 96 |
github_url=registry_key,
|
| 97 |
source_url=github_url,
|
| 98 |
session_key=session_key,
|
|
|
|
| 102 |
branch=info["branch"],
|
| 103 |
status="queued",
|
| 104 |
)
|
| 105 |
+
self.next_repo_id += 1
|
| 106 |
+
self.repositories[repo.id] = repo
|
| 107 |
+
self.repository_registry[registry_key] = repo.id
|
| 108 |
self.cancelled_repo_ids.discard(repo.id)
|
| 109 |
else:
|
| 110 |
repo.source_url = github_url
|
|
|
|
| 118 |
repo.file_count = 0
|
| 119 |
repo.chunk_count = 0
|
| 120 |
repo.indexed_at = None
|
| 121 |
+
self._mark_repo_updated(repo)
|
| 122 |
self.cancelled_repo_ids.discard(repo.id)
|
| 123 |
self.hybrid_search.remove_repository(repo.id)
|
| 124 |
self.vector_store.remove_repository(repo.id)
|
| 125 |
self.repo_chunks.pop(repo.id, None)
|
| 126 |
|
|
|
|
|
|
|
| 127 |
return repo
|
|
|
|
|
|
|
| 128 |
|
| 129 |
def index_repository(self, repo_id: int):
|
| 130 |
+
clone_info = None
|
| 131 |
try:
|
| 132 |
+
with self.repo_lock:
|
| 133 |
+
self._cleanup_expired_sessions()
|
| 134 |
+
repo = self.repositories.get(repo_id)
|
| 135 |
+
if repo is None:
|
| 136 |
+
raise ValueError("Repository not found")
|
| 137 |
+
self._ensure_repo_not_cancelled(repo.id)
|
| 138 |
+
print(f"[indexing] Starting repository index repo_id={repo.id}", flush=True)
|
| 139 |
+
|
| 140 |
+
repo.status = "indexing"
|
| 141 |
+
repo.error_message = None
|
| 142 |
+
repo.session_expires_at = self._session_expiry()
|
| 143 |
+
self._mark_repo_updated(repo)
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
self._set_progress(repo.id, phase="cloning", message="Cloning repository")
|
| 146 |
|
| 147 |
clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
|
| 148 |
self._ensure_repo_not_cancelled(repo.id)
|
| 149 |
+
with self.repo_lock:
|
| 150 |
+
self._ensure_repo_still_exists(repo.id)
|
| 151 |
+
repo.branch = clone_info["branch"]
|
| 152 |
+
repo.local_path = None
|
| 153 |
+
self._mark_repo_updated(repo)
|
| 154 |
print(
|
| 155 |
f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
|
| 156 |
f"path={clone_info['local_path']}",
|
|
|
|
| 267 |
}
|
| 268 |
created_rows.append(row)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
|
| 271 |
+
with self.repo_lock:
|
| 272 |
+
self._ensure_repo_still_exists(repo.id)
|
| 273 |
+
self._ensure_repo_not_cancelled(repo.id)
|
| 274 |
+
repo.status = "indexed"
|
| 275 |
+
repo.file_count = file_count
|
| 276 |
+
repo.chunk_count = len(created_rows)
|
| 277 |
+
repo.indexed_at = datetime.utcnow()
|
| 278 |
+
repo.session_expires_at = self._session_expiry()
|
| 279 |
+
self._mark_repo_updated(repo)
|
| 280 |
+
self.repo_chunks[repo.id] = serialized
|
| 281 |
self.vector_store.save()
|
| 282 |
+
with self.repo_lock:
|
| 283 |
+
self.indexing_progress.pop(repo.id, None)
|
| 284 |
+
self.cancelled_repo_ids.discard(repo.id)
|
| 285 |
self.repo_fetcher.cleanup_repository(clone_info["local_path"])
|
| 286 |
print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
|
| 287 |
except Exception as exc:
|
| 288 |
print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
|
|
|
|
| 289 |
self.vector_store.remove_repository(repo_id)
|
|
|
|
| 290 |
self.hybrid_search.remove_repository(repo_id)
|
| 291 |
+
with self.repo_lock:
|
| 292 |
+
self.repo_chunks.pop(repo_id, None)
|
| 293 |
+
repo = self.repositories.get(repo_id)
|
| 294 |
+
if repo:
|
| 295 |
+
if repo_id in self.cancelled_repo_ids:
|
| 296 |
+
self._delete_repositories([repo], track_cancellation=False)
|
| 297 |
+
else:
|
| 298 |
+
repo.status = "failed"
|
| 299 |
+
repo.error_message = str(exc)
|
| 300 |
+
self._mark_repo_updated(repo)
|
| 301 |
try:
|
| 302 |
+
if clone_info:
|
| 303 |
self.repo_fetcher.cleanup_repository(clone_info["local_path"])
|
| 304 |
except Exception:
|
| 305 |
pass
|
| 306 |
+
with self.repo_lock:
|
| 307 |
+
self.indexing_progress.pop(repo_id, None)
|
| 308 |
if isinstance(exc, SessionCancelledError):
|
| 309 |
return
|
| 310 |
raise
|
|
|
|
|
|
|
| 311 |
|
| 312 |
def list_repositories(self) -> List[dict]:
|
| 313 |
raise NotImplementedError
|
| 314 |
|
| 315 |
def list_repositories_for_session(self, session_key: str) -> List[dict]:
|
| 316 |
+
with self.repo_lock:
|
| 317 |
+
self._cleanup_expired_sessions()
|
| 318 |
+
repos = [
|
| 319 |
+
repo
|
| 320 |
+
for repo in self.repositories.values()
|
| 321 |
+
if repo.session_key == session_key
|
| 322 |
+
]
|
| 323 |
+
repos.sort(key=lambda repo: repo.updated_at, reverse=True)
|
| 324 |
+
self._touch_session(session_key)
|
|
|
|
| 325 |
return [self._serialize_repo(repo) for repo in repos]
|
|
|
|
|
|
|
| 326 |
|
| 327 |
def get_repository(self, repo_id: int) -> Optional[dict]:
|
| 328 |
raise NotImplementedError
|
| 329 |
|
| 330 |
def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]:
|
| 331 |
+
with self.repo_lock:
|
| 332 |
+
self._cleanup_expired_sessions()
|
| 333 |
+
repo = self.repositories.get(repo_id)
|
| 334 |
+
if repo and repo.session_key != session_key:
|
| 335 |
+
repo = None
|
| 336 |
+
self._touch_session(session_key)
|
|
|
|
|
|
|
|
|
|
| 337 |
return self._serialize_repo(repo) if repo else None
|
|
|
|
|
|
|
| 338 |
|
| 339 |
def answer_question(
|
| 340 |
self,
|
|
|
|
| 344 |
top_k: int = 8,
|
| 345 |
history=None,
|
| 346 |
) -> dict:
|
| 347 |
+
with self.repo_lock:
|
| 348 |
+
self._cleanup_expired_sessions()
|
| 349 |
+
repo = self.repositories.get(repo_id)
|
| 350 |
+
if repo and repo.session_key != session_key:
|
| 351 |
+
repo = None
|
|
|
|
|
|
|
|
|
|
| 352 |
if repo is None:
|
| 353 |
raise ValueError("Repository not found")
|
| 354 |
if repo.status != "indexed":
|
| 355 |
raise ValueError("Repository is not ready for questions yet")
|
| 356 |
if repo_id not in self.repo_chunks:
|
| 357 |
raise ValueError("Session cache expired. Re-index the repository and try again.")
|
| 358 |
+
repo_chunks = list(self.repo_chunks[repo_id])
|
| 359 |
+
self._touch_session(session_key)
|
| 360 |
|
| 361 |
+
normalized_history = self._normalize_history(history or [])
|
| 362 |
+
question_intent = self._question_intent(question)
|
| 363 |
+
deep_search_intents = {
|
| 364 |
+
"api",
|
| 365 |
+
"implementation",
|
| 366 |
+
"cross_file",
|
| 367 |
+
"error_handling",
|
| 368 |
+
"setup",
|
| 369 |
+
"tests",
|
| 370 |
+
}
|
| 371 |
+
deep_multiplier = int(os.getenv("RAG_DEEP_SEARCH_MULTIPLIER", "8"))
|
| 372 |
+
shallow_multiplier = int(os.getenv("RAG_SEARCH_MULTIPLIER", "4"))
|
| 373 |
+
search_depth = (
|
| 374 |
+
top_k * deep_multiplier
|
| 375 |
+
if question_intent in deep_search_intents
|
| 376 |
+
else top_k * shallow_multiplier
|
| 377 |
+
)
|
| 378 |
+
search_depth = max(top_k, min(search_depth, 120))
|
| 379 |
|
| 380 |
+
retrieval_query = self._build_retrieval_query(question, normalized_history)
|
| 381 |
+
query_embedding = self.embedder.embed_text(retrieval_query)
|
| 382 |
|
| 383 |
+
semantic_hits = []
|
| 384 |
+
for score, meta in self.vector_store.search(query_embedding, k=search_depth, repo_filter=repo_id):
|
| 385 |
+
serialized = dict(meta)
|
| 386 |
+
serialized["semantic_score"] = score
|
| 387 |
+
semantic_hits.append(serialized)
|
| 388 |
|
| 389 |
+
lexical_hits = self.hybrid_search.bm25_search(
|
| 390 |
+
repo_chunks,
|
| 391 |
+
retrieval_query,
|
| 392 |
+
top_k=search_depth,
|
| 393 |
+
)
|
| 394 |
+
semantic_hits = self.hybrid_search.normalize_semantic_results(semantic_hits)
|
| 395 |
+
fused = self.hybrid_search.reciprocal_rank_fusion(lexical_hits, semantic_hits, top_k=search_depth)
|
| 396 |
+
|
| 397 |
+
path_hits = self._path_intent_search(
|
| 398 |
+
repo_chunks,
|
| 399 |
+
question,
|
| 400 |
+
retrieval_query,
|
| 401 |
+
top_k=search_depth,
|
| 402 |
+
)
|
| 403 |
+
fused = self._merge_ranked_candidates(fused, path_hits, top_k=search_depth)
|
| 404 |
|
| 405 |
+
rerank_query = retrieval_query if question_intent in deep_search_intents else question
|
| 406 |
|
| 407 |
+
# FIX: rerank to a small candidate pool first (20), then let
|
| 408 |
+
# _prioritize_results and _select_answer_sources trim to final top_k.
|
| 409 |
+
# Previously rerank was called with search_depth (up to 120), meaning
|
| 410 |
+
# the LLM received far too many chunks and faithfulness dropped.
|
| 411 |
+
rerank_pool = min(search_depth, 20)
|
| 412 |
+
reranked = self.hybrid_search.rerank(rerank_query, fused, top_k=rerank_pool)
|
| 413 |
|
| 414 |
+
reranked = self._prioritize_results(question, retrieval_query, reranked, top_k=top_k)
|
| 415 |
|
| 416 |
+
# FIX: cap final sources at 5 instead of top_k (8).
|
| 417 |
+
# 5 sources × 1500 chars = ~7500 chars context, which the LLM handles well.
|
| 418 |
+
# 8 sources × 2500 chars = ~20000 chars, which causes lost-in-the-middle issues.
|
| 419 |
+
final_top_k = min(top_k, 5)
|
| 420 |
+
reranked = self._select_answer_sources(question, reranked, top_k=final_top_k)
|
| 421 |
|
| 422 |
+
answer = self._generate_answer(repo, question, reranked, normalized_history)
|
| 423 |
+
return answer
|
|
|
|
|
|
|
| 424 |
|
| 425 |
def end_session(self, session_key: str):
|
| 426 |
+
with self.repo_lock:
|
| 427 |
+
repos = [
|
| 428 |
+
repo
|
| 429 |
+
for repo in self.repositories.values()
|
| 430 |
+
if repo.session_key == session_key
|
| 431 |
+
]
|
| 432 |
+
self._delete_repositories(repos)
|
| 433 |
|
| 434 |
|
| 435 |
def _generate_answer(
|
|
|
|
| 805 |
return payload
|
| 806 |
|
| 807 |
def _set_progress(self, repo_id: int, **progress):
|
| 808 |
+
with self.repo_lock:
|
| 809 |
+
self.indexing_progress[repo_id] = {
|
| 810 |
+
**self.indexing_progress.get(repo_id, {}),
|
| 811 |
+
**progress,
|
| 812 |
+
"updated_at": datetime.utcnow().isoformat(),
|
| 813 |
+
}
|
| 814 |
|
| 815 |
+
def _touch_session(self, session_key: str):
|
| 816 |
expiry = self._session_expiry()
|
| 817 |
+
for repo in self.repositories.values():
|
| 818 |
+
if repo.session_key == session_key:
|
| 819 |
+
repo.session_expires_at = expiry
|
| 820 |
+
self._mark_repo_updated(repo)
|
| 821 |
|
| 822 |
+
def _cleanup_expired_sessions(self):
|
| 823 |
now = datetime.utcnow()
|
| 824 |
+
expired = [
|
| 825 |
+
repo
|
| 826 |
+
for repo in self.repositories.values()
|
| 827 |
+
if repo.session_expires_at is not None and repo.session_expires_at < now
|
| 828 |
+
]
|
|
|
|
| 829 |
if not expired:
|
| 830 |
return
|
| 831 |
+
self._delete_repositories(expired)
|
|
|
|
| 832 |
|
| 833 |
def _delete_repositories(
|
| 834 |
self,
|
|
|
|
| 835 |
repos: List[Repository],
|
| 836 |
track_cancellation: bool = True,
|
| 837 |
):
|
|
|
|
| 843 |
self.vector_store.remove_repository(repo_id)
|
| 844 |
self.repo_chunks.pop(repo_id, None)
|
| 845 |
self.indexing_progress.pop(repo_id, None)
|
| 846 |
+
repo = self.repositories.pop(repo_id, None)
|
| 847 |
+
if repo:
|
| 848 |
+
self.repository_registry.pop(repo.github_url, None)
|
| 849 |
|
| 850 |
def _ensure_repo_not_cancelled(self, repo_id: int):
|
| 851 |
if repo_id in self.cancelled_repo_ids:
|
| 852 |
raise SessionCancelledError("Session ended before indexing completed.")
|
| 853 |
|
| 854 |
+
@staticmethod
|
| 855 |
+
def _mark_repo_updated(repo: Repository):
|
| 856 |
+
repo.updated_at = datetime.utcnow()
|
| 857 |
+
|
| 858 |
def _build_retrieval_query(self, question: str, history: List[dict]) -> str:
|
| 859 |
normalized = " ".join(question.strip().split())
|
| 860 |
if self._is_repo_overview_question(normalized):
|
|
|
|
| 1668 |
lines.append(f"{role}: {content[:400]}")
|
| 1669 |
return "\n".join(lines) if lines else "None"
|
| 1670 |
|
| 1671 |
+
def _ensure_repo_still_exists(self, repo_id: int):
|
| 1672 |
+
if repo_id not in self.repositories:
|
|
|
|
| 1673 |
raise RuntimeError("Repository was removed before indexing completed.")
|
| 1674 |
|
| 1675 |
def _session_expiry(self) -> datetime:
|
src/vector_store.py
CHANGED
|
@@ -1,60 +1,44 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from typing import List, Optional, Tuple
|
| 3 |
from uuid import uuid4
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
-
from
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
-
class
|
| 10 |
-
def __init__(self, embedding_dim: int, index_path: str = None, persist: bool =
|
| 11 |
self.embedding_dim = embedding_dim
|
| 12 |
-
self.collection_name = os.getenv("
|
| 13 |
-
self.upsert_batch_size = max(1, int(os.getenv("
|
| 14 |
-
self.
|
| 15 |
-
self.
|
| 16 |
-
self.timeout = int(os.getenv("QDRANT_TIMEOUT_SECONDS", "120"))
|
| 17 |
self.client = self._create_client()
|
| 18 |
-
self._ensure_collection()
|
| 19 |
|
| 20 |
def _create_client(self):
|
| 21 |
-
if self.
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
| 27 |
)
|
| 28 |
-
return QdrantClient(":memory:")
|
| 29 |
|
| 30 |
-
|
| 31 |
-
def _clean_env(name: str) -> Optional[str]:
|
| 32 |
-
value = os.getenv(name)
|
| 33 |
-
if value is None:
|
| 34 |
-
return None
|
| 35 |
-
cleaned = value.strip()
|
| 36 |
-
return cleaned or None
|
| 37 |
|
| 38 |
def _ensure_collection(self):
|
| 39 |
-
|
| 40 |
-
self.
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
size=self.embedding_dim,
|
| 44 |
-
distance=models.Distance.COSINE,
|
| 45 |
-
),
|
| 46 |
-
)
|
| 47 |
-
self._ensure_payload_indexes()
|
| 48 |
-
|
| 49 |
-
def _ensure_payload_indexes(self):
|
| 50 |
-
self.client.create_payload_index(
|
| 51 |
-
collection_name=self.collection_name,
|
| 52 |
-
field_name="repository_id",
|
| 53 |
-
field_schema=models.PayloadSchemaType.INTEGER,
|
| 54 |
-
wait=True,
|
| 55 |
)
|
| 56 |
|
| 57 |
-
def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[
|
| 58 |
if embeddings.size == 0:
|
| 59 |
return []
|
| 60 |
|
|
@@ -63,31 +47,33 @@ class QdrantVectorStore:
|
|
| 63 |
embeddings = embeddings.reshape(1, -1)
|
| 64 |
|
| 65 |
ids = [uuid4().hex for _ in metadata]
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
payload = dict(meta)
|
| 69 |
-
payload["id"] = idx
|
| 70 |
-
points.append(
|
| 71 |
-
models.PointStruct(
|
| 72 |
-
id=idx,
|
| 73 |
-
vector=embedding.tolist(),
|
| 74 |
-
payload=payload,
|
| 75 |
-
)
|
| 76 |
-
)
|
| 77 |
-
total_points = len(points)
|
| 78 |
for start in range(0, total_points, self.upsert_batch_size):
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
batch_number = (start // self.upsert_batch_size) + 1
|
| 81 |
total_batches = (total_points + self.upsert_batch_size - 1) // self.upsert_batch_size
|
| 82 |
print(
|
| 83 |
-
f"[
|
| 84 |
-
f"points={len(
|
| 85 |
flush=True,
|
| 86 |
)
|
| 87 |
-
self.
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
return ids
|
|
@@ -102,67 +88,75 @@ class QdrantVectorStore:
|
|
| 102 |
query_embedding = query_embedding.reshape(1, -1)
|
| 103 |
query_embedding = query_embedding.astype("float32")
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
match=models.MatchValue(value=repo_filter),
|
| 112 |
-
)
|
| 113 |
-
]
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
hits = self.client.search(
|
| 117 |
-
collection_name=self.collection_name,
|
| 118 |
-
query_vector=query_embedding[0].tolist(),
|
| 119 |
-
query_filter=query_filter,
|
| 120 |
-
limit=k,
|
| 121 |
)
|
| 122 |
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def remove_repository(self, repo_id: int):
|
| 126 |
-
self.
|
| 127 |
-
collection_name=self.collection_name,
|
| 128 |
-
wait=True,
|
| 129 |
-
points_selector=models.FilterSelector(
|
| 130 |
-
filter=models.Filter(
|
| 131 |
-
must=[
|
| 132 |
-
models.FieldCondition(
|
| 133 |
-
key="repository_id",
|
| 134 |
-
match=models.MatchValue(value=repo_id),
|
| 135 |
-
)
|
| 136 |
-
]
|
| 137 |
-
)
|
| 138 |
-
),
|
| 139 |
-
)
|
| 140 |
|
| 141 |
def clear(self):
|
| 142 |
-
|
| 143 |
-
self.client.delete_collection(self.collection_name)
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def save(self):
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
|
| 149 |
def load(self):
|
| 150 |
-
self._ensure_collection()
|
| 151 |
-
|
| 152 |
-
def is_remote(self) -> bool:
|
| 153 |
-
return self.qdrant_url is not None
|
| 154 |
|
| 155 |
def keep_alive(self) -> dict:
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
}
|
| 161 |
|
| 162 |
def get_stats(self) -> dict:
|
| 163 |
-
info = self.client.get_collection(self.collection_name)
|
| 164 |
return {
|
| 165 |
-
"total_vectors":
|
| 166 |
"embedding_dim": self.embedding_dim,
|
| 167 |
"collection_name": self.collection_name,
|
|
|
|
| 168 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
from typing import List, Optional, Tuple
|
| 4 |
from uuid import uuid4
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
+
from chromadb import Client
|
| 8 |
+
from chromadb.config import Settings
|
| 9 |
|
| 10 |
|
| 11 |
+
class ChromaVectorStore:
|
| 12 |
+
def __init__(self, embedding_dim: int, index_path: str = None, persist: bool = True):
|
| 13 |
self.embedding_dim = embedding_dim
|
| 14 |
+
self.collection_name = os.getenv("CHROMA_COLLECTION", "repo_qa_chunks")
|
| 15 |
+
self.upsert_batch_size = max(1, int(os.getenv("CHROMA_UPSERT_BATCH_SIZE", "64")))
|
| 16 |
+
self.persist_path = os.getenv("CHROMA_PATH", index_path or "./data/chroma")
|
| 17 |
+
self.persist = persist
|
|
|
|
| 18 |
self.client = self._create_client()
|
| 19 |
+
self.collection = self._ensure_collection()
|
| 20 |
|
| 21 |
def _create_client(self):
|
| 22 |
+
if self.persist:
|
| 23 |
+
Path(self.persist_path).mkdir(parents=True, exist_ok=True)
|
| 24 |
+
return Client(
|
| 25 |
+
Settings(
|
| 26 |
+
is_persistent=True,
|
| 27 |
+
persist_directory=self.persist_path,
|
| 28 |
+
anonymized_telemetry=False,
|
| 29 |
+
)
|
| 30 |
)
|
|
|
|
| 31 |
|
| 32 |
+
return Client(Settings(anonymized_telemetry=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def _ensure_collection(self):
|
| 35 |
+
return self.client.get_or_create_collection(
|
| 36 |
+
name=self.collection_name,
|
| 37 |
+
embedding_function=None,
|
| 38 |
+
metadata={"hnsw:space": "cosine"},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
|
| 41 |
+
def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
|
| 42 |
if embeddings.size == 0:
|
| 43 |
return []
|
| 44 |
|
|
|
|
| 47 |
embeddings = embeddings.reshape(1, -1)
|
| 48 |
|
| 49 |
ids = [uuid4().hex for _ in metadata]
|
| 50 |
+
total_points = len(ids)
|
| 51 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
for start in range(0, total_points, self.upsert_batch_size):
|
| 53 |
+
end = start + self.upsert_batch_size
|
| 54 |
+
batch_ids = ids[start:end]
|
| 55 |
+
batch_embeddings = embeddings[start:end].tolist()
|
| 56 |
+
batch_metadata = []
|
| 57 |
+
batch_documents = []
|
| 58 |
+
|
| 59 |
+
for idx, meta in zip(batch_ids, metadata[start:end]):
|
| 60 |
+
payload = self._sanitize_metadata(meta)
|
| 61 |
+
payload["id"] = idx
|
| 62 |
+
batch_metadata.append(payload)
|
| 63 |
+
batch_documents.append(str(meta.get("content") or ""))
|
| 64 |
+
|
| 65 |
batch_number = (start // self.upsert_batch_size) + 1
|
| 66 |
total_batches = (total_points + self.upsert_batch_size - 1) // self.upsert_batch_size
|
| 67 |
print(
|
| 68 |
+
f"[chroma] Adding batch {batch_number}/{total_batches} "
|
| 69 |
+
f"points={len(batch_ids)} progress={start}/{total_points}",
|
| 70 |
flush=True,
|
| 71 |
)
|
| 72 |
+
self.collection.add(
|
| 73 |
+
ids=batch_ids,
|
| 74 |
+
embeddings=batch_embeddings,
|
| 75 |
+
metadatas=batch_metadata,
|
| 76 |
+
documents=batch_documents,
|
| 77 |
)
|
| 78 |
|
| 79 |
return ids
|
|
|
|
| 88 |
query_embedding = query_embedding.reshape(1, -1)
|
| 89 |
query_embedding = query_embedding.astype("float32")
|
| 90 |
|
| 91 |
+
where = {"repository_id": repo_filter} if repo_filter is not None else None
|
| 92 |
+
results = self.collection.query(
|
| 93 |
+
query_embeddings=[query_embedding[0].tolist()],
|
| 94 |
+
n_results=k,
|
| 95 |
+
where=where,
|
| 96 |
+
include=["documents", "metadatas", "distances"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
+
ids = (results.get("ids") or [[]])[0]
|
| 100 |
+
documents = (results.get("documents") or [[]])[0]
|
| 101 |
+
metadatas = (results.get("metadatas") or [[]])[0]
|
| 102 |
+
distances = (results.get("distances") or [[]])[0]
|
| 103 |
+
|
| 104 |
+
hits = []
|
| 105 |
+
for idx, document, meta, distance in zip(ids, documents, metadatas, distances):
|
| 106 |
+
payload = dict(meta or {})
|
| 107 |
+
payload["id"] = payload.get("id") or idx
|
| 108 |
+
payload["content"] = document or ""
|
| 109 |
+
hits.append((self._distance_to_score(distance), payload))
|
| 110 |
+
return hits
|
| 111 |
|
| 112 |
def remove_repository(self, repo_id: int):
|
| 113 |
+
self.collection.delete(where={"repository_id": repo_id})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def clear(self):
|
| 116 |
+
try:
|
| 117 |
+
self.client.delete_collection(name=self.collection_name)
|
| 118 |
+
except Exception:
|
| 119 |
+
pass
|
| 120 |
+
self.collection = self._ensure_collection()
|
| 121 |
|
| 122 |
def save(self):
|
| 123 |
+
persist = getattr(self.client, "persist", None)
|
| 124 |
+
if callable(persist):
|
| 125 |
+
persist()
|
| 126 |
|
| 127 |
def load(self):
|
| 128 |
+
self.collection = self._ensure_collection()
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
def keep_alive(self) -> dict:
|
| 131 |
+
heartbeat = getattr(self.client, "heartbeat", None)
|
| 132 |
+
if callable(heartbeat):
|
| 133 |
+
heartbeat()
|
| 134 |
+
return self.get_stats()
|
|
|
|
| 135 |
|
| 136 |
def get_stats(self) -> dict:
|
|
|
|
| 137 |
return {
|
| 138 |
+
"total_vectors": self.collection.count(),
|
| 139 |
"embedding_dim": self.embedding_dim,
|
| 140 |
"collection_name": self.collection_name,
|
| 141 |
+
"persist_path": self.persist_path if self.persist else None,
|
| 142 |
}
|
| 143 |
+
|
| 144 |
+
@staticmethod
|
| 145 |
+
def _sanitize_metadata(meta: dict) -> dict:
|
| 146 |
+
sanitized = {}
|
| 147 |
+
for key, value in meta.items():
|
| 148 |
+
if key == "content":
|
| 149 |
+
continue
|
| 150 |
+
if value is None:
|
| 151 |
+
sanitized[key] = ""
|
| 152 |
+
elif isinstance(value, (str, int, float, bool)):
|
| 153 |
+
sanitized[key] = value
|
| 154 |
+
else:
|
| 155 |
+
sanitized[key] = str(value)
|
| 156 |
+
return sanitized
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
def _distance_to_score(distance: float) -> float:
|
| 160 |
+
if distance is None:
|
| 161 |
+
return 0.0
|
| 162 |
+
return max(0.0, min(1.0, 1.0 - float(distance)))
|