Spaces:
Runtime error
Runtime error
| from smolagents.tools import Tool | |
| import asyncio | |
| from collections import Counter | |
| import hashlib | |
| import logging | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Iterable, List, Optional | |
| import re | |
| from dotenv import load_dotenv | |
| import chromadb | |
| from chromadb.errors import NotFoundError | |
| from pypdf import PdfReader | |
| from llama_index.core import StorageContext, VectorStoreIndex | |
| from llama_index.core.schema import Document, BaseNode, NodeWithScore, TextNode | |
| from llama_index.core.node_parser import SentenceSplitter | |
| from llama_index.vector_stores.chroma import ChromaVectorStore | |
load_dotenv()

# Directory layout: this module sits one level below the project root; the
# knowledge base (raw sources + Chroma store) and the model cache live at the root.
BASE_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = BASE_DIR.parent
# Older location next to this module; effective_raw_dir() still falls back to it.
LEGACY_KNOWLEDGE_BASE_DIR = BASE_DIR / "knowledge_base"
KNOWLEDGE_BASE_DIR = PROJECT_ROOT / "knowledge_base"
RAW_DIR = KNOWLEDGE_BASE_DIR / "raw"
CHROMA_DB_DIR = KNOWLEDGE_BASE_DIR / "chroma_db"
HF_CACHE_DIR = PROJECT_ROOT / "hf_cache"
COLLECTION_NAME = "options_knowledge"

# Model settings, overridable through the environment (.env is loaded above).
EMBED_MODEL_NAME = os.getenv("RAG_EMBED_MODEL", "BAAI/bge-small-en-v1.5")
RERANKER_MODEL_NAME = os.getenv(
    "RAG_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_BATCH_SIZE = int(os.getenv("RAG_RERANKER_BATCH_SIZE", "16"))

# Metadata keys written onto nodes.
EMBED_MODEL_METADATA_KEY = "embedding_model"
BM25_METADATA_KEY = "bm25_score"
VECTOR_METADATA_KEY = "vector_score"

# Chunking parameters (units depend on the splitter that consumes them).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 150

# PDF extraction heuristics.
PDF_REPEATED_LINE_MIN_PAGES = 3
PDF_BOUNDARY_LINE_COUNT = 4
PDF_EXTRACTION_METHOD = "pymupdf_formula_blocks_v5"
PDF_LINE_Y_TOLERANCE = 3.0
PDF_MIN_SECTION_CHARS = 240
PDF_STRONG_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_σΣΔδθΘλΛμρπΠφΦτν𝜎𝜇𝜌𝜃𝜕")
PDF_WEAK_MATH_SYMBOLS = set("+-−*/∕<>")
PDF_MATH_SYMBOLS = PDF_STRONG_MATH_SYMBOLS | PDF_WEAK_MATH_SYMBOLS
PDF_OPERATOR_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_+-−*/∕<>")
PDF_FORMULA_TRIGGER_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_∕<>")

# pypdf warns loudly on malformed PDFs; only surface errors.
logging.getLogger("pypdf").setLevel(logging.ERROR)
def load_pymupdf():
    """Return the PyMuPDF (``fitz``) module, or None when it is not installed."""
    module = None
    try:
        import fitz as module
    except ImportError:
        pass  # optional dependency: callers fall back to pypdf
    return module
# Metadata keys treated as required (the check itself is outside this section).
# NOTE(review): chunk_id / chunk_index are not set by the document loaders in
# this section; presumably they are added per chunk later — confirm.
REQUIRED_METADATA = [
    "source_file", "file_name", "file_type", "document_title",
    "file_hash", "chunk_id", "chunk_index",
]
def configure_model_cache() -> None:
    """Point Hugging Face / sentence-transformers caches at ``HF_CACHE_DIR``.

    Existing environment values are respected (``setdefault``). Offline mode
    (``HF_HUB_OFFLINE`` / ``TRANSFORMERS_OFFLINE``) is forced only when both the
    embedding model and the reranker model have a usable local snapshot.
    Forcing it on the embedding snapshot alone made an uncached reranker
    impossible to download.
    """
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("HF_HOME", str(HF_CACHE_DIR))
    os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str(
        HF_CACHE_DIR / "sentence_transformers"))
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    models_to_load = (EMBED_MODEL_NAME, RERANKER_MODEL_NAME)
    if all(local_model_snapshot(model_name) for model_name in models_to_load):
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
def local_model_snapshot(model_name: str) -> Optional[Path]:
    """Find a cached snapshot directory for *model_name*, or return None.

    Looks under ``HF_CACHE_DIR/sentence_transformers/models--<org>--<name>/snapshots``.
    Snapshot directories are examined in descending name order, and the first
    one that contains a ``config.json`` is returned.
    """
    repo_folder = "models--" + model_name.replace("/", "--")
    snapshots_dir = HF_CACHE_DIR / "sentence_transformers" / repo_folder / "snapshots"
    if not snapshots_dir.exists():
        return None
    candidates = sorted(
        (entry for entry in snapshots_dir.iterdir() if entry.is_dir()),
        reverse=True,
    )
    return next(
        (entry for entry in candidates if (entry / "config.json").exists()),
        None,
    )
def resolve_embed_model_name() -> str:
    """Return the local snapshot path for the embedding model, else its hub name."""
    cached = local_model_snapshot(EMBED_MODEL_NAME)
    return str(cached) if cached else EMBED_MODEL_NAME
def resolve_reranker_model_name(model_name: str = RERANKER_MODEL_NAME) -> str:
    """Return the local snapshot path for *model_name*, else the name unchanged."""
    cached = local_model_snapshot(model_name)
    return str(cached) if cached else model_name
