technophyle commited on
Commit
24e05bd
·
verified ·
1 Parent(s): 26349ea

Sync from GitHub via hub-sync

Browse files
README.md CHANGED
@@ -23,8 +23,8 @@ FastAPI backend for Code Compass, a personal full-stack RAG project that indexes
23
 
24
  - Clone a public GitHub repository into temporary storage
25
  - Filter and chunk source files for retrieval
26
- - Generate embeddings and store chunks in Qdrant
27
- - Maintain lightweight repository and session metadata in SQLite
28
  - Run indexing as a background task
29
  - Retrieve evidence with semantic search, lexical search, fusion, and reranking
30
  - Generate answers from the selected context and return citations to the UI
@@ -45,21 +45,17 @@ Production is configured for lower-cost hosting:
45
  - `EMBEDDING_PROVIDER=local`
46
  - Groq-hosted Llama for answer generation
47
  - Local sentence-transformer embeddings for retrieval
48
- - Qdrant Cloud for vector storage
49
 
50
- ## Qdrant Keepalive
51
 
52
- The backend starts a lightweight Qdrant keepalive scheduler when `QDRANT_URL` is configured. It calls the configured collection every 12 hours by default so a free-tier Qdrant cluster does not become inactive while the backend process is running.
53
 
54
  Configuration:
55
 
56
- - `QDRANT_URL`
57
- - `QDRANT_API_KEY`
58
- - `QDRANT_COLLECTION=repo_qa_chunks`
59
- - `QDRANT_KEEPALIVE_ENABLED=true`
60
- - `QDRANT_KEEPALIVE_INTERVAL_SECONDS=43200`
61
-
62
- The main repository also includes a GitHub Actions keepalive workflow for cases where the backend host is asleep.
63
 
64
  ## Metrics
65
 
 
23
 
24
  - Clone a public GitHub repository into temporary storage
25
  - Filter and chunk source files for retrieval
26
+ - Generate embeddings and store chunks in Chroma DB
27
+ - Maintain lightweight repository and session metadata in memory
28
  - Run indexing as a background task
29
  - Retrieve evidence with semantic search, lexical search, fusion, and reranking
30
  - Generate answers from the selected context and return citations to the UI
 
45
  - `EMBEDDING_PROVIDER=local`
46
  - Groq-hosted Llama for answer generation
47
  - Local sentence-transformer embeddings for retrieval
48
+ - Chroma DB for vector storage
49
 
50
+ ## Chroma Storage
51
 
52
+ The backend uses Chroma DB for vector storage in both local development and production. By default it stores the collection under `./data/chroma`, and you can point it somewhere else with `CHROMA_PATH`.
53
 
54
  Configuration:
55
 
56
+ - `CHROMA_PATH=./data/chroma`
57
+ - `CHROMA_COLLECTION=repo_qa_chunks`
58
+ - `CHROMA_UPSERT_BATCH_SIZE=64`
 
 
 
 
59
 
60
  ## Metrics
61
 
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  fastapi==0.109.2
2
  uvicorn[standard]==0.27.1
3
- sqlalchemy==2.0.25
4
  pydantic==2.6.1
5
  python-dotenv==1.0.1
6
 
@@ -11,7 +10,7 @@ google-genai==1.12.1
11
  httpx==0.28.1
12
  numpy==1.26.4
13
  rank-bm25==0.2.2
14
- qdrant-client==1.15.1
15
  sentence-transformers==2.7.0
16
  einops==0.8.1
17
  tree-sitter==0.21.3
 
1
  fastapi==0.109.2
2
  uvicorn[standard]==0.27.1
 
3
  pydantic==2.6.1
4
  python-dotenv==1.0.1
5
 
 
10
  httpx==0.28.1
11
  numpy==1.26.4
12
  rank-bm25==0.2.2
13
+ chromadb>=0.5.23
14
  sentence-transformers==2.7.0
15
  einops==0.8.1
16
  tree-sitter==0.21.3
server_app.py CHANGED
@@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, HttpUrl
8
  from dotenv import load_dotenv
9
 
10
  from src.bedrock_claude import BedrockTransientError, is_bedrock_retryable_error
11
- from src.qdrant_keepalive import QdrantKeepAliveScheduler
12
  from src.rag_system import CodebaseRAGSystem
13
 
14
  load_dotenv(Path(__file__).with_name(".env"))
@@ -35,7 +34,6 @@ app.add_middleware(
35
  )
36
 
37
  rag_system: Optional[CodebaseRAGSystem] = None
38
- qdrant_keepalive: Optional[QdrantKeepAliveScheduler] = None
39
 
40
 
41
  class RepoIndexRequest(BaseModel):