def env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean.

    Unset variables yield *default*. Otherwise the value is true when it is
    (case-insensitively, ignoring surrounding whitespace) 1/true/yes/on.
    """
    raw_value = os.getenv(name)
    if raw_value is not None:
        return raw_value.strip().lower() in {"1", "true", "yes", "on"}
    return default
def effective_raw_dir(raw_dir: Path = RAW_DIR) -> Path:
    """Choose the directory that holds the raw knowledge-base sources.

    Returns *raw_dir* when it has source files. Otherwise it returns the legacy
    ``LEGACY_KNOWLEDGE_BASE_DIR/raw`` when that has files, logging a warning.
    When neither has files, it returns *raw_dir*.
    """
    if any(iter_source_files(raw_dir)):
        return raw_dir
    legacy_dir = LEGACY_KNOWLEDGE_BASE_DIR / "raw"
    if not any(iter_source_files(legacy_dir)):
        return raw_dir
    logging.warning(
        "Using legacy knowledge base path %s. Move files to %s when convenient.",
        legacy_dir,
        raw_dir,
    )
    return legacy_dir
class CrossEncoderReranker:
    """Rerank retrieval results with a sentence-transformers CrossEncoder.

    The model is loaded lazily on first use and reused afterwards.
    """

    def __init__(
        self,
        model_name: str = RERANKER_MODEL_NAME,
        batch_size: int = RERANKER_BATCH_SIZE,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self._model = None  # created on first call to _load_model()

    def _load_model(self):
        """Load the CrossEncoder once (preferring a cached local snapshot) and return it."""
        if self._model is not None:
            return self._model
        # Imported lazily so the heavy dependency is only needed when reranking.
        from sentence_transformers import CrossEncoder
        self._model = CrossEncoder(
            resolve_reranker_model_name(self.model_name),
            max_length=512,
            cache_folder=str(HF_CACHE_DIR / "sentence_transformers"),
        )
        return self._model

    def rerank(
        self,
        query: str,
        results: list[NodeWithScore],
        top_n: Optional[int] = None,
    ) -> list[NodeWithScore]:
        """Score ``(query, node text)`` pairs and return nodes sorted best-first.

        Args:
            query: The user query.
            results: Candidate nodes; their original scores are discarded.
            top_n: Keep only this many results. A falsy value (None or 0)
                keeps all of them.

        Returns:
            New ``NodeWithScore`` objects carrying the cross-encoder score.
        """
        if not results:
            return []
        pairs = [(query, result.node.get_content()) for result in results]
        scores = self._load_model().predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
        )
        reranked = [
            NodeWithScore(node=result.node, score=float(score))
            for result, score in zip(results, scores)
        ]
        # Cross-encoder logits can be negative or exactly 0.0. `score or -inf`
        # would sink a legitimate 0.0 below every negative score, so only a
        # missing (None) score is treated as -inf.
        reranked.sort(
            key=lambda item: item.score if item.score is not None else float("-inf"),
            reverse=True,
        )
        return reranked[:top_n] if top_n else reranked
class BM25Retriever:
    """Okapi BM25 keyword retriever over an in-memory list of text nodes.

    Per-document token lists, document frequencies and the average document
    length are computed once in ``__init__``. ``retrieve`` scores every node
    against the query. Each node with a positive score gets that score written
    into ``node.metadata[BM25_METADATA_KEY]``, so the nodes are mutated in place.
    """

    def __init__(self, nodes: list[TextNode]):
        self.nodes = nodes
        self.tokenized_docs = [self.tokenize(node.get_content()) for node in nodes]
        self.doc_freqs: Counter[str] = Counter()
        for tokens in self.tokenized_docs:
            # Document frequency: count each distinct term once per document.
            self.doc_freqs.update(set(tokens))
        self.avg_doc_len = (
            sum(len(tokens) for tokens in self.tokenized_docs) / len(self.tokenized_docs)
            if self.tokenized_docs
            else 0.0
        )

    @staticmethod
    def tokenize(text: str) -> list[str]:
        """Split *text* into lower-cased word, number and single-symbol tokens.

        Declared static because it is invoked as ``self.tokenize(text)``. As a
        plain function that call would pass ``self`` as ``text`` and raise TypeError.
        """
        return [
            token.lower()
            for token in re.findall(r"[A-Za-z]+(?:[-'][A-Za-z]+)*|\d+(?:\.\d+)*|[^\sA-Za-z0-9]", text)
            if token.strip()
        ]

    def score(self, query_tokens: list[str], doc_tokens: list[str]) -> float:
        """Return the BM25 score (k1=1.5, b=0.75) of one document for the query tokens."""
        if not query_tokens or not doc_tokens:
            return 0.0
        token_counts = Counter(doc_tokens)
        doc_len = len(doc_tokens)
        total_docs = len(self.tokenized_docs)
        k1 = 1.5
        b = 0.75
        # Guard against division by a zero/near-zero average length.
        length_norm = 1 - b + b * doc_len / max(self.avg_doc_len, 1.0)
        score = 0.0
        for token in query_tokens:
            term_freq = token_counts.get(token, 0)
            if term_freq == 0:
                continue
            doc_freq = self.doc_freqs.get(token, 0)
            # "+1" inside the log keeps IDF positive for very common terms.
            idf = math.log(1 + (total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
            score += idf * (term_freq * (k1 + 1)) / (term_freq + k1 * length_norm)
        return score

    def retrieve(self, query: str, top_k: int) -> list[NodeWithScore]:
        """Return up to *top_k* nodes with a positive BM25 score, best first."""
        query_tokens = self.tokenize(query)
        scored: list[NodeWithScore] = []
        for node, doc_tokens in zip(self.nodes, self.tokenized_docs):
            score = self.score(query_tokens, doc_tokens)
            if score <= 0:
                continue
            node.metadata[BM25_METADATA_KEY] = score
            scored.append(NodeWithScore(node=node, score=score))
        scored.sort(
            key=lambda item: item.score if item.score is not None else float("-inf"),
            reverse=True,
        )
        return scored[:top_k]
def file_sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024
    with path.open("rb") as handle:
        while chunk := handle.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()
def load_md_file(path: Path) -> Document:
    """Load a whole Markdown file as one Document, without section splitting."""
    content = path.read_text(encoding="utf-8")
    file_metadata = {
        "source_file": str(path.resolve()),
        "file_name": path.name,
        "file_type": "md",
        "document_title": path.stem,
        "file_hash": file_sha256(path),
    }
    return Document(text=content, metadata=file_metadata)
def load_md_documents(path: Path) -> List[Document]:
    """Split a Markdown file into one Document per ATX-heading section.

    A section starts at a ``#``..``######`` heading line. The heading line stays
    in the section text, and the heading title becomes ``section_title`` /
    ``section_path``. Text before the first heading forms an untitled section.

    Lines inside fenced code blocks (``` or ~~~) are never treated as headings,
    so ``# comment`` lines in code samples do not split a section.

    Falls back to ``load_md_file`` when no non-empty section is produced.
    """
    text = path.read_text(encoding="utf-8")
    file_hash = file_sha256(path)
    source_file = str(path.resolve())
    file_type = path.suffix.lower().lstrip(".")
    documents: List[Document] = []
    current_heading = ""
    current_lines: List[str] = []
    open_fence = ""  # fence run (e.g. "```") of the code block we are inside

    def flush() -> None:
        """Emit the accumulated lines as one section Document, then reset them."""
        nonlocal current_lines
        section_text = "\n".join(current_lines).strip()
        current_lines = []
        if not section_text:
            return
        documents.append(
            Document(
                text=section_text,
                metadata={
                    "source_file": source_file,
                    "file_name": path.name,
                    "file_type": file_type,
                    "document_title": path.stem,
                    "file_hash": file_hash,
                    "content_type": "markdown_section",
                    "chapter_title": "",
                    "section_title": current_heading,
                    "section_path": current_heading,
                    "char_count": len(section_text),
                },
            )
        )

    for line in text.splitlines():
        fence_match = re.match(r"^ {0,3}(`{3,}|~{3,})", line)
        if fence_match:
            fence = fence_match.group(1)
            if not open_fence:
                open_fence = fence
            elif (
                fence[0] == open_fence[0]
                and len(fence) >= len(open_fence)
                # A closing fence carries no info string after the marker.
                and not line.strip()[len(fence):].strip()
            ):
                open_fence = ""
        elif not open_fence:
            heading_match = re.match(r"^(#{1,6})\s+(.+?)\s*$", line)
            if heading_match:
                flush()
                current_heading = heading_match.group(2).strip()
        current_lines.append(line)
    flush()
    return documents or [load_md_file(path)]
def append_visual_fragment(line_parts: List[str], text: str, baseline_y: float, item: dict) -> None:
    """Append one positioned text fragment to *line_parts*, marking scripts.

    A fragment whose font is under 82% of the line's largest font and whose y is
    shifted from *baseline_y* by more than ``max(1.5, 18% of line size)`` is
    wrapped as ``^{…}`` (shifted up) or ``_{…}`` (shifted down). Other
    fragments are appended stripped. Blank fragments are ignored.
    """
    stripped = text.strip() if text else ""
    if not stripped:
        return
    line_size = item["line_font_size"]
    shift_limit = max(1.5, line_size * 0.18)
    y_offset = item["y"] - baseline_y
    if item["font_size"] < line_size * 0.82:
        if y_offset > shift_limit:
            line_parts.append(f"^{{{stripped}}}")
            return
        if y_offset < -shift_limit:
            line_parts.append(f"_{{{stripped}}}")
            return
    line_parts.append(stripped)
def join_visual_line(items: List[dict]) -> str:
    """Render one visual line of positioned fragments as a string.

    Fragments are ordered left to right. A space is inserted when the horizontal
    gap exceeds ``max(2.5, 28% of the largest font)``. Super/subscripts are
    marked by append_visual_fragment relative to the upper-median y. Each item
    dict is annotated in place with ``line_font_size``.
    """
    if not items:
        return ""
    ordered = sorted(items, key=lambda entry: entry["x"])
    y_values = sorted(entry["y"] for entry in ordered)
    baseline_y = y_values[len(ordered) // 2]
    line_font_size = max(entry["font_size"] for entry in ordered)
    space_gap = max(2.5, line_font_size * 0.28)
    pieces: List[str] = []
    right_edge = None
    for entry in ordered:
        entry["line_font_size"] = line_font_size
        if right_edge is not None and entry["x"] - right_edge > space_gap:
            pieces.append(" ")
        append_visual_fragment(pieces, entry["text"], baseline_y, entry)
        right_edge = max(right_edge or entry["x"], entry["x"] + entry["width"])
    return normalize_pdf_line("".join(pieces))
def extract_pdf_text_by_position(page) -> str:
    """Rebuild a pypdf page's text from the positions reported to a visitor.

    Each non-blank fragment is recorded with its text-matrix x/y, its font size
    and an estimated width. Fragments are then grouped into lines: a fragment
    joins the first line whose first fragment's y is within
    ``PDF_LINE_Y_TOLERANCE``. Lines are emitted top to bottom (descending y)
    and rendered by join_visual_line.

    Returns "" when pypdf raises or no text fragment is reported.

    NOTE(review): positions come from the text matrix ``tm`` only; the current
    transformation matrix ``cm`` is ignored, so pages with a non-identity ``cm``
    may be laid out incorrectly — confirm.
    """
    fragments: List[dict] = []

    def visitor_text(text, cm, tm, font_dict, font_size):
        stripped = text.strip() if text else ""
        if not stripped:
            return
        # pypdf can report a missing or zero font size. Use one fallback for both
        # the stored size and the width estimate: float(None) used to raise
        # TypeError here and discard the whole page.
        size = float(font_size or 1.0)
        fragments.append(
            {
                "text": text,
                "x": float(tm[4]),
                "y": float(tm[5]),
                "width": max(len(stripped) * size * 0.45, size),
                "font_size": size,
            }
        )

    try:
        page.extract_text(visitor_text=visitor_text)
    except Exception:
        # Best effort: the caller compares this against other extraction modes.
        return ""
    if not fragments:
        return ""
    lines: List[List[dict]] = []
    for fragment in sorted(fragments, key=lambda value: (-value["y"], value["x"])):
        for line in lines:
            if abs(line[0]["y"] - fragment["y"]) <= PDF_LINE_Y_TOLERANCE:
                line.append(fragment)
                break
        else:
            lines.append([fragment])
    return "\n".join(join_visual_line(line) for line in lines)
def math_text_score(text: str) -> float:
    """Score how well *text* preserved mathematical content (higher is better).

    The score sums four parts:
    - the non-whitespace length;
    - 12 per math symbol;
    - 20 per ``^{``/``_{`` script marker;
    - 8 per formula-like line, plus 12 more per formula-like line that has a
      formula-like neighbour.
    Blank text scores 0.0.
    """
    if not text.strip():
        return 0.0
    nonblank_rows = [row for row in text.splitlines() if row.strip()]
    row_is_formula = [is_formula_like(row) for row in nonblank_rows]
    row_total = len(row_is_formula)
    compact_length = len(re.sub(r"\s+", "", text))
    symbol_hits = sum(char in PDF_MATH_SYMBOLS for char in text)
    script_markers = text.count("^{") + text.count("_{")
    formula_rows = sum(row_is_formula)
    block_rows = sum(
        1
        for position, flagged in enumerate(row_is_formula)
        if flagged
        and (
            (position > 0 and row_is_formula[position - 1])
            or (position + 1 < row_total and row_is_formula[position + 1])
        )
    )
    return (
        compact_length
        + symbol_hits * 12
        + script_markers * 20
        + formula_rows * 8
        + block_rows * 12
    )
def extract_pdf_text(page) -> str:
    """Extract a pypdf page's text, keeping the most math-preserving variant.

    Three candidates are tried in order: the position-based reconstruction,
    pypdf ``layout`` mode and plain mode. Extraction failures count as empty.
    The non-blank candidate with the highest math_text_score wins; ties go to
    the earlier candidate. Returns "" when all are blank.
    """
    def safe_extract(**options) -> str:
        try:
            return page.extract_text(**options) or ""
        except Exception:
            return ""

    attempts = (
        extract_pdf_text_by_position(page),
        safe_extract(extraction_mode="layout"),
        safe_extract(),
    )
    usable = [attempt for attempt in attempts if attempt.strip()]
    return max(usable, key=math_text_score) if usable else ""
def pymupdf_span_text(span: dict) -> str:
    """Return a PyMuPDF span's text, normalized by normalize_pdf_line."""
    raw_text = span.get("text", "")
    return normalize_pdf_line(raw_text)
def pymupdf_line_text(line: dict) -> str:
    """Return the normalized text of a PyMuPDF line dict.

    The raw span texts are concatenated first and normalized once. Stripping
    each span before joining (the previous behaviour) deleted the spaces at span
    boundaries. Those boundaries occur at every font change, e.g. an italic
    variable, so "price of " + "S" became "price ofS".
    """
    raw_text = "".join(span.get("text", "") for span in line.get("spans", []))
    return normalize_pdf_line(raw_text)
def pymupdf_block_text(block: dict) -> str:
    """Join a PyMuPDF block's non-empty line texts with newlines."""
    rendered_lines = (pymupdf_line_text(line) for line in block.get("lines", []))
    return "\n".join(filter(None, rendered_lines))
def pymupdf_span_has_math_font(span: dict) -> bool:
    """True when the span's font name contains a math-font marker.

    The markers are math, symbol, cmmi, cmsy, cmex and stix, matched
    case-insensitively.
    """
    lowered_font = span.get("font", "").lower()
    for marker in ("math", "symbol", "cmmi", "cmsy", "cmex", "stix"):
        if marker in lowered_font:
            return True
    return False
def is_formula_block_line(line: str) -> bool:
    """Heuristically decide whether one text line is a display formula.

    Lines with fewer than 3 non-space chars and bare (optionally parenthesised)
    numbers are rejected. Otherwise the line qualifies if any of these holds:
    - it ends with an equation number such as ``(3.2)``;
    - it contains ``=``;
    - it contains a strong operator symbol;
    - it has two or more trigger symbols;
    - it has a trigger symbol plus a digit.
    Each of these rules also carries its own length and word-count limit.
    """
    text = line.strip()
    if not text:
        return False
    compact_length = len(re.sub(r"\s+", "", text))
    if compact_length < 3 or re.fullmatch(r"\(?\d+(\.\d+)?\)?", text):
        return False
    stopwords = {"and", "or", "the", "where", "then", "with", "for"}
    word_count = sum(
        1 for word in re.findall(r"[A-Za-z]+", text) if word.lower() not in stopwords
    )
    trigger_count = sum(char in PDF_FORMULA_TRIGGER_SYMBOLS for char in text)
    digit_count = sum(char.isdigit() for char in text)
    letter_count = sum(char.isalpha() for char in text)
    return any(
        (
            bool(re.search(r"\(\d+(\.\d+)+[a-z]?\)$", text)) and compact_length <= 240,
            "=" in text and compact_length <= 260 and word_count <= 12,
            any(char in text for char in "∂∫∑∏√∞≈≠≤≥±×÷")
            and compact_length <= 220
            and word_count <= 10,
            trigger_count >= 2 and compact_length <= 120 and word_count <= 6,
            trigger_count >= 1
            and digit_count >= 1
            and letter_count <= 18
            and compact_length <= 100,
        )
    )
def is_formula_block(block: dict) -> bool:
    """Heuristically decide whether a whole PyMuPDF text block is a formula.

    True if any line passes is_formula_block_line. Otherwise the decision uses
    block-wide counts of math-font spans, strong math symbols, letters and
    digits, each with its own length limit. Empty blocks are never formulas.
    """
    text = pymupdf_block_text(block)
    if not text:
        return False
    if any(is_formula_block_line(row) for row in text.splitlines() if row.strip()):
        return True
    spans = [
        span
        for line in block.get("lines", [])
        for span in line.get("spans", [])
        if pymupdf_span_text(span)
    ]
    if not spans:
        return False
    compact_length = len(re.sub(r"\s+", "", text))
    math_font_spans = sum(1 for span in spans if pymupdf_span_has_math_font(span))
    strong_symbols = sum(char in PDF_STRONG_MATH_SYMBOLS for char in text)
    letters = sum(char.isalpha() for char in text)
    digits = sum(char.isdigit() for char in text)
    return (
        (math_font_spans >= 2 and compact_length <= 220)
        or (strong_symbols >= 3 and compact_length <= 260)
        or (strong_symbols >= 1 and digits >= 1 and letters <= 20 and compact_length <= 160)
    )
def block_bbox_string(block: dict) -> str:
    """Format a block's 4-value ``bbox`` as "x0,y0,x1,y1" (2 decimals), else ""."""
    coords = block.get("bbox") or []
    return ",".join(f"{float(coord):.2f}" for coord in coords) if len(coords) == 4 else ""
def line_bbox_string(line: dict) -> str:
    """Format a line's 4-value ``bbox`` as "x0,y0,x1,y1" (2 decimals), else ""."""
    bbox = line.get("bbox") or []
    if len(bbox) != 4:
        return ""
    x0, y0, x1, y1 = (float(value) for value in bbox)
    return f"{x0:.2f},{y0:.2f},{x1:.2f},{y1:.2f}"
def pymupdf_line_has_math_font(line: dict) -> bool:
    """True when any non-empty span in the PyMuPDF line uses a math font."""
    for span in line.get("spans", []):
        if pymupdf_span_text(span) and pymupdf_span_has_math_font(span):
            return True
    return False
def should_extract_formula_line(line: dict) -> bool:
    """Decide whether a PyMuPDF line should start or extend a formula block.

    A line qualifies if is_formula_block_line accepts its text. It also
    qualifies if all of the following hold:
    - it has at most 180 non-space chars;
    - it has at most 6 alphabetic words;
    - it contains a trigger symbol;
    - one of its spans uses a math font.
    """
    text = pymupdf_line_text(line)
    if not text:
        return False
    if is_formula_block_line(text):
        return True
    if len(re.sub(r"\s+", "", text)) > 180:
        return False
    if len(re.findall(r"[A-Za-z]+", text)) > 6:
        return False
    if not any(char in PDF_FORMULA_TRIGGER_SYMBOLS for char in text):
        return False
    return pymupdf_line_has_math_font(line)
def is_formula_continuation_line(text: str) -> bool:
    """Decide whether a short line can continue an open formula block.

    Lone brackets or a radical sign always qualify. Other lines with at most 90
    non-space chars qualify when they have at most 4 alphabetic words and
    contain a math symbol or a digit.
    """
    stripped = text.strip()
    if not stripped:
        return False
    compact = re.sub(r"\s+", "", stripped)
    if len(compact) > 90:
        return False
    if compact in {"(", ")", "[", "]", "{", "}", "√"}:
        return True
    if len(re.findall(r"[A-Za-z]+", stripped)) > 4:
        return False
    return any(char in PDF_MATH_SYMBOLS or char.isdigit() for char in stripped)
def append_formula_block(
    formula_blocks: List[dict],
    body_blocks: List[str],
    page_number: int,
    formula_index: int,
    formula_lines: List[str],
    formula_bboxes: List[str],
) -> int:
    """Flush accumulated formula lines into *formula_blocks* / *body_blocks*.

    When the cleaned text passes is_useful_formula_text:
    - a ``{"id", "text", "bbox"}`` record is appended to *formula_blocks*;
    - a ``[FORMULA id=…]…[/FORMULA]`` placeholder is appended to *body_blocks*;
    - the next formula index is returned.

    Otherwise the lines are kept as plain body text, because the heuristics
    that routed them here are loose (e.g. "Step 1: S < K"), and
    *formula_index* is returned unchanged. Previously such lines were dropped
    from the page text entirely.
    """
    formula_text = clean_formula_text("\n".join(formula_lines))
    if not is_useful_formula_text(formula_text):
        fallback_text = "\n".join(line for line in formula_lines if line.strip())
        if fallback_text:
            body_blocks.append(fallback_text)
        return formula_index
    formula_id = f"formula-{page_number}-{formula_index}"
    formula_bbox = merge_bbox_strings(formula_bboxes)
    formula_blocks.append(
        {
            "id": formula_id,
            "text": formula_text,
            "bbox": formula_bbox,
        }
    )
    body_blocks.append(f"[FORMULA id={formula_id}]\n{formula_text}\n[/FORMULA]")
    return formula_index + 1
def merge_bbox_strings(bbox_strings: List[str]) -> str:
    """Return the union of "x0,y0,x1,y1" boxes as one such string.

    Empty, malformed or non-numeric entries are skipped. Returns "" if none
    are valid.
    """
    boxes = []
    for raw in bbox_strings:
        parts = raw.split(",") if raw else []
        if len(parts) != 4:
            continue
        try:
            boxes.append(tuple(map(float, parts)))
        except ValueError:
            continue
    if not boxes:
        return ""
    left, top, right, bottom = zip(*boxes)
    return f"{min(left):.2f},{min(top):.2f},{max(right):.2f},{max(bottom):.2f}"
def is_useful_formula_text(text: str) -> bool:
    """Decide whether extracted formula text is worth keeping as a formula.

    Text with fewer than 6 non-space chars is rejected. Otherwise it is kept
    when any of these holds:
    - it contains an equation number such as ``(2.1)``;
    - it contains a strong operator symbol and either has ``=`` or at most 12
      words;
    - some line with ``=`` has at most 12 non-stopword words and at most
      260 chars.
    """
    body = text.strip()
    if not body:
        return False
    if len(re.sub(r"\s+", "", body)) < 6:
        return False
    if re.search(r"\(\d+(\.\d+)+[a-z]?\)", body):
        return True
    if any(symbol in body for symbol in "∂∫∑∏∞≈≠≤≥±×÷"):
        return "=" in body or len(re.findall(r"[A-Za-z]+", body)) <= 12
    stopwords = {"and", "or", "the", "where", "then", "with", "for"}
    for raw_row in body.splitlines():
        row = raw_row.strip()
        if "=" not in row:
            continue
        words = [
            word for word in re.findall(r"[A-Za-z]+", row) if word.lower() not in stopwords
        ]
        if len(words) <= 12 and len(row) <= 260:
            return True
    return False
def extract_pymupdf_page(page) -> dict:
    """Extract one PyMuPDF page into body text plus separate formula blocks.

    Text blocks (type 0) are walked line by line in PyMuPDF's sorted order.
    Lines accepted by should_extract_formula_line accumulate into a formula run.
    While a run is open, lines accepted by is_formula_continuation_line extend
    it, even across block boundaries. The run is flushed through
    append_formula_block at the next ordinary line and at the end of the page.
    Ordinary lines are grouped per block into body paragraphs.

    Returns ``{"text", "formula_blocks", "backend": "pymupdf"}``.
    """
    page_number = page.number + 1
    text_blocks = [
        block
        for block in page.get_text("dict", sort=True).get("blocks", [])
        if block.get("type") == 0
    ]
    body_blocks: List[str] = []
    formula_blocks: List[dict] = []
    pending_formula: List[str] = []
    pending_bboxes: List[str] = []
    next_formula_index = 0

    for block in text_blocks:
        prose: List[str] = []
        for line in block.get("lines", []):
            line_text = pymupdf_line_text(line)
            if not line_text:
                continue
            joins_formula = should_extract_formula_line(line) or (
                bool(pending_formula) and is_formula_continuation_line(line_text)
            )
            if joins_formula:
                if prose:
                    body_blocks.append("\n".join(prose))
                    prose = []
                pending_formula.append(line_text)
                pending_bboxes.append(line_bbox_string(line))
                continue
            if pending_formula:
                next_formula_index = append_formula_block(
                    formula_blocks=formula_blocks,
                    body_blocks=body_blocks,
                    page_number=page_number,
                    formula_index=next_formula_index,
                    formula_lines=pending_formula,
                    formula_bboxes=pending_bboxes,
                )
                pending_formula = []
                pending_bboxes = []
            prose.append(line_text)
        if prose:
            body_blocks.append("\n".join(prose))

    if pending_formula:
        append_formula_block(
            formula_blocks=formula_blocks,
            body_blocks=body_blocks,
            page_number=page_number,
            formula_index=next_formula_index,
            formula_lines=pending_formula,
            formula_bboxes=pending_bboxes,
        )
    return {
        "text": "\n".join(body_blocks),
        "formula_blocks": formula_blocks,
        "backend": "pymupdf",
    }
def extract_pdf_pages_with_pymupdf(path: Path) -> Optional[List[dict]]:
    """Extract every page of *path* with PyMuPDF.

    Returns None when PyMuPDF is unavailable or the file cannot be opened. The
    document is always closed after extraction.
    """
    fitz = load_pymupdf()
    if fitz is None:
        return None
    try:
        document = fitz.open(str(path))
    except Exception:
        return None
    # fitz.Document is a context manager whose exit closes the document.
    with document:
        return [extract_pymupdf_page(page) for page in document]
def clean_formula_text(text: str) -> str:
    """Normalize formula text.

    Noise lines are dropped via page_lines, space runs are collapsed, and runs
    of 3+ newlines are reduced to a blank line.
    """
    kept_lines = page_lines(text)
    if not kept_lines:
        return ""
    collapsed = re.sub(r"[ \t]+", " ", "\n".join(kept_lines))
    return re.sub(r"\n{3,}", "\n\n", collapsed).strip()
# NUL bytes become spaces; Latin typographic ligatures are expanded.
_PDF_CHAR_REPLACEMENTS = str.maketrans(
    {
        "\x00": " ",
        "\ufb00": "ff",
        "\ufb01": "fi",
        "\ufb02": "fl",
        "\ufb03": "ffi",
        "\ufb04": "ffl",
    }
)


def normalize_pdf_line(line: str) -> str:
    """Clean one extracted PDF line.

    NULs and ligatures are replaced, space/tab runs are collapsed, and the
    result is stripped.
    """
    translated = line.translate(_PDF_CHAR_REPLACEMENTS)
    return re.sub(r"[ \t]+", " ", translated).strip()
def is_noise_line(line: str) -> bool:
    """True for empty lines, bare numbers, "Page N [of M]" and separator rules."""
    if not line:
        return True
    noise_patterns = (
        (r"\d+", 0),
        (r"page\s+\d+(\s+of\s+\d+)?", re.IGNORECASE),
        (r"[-_=\s]{3,}", 0),
    )
    return any(re.fullmatch(pattern, line, flags=flags) for pattern, flags in noise_patterns)
def is_formula_like(line: str) -> bool:
    """Heuristically decide whether a text line looks like (part of) an equation.

    The checks run in order and any match returns True:
    - explicit ``={`` / ``^{`` / ``_{`` markup;
    - a lone bracket;
    - a short line (at most 40 non-space chars) with any math symbol;
    - dense strong-math symbols;
    - an ``=`` with operands;
    - a call-like ``exp(`` / ``ln[`` pattern;
    - a short numeric expression.
    """
    text = line.strip()
    if not text:
        return False
    compact = text.replace(" ", "")
    if any(marker in compact for marker in ("={", "^{", "_{")):
        return True
    if compact in {"(", ")", "[", "]", "{", "}"}:
        return True
    if len(compact) <= 40 and any(char in PDF_MATH_SYMBOLS for char in compact):
        return True
    strong = sum(char in PDF_STRONG_MATH_SYMBOLS for char in text)
    weak = sum(char in PDF_WEAK_MATH_SYMBOLS for char in text)
    letters = sum(char.isalpha() for char in text)
    digits = sum(char.isdigit() for char in text)
    fits = len(text) <= 180
    if strong >= 2 and fits:
        return True
    if strong >= 1 and weak >= 1 and fits:
        return True
    if "=" in text and letters + digits >= 2 and len(text) <= 220:
        return True
    if re.search(r"\b(d|D|exp|ln|sqrt|max|min|var|cov)\s*[\(\[]", text):
        return True
    return letters <= 4 and strong + weak >= 1 and digits >= 1
def normalized_line_key(line: str) -> str:
    """Case-fold *line* and replace digit runs with "#" so page numbers compare equal."""
    lowered = line.lower()
    masked = re.sub(r"\d+", "#", lowered)
    return masked.strip()
def page_lines(text: str) -> List[str]:
    """Split page text on any newline style, normalize each line, and drop noise lines."""
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    normalized = (normalize_pdf_line(raw) for raw in unified.split("\n"))
    return [candidate for candidate in normalized if not is_noise_line(candidate)]
def find_repeated_boundary_lines(raw_pages: List[str]) -> set[str]:
    """Return normalized keys of header/footer lines that repeat across pages.

    A line is a candidate when it is among the first or last
    ``PDF_BOUNDARY_LINE_COUNT`` cleaned lines of a page and is 3-140 chars long.
    Each distinct key is counted at most once per page, so the threshold is a
    number of pages: ``min(PDF_REPEATED_LINE_MIN_PAGES, max(2, pages // 3))``.

    Counting per occurrence double-counted every line on pages with at most
    ``2 * PDF_BOUNDARY_LINE_COUNT`` lines, because the head and tail slices
    overlap there. With a threshold of 2, a short page's own content was then
    flagged as a repeated header and stripped.
    """
    boundary = PDF_BOUNDARY_LINE_COUNT
    pages_containing: Counter[str] = Counter()
    for raw_text in raw_pages:
        lines = page_lines(raw_text)
        head = lines[:boundary]
        tail = lines[-boundary:] if boundary > 0 else []
        pages_containing.update(
            {
                normalized_line_key(line)
                for line in head + tail
                if 3 <= len(line) <= 140
            }
        )
    min_count = min(
        PDF_REPEATED_LINE_MIN_PAGES,
        max(2, len(raw_pages) // 3),
    )
    return {key for key, count in pages_containing.items() if count >= min_count}
def clean_pdf_text(text: str, repeated_boundary_lines: set[str]) -> str:
    """Clean one page of extracted PDF text.

    Steps:
    1. Drop noise lines.
    2. Drop repeated header/footer lines found in the first or last
       ``PDF_BOUNDARY_LINE_COUNT`` lines.
    3. Re-join words hyphenated across lines.
    4. Merge wrapped prose while keeping formula line breaks.
    5. Collapse whitespace runs.
    """
    lines = page_lines(text)
    line_total = len(lines)

    def is_repeated_boundary(position: int, candidate: str) -> bool:
        at_edge = (
            position < PDF_BOUNDARY_LINE_COUNT
            or position >= line_total - PDF_BOUNDARY_LINE_COUNT
        )
        return at_edge and normalized_line_key(candidate) in repeated_boundary_lines

    kept = [
        candidate
        for position, candidate in enumerate(lines)
        if not is_repeated_boundary(position, candidate)
    ]
    dehyphenated: list = []
    for candidate in kept:
        if dehyphenated and dehyphenated[-1].endswith("-") and candidate[:1].islower():
            dehyphenated[-1] = dehyphenated[-1][:-1] + candidate
            continue
        dehyphenated.append(candidate)
    result = preserve_math_line_breaks("\n".join(dehyphenated))
    result = re.sub(r"[ \t]+", " ", result)
    result = re.sub(r"\n{3,}", "\n\n", result)
    return result.strip()
def preserve_math_line_breaks(text: str) -> str:
    """Merge wrapped prose lines but keep line breaks around formulas.

    A line is kept on its own when:
    - it or the previous output line is formula-like;
    - a formula run is still open (it stays open until a non-formula line
      ending in . ; : ? or ! closes it);
    - the previous line ends with sentence punctuation or ")".
    Any other line is appended to the previous one with a space.
    """
    rows = text.split("\n")
    merged = [rows[0]]
    inside_formula = is_formula_like(rows[0])
    for row in rows[1:]:
        prev = merged[-1]
        row_is_formula = is_formula_like(row)
        if inside_formula or row_is_formula or is_formula_like(prev):
            merged.append(row)
            inside_formula = row_is_formula or (
                inside_formula and not row.endswith((".", ";", ":", "?", "!"))
            )
            continue
        if prev.endswith((".", ":", ";", "?", "!", ")")):
            merged.append(row)
        else:
            merged[-1] = f"{prev} {row}"
        inside_formula = False
    return "\n".join(merged)
def is_chapter_heading(line: str) -> bool:
    """True for a bare "Chapter N" / "Appendix X" label (arabic, roman or single letter)."""
    match = re.fullmatch(
        r"(chapter|appendix)\s+([0-9]+|[ivxlcdm]+|[a-z])",
        line.strip(),
        flags=re.IGNORECASE,
    )
    return match is not None
def titlecase_word_ratio(words: List[str]) -> float:
    """Share of alphabetic words that are capitalised or common minor words.

    The minor words are "a", "the", "of", etc. Surrounding bracket and
    punctuation characters are stripped first. Returns 0.0 when no word
    contains a letter.
    """
    minor_words = {"a", "an", "and", "for", "in", "of", "on", "or", "the", "to", "with"}
    cleaned = [
        word.strip("()[]{}:;,.")
        for word in words
        if any(char.isalpha() for char in word)
    ]
    if not cleaned:
        return 0.0
    hits = sum(1 for word in cleaned if word[:1].isupper() or word.lower() in minor_words)
    return hits / len(cleaned)
def uppercase_letter_ratio(text: str) -> float:
    """Fraction of alphabetic characters in *text* that are upper-case (0.0 if none)."""
    letters = [char for char in text if char.isalpha()]
    return sum(map(str.isupper, letters)) / len(letters) if letters else 0.0
def is_section_heading(line: str) -> bool:
    """Heuristically decide whether a PDF line is a section heading.

    The line must be 4-150 chars, with at least 6 letters and 2 alphabetic
    words. It must not be digit-heavy or a percentage row, and must not start
    with a digit unless it is numbered like "2.3 ". It must not start with
    prose openers such as "where"/"figure", be formula-like, or end with
    . , or ;.

    Numbered headings are then accepted. Other lines with at most 16 words are
    accepted when mostly upper-case (at least 72% of letters) or, with at least
    4 words, mostly title-case (at least 68% of words).
    """
    candidate = line.strip()
    if len(candidate) < 4 or len(candidate) > 150:
        return False
    letter_count = sum(1 for char in candidate if char.isalpha())
    digit_count = sum(1 for char in candidate if char.isdigit())
    alpha_word_count = sum(
        1 for word in candidate.split() if any(char.isalpha() for char in word)
    )
    if letter_count < 6 or alpha_word_count < 2:
        return False
    if digit_count > max(4, letter_count):
        return False
    if "%" in candidate and digit_count >= letter_count / 2:
        return False
    is_numbered = re.match(r"^\d+(\.\d+)+\s+", candidate) is not None
    if candidate[:1].isdigit() and not is_numbered:
        return False
    # This also covers "Figure 3"/"Table 2" captions.
    if re.match(
        r"^(in|from|where|thus|then|now|let|because|while|figure|table|for)\b",
        candidate,
        flags=re.IGNORECASE,
    ):
        return False
    if is_formula_like(candidate):
        return False
    if candidate.endswith((".", ",", ";")):
        return False
    if is_numbered:
        return True
    tokens = candidate.split()
    if len(tokens) > 16:
        return False
    mostly_upper = uppercase_letter_ratio(candidate) >= 0.72 and len(tokens) >= 2
    mostly_title = len(tokens) >= 4 and titlecase_word_ratio(tokens) >= 0.68
    return mostly_upper or mostly_title
def make_section_path(chapter_title: str, section_title: str) -> str:
    """Return "Chapter > Section", or whichever single title is set when they coincide or one is empty."""
    if not (chapter_title and section_title) or section_title == chapter_title:
        return section_title or chapter_title
    return f"{chapter_title} > {section_title}"
def split_pdf_page_into_sections(
    path: Path,
    page_index: int,
    text: str,
    file_hash: str,
    section_state: dict,
    extraction_backend: str,
    formula_count: int,
) -> List[Document]:
    """Split one page's cleaned text into section-level Documents.

    ``section_state`` is read and updated in place (keys ``chapter_title``,
    ``section_title`` and ``pending_chapter_label``), so heading context carries
    over to the next page. At a chapter or section heading, the pending text is
    emitted as a Document only if it has at least ``PDF_MIN_SECTION_CHARS``
    characters. Shorter runs are folded into the following section.

    A chapter label line (e.g. "CHAPTER 3") followed by a section-like title is
    merged into one heading, "Chapter 3: Title". The label is stored
    title-cased, so the pending line is title-cased before comparing. Comparing
    it verbatim meant upper-case labels were never merged.
    """
    documents: List[Document] = []
    pending_lines: List[str] = []
    pending_metadata = {
        "chapter_title": section_state.get("chapter_title", ""),
        "section_title": section_state.get("section_title", ""),
    }

    def flush_pending() -> None:
        """Emit the pending lines as one Document tagged with the current headings."""
        nonlocal pending_lines
        section_text = "\n".join(line for line in pending_lines if line.strip()).strip()
        pending_lines = []
        if not section_text:
            return
        chapter_title = pending_metadata.get("chapter_title", "")
        section_title = pending_metadata.get("section_title", "")
        documents.append(
            Document(
                text=section_text,
                metadata={
                    "source_file": str(path.resolve()),
                    "file_name": path.name,
                    "file_type": "pdf",
                    "document_title": path.stem,
                    "file_hash": file_hash,
                    "page_number": page_index,
                    "extraction_method": PDF_EXTRACTION_METHOD,
                    "extraction_backend": extraction_backend,
                    "char_count": len(section_text),
                    "formula_count": formula_count,
                    "content_type": "text",
                    "chapter_title": chapter_title,
                    "section_title": section_title,
                    "section_path": make_section_path(chapter_title, section_title),
                },
            )
        )

    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            continue
        if is_chapter_heading(stripped):
            if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS:
                flush_pending()
            label = stripped.title()
            section_state["pending_chapter_label"] = label
            section_state["chapter_title"] = label
            section_state["section_title"] = label
            pending_metadata = {
                "chapter_title": section_state["chapter_title"],
                "section_title": section_state["section_title"],
            }
            pending_lines.append(stripped)
            continue
        label = section_state.get("pending_chapter_label")
        if label and is_section_heading(stripped):
            # Merge "CHAPTER 3" + "Title" into "Chapter 3: Title" when the
            # label is the most recent pending line (compared title-cased).
            if pending_lines and pending_lines[-1].title() == label:
                pending_lines[-1] = f"{label}: {stripped}"
            else:
                pending_lines.append(stripped)
            section_state["chapter_title"] = pending_lines[-1]
            section_state["section_title"] = pending_lines[-1]
            section_state["pending_chapter_label"] = ""
            pending_metadata = {
                "chapter_title": section_state["chapter_title"],
                "section_title": section_state["section_title"],
            }
            continue
        if is_section_heading(stripped):
            if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS:
                flush_pending()
            section_state["section_title"] = stripped
            section_state["pending_chapter_label"] = ""
            pending_metadata = {
                "chapter_title": section_state.get("chapter_title", ""),
                "section_title": section_state["section_title"],
            }
        pending_lines.append(stripped)
    flush_pending()
    return documents
def make_formula_documents(
    path: Path,
    page_index: int,
    formula_blocks: List[dict],
    file_hash: str,
    extraction_backend: str,
) -> List[Document]:
    """Turn the formula blocks found on one PDF page into standalone Documents.

    Every block whose ``"text"`` is non-empty after stripping becomes one
    Document, wrapped in ``[FORMULA]`` / ``[/FORMULA]`` markers and tagged
    ``content_type="formula"``. Blocks with empty text are skipped, but
    ``formula_index`` still records each block's position in
    ``formula_blocks``, so the indices may have gaps.

    Args:
        path: PDF the page belongs to.
        page_index: page number as supplied by the caller (``load_pdf_file``
            counts from 1).
        formula_blocks: dicts read via ``"text"`` and, optionally, ``"id"``
            and ``"bbox"``.
        file_hash: hash of the whole file, copied into every document.
        extraction_backend: name of the backend that produced the blocks.

    Returns:
        One Document per non-empty block, in the order of ``formula_blocks``.
    """
    formula_docs: List[Document] = []
    for position, block in enumerate(formula_blocks):
        body = block.get("text", "").strip()
        if not body:
            continue
        block_metadata = {
            "source_file": str(path.resolve()),
            "file_name": path.name,
            "file_type": "pdf",
            "document_title": path.stem,
            "file_hash": file_hash,
            "page_number": page_index,
            "extraction_method": PDF_EXTRACTION_METHOD,
            "extraction_backend": extraction_backend,
            "char_count": len(body),
            "content_type": "formula",
            # Fall back to a page/position based id when the extractor gave none.
            "formula_id": block.get("id", f"formula-{page_index}-{position}"),
            "formula_index": position,
            "formula_bbox": block.get("bbox", ""),
            "formula_count": 1,
            # Formula documents carry no section context.
            "chapter_title": "",
            "section_title": "",
            "section_path": "",
        }
        formula_docs.append(
            Document(text=f"[FORMULA]\n{body}\n[/FORMULA]", metadata=block_metadata)
        )
    return formula_docs
def load_pdf_file(path: Path) -> List[Document]:
    """Load a PDF as section-level text Documents plus one Document per formula.

    Page payloads (dicts with ``"text"``, ``"formula_blocks"`` and
    ``"backend"``) come from ``extract_pdf_pages_with_pymupdf``. Only when
    that returns nothing is the file parsed with pypdf, which yields text
    only and no formula blocks. Lines reported by
    ``find_repeated_boundary_lines`` across all pages are stripped from each
    page by ``clean_pdf_text``. A single ``section_state`` dict is threaded
    through every page, so chapter/section titles detected on one page still
    apply on the following pages.

    Args:
        path: PDF file to load.

    Returns:
        For each page in order: its text-section Documents (only if the
        cleaned page text is non-blank), followed by its formula Documents.
    """
    page_payloads = extract_pdf_pages_with_pymupdf(path)
    if not page_payloads:
        # pypdf is only the fallback. Building the reader lazily avoids
        # parsing every PDF a second time when PyMuPDF already succeeded.
        reader = PdfReader(str(path))
        page_payloads = [
            {
                "text": extract_pdf_text(page),
                "formula_blocks": [],
                "backend": "pypdf",
            }
            for page in reader.pages
        ]
    repeated_boundary_lines = find_repeated_boundary_lines(
        [payload["text"] for payload in page_payloads]
    )
    file_hash = file_sha256(path)
    # Shared across pages: split_pdf_page_into_sections reads and updates it.
    section_state: dict = {
        "chapter_title": "",
        "section_title": "",
        "pending_chapter_label": "",
    }
    documents: List[Document] = []
    for page_index, payload in enumerate(page_payloads, start=1):
        text = clean_pdf_text(payload["text"], repeated_boundary_lines)
        formula_blocks = payload.get("formula_blocks", [])
        extraction_backend = payload.get("backend", "pypdf")
        if text.strip():
            documents.extend(
                split_pdf_page_into_sections(
                    path=path,
                    page_index=page_index,
                    text=text,
                    file_hash=file_hash,
                    section_state=section_state,
                    extraction_backend=extraction_backend,
                    formula_count=len(formula_blocks),
                )
            )
        # Formula documents are emitted whether or not the page had text.
        documents.extend(
            make_formula_documents(
                path=path,
                page_index=page_index,
                formula_blocks=formula_blocks,
                file_hash=file_hash,
                extraction_backend=extraction_backend,
            )
        )
    return documents
def load_txt_file(path: Path) -> List[Document]:
    """Read a UTF-8 text file and return it as a single, unsplit Document.

    Section fields are left empty because no heading detection is done for
    ``.txt`` sources. ``char_count`` is the length of the decoded text.
    """
    body = path.read_text(encoding="utf-8")
    txt_metadata = {
        "source_file": str(path.resolve()),
        "file_name": path.name,
        "file_type": "txt",
        "document_title": path.stem,
        "file_hash": file_sha256(path),
        "content_type": "text",
        "chapter_title": "",
        "section_title": "",
        "section_path": "",
        "char_count": len(body),
    }
    return [Document(text=body, metadata=txt_metadata)]
def iter_source_files(raw_dir: Path) -> Iterable[Path]:
    """Yield every supported source file under ``raw_dir``, recursively.

    Paths are produced in sorted order. Only regular files whose extension
    (case-insensitive) is ``.md``, ``.markdown``, ``.pdf`` or ``.txt`` are
    yielded; directories are skipped even if their name has such a suffix.
    """
    allowed = {".md", ".markdown", ".pdf", ".txt"}
    yield from (
        candidate
        for candidate in sorted(raw_dir.rglob("*"))
        if candidate.suffix.lower() in allowed and candidate.is_file()
    )
def load_docs(raw_dir: Path = RAW_DIR) -> List[Document]:
    """Load every supported source document under the effective raw directory.

    Per-extension dispatch goes through ``load_source_file``, so a full
    rebuild and the incremental update path parse each file identically.
    Previously the dispatch was duplicated here, which allowed the two paths
    to drift apart.

    Args:
        raw_dir: directory to scan; resolved through ``effective_raw_dir``.

    Returns:
        All Documents produced, in sorted source-path order.

    Raises:
        ValueError: if no Documents were produced at all.
    """
    raw_dir = effective_raw_dir(raw_dir)
    documents: List[Document] = []
    for path in iter_source_files(raw_dir):
        documents.extend(load_source_file(path))
    if not documents:
        raise ValueError(f"No supported documents found under {raw_dir}")
    return documents
def add_chunk_metadata(nodes: List[BaseNode]) -> List[BaseNode]:
    """Give each chunk a deterministic id and bookkeeping metadata, in place.

    Chunks are numbered per ``source_file`` in the order they appear. The id
    is ``<file stem>-<first 12 chars of file_hash>-p<page_number or 'na'>-c<n>``.
    It is written to ``metadata["chunk_id"]`` and to the node's ``id_``.
    ``chunk_index`` and the current embedding model name are also recorded.

    Returns:
        The same ``nodes`` list, for chaining.
    """
    seen_per_source = Counter()
    for chunk in nodes:
        meta = chunk.metadata
        source = meta["source_file"]
        position = seen_per_source[source]
        seen_per_source[source] += 1
        hash_prefix = meta["file_hash"][:12]
        page = meta.get("page_number", "na")
        stable_id = f"{Path(source).stem}-{hash_prefix}-p{page}-c{position}"
        meta["chunk_id"] = stable_id
        meta["chunk_index"] = position
        meta[EMBED_MODEL_METADATA_KEY] = EMBED_MODEL_NAME
        chunk.id_ = stable_id
    return nodes
def validate_nodes(nodes: List[BaseNode]) -> None:
    """Sanity-check freshly split chunks before they are indexed.

    Raises:
        ValueError: if ``nodes`` is empty, if a node lacks any key listed in
            ``REQUIRED_METADATA``, or if a PDF node has no ``page_number``.
    """
    if not nodes:
        raise ValueError("No chunks were created from the source documents.")
    for chunk in nodes:
        meta = chunk.metadata
        absent = [field for field in REQUIRED_METADATA if field not in meta]
        if absent:
            raise ValueError(
                f"Node {chunk.node_id} is missing metadata fields: {absent}")
        if meta["file_type"] != "pdf":
            continue
        if "page_number" not in meta:
            raise ValueError(
                f"PDF node {chunk.node_id} is missing page_number metadata.")
def split_documents(documents: List[Document]) -> List[BaseNode]:
    """Chunk documents and prepare the chunks for indexing.

    Splits with a ``SentenceSplitter`` configured from ``CHUNK_SIZE`` /
    ``CHUNK_OVERLAP``. Each chunk is then stamped with its id and metadata
    via ``add_chunk_metadata`` (in place) and checked with ``validate_nodes``.

    Raises:
        ValueError: propagated from ``validate_nodes``.
    """
    chunks = SentenceSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
    ).get_nodes_from_documents(documents)
    # add_chunk_metadata mutates and returns the same list.
    validate_nodes(add_chunk_metadata(chunks))
    return chunks
def build_nodes(raw_dir: Path = RAW_DIR) -> List[BaseNode]:
    """Load every source document under ``raw_dir`` and return validated chunks."""
    return split_documents(load_docs(raw_dir))
| def load_source_file(path: Path) -> List[Document]: | |
| suffix = path.suffix.lower() | |
| if suffix in {".md", ".markdown"}: | |
| return load_md_documents(path) | |
| if suffix == ".pdf": | |
| return load_pdf_file(path) | |
| if suffix == ".txt": | |
| return load_txt_file(path) | |
| return [] | |
def list_current_sources(raw_dir: Path = RAW_DIR) -> dict[str, dict[str, str]]:
    """Describe the source files currently on disk.

    Returns:
        Mapping from each supported file's resolved absolute path (as a str)
        to ``{"file_hash": <sha256 from file_sha256>, "file_type":
        <lower-cased extension without the dot>}``.
    """
    root = effective_raw_dir(raw_dir)
    return {
        str(source.resolve()): {
            "file_hash": file_sha256(source),
            "file_type": source.suffix.lower().lstrip("."),
        }
        for source in iter_source_files(root)
    }
def existing_source_metadata(chroma_collection) -> dict[str, dict[str, str]]:
    """Summarize, per source file, the metadata stored on its indexed chunks.

    The collection's metadatas are read in pages of 500. Records without a
    ``source_file`` are ignored. When a source has several chunks, the last
    one read determines its summary.

    Returns:
        Mapping ``source_file -> {"file_hash", "file_type", "embedding_model",
        "extraction_method"}``; missing values default to ``""``.
    """
    summary: dict[str, dict[str, str]] = {}
    if chroma_collection.count() == 0:
        return summary
    page_size = 500
    start = 0
    while True:
        page = chroma_collection.get(
            limit=page_size,
            offset=start,
            include=["metadatas"],
        )
        records = page.get("metadatas") or []
        if not records:
            return summary
        for record in records:
            source = record.get("source_file")
            if source:
                summary[source] = {
                    "file_hash": record.get("file_hash", ""),
                    "file_type": record.get("file_type", ""),
                    "embedding_model": record.get(EMBED_MODEL_METADATA_KEY, ""),
                    "extraction_method": record.get("extraction_method", ""),
                }
        # A short page means the end of the collection was reached.
        if len(records) < page_size:
            return summary
        start += page_size
| def source_needs_update(current: dict[str, str], existing: dict[str, str] | None) -> bool: | |
| if not existing: | |
| return True | |
| if existing.get("file_hash") != current["file_hash"]: | |
| return True | |
| if existing.get("embedding_model") != EMBED_MODEL_NAME: | |
| return True | |
| if current["file_type"] == "pdf" and existing.get("extraction_method") != PDF_EXTRACTION_METHOD: | |
| return True | |
| return False | |
def incremental_update_index(
    raw_dir: Path,
    chroma_collection,
    storage_context: StorageContext,
    embed_model,
) -> bool:
    """Sync the Chroma collection with the files currently under ``raw_dir``.

    Chunks are deleted by ``source_file`` for two groups of sources: those
    that no longer exist on disk, and those flagged by
    ``source_needs_update`` (new file, changed hash, different embedding
    model, or outdated PDF extraction). The flagged sources are then
    re-loaded, split and embedded through ``storage_context``. Deletion is
    best-effort: failures are logged and the update continues.

    Returns:
        True if chunks were removed and/or added, False if nothing changed.
    """
    current_sources = list_current_sources(raw_dir)
    existing_sources = existing_source_metadata(chroma_collection)
    deleted_sources = sorted(set(existing_sources) - set(current_sources))
    changed_sources = sorted(
        source_file
        for source_file, current in current_sources.items()
        if source_needs_update(current, existing_sources.get(source_file))
    )
    for source_file in deleted_sources + changed_sources:
        try:
            chroma_collection.delete(where={"source_file": source_file})
        except Exception as exc:
            logging.warning("Could not delete stale chunks for %s: %s", source_file, exc)
    if not changed_sources:
        if deleted_sources:
            print(f"Removed {len(deleted_sources)} stale source(s) from collection '{COLLECTION_NAME}'.")
        return bool(deleted_sources)
    documents: List[Document] = []
    for source_file in changed_sources:
        documents.extend(load_source_file(Path(source_file)))
    if not documents:
        # Every changed source produced no Documents (e.g. a PDF whose pages
        # have no extractable text). Calling split_documents here would raise
        # "No chunks were created ..." and abort the whole build. Because such
        # a source is never stored, it would be re-flagged, and fail again, on
        # every start.
        logging.warning(
            "No indexable content in %d changed source file(s); nothing to embed.",
            len(changed_sources),
        )
        # The collection still changed if stale chunks were purged above.
        return bool(deleted_sources) or any(
            source_file in existing_sources for source_file in changed_sources
        )
    nodes = split_documents(documents)
    VectorStoreIndex(
        nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )
    print(
        f"Incrementally indexed {len(nodes)} chunk(s) from {len(changed_sources)} source file(s)."
    )
    return True
def collection_needs_rebuild(chroma_collection) -> bool:
    """Return True when a sample of stored chunks shows the collection is stale.

    An empty collection always needs a build. Otherwise up to 20 chunks are
    inspected via ``peek()``. Either of these marks the collection stale:
    - any chunk embedded with a model other than ``EMBED_MODEL_NAME``;
    - any PDF chunk extracted with a method other than
      ``PDF_EXTRACTION_METHOD``.

    Previously the first PDF chunk ended the scan, even when it was up to
    date, so later stale chunks in the sample were never checked. Chunks
    stored without metadata are treated as having no recorded embedding
    model, and are therefore stale. If ``peek()`` fails, the check is
    best-effort and reports False.
    """
    total = chroma_collection.count()
    if total == 0:
        return True
    try:
        sample = chroma_collection.peek(limit=min(total, 20))
    except Exception:
        return False
    for metadata in sample.get("metadatas") or []:
        metadata = metadata or {}
        if metadata.get(EMBED_MODEL_METADATA_KEY) != EMBED_MODEL_NAME:
            return True
        if (
            metadata.get("file_type") == "pdf"
            and metadata.get("extraction_method") != PDF_EXTRACTION_METHOD
        ):
            return True
    return False
async def build_index(raw_dir: Path = RAW_DIR, rebuild: bool = False) -> VectorStoreIndex:
    """Open (or create) the persistent Chroma collection and return an index over it.

    With ``rebuild=True`` the collection is dropped first. A full build from
    ``raw_dir`` happens when rebuilding or when the collection is empty.
    Otherwise the collection is synced with ``incremental_update_index`` and
    wrapped with ``VectorStoreIndex.from_vector_store``.

    Args:
        raw_dir: source directory, resolved through ``effective_raw_dir``.
        rebuild: discard the stored collection and re-index everything.
    """
    configure_model_cache()
    # Imported only after the model cache has been configured.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    load_dotenv()
    source_root = effective_raw_dir(raw_dir)
    CHROMA_DB_DIR.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(CHROMA_DB_DIR))
    if rebuild:
        try:
            client.delete_collection(COLLECTION_NAME)
        except (NotFoundError, ValueError):
            pass  # nothing to drop yet
    collection = client.get_or_create_collection(COLLECTION_NAME)
    vector_store = ChromaVectorStore(chroma_collection=collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    embed_model = HuggingFaceEmbedding(
        model_name=resolve_embed_model_name(),
        cache_folder=str(HF_CACHE_DIR / "sentence_transformers"),
    )
    full_build = rebuild or collection.count() == 0
    if not full_build:
        incremental_update_index(
            raw_dir=source_root,
            chroma_collection=collection,
            storage_context=storage_context,
            embed_model=embed_model,
        )
        print(
            f"Loaded existing collection '{COLLECTION_NAME}' with {collection.count()} chunks.")
        return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
    nodes = build_nodes(source_root)
    fresh_index = VectorStoreIndex(
        nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )
    print(
        f"Indexed {len(nodes)} chunks into collection '{COLLECTION_NAME}'")
    return fresh_index
class QueryKnowledgeTool(Tool):
    """smolagents tool that searches the local options-trading knowledge base.

    Retrieval runs in up to three stages, each configured in ``__init__``:
    1. dense retrieval from the Chroma-backed ``VectorStoreIndex``;
    2. optionally, BM25 keyword retrieval fused with the dense hits by
       reciprocal rank (``merge_results``);
    3. optionally, cross-encoder reranking.

    The final hits are rendered as plain text by ``format_results``.
    """

    name = "query_knowledge"
    description = (
        "Searches the local options trading knowledge base. Use this for option "
        "concepts, volatility, Greeks, strategies, formulas, equation numbers, "
        "and citations from reference books."
    )
    inputs = {'query': {'type': 'string',
                        'description': 'The search query to perform.'}}
    output_type = "string"

    # The three helpers below take no ``self``, so they must be static
    # methods. Each is called through the instance. As plain functions,
    # ``self.merge_results(vector_results=...)`` raised TypeError (the instance
    # was bound to ``vector_results`` as well), and ``self.load_bm25_nodes()``
    # received the tool object as ``collection_name``.

    @staticmethod
    def format_results(results, max_chars: int = 800):
        """Render retrieval hits as text blocks separated by ``---``.

        Each block lists source file, page, section, content type, formula
        id, the final score, and the raw vector/BM25 scores recorded in
        metadata (``n/a`` when absent). It ends with the chunk text,
        truncated to ``max_chars`` characters plus ``...``.
        """
        output = []
        for result in results:
            metadata = result.node.metadata
            source = metadata.get("file_name", "unknown")
            page = metadata.get("page_number", "n/a")
            section = metadata.get("section_path") or metadata.get("section_title") or "n/a"
            content_type = metadata.get("content_type", "text")
            formula_id = metadata.get("formula_id", "")
            # NodeWithScore.score may be None (merge_results guards for it too);
            # formatting it with :.4f would raise TypeError.
            score = "n/a" if result.score is None else f"{result.score:.4f}"
            text = result.node.get_content()
            if len(text) > max_chars:
                text = f"{text[:max_chars].rstrip()}..."
            output.append(
                f"source:{source}\n"
                f"page:{page}\n"
                f"section:{section}\n"
                f"content_type:{content_type}\n"
                f"formula_id:{formula_id or 'n/a'}\n"
                f"score:{score}\n"
                f"vector_score:{metadata.get(VECTOR_METADATA_KEY, 'n/a')}\n"
                f"bm25_score:{metadata.get(BM25_METADATA_KEY, 'n/a')}\n"
                f"content:{text}"
            )
        return "\n\n---\n\n".join(output)

    @staticmethod
    def load_bm25_nodes(collection_name: str = COLLECTION_NAME) -> list[TextNode]:
        """Read every stored chunk of ``collection_name`` back as TextNodes.

        Pages through the persistent Chroma collection 500 records at a time
        to build the BM25 corpus. Each node id is the Chroma id, falling back
        to the ``chunk_id`` metadata; ``merge_results`` matches hits by node
        id. Returns an empty list (and logs a warning) if the collection
        cannot be opened.
        """
        db = chromadb.PersistentClient(path=str(CHROMA_DB_DIR))
        try:
            collection = db.get_collection(collection_name)
        except Exception as exc:
            # Best-effort: hybrid search degrades to an empty BM25 corpus.
            logging.warning("Could not open collection %s for BM25: %s", collection_name, exc)
            return []
        nodes: list[TextNode] = []
        offset = 0
        limit = 500
        while True:
            batch = collection.get(
                limit=limit,
                offset=offset,
                include=["documents", "metadatas"],
            )
            documents = batch.get("documents") or []
            metadatas = batch.get("metadatas") or []
            ids = batch.get("ids") or []
            if not documents:
                break
            for index, text in enumerate(documents):
                # Tolerate a shorter metadata list or None entries.
                raw_metadata = metadatas[index] if index < len(metadatas) else None
                metadata = dict(raw_metadata or {})
                node_id = ids[index] if index < len(ids) else metadata.get("chunk_id", "")
                nodes.append(TextNode(id_=node_id, text=text or "", metadata=metadata))
            if len(documents) < limit:
                break
            offset += limit
        return nodes

    @staticmethod
    def merge_results(
        vector_results: list[NodeWithScore],
        bm25_results: list[NodeWithScore],
        top_k: int,
    ) -> list[NodeWithScore]:
        """Fuse vector and BM25 hits by reciprocal rank; keep the best ``top_k``.

        Each list contributes ``1 / (rank + 1)`` (rank is 0-based) to a
        node's fused score, and a node found by both retrievers gets the sum.
        The original retriever scores are written into node metadata under
        ``VECTOR_METADATA_KEY`` / ``BM25_METADATA_KEY`` for display. This
        mutates the metadata of the nodes passed in.
        """
        merged: dict[str, NodeWithScore] = {}
        for rank, result in enumerate(vector_results):
            node_id = result.node.node_id
            result.node.metadata[VECTOR_METADATA_KEY] = result.score
            merged[node_id] = NodeWithScore(
                node=result.node,
                score=1.0 / (rank + 1),
            )
        for rank, result in enumerate(bm25_results):
            node_id = result.node.node_id
            result.node.metadata[BM25_METADATA_KEY] = result.score
            reciprocal_rank_score = 1.0 / (rank + 1)
            if node_id in merged:
                merged[node_id].score = (merged[node_id].score or 0.0) + reciprocal_rank_score
                merged[node_id].node.metadata[BM25_METADATA_KEY] = result.score
            else:
                merged[node_id] = NodeWithScore(
                    node=result.node,
                    score=reciprocal_rank_score,
                )
        results = list(merged.values())
        results.sort(key=lambda item: item.score or float("-inf"), reverse=True)
        return results[:top_k]

    def __init__(
        self,
        max_results=20,
        top_k=5,
        use_reranker: Optional[bool] = None,
        use_hybrid: Optional[bool] = None,
        reranker_top_n: Optional[int] = None,
        reranker_model_name: Optional[str] = None,
        **kwargs,
    ):
        """Build or load the index and set up the retrievers.

        Args:
            max_results: candidate depth for BM25 and fusion, and for vector
                retrieval when reranking.
            top_k: number of hits returned by ``forward``.
            use_reranker: enable the cross-encoder; ``None`` reads the
                ``RAG_USE_RERANKER`` env flag (default True).
            use_hybrid: enable BM25 fusion; ``None`` reads ``RAG_USE_HYBRID``
                (default True).
            reranker_top_n: hits kept by the reranker (defaults to ``top_k``).
            reranker_model_name: overrides ``RERANKER_MODEL_NAME``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.max_results = max_results
        self.top_k = top_k
        self.use_reranker = (
            env_flag("RAG_USE_RERANKER", True)
            if use_reranker is None
            else use_reranker
        )
        self.use_hybrid = (
            env_flag("RAG_USE_HYBRID", True)
            if use_hybrid is None
            else use_hybrid
        )
        self.reranker_top_n = reranker_top_n or top_k
        self.reranker = (
            CrossEncoderReranker(reranker_model_name or RERANKER_MODEL_NAME)
            if self.use_reranker
            else None
        )
        index = asyncio.run(build_index(rebuild=False))
        # Fetch a deeper candidate pool only when a reranker will re-score it.
        retrieve_top_k = max(max_results, top_k) if self.use_reranker else top_k
        self.retriever = index.as_retriever(similarity_top_k=retrieve_top_k)
        self.bm25_retriever = (
            BM25Retriever(self.load_bm25_nodes())
            if self.use_hybrid
            else None
        )

    def forward(self, query: str) -> str:
        """Run the retrieval pipeline for ``query`` and return formatted hits.

        If the reranker raises, a warning is logged and the fused/vector
        ranking is truncated to ``top_k`` instead.
        """
        vector_results = self.retriever.retrieve(query)
        results = vector_results
        if self.bm25_retriever:
            bm25_results = self.bm25_retriever.retrieve(query, self.max_results)
            results = self.merge_results(
                vector_results=vector_results,
                bm25_results=bm25_results,
                top_k=max(self.max_results, self.top_k),
            )
        if self.reranker:
            try:
                results = self.reranker.rerank(
                    query,
                    results,
                    top_n=self.reranker_top_n,
                )
            except Exception as exc:
                logging.warning("Reranker failed; falling back to vector ranking: %s", exc)
                results = results[:self.top_k]
        return QueryKnowledgeTool.format_results(results[:self.top_k])
if __name__ == "__main__":
    # Manual smoke test: build/load the index and print hits for a sample query.
    print(QueryKnowledgeTool().forward("What is option?"))