@@ -62,17 +60,15 @@ def require_session_id(x_session_id: Optional[str] = Header(None, alias="X-Sessi
62
 
63
  @app.on_event("startup")
64
  def startup():
65
- global qdrant_keepalive, rag_system
66
  Path("./data").mkdir(exist_ok=True)
67
  rag_system = CodebaseRAGSystem()
68
- qdrant_keepalive = QdrantKeepAliveScheduler(rag_system.vector_store)
69
- qdrant_keepalive.start()
70
 
71
 
72
  @app.on_event("shutdown")
73
  def shutdown():
74
- if qdrant_keepalive is not None:
75
- qdrant_keepalive.stop()
76
 
77
 
78
  @app.get("/")
 
8
  from dotenv import load_dotenv
9
 
10
  from src.bedrock_claude import BedrockTransientError, is_bedrock_retryable_error
 
11
  from src.rag_system import CodebaseRAGSystem
12
 
13
  load_dotenv(Path(__file__).with_name(".env"))
 
34
  )
35
 
36
  rag_system: Optional[CodebaseRAGSystem] = None
 
37
 
38
 
39
  class RepoIndexRequest(BaseModel):
 
60
 
61
  @app.on_event("startup")
62
  def startup():
63
+ global rag_system
64
  Path("./data").mkdir(exist_ok=True)
65
  rag_system = CodebaseRAGSystem()
 
 
66
 
67
 
68
  @app.on_event("shutdown")
69
  def shutdown():
70
+ if rag_system is not None:
71
+ rag_system.vector_store.save()
72
 
73
 
74
  @app.get("/")
src/__init__.py CHANGED
@@ -7,14 +7,14 @@ from .embeddings import EmbeddingGenerator
7
  from .hybrid_search import HybridSearchEngine
8
  from .rag_system import CodebaseRAGSystem
9
  from .repo_fetcher import RepoFetcher
10
- from .vector_store import QdrantVectorStore
11
 
12
  __version__ = "2.0.0"
13
  __all__ = [
14
  "CodeParser",
15
  "CodebaseRAGSystem",
16
  "EmbeddingGenerator",
17
- "QdrantVectorStore",
18
  "HybridSearchEngine",
19
  "RepoFetcher",
20
  ]
 
7
  from .hybrid_search import HybridSearchEngine
8
  from .rag_system import CodebaseRAGSystem
9
  from .repo_fetcher import RepoFetcher
10
+ from .vector_store import ChromaVectorStore
11
 
12
  __version__ = "2.0.0"
13
  __all__ = [
14
  "CodeParser",
15
  "CodebaseRAGSystem",
16
  "EmbeddingGenerator",
17
+ "ChromaVectorStore",
18
  "HybridSearchEngine",
19
  "RepoFetcher",
20
  ]
src/database.py DELETED
@@ -1,143 +0,0 @@
1
- import os
2
- from datetime import datetime
3
- from pathlib import Path
4
-
5
- from sqlalchemy import (
6
- JSON,
7
- Column,
8
- DateTime,
9
- Float,
10
- ForeignKey,
11
- Integer,
12
- String,
13
- Text,
14
- create_engine,
15
- inspect,
16
- text,
17
- )
18
- from sqlalchemy.orm import declarative_base, relationship, sessionmaker
19
-
20
- Base = declarative_base()
21
- _ENGINE_CACHE = {}
22
- _SESSION_FACTORY_CACHE = {}
23
- SERVER_DIR = Path(__file__).resolve().parents[1]
24
-
25
-
26
- class Repository(Base):
27
- __tablename__ = "repositories"
28
-
29
- id = Column(Integer, primary_key=True)
30
- github_url = Column(String(1024), nullable=False, unique=True)
31
- source_url = Column(String(1024))
32
- session_key = Column(String(255), index=True)
33
- session_expires_at = Column(DateTime)
34
- owner = Column(String(255), nullable=False)
35
- name = Column(String(255), nullable=False)
36
- branch = Column(String(255), nullable=False, default="main")
37
- local_path = Column(String(1024))
38
- status = Column(String(64), nullable=False, default="queued")
39
- error_message = Column(Text)
40
- file_count = Column(Integer, nullable=False, default=0)
41
- chunk_count = Column(Integer, nullable=False, default=0)
42
- indexed_at = Column(DateTime)
43
- created_at = Column(DateTime, default=datetime.utcnow)
44
- updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
45
-
46
- chunks = relationship(
47
- "CodeChunk", back_populates="repository", cascade="all, delete-orphan"
48
- )
49
- chat_turns = relationship(
50
- "ChatTurn", back_populates="repository", cascade="all, delete-orphan"
51
- )
52
-
53
-
54
- class CodeChunk(Base):
55
- __tablename__ = "code_chunks"
56
-
57
- id = Column(Integer, primary_key=True)
58
- repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
59
- file_path = Column(String(1024), nullable=False)
60
- language = Column(String(64), nullable=False)
61
- symbol_name = Column(String(255))
62
- symbol_type = Column(String(128), nullable=False, default="chunk")
63
- line_start = Column(Integer, nullable=False)
64
- line_end = Column(Integer, nullable=False)
65
- signature = Column(Text)
66
- content = Column(Text, nullable=False)
67
- searchable_text = Column(Text, nullable=False)
68
- metadata_json = Column(JSON, nullable=False, default=dict)
69
- embedding_id = Column(Integer)
70
- rerank_score = Column(Float)
71
- created_at = Column(DateTime, default=datetime.utcnow)
72
-
73
- repository = relationship("Repository", back_populates="chunks")
74
-
75
-
76
- class ChatTurn(Base):
77
- __tablename__ = "chat_turns"
78
-
79
- id = Column(Integer, primary_key=True)
80
- repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
81
- role = Column(String(32), nullable=False)
82
- content = Column(Text, nullable=False)
83
- answer_json = Column(JSON)
84
- created_at = Column(DateTime, default=datetime.utcnow)
85
-
86
- repository = relationship("Repository", back_populates="chat_turns")
87
-
88
-
89
- def init_db(database_url: str = None):
90
- if database_url is None:
91
- database_url = os.getenv("DATABASE_URL", "sqlite:///./codebase_rag.db")
92
-
93
- database_url = resolve_database_url(database_url)
94
- if database_url in _ENGINE_CACHE:
95
- return _ENGINE_CACHE[database_url], _SESSION_FACTORY_CACHE[database_url]
96
-
97
- connect_args = {"check_same_thread": False} if database_url.startswith("sqlite") else {}
98
- engine = create_engine(database_url, echo=False, connect_args=connect_args)
99
- Base.metadata.create_all(engine)
100
- _ensure_runtime_columns(engine)
101
- session_local = sessionmaker(bind=engine)
102
- _ENGINE_CACHE[database_url] = engine
103
- _SESSION_FACTORY_CACHE[database_url] = session_local
104
- return engine, session_local
105
-
106
-
107
- def resolve_database_url(database_url: str) -> str:
108
- if not database_url.startswith("sqlite:///"):
109
- return database_url
110
-
111
- sqlite_path = database_url.removeprefix("sqlite:///")
112
- if sqlite_path == ":memory:":
113
- return database_url
114
-
115
- path = Path(sqlite_path)
116
- if not path.is_absolute():
117
- path = SERVER_DIR / path
118
- path.parent.mkdir(parents=True, exist_ok=True)
119
- path.touch(exist_ok=True)
120
- return f"sqlite:///{path.resolve()}"
121
-
122
-
123
- def _ensure_runtime_columns(engine):
124
- inspector = inspect(engine)
125
- if "repositories" not in inspector.get_table_names():
126
- return
127
-
128
- existing = {column["name"] for column in inspector.get_columns("repositories")}
129
- alterations = {
130
- "source_url": "ALTER TABLE repositories ADD COLUMN source_url VARCHAR(1024)",
131
- "session_key": "ALTER TABLE repositories ADD COLUMN session_key VARCHAR(255)",
132
- "session_expires_at": "ALTER TABLE repositories ADD COLUMN session_expires_at DATETIME",
133
- }
134
-
135
- with engine.begin() as connection:
136
- for column_name, statement in alterations.items():
137
- if column_name not in existing:
138
- connection.execute(text(statement))
139
-
140
-
141
- def get_db_session(database_url: str = None):
142
- _, session_local = init_db(database_url)
143
- return session_local()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/qdrant_keepalive.py DELETED
@@ -1,83 +0,0 @@
1
- import os
2
- import threading
3
- from typing import Optional
4
-
5
- from src.vector_store import QdrantVectorStore
6
-
7
-
8
- class QdrantKeepAliveScheduler:
9
- def __init__(self, vector_store: QdrantVectorStore):
10
- self.vector_store = vector_store
11
- self.interval_seconds = self._interval_seconds()
12
- self.run_on_start = self._env_flag("QDRANT_KEEPALIVE_RUN_ON_START", True)
13
- self.keepalive_enabled = self._env_flag("QDRANT_KEEPALIVE_ENABLED", True)
14
- self.enabled = self.keepalive_enabled and self.vector_store.is_remote()
15
- self._stop_event = threading.Event()
16
- self._thread: Optional[threading.Thread] = None
17
-
18
- def start(self):
19
- if not self.enabled:
20
- reason = (
21
- "disabled by QDRANT_KEEPALIVE_ENABLED"
22
- if not self.keepalive_enabled
23
- else "set QDRANT_URL to enable remote Qdrant pings"
24
- )
25
- print(
26
- f"[qdrant-keepalive] Disabled; {reason}",
27
- flush=True,
28
- )
29
- return
30
- if self._thread and self._thread.is_alive():
31
- return
32
-
33
- self._stop_event.clear()
34
- self._thread = threading.Thread(
35
- target=self._run,
36
- name="qdrant-keepalive",
37
- daemon=True,
38
- )
39
- self._thread.start()
40
- print(
41
- f"[qdrant-keepalive] Started interval_seconds={self.interval_seconds}",
42
- flush=True,
43
- )
44
-
45
- def stop(self):
46
- self._stop_event.set()
47
- if self._thread and self._thread.is_alive():
48
- self._thread.join(timeout=5)
49
- self._thread = None
50
-
51
- def _run(self):
52
- if self.run_on_start:
53
- self._ping()
54
-
55
- while not self._stop_event.wait(self.interval_seconds):
56
- self._ping()
57
-
58
- def _ping(self):
59
- try:
60
- stats = self.vector_store.keep_alive()
61
- print(
62
- "[qdrant-keepalive] Ping succeeded "
63
- f"collection={stats['collection_name']} "
64
- f"points={stats['total_vectors']}",
65
- flush=True,
66
- )
67
- except Exception as exc:
68
- print(f"[qdrant-keepalive] Ping failed: {exc}", flush=True)
69
-
70
- @staticmethod
71
- def _env_flag(name: str, default: bool) -> bool:
72
- value = os.getenv(name)
73
- if value is None:
74
- return default
75
- return value.strip().lower() not in {"0", "false", "no", "off"}
76
-
77
- @staticmethod
78
- def _interval_seconds() -> int:
79
- value = os.getenv("QDRANT_KEEPALIVE_INTERVAL_SECONDS", "43200")
80
- try:
81
- return max(60, int(value))
82
- except ValueError:
83
- return 43200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rag_system.py CHANGED
@@ -1,44 +1,57 @@
1
  import os
2
  import re
 
3
  from datetime import datetime, timedelta
 
4
  from typing import Dict, List, Optional
5
 
6
  from openai import OpenAI
7
 
8
  from src.code_parser import CodeParser
9
  from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
10
- from src.database import Repository, get_db_session, init_db, resolve_database_url
11
  from src.embeddings import EmbeddingGenerator
12
  from src.hybrid_search import HybridSearchEngine
13
  from src.repo_fetcher import RepoFetcher
14
- from src.vector_store import QdrantVectorStore
15
 
16
 
17
  class SessionCancelledError(RuntimeError):
18
  pass
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class CodebaseRAGSystem:
22
  def __init__(
23
  self,
24
- database_url: str = None,
25
  repo_dir: str = None,
26
  index_path: str = None,
27
  ):
28
- self.database_url = database_url or os.getenv(
29
- "DATABASE_URL", "sqlite:///./codebase_rag.db"
30
- )
31
- self.database_url = resolve_database_url(self.database_url)
32
- init_db(self.database_url)
33
- print(f"[database] Using database_url={self.database_url}", flush=True)
34
-
35
  self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
36
  self.parser = CodeParser()
37
  self.embedder = EmbeddingGenerator()
38
- self.vector_store = QdrantVectorStore(
39
  embedding_dim=self.embedder.get_embedding_dim(),
40
- index_path=index_path or "./data/faiss/codebase_index",
41
- persist=False,
42
  )
43
  self.hybrid_search = HybridSearchEngine(
44
  reranker_model=os.getenv(
@@ -51,34 +64,35 @@ class CodebaseRAGSystem:
51
  self.llm_model = ""
52
  self._configure_llm()
53
  self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
 
 
 
 
54
  self.indexing_progress: Dict[int, dict] = {}
55
  self.repo_chunks: Dict[int, List[dict]] = {}
56
  self.cancelled_repo_ids = set()
57
  self.rebuild_indexes()
58
 
59
  def rebuild_indexes(self):
60
- session = get_db_session(self.database_url)
61
- try:
62
  self.vector_store.clear()
 
 
 
63
  self.repo_chunks.clear()
64
  self.indexing_progress.clear()
65
  self.cancelled_repo_ids.clear()
66
- repos = session.query(Repository).all()
67
- self._delete_repositories(session, repos, track_cancellation=False)
68
- self.cancelled_repo_ids.clear()
69
- session.commit()
70
- finally:
71
- session.close()
72
 
73
  def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
74
  info = self.repo_fetcher.parse_github_url(github_url)
75
  registry_key = self._build_registry_key(session_key, github_url)
76
- session = get_db_session(self.database_url)
77
- try:
78
- self._cleanup_expired_sessions(session)
79
- repo = session.query(Repository).filter_by(github_url=registry_key).first()
80
  if repo is None:
81
  repo = Repository(
 
82
  github_url=registry_key,
83
  source_url=github_url,
84
  session_key=session_key,
@@ -88,8 +102,9 @@ class CodebaseRAGSystem:
88
  branch=info["branch"],
89
  status="queued",
90
  )
91
- session.add(repo)
92
- session.flush()
 
93
  self.cancelled_repo_ids.discard(repo.id)
94
  else:
95
  repo.source_url = github_url
@@ -103,37 +118,39 @@ class CodebaseRAGSystem:
103
  repo.file_count = 0
104
  repo.chunk_count = 0
105
  repo.indexed_at = None
 
106
  self.cancelled_repo_ids.discard(repo.id)
107
  self.hybrid_search.remove_repository(repo.id)
108
  self.vector_store.remove_repository(repo.id)
109
  self.repo_chunks.pop(repo.id, None)
110
 
111
- session.commit()
112
- session.refresh(repo)
113
  return repo
114
- finally:
115
- session.close()
116
 
117
  def index_repository(self, repo_id: int):
118
- session = get_db_session(self.database_url)
119
  try:
120
- self._cleanup_expired_sessions(session)
121
- repo = session.query(Repository).filter_by(id=repo_id).first()
122
- if repo is None:
123
- raise ValueError("Repository not found")
124
- self._ensure_repo_not_cancelled(repo.id)
125
- print(f"[indexing] Starting repository index repo_id={repo.id}", flush=True)
 
 
 
 
 
 
126
 
127
- repo.status = "indexing"
128
- repo.error_message = None
129
- repo.session_expires_at = self._session_expiry()
130
- session.commit()
131
  self._set_progress(repo.id, phase="cloning", message="Cloning repository")
132
 
133
  clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
134
  self._ensure_repo_not_cancelled(repo.id)
135
- repo.local_path = None
136
- repo.branch = clone_info["branch"]
 
 
 
137
  print(
138
  f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
139
  f"path={clone_info['local_path']}",
@@ -250,82 +267,74 @@ class CodebaseRAGSystem:
250
  }
251
  created_rows.append(row)
252
 
253
- repo.status = "indexed"
254
- repo.file_count = file_count
255
- repo.chunk_count = len(created_rows)
256
- repo.indexed_at = datetime.utcnow()
257
- repo.session_expires_at = self._session_expiry()
258
- self._ensure_repo_still_exists(session, repo.id)
259
- self._ensure_repo_not_cancelled(repo.id)
260
- session.commit()
261
-
262
  serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
263
- self.repo_chunks[repo.id] = serialized
 
 
 
 
 
 
 
 
 
264
  self.vector_store.save()
265
- self.indexing_progress.pop(repo.id, None)
266
- self.cancelled_repo_ids.discard(repo.id)
 
267
  self.repo_fetcher.cleanup_repository(clone_info["local_path"])
268
  print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
269
  except Exception as exc:
270
  print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
271
- session.rollback()
272
  self.vector_store.remove_repository(repo_id)
273
- self.repo_chunks.pop(repo_id, None)
274
  self.hybrid_search.remove_repository(repo_id)
275
- repo = session.query(Repository).filter_by(id=repo_id).first()
276
- if repo:
277
- if repo_id in self.cancelled_repo_ids:
278
- session.delete(repo)
279
- else:
280
- repo.status = "failed"
281
- repo.error_message = str(exc)
282
- session.commit()
 
 
283
  try:
284
- if "clone_info" in locals():
285
  self.repo_fetcher.cleanup_repository(clone_info["local_path"])
286
  except Exception:
287
  pass
288
- self.indexing_progress.pop(repo_id, None)
 
289
  if isinstance(exc, SessionCancelledError):
290
  return
291
  raise
292
- finally:
293
- session.close()
294
 
295
  def list_repositories(self) -> List[dict]:
296
  raise NotImplementedError
297
 
298
  def list_repositories_for_session(self, session_key: str) -> List[dict]:
299
- session = get_db_session(self.database_url)
300
- try:
301
- self._cleanup_expired_sessions(session)
302
- repos = (
303
- session.query(Repository)
304
- .filter_by(session_key=session_key)
305
- .order_by(Repository.updated_at.desc())
306
- .all()
307
- )
308
- self._touch_session(session, session_key)
309
  return [self._serialize_repo(repo) for repo in repos]
310
- finally:
311
- session.close()
312
 
313
  def get_repository(self, repo_id: int) -> Optional[dict]:
314
  raise NotImplementedError
315
 
316
  def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]:
317
- session = get_db_session(self.database_url)
318
- try:
319
- self._cleanup_expired_sessions(session)
320
- repo = (
321
- session.query(Repository)
322
- .filter_by(id=repo_id, session_key=session_key)
323
- .first()
324
- )
325
- self._touch_session(session, session_key)
326
  return self._serialize_repo(repo) if repo else None
327
- finally:
328
- session.close()
329
 
330
  def answer_question(
331
  self,
@@ -335,96 +344,92 @@ class CodebaseRAGSystem:
335
  top_k: int = 8,
336
  history=None,
337
  ) -> dict:
338
- session = get_db_session(self.database_url)
339
- try:
340
- self._cleanup_expired_sessions(session)
341
- repo = (
342
- session.query(Repository)
343
- .filter_by(id=repo_id, session_key=session_key)
344
- .first()
345
- )
346
  if repo is None:
347
  raise ValueError("Repository not found")
348
  if repo.status != "indexed":
349
  raise ValueError("Repository is not ready for questions yet")
350
  if repo_id not in self.repo_chunks:
351
  raise ValueError("Session cache expired. Re-index the repository and try again.")
352
- self._touch_session(session, session_key)
 
353
 
354
- normalized_history = self._normalize_history(history or [])
355
- question_intent = self._question_intent(question)
356
- deep_search_intents = {
357
- "api",
358
- "implementation",
359
- "cross_file",
360
- "error_handling",
361
- "setup",
362
- "tests",
363
- }
364
- deep_multiplier = int(os.getenv("RAG_DEEP_SEARCH_MULTIPLIER", "8"))
365
- shallow_multiplier = int(os.getenv("RAG_SEARCH_MULTIPLIER", "4"))
366
- search_depth = (
367
- top_k * deep_multiplier
368
- if question_intent in deep_search_intents
369
- else top_k * shallow_multiplier
370
- )
371
- search_depth = max(top_k, min(search_depth, 120))
372
 
373
- retrieval_query = self._build_retrieval_query(question, normalized_history)
374
- query_embedding = self.embedder.embed_text(retrieval_query)
375
 
376
- semantic_hits = []
377
- for score, meta in self.vector_store.search(query_embedding, k=search_depth, repo_filter=repo_id):
378
- serialized = dict(meta)
379
- serialized["semantic_score"] = score
380
- semantic_hits.append(serialized)
381
 
382
- lexical_hits = self.hybrid_search.bm25_search(
383
- self.repo_chunks[repo_id],
384
- retrieval_query,
385
- top_k=search_depth,
386
- )
387
- semantic_hits = self.hybrid_search.normalize_semantic_results(semantic_hits)
388
- fused = self.hybrid_search.reciprocal_rank_fusion(lexical_hits, semantic_hits, top_k=search_depth)
389
-
390
- path_hits = self._path_intent_search(
391
- self.repo_chunks[repo_id],
392
- question,
393
- retrieval_query,
394
- top_k=search_depth,
395
- )
396
- fused = self._merge_ranked_candidates(fused, path_hits, top_k=search_depth)
397
 
398
- rerank_query = retrieval_query if question_intent in deep_search_intents else question
399
 
400
- # FIX: rerank to a small candidate pool first (20), then let
401
- # _prioritize_results and _select_answer_sources trim to final top_k.
402
- # Previously rerank was called with search_depth (up to 120), meaning
403
- # the LLM received far too many chunks and faithfulness dropped.
404
- rerank_pool = min(search_depth, 20)
405
- reranked = self.hybrid_search.rerank(rerank_query, fused, top_k=rerank_pool)
406
 
407
- reranked = self._prioritize_results(question, retrieval_query, reranked, top_k=top_k)
408
 
409
- # FIX: cap final sources at 5 instead of top_k (8).
410
- # 5 sources × 1500 chars = ~7500 chars context, which the LLM handles well.
411
- # 8 sources × 2500 chars = ~20000 chars, which causes lost-in-the-middle issues.
412
- final_top_k = min(top_k, 5)
413
- reranked = self._select_answer_sources(question, reranked, top_k=final_top_k)
414
 
415
- answer = self._generate_answer(repo, question, reranked, normalized_history)
416
- return answer
417
- finally:
418
- session.close()
419
 
420
  def end_session(self, session_key: str):
421
- session = get_db_session(self.database_url)
422
- try:
423
- repos = session.query(Repository).filter_by(session_key=session_key).all()
424
- self._delete_repositories(session, repos)
425
- session.commit()
426
- finally:
427
- session.close()
428
 
429
 
430
  def _generate_answer(
@@ -800,35 +805,33 @@ Do not leave the answer unfinished.
800
  return payload
801
 
802
  def _set_progress(self, repo_id: int, **progress):
803
- self.indexing_progress[repo_id] = {
804
- **self.indexing_progress.get(repo_id, {}),
805
- **progress,
806
- "updated_at": datetime.utcnow().isoformat(),
807
- }
 
808
 
809
- def _touch_session(self, session, session_key: str):
810
  expiry = self._session_expiry()
811
- repos = session.query(Repository).filter_by(session_key=session_key).all()
812
- for repo in repos:
813
- repo.session_expires_at = expiry
814
- session.commit()
815
 
816
- def _cleanup_expired_sessions(self, session):
817
  now = datetime.utcnow()
818
- expired = (
819
- session.query(Repository)
820
- .filter(Repository.session_expires_at.is_not(None))
821
- .filter(Repository.session_expires_at < now)
822
- .all()
823
- )
824
  if not expired:
825
  return
826
- self._delete_repositories(session, expired)
827
- session.commit()
828
 
829
  def _delete_repositories(
830
  self,
831
- session,
832
  repos: List[Repository],
833
  track_cancellation: bool = True,
834
  ):
@@ -840,13 +843,18 @@ Do not leave the answer unfinished.
840
  self.vector_store.remove_repository(repo_id)
841
  self.repo_chunks.pop(repo_id, None)
842
  self.indexing_progress.pop(repo_id, None)
843
- for repo in repos:
844
- session.delete(repo)
 
845
 
846
  def _ensure_repo_not_cancelled(self, repo_id: int):
847
  if repo_id in self.cancelled_repo_ids:
848
  raise SessionCancelledError("Session ended before indexing completed.")
849
 
 
 
 
 
850
  def _build_retrieval_query(self, question: str, history: List[dict]) -> str:
851
  normalized = " ".join(question.strip().split())
852
  if self._is_repo_overview_question(normalized):
@@ -1660,9 +1668,8 @@ Do not leave the answer unfinished.
1660
  lines.append(f"{role}: {content[:400]}")
1661
  return "\n".join(lines) if lines else "None"
1662
 
1663
- @staticmethod
1664
- def _ensure_repo_still_exists(session, repo_id: int):
1665
- if session.query(Repository.id).filter_by(id=repo_id).first() is None:
1666
  raise RuntimeError("Repository was removed before indexing completed.")
1667
 
1668
  def _session_expiry(self) -> datetime:
 
1
  import os
2
  import re
3
+ from dataclasses import dataclass, field
4
  from datetime import datetime, timedelta
5
+ from threading import RLock
6
  from typing import Dict, List, Optional
7
 
8
  from openai import OpenAI
9
 
10
  from src.code_parser import CodeParser
11
  from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
 
12
  from src.embeddings import EmbeddingGenerator
13
  from src.hybrid_search import HybridSearchEngine
14
  from src.repo_fetcher import RepoFetcher
15
+ from src.vector_store import ChromaVectorStore
16
 
17
 
18
  class SessionCancelledError(RuntimeError):
19
  pass
20
 
21
 
22
+ @dataclass
23
+ class Repository:
24
+ id: int
25
+ github_url: str
26
+ source_url: str
27
+ session_key: str
28
+ session_expires_at: datetime
29
+ owner: str
30
+ name: str
31
+ branch: str = "main"
32
+ local_path: Optional[str] = None
33
+ status: str = "queued"
34
+ error_message: Optional[str] = None
35
+ file_count: int = 0
36
+ chunk_count: int = 0
37
+ indexed_at: Optional[datetime] = None
38
+ created_at: datetime = field(default_factory=datetime.utcnow)
39
+ updated_at: datetime = field(default_factory=datetime.utcnow)
40
+
41
+
42
  class CodebaseRAGSystem:
43
  def __init__(
44
  self,
 
45
  repo_dir: str = None,
46
  index_path: str = None,
47
  ):
 
 
 
 
 
 
 
48
  self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
49
  self.parser = CodeParser()
50
  self.embedder = EmbeddingGenerator()
51
+ self.vector_store = ChromaVectorStore(
52
  embedding_dim=self.embedder.get_embedding_dim(),
53
+ index_path=index_path or "./data/chroma",
54
+ persist=True,
55
  )
56
  self.hybrid_search = HybridSearchEngine(
57
  reranker_model=os.getenv(
 
64
  self.llm_model = ""
65
  self._configure_llm()
66
  self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
67
+ self.repo_lock = RLock()
68
+ self.repositories: Dict[int, Repository] = {}
69
+ self.repository_registry: Dict[str, int] = {}
70
+ self.next_repo_id = 1
71
  self.indexing_progress: Dict[int, dict] = {}
72
  self.repo_chunks: Dict[int, List[dict]] = {}
73
  self.cancelled_repo_ids = set()
74
  self.rebuild_indexes()
75
 
76
  def rebuild_indexes(self):
77
+ with self.repo_lock:
 
78
  self.vector_store.clear()
79
+ self.repositories.clear()
80
+ self.repository_registry.clear()
81
+ self.next_repo_id = 1
82
  self.repo_chunks.clear()
83
  self.indexing_progress.clear()
84
  self.cancelled_repo_ids.clear()
 
 
 
 
 
 
85
 
86
  def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
87
  info = self.repo_fetcher.parse_github_url(github_url)
88
  registry_key = self._build_registry_key(session_key, github_url)
89
+ with self.repo_lock:
90
+ self._cleanup_expired_sessions()
91
+ repo_id = self.repository_registry.get(registry_key)
92
+ repo = self.repositories.get(repo_id) if repo_id else None
93
  if repo is None:
94
  repo = Repository(
95
+ id=self.next_repo_id,
96
  github_url=registry_key,
97
  source_url=github_url,
98
  session_key=session_key,
 
102
  branch=info["branch"],
103
  status="queued",
104
  )
105
+ self.next_repo_id += 1
106
+ self.repositories[repo.id] = repo
107
+ self.repository_registry[registry_key] = repo.id
108
  self.cancelled_repo_ids.discard(repo.id)
109
  else:
110
  repo.source_url = github_url
 
118
  repo.file_count = 0
119
  repo.chunk_count = 0
120
  repo.indexed_at = None
121
+ self._mark_repo_updated(repo)
122
  self.cancelled_repo_ids.discard(repo.id)
123
  self.hybrid_search.remove_repository(repo.id)
124
  self.vector_store.remove_repository(repo.id)
125
  self.repo_chunks.pop(repo.id, None)
126
 
 
 
127
  return repo
 
 
128
 
129
  def index_repository(self, repo_id: int):
130
+ clone_info = None
131
  try:
132
+ with self.repo_lock:
133
+ self._cleanup_expired_sessions()
134
+ repo = self.repositories.get(repo_id)
135
+ if repo is None:
136
+ raise ValueError("Repository not found")
137
+ self._ensure_repo_not_cancelled(repo.id)
138
+ print(f"[indexing] Starting repository index repo_id={repo.id}", flush=True)
139
+
140
+ repo.status = "indexing"
141
+ repo.error_message = None
142
+ repo.session_expires_at = self._session_expiry()
143
+ self._mark_repo_updated(repo)
144
 
 
 
 
 
145
  self._set_progress(repo.id, phase="cloning", message="Cloning repository")
146
 
147
  clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
148
  self._ensure_repo_not_cancelled(repo.id)
149
+ with self.repo_lock:
150
+ self._ensure_repo_still_exists(repo.id)
151
+ repo.branch = clone_info["branch"]
152
+ repo.local_path = None
153
+ self._mark_repo_updated(repo)
154
  print(
155
  f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
156
  f"path={clone_info['local_path']}",
 
267
  }
268
  created_rows.append(row)
269
 
 
 
 
 
 
 
 
 
 
270
  serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
271
+ with self.repo_lock:
272
+ self._ensure_repo_still_exists(repo.id)
273
+ self._ensure_repo_not_cancelled(repo.id)
274
+ repo.status = "indexed"
275
+ repo.file_count = file_count
276
+ repo.chunk_count = len(created_rows)
277
+ repo.indexed_at = datetime.utcnow()
278
+ repo.session_expires_at = self._session_expiry()
279
+ self._mark_repo_updated(repo)
280
+ self.repo_chunks[repo.id] = serialized
281
  self.vector_store.save()
282
+ with self.repo_lock:
283
+ self.indexing_progress.pop(repo.id, None)
284
+ self.cancelled_repo_ids.discard(repo.id)
285
  self.repo_fetcher.cleanup_repository(clone_info["local_path"])
286
  print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
287
  except Exception as exc:
288
  print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
 
289
  self.vector_store.remove_repository(repo_id)
 
290
  self.hybrid_search.remove_repository(repo_id)
291
+ with self.repo_lock:
292
+ self.repo_chunks.pop(repo_id, None)
293
+ repo = self.repositories.get(repo_id)
294
+ if repo:
295
+ if repo_id in self.cancelled_repo_ids:
296
+ self._delete_repositories([repo], track_cancellation=False)
297
+ else:
298
+ repo.status = "failed"
299
+ repo.error_message = str(exc)
300
+ self._mark_repo_updated(repo)
301
  try:
302
+ if clone_info:
303
  self.repo_fetcher.cleanup_repository(clone_info["local_path"])
304
  except Exception:
305
  pass
306
+ with self.repo_lock:
307
+ self.indexing_progress.pop(repo_id, None)
308
  if isinstance(exc, SessionCancelledError):
309
  return
310
  raise
 
 
311
 
312
  def list_repositories(self) -> List[dict]:
313
  raise NotImplementedError
314
 
315
  def list_repositories_for_session(self, session_key: str) -> List[dict]:
316
+ with self.repo_lock:
317
+ self._cleanup_expired_sessions()
318
+ repos = [
319
+ repo
320
+ for repo in self.repositories.values()
321
+ if repo.session_key == session_key
322
+ ]
323
+ repos.sort(key=lambda repo: repo.updated_at, reverse=True)
324
+ self._touch_session(session_key)
 
325
  return [self._serialize_repo(repo) for repo in repos]
 
 
326
 
327
  def get_repository(self, repo_id: int) -> Optional[dict]:
328
  raise NotImplementedError
329
 
330
  def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]:
331
+ with self.repo_lock:
332
+ self._cleanup_expired_sessions()
333
+ repo = self.repositories.get(repo_id)
334
+ if repo and repo.session_key != session_key:
335
+ repo = None
336
+ self._touch_session(session_key)
 
 
 
337
  return self._serialize_repo(repo) if repo else None
 
 
338
 
339
  def answer_question(
340
  self,
 
344
  top_k: int = 8,
345
  history=None,
346
  ) -> dict:
347
+ with self.repo_lock:
348
+ self._cleanup_expired_sessions()
349
+ repo = self.repositories.get(repo_id)
350
+ if repo and repo.session_key != session_key:
351
+ repo = None
 
 
 
352
  if repo is None:
353
  raise ValueError("Repository not found")
354
  if repo.status != "indexed":
355
  raise ValueError("Repository is not ready for questions yet")
356
  if repo_id not in self.repo_chunks:
357
  raise ValueError("Session cache expired. Re-index the repository and try again.")
358
+ repo_chunks = list(self.repo_chunks[repo_id])
359
+ self._touch_session(session_key)
360
 
361
+ normalized_history = self._normalize_history(history or [])
362
+ question_intent = self._question_intent(question)
363
+ deep_search_intents = {
364
+ "api",
365
+ "implementation",
366
+ "cross_file",
367
+ "error_handling",
368
+ "setup",
369
+ "tests",
370
+ }
371
+ deep_multiplier = int(os.getenv("RAG_DEEP_SEARCH_MULTIPLIER", "8"))
372
+ shallow_multiplier = int(os.getenv("RAG_SEARCH_MULTIPLIER", "4"))
373
+ search_depth = (
374
+ top_k * deep_multiplier
375
+ if question_intent in deep_search_intents
376
+ else top_k * shallow_multiplier
377
+ )
378
+ search_depth = max(top_k, min(search_depth, 120))
379
 
380
+ retrieval_query = self._build_retrieval_query(question, normalized_history)
381
+ query_embedding = self.embedder.embed_text(retrieval_query)
382
 
383
+ semantic_hits = []
384
+ for score, meta in self.vector_store.search(query_embedding, k=search_depth, repo_filter=repo_id):
385
+ serialized = dict(meta)
386
+ serialized["semantic_score"] = score
387
+ semantic_hits.append(serialized)
388
 
389
+ lexical_hits = self.hybrid_search.bm25_search(
390
+ repo_chunks,
391
+ retrieval_query,
392
+ top_k=search_depth,
393
+ )
394
+ semantic_hits = self.hybrid_search.normalize_semantic_results(semantic_hits)
395
+ fused = self.hybrid_search.reciprocal_rank_fusion(lexical_hits, semantic_hits, top_k=search_depth)
396
+
397
+ path_hits = self._path_intent_search(
398
+ repo_chunks,
399
+ question,
400
+ retrieval_query,
401
+ top_k=search_depth,
402
+ )
403
+ fused = self._merge_ranked_candidates(fused, path_hits, top_k=search_depth)
404
 
405
+ rerank_query = retrieval_query if question_intent in deep_search_intents else question
406
 
407
+ # FIX: rerank to a small candidate pool first (20), then let
408
+ # _prioritize_results and _select_answer_sources trim to final top_k.
409
+ # Previously rerank was called with search_depth (up to 120), meaning
410
+ # the LLM received far too many chunks and faithfulness dropped.
411
+ rerank_pool = min(search_depth, 20)
412
+ reranked = self.hybrid_search.rerank(rerank_query, fused, top_k=rerank_pool)
413
 
414
+ reranked = self._prioritize_results(question, retrieval_query, reranked, top_k=top_k)
415
 
416
+ # FIX: cap final sources at 5 instead of top_k (8).
417
+ # 5 sources × 1500 chars = ~7500 chars context, which the LLM handles well.
418
+ # 8 sources × 2500 chars = ~20000 chars, which causes lost-in-the-middle issues.
419
+ final_top_k = min(top_k, 5)
420
+ reranked = self._select_answer_sources(question, reranked, top_k=final_top_k)
421
 
422
+ answer = self._generate_answer(repo, question, reranked, normalized_history)
423
+ return answer
 
 
424
 
425
  def end_session(self, session_key: str):
426
+ with self.repo_lock:
427
+ repos = [
428
+ repo
429
+ for repo in self.repositories.values()
430
+ if repo.session_key == session_key
431
+ ]
432
+ self._delete_repositories(repos)
433
 
434
 
435
  def _generate_answer(
 
805
  return payload
806
 
807
  def _set_progress(self, repo_id: int, **progress):
808
+ with self.repo_lock:
809
+ self.indexing_progress[repo_id] = {
810
+ **self.indexing_progress.get(repo_id, {}),
811
+ **progress,
812
+ "updated_at": datetime.utcnow().isoformat(),
813
+ }
814
 
815
+ def _touch_session(self, session_key: str):
816
  expiry = self._session_expiry()
817
+ for repo in self.repositories.values():
818
+ if repo.session_key == session_key:
819
+ repo.session_expires_at = expiry
820
+ self._mark_repo_updated(repo)
821
 
822
+ def _cleanup_expired_sessions(self):
823
  now = datetime.utcnow()
824
+ expired = [
825
+ repo
826
+ for repo in self.repositories.values()
827
+ if repo.session_expires_at is not None and repo.session_expires_at < now
828
+ ]
 
829
  if not expired:
830
  return
831
+ self._delete_repositories(expired)
 
832
 
833
  def _delete_repositories(
834
  self,
 
835
  repos: List[Repository],
836
  track_cancellation: bool = True,
837
  ):
 
843
  self.vector_store.remove_repository(repo_id)
844
  self.repo_chunks.pop(repo_id, None)
845
  self.indexing_progress.pop(repo_id, None)
846
+ repo = self.repositories.pop(repo_id, None)
847
+ if repo:
848
+ self.repository_registry.pop(repo.github_url, None)
849
 
850
  def _ensure_repo_not_cancelled(self, repo_id: int):
851
  if repo_id in self.cancelled_repo_ids:
852
  raise SessionCancelledError("Session ended before indexing completed.")
853
 
854
+ @staticmethod
855
+ def _mark_repo_updated(repo: Repository):
856
+ repo.updated_at = datetime.utcnow()
857
+
858
  def _build_retrieval_query(self, question: str, history: List[dict]) -> str:
859
  normalized = " ".join(question.strip().split())
860
  if self._is_repo_overview_question(normalized):
 
1668
  lines.append(f"{role}: {content[:400]}")
1669
  return "\n".join(lines) if lines else "None"
1670
 
1671
+ def _ensure_repo_still_exists(self, repo_id: int):
1672
+ if repo_id not in self.repositories:
 
1673
  raise RuntimeError("Repository was removed before indexing completed.")
1674
 
1675
  def _session_expiry(self) -> datetime:
src/vector_store.py CHANGED
@@ -1,60 +1,44 @@
1
  import os
 
2
  from typing import List, Optional, Tuple
3
  from uuid import uuid4
4
 
5
  import numpy as np
6
- from qdrant_client import QdrantClient, models
 
7
 
8
 
9
- class QdrantVectorStore:
10
- def __init__(self, embedding_dim: int, index_path: str = None, persist: bool = False):
11
  self.embedding_dim = embedding_dim
12
- self.collection_name = os.getenv("QDRANT_COLLECTION", "repo_qa_chunks")
13
- self.upsert_batch_size = max(1, int(os.getenv("QDRANT_UPSERT_BATCH_SIZE", "64")))
14
- self.qdrant_url = self._clean_env("QDRANT_URL")
15
- self.qdrant_api_key = self._clean_env("QDRANT_API_KEY")
16
- self.timeout = int(os.getenv("QDRANT_TIMEOUT_SECONDS", "120"))
17
  self.client = self._create_client()
18
- self._ensure_collection()
19
 
20
  def _create_client(self):
21
- if self.qdrant_url:
22
- return QdrantClient(
23
- url=self.qdrant_url,
24
- api_key=self.qdrant_api_key,
25
- timeout=self.timeout,
26
- check_compatibility=False,
 
 
27
  )
28
- return QdrantClient(":memory:")
29
 
30
- @staticmethod
31
- def _clean_env(name: str) -> Optional[str]:
32
- value = os.getenv(name)
33
- if value is None:
34
- return None
35
- cleaned = value.strip()
36
- return cleaned or None
37
 
38
  def _ensure_collection(self):
39
- if not self.client.collection_exists(self.collection_name):
40
- self.client.create_collection(
41
- collection_name=self.collection_name,
42
- vectors_config=models.VectorParams(
43
- size=self.embedding_dim,
44
- distance=models.Distance.COSINE,
45
- ),
46
- )
47
- self._ensure_payload_indexes()
48
-
49
- def _ensure_payload_indexes(self):
50
- self.client.create_payload_index(
51
- collection_name=self.collection_name,
52
- field_name="repository_id",
53
- field_schema=models.PayloadSchemaType.INTEGER,
54
- wait=True,
55
  )
56
 
57
- def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[int]:
58
  if embeddings.size == 0:
59
  return []
60
 
@@ -63,31 +47,33 @@ class QdrantVectorStore:
63
  embeddings = embeddings.reshape(1, -1)
64
 
65
  ids = [uuid4().hex for _ in metadata]
66
- points = []
67
- for idx, meta, embedding in zip(ids, metadata, embeddings):
68
- payload = dict(meta)
69
- payload["id"] = idx
70
- points.append(
71
- models.PointStruct(
72
- id=idx,
73
- vector=embedding.tolist(),
74
- payload=payload,
75
- )
76
- )
77
- total_points = len(points)
78
  for start in range(0, total_points, self.upsert_batch_size):
79
- batch = points[start : start + self.upsert_batch_size]
 
 
 
 
 
 
 
 
 
 
 
80
  batch_number = (start // self.upsert_batch_size) + 1
81
  total_batches = (total_points + self.upsert_batch_size - 1) // self.upsert_batch_size
82
  print(
83
- f"[qdrant] Upserting batch {batch_number}/{total_batches} "
84
- f"points={len(batch)} progress={start}/{total_points}",
85
  flush=True,
86
  )
87
- self.client.upsert(
88
- collection_name=self.collection_name,
89
- wait=True,
90
- points=batch,
 
91
  )
92
 
93
  return ids
@@ -102,67 +88,75 @@ class QdrantVectorStore:
102
  query_embedding = query_embedding.reshape(1, -1)
103
  query_embedding = query_embedding.astype("float32")
104
 
105
- query_filter = None
106
- if repo_filter is not None:
107
- query_filter = models.Filter(
108
- must=[
109
- models.FieldCondition(
110
- key="repository_id",
111
- match=models.MatchValue(value=repo_filter),
112
- )
113
- ]
114
- )
115
-
116
- hits = self.client.search(
117
- collection_name=self.collection_name,
118
- query_vector=query_embedding[0].tolist(),
119
- query_filter=query_filter,
120
- limit=k,
121
  )
122
 
123
- return [(float(hit.score), dict(hit.payload or {})) for hit in hits]
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def remove_repository(self, repo_id: int):
126
- self.client.delete(
127
- collection_name=self.collection_name,
128
- wait=True,
129
- points_selector=models.FilterSelector(
130
- filter=models.Filter(
131
- must=[
132
- models.FieldCondition(
133
- key="repository_id",
134
- match=models.MatchValue(value=repo_id),
135
- )
136
- ]
137
- )
138
- ),
139
- )
140
 
141
  def clear(self):
142
- if self.client.collection_exists(self.collection_name):
143
- self.client.delete_collection(self.collection_name)
144
- self._ensure_collection()
 
 
145
 
146
  def save(self):
147
- return None
 
 
148
 
149
  def load(self):
150
- self._ensure_collection()
151
-
152
- def is_remote(self) -> bool:
153
- return self.qdrant_url is not None
154
 
155
  def keep_alive(self) -> dict:
156
- info = self.client.get_collection(self.collection_name)
157
- return {
158
- "total_vectors": info.points_count or 0,
159
- "collection_name": self.collection_name,
160
- }
161
 
162
  def get_stats(self) -> dict:
163
- info = self.client.get_collection(self.collection_name)
164
  return {
165
- "total_vectors": info.points_count or 0,
166
  "embedding_dim": self.embedding_dim,
167
  "collection_name": self.collection_name,
 
168
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from pathlib import Path
3
  from typing import List, Optional, Tuple
4
  from uuid import uuid4
5
 
6
  import numpy as np
7
+ from chromadb import Client
8
+ from chromadb.config import Settings
9
 
10
 
11
+ class ChromaVectorStore:
12
+ def __init__(self, embedding_dim: int, index_path: str = None, persist: bool = True):
13
  self.embedding_dim = embedding_dim
14
+ self.collection_name = os.getenv("CHROMA_COLLECTION", "repo_qa_chunks")
15
+ self.upsert_batch_size = max(1, int(os.getenv("CHROMA_UPSERT_BATCH_SIZE", "64")))
16
+ self.persist_path = os.getenv("CHROMA_PATH", index_path or "./data/chroma")
17
+ self.persist = persist
 
18
  self.client = self._create_client()
19
+ self.collection = self._ensure_collection()
20
 
21
  def _create_client(self):
22
+ if self.persist:
23
+ Path(self.persist_path).mkdir(parents=True, exist_ok=True)
24
+ return Client(
25
+ Settings(
26
+ is_persistent=True,
27
+ persist_directory=self.persist_path,
28
+ anonymized_telemetry=False,
29
+ )
30
  )
 
31
 
32
+ return Client(Settings(anonymized_telemetry=False))
 
 
 
 
 
 
33
 
34
  def _ensure_collection(self):
35
+ return self.client.get_or_create_collection(
36
+ name=self.collection_name,
37
+ embedding_function=None,
38
+ metadata={"hnsw:space": "cosine"},
 
 
 
 
 
 
 
 
 
 
 
 
39
  )
40
 
41
+ def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[str]:
42
  if embeddings.size == 0:
43
  return []
44
 
 
47
  embeddings = embeddings.reshape(1, -1)
48
 
49
  ids = [uuid4().hex for _ in metadata]
50
+ total_points = len(ids)
51
+
 
 
 
 
 
 
 
 
 
 
52
  for start in range(0, total_points, self.upsert_batch_size):
53
+ end = start + self.upsert_batch_size
54
+ batch_ids = ids[start:end]
55
+ batch_embeddings = embeddings[start:end].tolist()
56
+ batch_metadata = []
57
+ batch_documents = []
58
+
59
+ for idx, meta in zip(batch_ids, metadata[start:end]):
60
+ payload = self._sanitize_metadata(meta)
61
+ payload["id"] = idx
62
+ batch_metadata.append(payload)
63
+ batch_documents.append(str(meta.get("content") or ""))
64
+
65
  batch_number = (start // self.upsert_batch_size) + 1
66
  total_batches = (total_points + self.upsert_batch_size - 1) // self.upsert_batch_size
67
  print(
68
+ f"[chroma] Adding batch {batch_number}/{total_batches} "
69
+ f"points={len(batch_ids)} progress={start}/{total_points}",
70
  flush=True,
71
  )
72
+ self.collection.add(
73
+ ids=batch_ids,
74
+ embeddings=batch_embeddings,
75
+ metadatas=batch_metadata,
76
+ documents=batch_documents,
77
  )
78
 
79
  return ids
 
88
  query_embedding = query_embedding.reshape(1, -1)
89
  query_embedding = query_embedding.astype("float32")
90
 
91
+ where = {"repository_id": repo_filter} if repo_filter is not None else None
92
+ results = self.collection.query(
93
+ query_embeddings=[query_embedding[0].tolist()],
94
+ n_results=k,
95
+ where=where,
96
+ include=["documents", "metadatas", "distances"],
 
 
 
 
 
 
 
 
 
 
97
  )
98
 
99
+ ids = (results.get("ids") or [[]])[0]
100
+ documents = (results.get("documents") or [[]])[0]
101
+ metadatas = (results.get("metadatas") or [[]])[0]
102
+ distances = (results.get("distances") or [[]])[0]
103
+
104
+ hits = []
105
+ for idx, document, meta, distance in zip(ids, documents, metadatas, distances):
106
+ payload = dict(meta or {})
107
+ payload["id"] = payload.get("id") or idx
108
+ payload["content"] = document or ""
109
+ hits.append((self._distance_to_score(distance), payload))
110
+ return hits
111
 
112
  def remove_repository(self, repo_id: int):
113
+ self.collection.delete(where={"repository_id": repo_id})
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  def clear(self):
116
+ try:
117
+ self.client.delete_collection(name=self.collection_name)
118
+ except Exception:
119
+ pass
120
+ self.collection = self._ensure_collection()
121
 
122
  def save(self):
123
+ persist = getattr(self.client, "persist", None)
124
+ if callable(persist):
125
+ persist()
126
 
127
  def load(self):
128
+ self.collection = self._ensure_collection()
 
 
 
129
 
130
  def keep_alive(self) -> dict:
131
+ heartbeat = getattr(self.client, "heartbeat", None)
132
+ if callable(heartbeat):
133
+ heartbeat()
134
+ return self.get_stats()
 
135
 
136
  def get_stats(self) -> dict:
 
137
  return {
138
+ "total_vectors": self.collection.count(),
139
  "embedding_dim": self.embedding_dim,
140
  "collection_name": self.collection_name,
141
+ "persist_path": self.persist_path if self.persist else None,
142
  }
143
+
144
+ @staticmethod
145
+ def _sanitize_metadata(meta: dict) -> dict:
146
+ sanitized = {}
147
+ for key, value in meta.items():
148
+ if key == "content":
149
+ continue
150
+ if value is None:
151
+ sanitized[key] = ""
152
+ elif isinstance(value, (str, int, float, bool)):
153
+ sanitized[key] = value
154
+ else:
155
+ sanitized[key] = str(value)
156
+ return sanitized
157
+
158
+ @staticmethod
159
+ def _distance_to_score(distance: float) -> float:
160
+ if distance is None:
161
+ return 0.0
162
+ return max(0.0, min(1.0, 1.0 - float(distance)))