from smolagents.tools import Tool import asyncio from collections import Counter import hashlib import logging import math import os from pathlib import Path from typing import Iterable, List, Optional import re from dotenv import load_dotenv import chromadb from chromadb.errors import NotFoundError from pypdf import PdfReader from llama_index.core import StorageContext, VectorStoreIndex from llama_index.core.schema import Document, BaseNode, NodeWithScore, TextNode from llama_index.core.node_parser import SentenceSplitter from llama_index.vector_stores.chroma import ChromaVectorStore load_dotenv() BASE_DIR = Path(__file__).resolve().parent PROJECT_ROOT = BASE_DIR.parent KNOWLEDGE_BASE_DIR = BASE_DIR / "knowledge_base" LEGACY_KNOWLEDGE_BASE_DIR = BASE_DIR / "knowledge_base" KNOWLEDGE_BASE_DIR = PROJECT_ROOT / "knowledge_base" RAW_DIR = KNOWLEDGE_BASE_DIR / "raw" CHROMA_DB_DIR = KNOWLEDGE_BASE_DIR / "chroma_db" HF_CACHE_DIR = PROJECT_ROOT / "hf_cache" COLLECTION_NAME = "options_knowledge" EMBED_MODEL_NAME = os.getenv("RAG_EMBED_MODEL", "BAAI/bge-small-en-v1.5") RERANKER_MODEL_NAME = os.getenv( "RAG_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2") RERANKER_BATCH_SIZE = int(os.getenv("RAG_RERANKER_BATCH_SIZE", "16")) EMBED_MODEL_METADATA_KEY = "embedding_model" BM25_METADATA_KEY = "bm25_score" VECTOR_METADATA_KEY = "vector_score" CHUNK_SIZE = 1000 CHUNK_OVERLAP = 150 PDF_REPEATED_LINE_MIN_PAGES = 3 PDF_BOUNDARY_LINE_COUNT = 4 PDF_EXTRACTION_METHOD = "pymupdf_formula_blocks_v5" PDF_LINE_Y_TOLERANCE = 3.0 PDF_MIN_SECTION_CHARS = 240 PDF_STRONG_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_σΣΔδθΘλΛμρπΠφΦτν𝜎𝜇𝜌𝜃𝜕") PDF_WEAK_MATH_SYMBOLS = set("+-−*/∕<>") PDF_MATH_SYMBOLS = PDF_STRONG_MATH_SYMBOLS | PDF_WEAK_MATH_SYMBOLS PDF_OPERATOR_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_+-−*/∕<>") PDF_FORMULA_TRIGGER_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_∕<>") logging.getLogger("pypdf").setLevel(logging.ERROR) def load_pymupdf(): try: import fitz except ImportError: return None return fitz REQUIRED_METADATA = [ "source_file", "file_name", "file_type", "document_title", "file_hash", "chunk_id", "chunk_index", ] def configure_model_cache() -> None: HF_CACHE_DIR.mkdir(parents=True, exist_ok=True) os.environ.setdefault("HF_HOME", str(HF_CACHE_DIR)) os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", str( HF_CACHE_DIR / "sentence_transformers")) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") if local_model_snapshot(EMBED_MODEL_NAME): os.environ.setdefault("HF_HUB_OFFLINE", "1") os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") def local_model_snapshot(model_name: str) -> Optional[Path]: cached_model_dir = ( HF_CACHE_DIR / "sentence_transformers" / f"models--{model_name.replace('/', '--')}" ) snapshots_dir = cached_model_dir / "snapshots" if snapshots_dir.exists(): snapshots = sorted(path for path in snapshots_dir.iterdir() if path.is_dir()) for snapshot in reversed(snapshots): if (snapshot / "config.json").exists(): return snapshot return None def resolve_embed_model_name() -> str: snapshot = local_model_snapshot(EMBED_MODEL_NAME) if snapshot: return str(snapshot) return EMBED_MODEL_NAME def resolve_reranker_model_name(model_name: str = RERANKER_MODEL_NAME) -> str: snapshot = local_model_snapshot(model_name) if snapshot: return str(snapshot) return model_name def env_flag(name: str, default: bool = False) -> bool: value = os.getenv(name) if value is None: return default return value.strip().lower() in {"1", "true", "yes", "on"} def effective_raw_dir(raw_dir: Path = RAW_DIR) -> Path: if any(iter_source_files(raw_dir)): return raw_dir legacy_raw_dir = LEGACY_KNOWLEDGE_BASE_DIR / "raw" if any(iter_source_files(legacy_raw_dir)): logging.warning( "Using legacy knowledge base path %s. Move files to %s when convenient.", legacy_raw_dir, raw_dir, ) return legacy_raw_dir return raw_dir class CrossEncoderReranker: def __init__( self, model_name: str = RERANKER_MODEL_NAME, batch_size: int = RERANKER_BATCH_SIZE, ): self.model_name = model_name self.batch_size = batch_size self._model = None def _load_model(self): if self._model is not None: return self._model from sentence_transformers import CrossEncoder self._model = CrossEncoder( resolve_reranker_model_name(self.model_name), max_length=512, cache_folder=str(HF_CACHE_DIR / "sentence_transformers"), ) return self._model def rerank( self, query: str, results: list[NodeWithScore], top_n: Optional[int] = None, ) -> list[NodeWithScore]: if not results: return [] pairs = [ (query, result.node.get_content()) for result in results ] model = self._load_model() scores = model.predict( pairs, batch_size=self.batch_size, show_progress_bar=False, ) reranked = [ NodeWithScore(node=result.node, score=float(score)) for result, score in zip(results, scores) ] reranked.sort(key=lambda item: item.score or float("-inf"), reverse=True) return reranked[:top_n] if top_n else reranked class BM25Retriever: def __init__(self, nodes: list[TextNode]): self.nodes = nodes self.tokenized_docs = [self.tokenize(node.get_content()) for node in nodes] self.doc_freqs: Counter[str] = Counter() for tokens in self.tokenized_docs: self.doc_freqs.update(set(tokens)) self.avg_doc_len = ( sum(len(tokens) for tokens in self.tokenized_docs) / len(self.tokenized_docs) if self.tokenized_docs else 0.0 ) @staticmethod def tokenize(text: str) -> list[str]: return [ token.lower() for token in re.findall(r"[A-Za-z]+(?:[-'][A-Za-z]+)*|\d+(?:\.\d+)*|[^\sA-Za-z0-9]", text) if token.strip() ] def score(self, query_tokens: list[str], doc_tokens: list[str]) -> float: if not query_tokens or not doc_tokens: return 0.0 token_counts = Counter(doc_tokens) doc_len = len(doc_tokens) total_docs = len(self.tokenized_docs) k1 = 1.5 b = 0.75 score = 0.0 for token in query_tokens: term_freq = token_counts.get(token, 0) if term_freq == 0: continue doc_freq = self.doc_freqs.get(token, 0) idf = math.log(1 + (total_docs - doc_freq + 0.5) / (doc_freq + 0.5)) denominator = term_freq + k1 * ( 1 - b + b * doc_len / max(self.avg_doc_len, 1.0) ) score += idf * (term_freq * (k1 + 1)) / denominator return score def retrieve(self, query: str, top_k: int) -> list[NodeWithScore]: query_tokens = self.tokenize(query) scored: list[NodeWithScore] = [] for node, doc_tokens in zip(self.nodes, self.tokenized_docs): score = self.score(query_tokens, doc_tokens) if score <= 0: continue node.metadata[BM25_METADATA_KEY] = score scored.append(NodeWithScore(node=node, score=score)) scored.sort(key=lambda item: item.score or float("-inf"), reverse=True) return scored[:top_k] def file_sha256(path: Path) -> str: digest = hashlib.sha256() with path.open("rb") as file: for block in iter(lambda: file.read(1024 * 1024), b""): digest.update(block) return digest.hexdigest() def load_md_file(path: Path) -> Document: text = path.read_text(encoding="utf-8") return Document( text=text, metadata={ "source_file": str(path.resolve()), "file_name": path.name, "file_type": "md", "document_title": path.stem, "file_hash": file_sha256(path), }, ) def load_md_documents(path: Path) -> List[Document]: text = path.read_text(encoding="utf-8") file_hash = file_sha256(path) documents: List[Document] = [] current_heading = "" current_lines: List[str] = [] def flush() -> None: nonlocal current_lines section_text = "\n".join(current_lines).strip() if not section_text: current_lines = [] return documents.append( Document( text=section_text, metadata={ "source_file": str(path.resolve()), "file_name": path.name, "file_type": path.suffix.lower().lstrip("."), "document_title": path.stem, "file_hash": file_hash, "content_type": "markdown_section", "chapter_title": "", "section_title": current_heading, "section_path": current_heading, "char_count": len(section_text), }, ) ) current_lines = [] for line in text.splitlines(): heading_match = re.match(r"^(#{1,6})\s+(.+?)\s*$", line) if heading_match: flush() current_heading = heading_match.group(2).strip() current_lines.append(line) flush() return documents or [load_md_file(path)] def append_visual_fragment(line_parts: List[str], text: str, baseline_y: float, item: dict) -> None: if not text: return stripped = text.strip() if not stripped: return font_size = item["font_size"] y_offset = item["y"] - baseline_y is_small = font_size < item["line_font_size"] * 0.82 if is_small and y_offset > max(1.5, item["line_font_size"] * 0.18): line_parts.append(f"^{{{stripped}}}") elif is_small and y_offset < -max(1.5, item["line_font_size"] * 0.18): line_parts.append(f"_{{{stripped}}}") else: line_parts.append(stripped) def join_visual_line(items: List[dict]) -> str: if not items: return "" items = sorted(items, key=lambda value: value["x"]) baseline_y = sorted(item["y"] for item in items)[len(items) // 2] line_font_size = max(item["font_size"] for item in items) previous_right = None line_parts: List[str] = [] for item in items: item["line_font_size"] = line_font_size if previous_right is not None: gap = item["x"] - previous_right if gap > max(2.5, line_font_size * 0.28): line_parts.append(" ") append_visual_fragment(line_parts, item["text"], baseline_y, item) previous_right = max(previous_right or item["x"], item["x"] + item["width"]) return normalize_pdf_line("".join(line_parts)) def extract_pdf_text_by_position(page) -> str: fragments: List[dict] = [] def visitor_text(text, cm, tm, font_dict, font_size): if not text or not text.strip(): return x = float(tm[4]) y = float(tm[5]) width = max(len(text.strip()) * float(font_size) * 0.45, float(font_size)) fragments.append( { "text": text, "x": x, "y": y, "width": width, "font_size": float(font_size or 1.0), } ) try: page.extract_text(visitor_text=visitor_text) except Exception: return "" if not fragments: return "" lines: List[List[dict]] = [] for fragment in sorted(fragments, key=lambda value: (-value["y"], value["x"])): for line in lines: if abs(line[0]["y"] - fragment["y"]) <= PDF_LINE_Y_TOLERANCE: line.append(fragment) break else: lines.append([fragment]) return "\n".join(join_visual_line(line) for line in lines) def math_text_score(text: str) -> float: if not text.strip(): return 0.0 lines = [line for line in text.splitlines() if line.strip()] compact_length = len(re.sub(r"\s+", "", text)) math_symbol_count = sum(1 for char in text if char in PDF_MATH_SYMBOLS) superscript_markers = text.count("^{") + text.count("_{") multiline_bonus = sum(1 for line in lines if is_formula_like(line)) * 8 equation_block_bonus = sum( 1 for index, line in enumerate(lines) if is_formula_like(line) and ( index > 0 and is_formula_like(lines[index - 1]) or index + 1 < len(lines) and is_formula_like(lines[index + 1]) ) ) * 12 return ( compact_length + math_symbol_count * 12 + superscript_markers * 20 + multiline_bonus + equation_block_bonus ) def extract_pdf_text(page) -> str: positioned_text = extract_pdf_text_by_position(page) try: layout_text = page.extract_text(extraction_mode="layout") or "" except Exception: layout_text = "" try: plain_text = page.extract_text() or "" except Exception: plain_text = "" candidates = [positioned_text, layout_text, plain_text] candidates = [candidate for candidate in candidates if candidate.strip()] if not candidates: return "" return max(candidates, key=math_text_score) def pymupdf_span_text(span: dict) -> str: return normalize_pdf_line(span.get("text", "")) def pymupdf_line_text(line: dict) -> str: return normalize_pdf_line("".join(pymupdf_span_text(span) for span in line.get("spans", []))) def pymupdf_block_text(block: dict) -> str: lines = [ pymupdf_line_text(line) for line in block.get("lines", []) ] return "\n".join(line for line in lines if line) def pymupdf_span_has_math_font(span: dict) -> bool: font_name = span.get("font", "").lower() return any( marker in font_name for marker in ("math", "symbol", "cmmi", "cmsy", "cmex", "stix") ) def is_formula_block_line(line: str) -> bool: stripped = line.strip() if not stripped: return False trigger_math_count = sum(1 for char in stripped if char in PDF_FORMULA_TRIGGER_SYMBOLS) digit_count = sum(1 for char in stripped if char.isdigit()) alpha_count = sum(1 for char in stripped if char.isalpha()) alpha_words = [ word for word in re.findall(r"[A-Za-z]+", stripped) if word.lower() not in {"and", "or", "the", "where", "then", "with", "for"} ] compact_length = len(re.sub(r"\s+", "", stripped)) if compact_length < 3: return False if re.fullmatch(r"\(?\d+(\.\d+)?\)?", stripped): return False if re.search(r"\(\d+(\.\d+)+[a-z]?\)$", stripped) and compact_length <= 240: return True if "=" in stripped and compact_length <= 260 and len(alpha_words) <= 12: return True if any(char in stripped for char in "∂∫∑∏√∞≈≠≤≥±×÷") and compact_length <= 220 and len(alpha_words) <= 10: return True if trigger_math_count >= 2 and compact_length <= 120 and len(alpha_words) <= 6: return True if trigger_math_count >= 1 and digit_count >= 1 and alpha_count <= 18 and compact_length <= 100: return True return False def is_formula_block(block: dict) -> bool: text = pymupdf_block_text(block) if not text: return False lines = [line for line in text.splitlines() if line.strip()] if any(is_formula_block_line(line) for line in lines): return True spans = [ span for line in block.get("lines", []) for span in line.get("spans", []) if pymupdf_span_text(span) ] if not spans: return False math_font_count = sum(1 for span in spans if pymupdf_span_has_math_font(span)) strong_math_count = sum(1 for char in text if char in PDF_STRONG_MATH_SYMBOLS) alpha_count = sum(1 for char in text if char.isalpha()) digit_count = sum(1 for char in text if char.isdigit()) compact_length = len(re.sub(r"\s+", "", text)) if math_font_count >= 2 and compact_length <= 220: return True if strong_math_count >= 3 and compact_length <= 260: return True if strong_math_count >= 1 and digit_count >= 1 and alpha_count <= 20 and compact_length <= 160: return True return False def block_bbox_string(block: dict) -> str: bbox = block.get("bbox") or [] if len(bbox) != 4: return "" return ",".join(f"{float(value):.2f}" for value in bbox) def line_bbox_string(line: dict) -> str: bbox = line.get("bbox") or [] if len(bbox) != 4: return "" return ",".join(f"{float(value):.2f}" for value in bbox) def pymupdf_line_has_math_font(line: dict) -> bool: return any( pymupdf_span_has_math_font(span) for span in line.get("spans", []) if pymupdf_span_text(span) ) def should_extract_formula_line(line: dict) -> bool: text = pymupdf_line_text(line) if not text: return False if is_formula_block_line(text): return True compact_length = len(re.sub(r"\s+", "", text)) trigger_math_count = sum(1 for char in text if char in PDF_FORMULA_TRIGGER_SYMBOLS) alpha_words = re.findall(r"[A-Za-z]+", text) if ( pymupdf_line_has_math_font(line) and trigger_math_count >= 1 and compact_length <= 180 and len(alpha_words) <= 6 ): return True return False def is_formula_continuation_line(text: str) -> bool: stripped = text.strip() if not stripped: return False compact = re.sub(r"\s+", "", stripped) if len(compact) > 90: return False if compact in {"(", ")", "[", "]", "{", "}", "√"}: return True alpha_words = re.findall(r"[A-Za-z]+", stripped) math_count = sum(1 for char in stripped if char in PDF_MATH_SYMBOLS) digit_count = sum(1 for char in stripped if char.isdigit()) if len(alpha_words) <= 4 and (math_count >= 1 or digit_count >= 1): return True return False def append_formula_block( formula_blocks: List[dict], body_blocks: List[str], page_number: int, formula_index: int, formula_lines: List[str], formula_bboxes: List[str], ) -> int: formula_text = clean_formula_text("\n".join(formula_lines)) if not is_useful_formula_text(formula_text): return formula_index formula_id = f"formula-{page_number}-{formula_index}" formula_bbox = merge_bbox_strings(formula_bboxes) formula_blocks.append( { "id": formula_id, "text": formula_text, "bbox": formula_bbox, } ) body_blocks.append(f"[FORMULA id={formula_id}]\n{formula_text}\n[/FORMULA]") return formula_index + 1 def merge_bbox_strings(bbox_strings: List[str]) -> str: boxes = [] for bbox_string in bbox_strings: if not bbox_string: continue values = bbox_string.split(",") if len(values) != 4: continue try: boxes.append([float(value) for value in values]) except ValueError: continue if not boxes: return "" x0 = min(box[0] for box in boxes) y0 = min(box[1] for box in boxes) x1 = max(box[2] for box in boxes) y1 = max(box[3] for box in boxes) return f"{x0:.2f},{y0:.2f},{x1:.2f},{y1:.2f}" def is_useful_formula_text(text: str) -> bool: stripped = text.strip() if not stripped: return False compact_length = len(re.sub(r"\s+", "", stripped)) if compact_length < 6: return False lines = [line.strip() for line in stripped.splitlines() if line.strip()] if re.search(r"\(\d+(\.\d+)+[a-z]?\)", stripped): return True if any(char in stripped for char in "∂∫∑∏∞≈≠≤≥±×÷"): alpha_words = re.findall(r"[A-Za-z]+", stripped) return len(alpha_words) <= 12 or "=" in stripped for line in lines: if "=" not in line: continue alpha_words = [ word for word in re.findall(r"[A-Za-z]+", line) if word.lower() not in {"and", "or", "the", "where", "then", "with", "for"} ] if len(alpha_words) <= 12 and len(line) <= 260: return True return False def extract_pymupdf_page(page) -> dict: page_dict = page.get_text("dict", sort=True) body_blocks: List[str] = [] formula_blocks: List[dict] = [] formula_lines: List[str] = [] formula_bboxes: List[str] = [] formula_index = 0 page_number = page.number + 1 for block in page_dict.get("blocks", []): if block.get("type") != 0: continue normal_lines: List[str] = [] for line in block.get("lines", []): line_text = pymupdf_line_text(line) if not line_text: continue if should_extract_formula_line(line) or ( formula_lines and is_formula_continuation_line(line_text) ): if normal_lines: body_blocks.append("\n".join(normal_lines)) normal_lines = [] formula_lines.append(line_text) formula_bboxes.append(line_bbox_string(line)) else: if formula_lines: formula_index = append_formula_block( formula_blocks=formula_blocks, body_blocks=body_blocks, page_number=page_number, formula_index=formula_index, formula_lines=formula_lines, formula_bboxes=formula_bboxes, ) formula_lines = [] formula_bboxes = [] normal_lines.append(line_text) if normal_lines: body_blocks.append("\n".join(normal_lines)) if formula_lines: append_formula_block( formula_blocks=formula_blocks, body_blocks=body_blocks, page_number=page_number, formula_index=formula_index, formula_lines=formula_lines, formula_bboxes=formula_bboxes, ) return { "text": "\n".join(body_blocks), "formula_blocks": formula_blocks, "backend": "pymupdf", } def extract_pdf_pages_with_pymupdf(path: Path) -> Optional[List[dict]]: fitz = load_pymupdf() if fitz is None: return None try: document = fitz.open(str(path)) except Exception: return None try: return [extract_pymupdf_page(page) for page in document] finally: document.close() def clean_formula_text(text: str) -> str: lines = page_lines(text) if not lines: return "" text = "\n".join(lines) text = re.sub(r"[ \t]+", " ", text) text = re.sub(r"\n{3,}", "\n\n", text) return text.strip() def normalize_pdf_line(line: str) -> str: line = line.replace("\x00", " ") line = line.replace("\ufb00", "ff") line = line.replace("\ufb01", "fi") line = line.replace("\ufb02", "fl") line = line.replace("\ufb03", "ffi") line = line.replace("\ufb04", "ffl") line = re.sub(r"[ \t]+", " ", line) return line.strip() def is_noise_line(line: str) -> bool: if not line: return True if re.fullmatch(r"\d+", line): return True if re.fullmatch(r"page\s+\d+(\s+of\s+\d+)?", line, flags=re.IGNORECASE): return True if re.fullmatch(r"[-_=\s]{3,}", line): return True return False def is_formula_like(line: str) -> bool: stripped = line.strip() if not stripped: return False strong_math_count = sum(1 for char in stripped if char in PDF_STRONG_MATH_SYMBOLS) weak_math_count = sum(1 for char in stripped if char in PDF_WEAK_MATH_SYMBOLS) alpha_count = sum(1 for char in stripped if char.isalpha()) digit_count = sum(1 for char in stripped if char.isdigit()) compact = stripped.replace(" ", "") if "={" in compact or "^{" in compact or "_{" in compact: return True if compact in {"(", ")", "[", "]", "{", "}"}: return True if len(compact) <= 40 and any(char in compact for char in PDF_MATH_SYMBOLS): return True if strong_math_count >= 2 and len(stripped) <= 180: return True if strong_math_count >= 1 and weak_math_count >= 1 and len(stripped) <= 180: return True if "=" in stripped and (alpha_count + digit_count) >= 2 and len(stripped) <= 220: return True if re.search(r"\b(d|D|exp|ln|sqrt|max|min|var|cov)\s*[\(\[]", stripped): return True if alpha_count <= 4 and (strong_math_count + weak_math_count) >= 1 and digit_count >= 1: return True return False def normalized_line_key(line: str) -> str: return re.sub(r"\d+", "#", line.lower()).strip() def page_lines(text: str) -> List[str]: lines = [] for line in text.replace("\r\n", "\n").replace("\r", "\n").split("\n"): normalized = normalize_pdf_line(line) if not is_noise_line(normalized): lines.append(normalized) return lines def find_repeated_boundary_lines(raw_pages: List[str]) -> set[str]: counter: Counter[str] = Counter() for raw_text in raw_pages: lines = page_lines(raw_text) boundary_lines = lines[:PDF_BOUNDARY_LINE_COUNT] + lines[-PDF_BOUNDARY_LINE_COUNT:] counter.update( normalized_line_key(line) for line in boundary_lines if 3 <= len(line) <= 140 ) min_count = min( PDF_REPEATED_LINE_MIN_PAGES, max(2, len(raw_pages) // 3), ) return {line for line, count in counter.items() if count >= min_count} def clean_pdf_text(text: str, repeated_boundary_lines: set[str]) -> str: lines = page_lines(text) cleaned_lines = [] for index, line in enumerate(lines): is_boundary = ( index < PDF_BOUNDARY_LINE_COUNT or index >= len(lines) - PDF_BOUNDARY_LINE_COUNT ) if is_boundary and normalized_line_key(line) in repeated_boundary_lines: continue cleaned_lines.append(line) merged_lines = [] for line in cleaned_lines: if merged_lines and merged_lines[-1].endswith("-") and line[:1].islower(): merged_lines[-1] = merged_lines[-1][:-1] + line else: merged_lines.append(line) text = "\n".join(merged_lines) text = preserve_math_line_breaks(text) text = re.sub(r"[ \t]+", " ", text) text = re.sub(r"\n{3,}", "\n\n", text) return text.strip() def preserve_math_line_breaks(text: str) -> str: lines = text.split("\n") if not lines: return "" output = [lines[0]] in_formula_block = is_formula_like(lines[0]) for line in lines[1:]: previous = output[-1] line_is_formula = is_formula_like(line) previous_is_formula = is_formula_like(previous) if previous_is_formula or line_is_formula or in_formula_block: output.append(line) in_formula_block = line_is_formula or ( in_formula_block and not line.endswith((".", ";", ":", "?", "!")) ) elif previous.endswith((".", ":", ";", "?", "!", ")")): output.append(line) in_formula_block = False else: output[-1] = f"{previous} {line}" in_formula_block = False return "\n".join(output) def is_chapter_heading(line: str) -> bool: return bool(re.fullmatch( r"(chapter|appendix)\s+([0-9]+|[ivxlcdm]+|[a-z])", line.strip(), flags=re.IGNORECASE, )) def titlecase_word_ratio(words: List[str]) -> float: candidate_words = [ word.strip("()[]{}:;,.") for word in words if any(char.isalpha() for char in word) ] if not candidate_words: return 0.0 titlecase_words = [ word for word in candidate_words if word[:1].isupper() or word.lower() in {"a", "an", "and", "for", "in", "of", "on", "or", "the", "to", "with"} ] return len(titlecase_words) / len(candidate_words) def uppercase_letter_ratio(text: str) -> float: letters = [char for char in text if char.isalpha()] if not letters: return 0.0 return sum(1 for char in letters if char.isupper()) / len(letters) def is_section_heading(line: str) -> bool: stripped = line.strip() if not 4 <= len(stripped) <= 150: return False letters = [char for char in stripped if char.isalpha()] digit_count = sum(1 for char in stripped if char.isdigit()) alpha_words = [ word.strip("()[]{}:;,.") for word in stripped.split() if any(char.isalpha() for char in word) ] if len(letters) < 6 or len(alpha_words) < 2: return False if digit_count > max(4, len(letters)): return False if "%" in stripped and digit_count >= len(letters) / 2: return False numbered_heading = bool(re.match(r"^\d+(\.\d+)+\s+", stripped)) if stripped[:1].isdigit() and not numbered_heading: return False if re.match( r"^(in|from|where|thus|then|now|let|because|while|figure|table|for)\b", stripped, flags=re.IGNORECASE, ): return False if is_formula_like(stripped): return False if stripped.endswith((".", ",", ";")): return False if re.match(r"^(figure|table)\s+\d", stripped, flags=re.IGNORECASE): return False if numbered_heading: return True words = stripped.split() if len(words) > 16: return False if uppercase_letter_ratio(stripped) >= 0.72 and len(words) >= 2: return True if len(words) >= 4 and titlecase_word_ratio(words) >= 0.68: return True return False def make_section_path(chapter_title: str, section_title: str) -> str: if chapter_title and section_title and section_title != chapter_title: return f"{chapter_title} > {section_title}" return section_title or chapter_title def split_pdf_page_into_sections( path: Path, page_index: int, text: str, file_hash: str, section_state: dict, extraction_backend: str, formula_count: int, ) -> List[Document]: documents = [] lines = text.splitlines() pending_lines: List[str] = [] pending_metadata = { "chapter_title": section_state.get("chapter_title", ""), "section_title": section_state.get("section_title", ""), } def flush_pending() -> None: nonlocal pending_lines, pending_metadata section_text = "\n".join(line for line in pending_lines if line.strip()).strip() if not section_text: pending_lines = [] return chapter_title = pending_metadata.get("chapter_title", "") section_title = pending_metadata.get("section_title", "") documents.append( Document( text=section_text, metadata={ "source_file": str(path.resolve()), "file_name": path.name, "file_type": "pdf", "document_title": path.stem, "file_hash": file_hash, "page_number": page_index, "extraction_method": PDF_EXTRACTION_METHOD, "extraction_backend": extraction_backend, "char_count": len(section_text), "formula_count": formula_count, "content_type": "text", "chapter_title": chapter_title, "section_title": section_title, "section_path": make_section_path(chapter_title, section_title), }, ) ) pending_lines = [] for line in lines: stripped = line.strip() if not stripped: continue if is_chapter_heading(stripped): if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS: flush_pending() section_state["pending_chapter_label"] = stripped.title() section_state["chapter_title"] = stripped.title() section_state["section_title"] = stripped.title() pending_metadata = { "chapter_title": section_state["chapter_title"], "section_title": section_state["section_title"], } pending_lines.append(stripped) continue if section_state.get("pending_chapter_label") and is_section_heading(stripped): if pending_lines == [section_state["pending_chapter_label"]]: pending_lines[0] = f"{section_state['pending_chapter_label']}: {stripped}" else: pending_lines.append(stripped) section_state["chapter_title"] = pending_lines[-1] section_state["section_title"] = pending_lines[-1] section_state["pending_chapter_label"] = "" pending_metadata = { "chapter_title": section_state["chapter_title"], "section_title": section_state["section_title"], } continue if is_section_heading(stripped): if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS: flush_pending() section_state["section_title"] = stripped section_state["pending_chapter_label"] = "" pending_metadata = { "chapter_title": section_state.get("chapter_title", ""), "section_title": section_state["section_title"], } pending_lines.append(stripped) flush_pending() return documents def make_formula_documents( path: Path, page_index: int, formula_blocks: List[dict], file_hash: str, extraction_backend: str, ) -> List[Document]: documents = [] for formula_index, formula in enumerate(formula_blocks): formula_text = formula.get("text", "").strip() if not formula_text: continue documents.append( Document( text=f"[FORMULA]\n{formula_text}\n[/FORMULA]", metadata={ "source_file": str(path.resolve()), "file_name": path.name, "file_type": "pdf", "document_title": path.stem, "file_hash": file_hash, "page_number": page_index, "extraction_method": PDF_EXTRACTION_METHOD, "extraction_backend": extraction_backend, "char_count": len(formula_text), "content_type": "formula", "formula_id": formula.get("id", f"formula-{page_index}-{formula_index}"), "formula_index": formula_index, "formula_bbox": formula.get("bbox", ""), "formula_count": 1, "chapter_title": "", "section_title": "", "section_path": "", }, ) ) return documents def load_pdf_file(path: Path) -> List[Document]: reader = PdfReader(str(path)) documents = [] pymupdf_pages = extract_pdf_pages_with_pymupdf(path) if pymupdf_pages: page_payloads = pymupdf_pages else: page_payloads = [ { "text": extract_pdf_text(page), "formula_blocks": [], "backend": "pypdf", } for page in reader.pages ] raw_pages = [payload["text"] for payload in page_payloads] repeated_boundary_lines = find_repeated_boundary_lines(raw_pages) file_hash = file_sha256(path) section_state: dict = { "chapter_title": "", "section_title": "", "pending_chapter_label": "", } for page_index, payload in enumerate(page_payloads, start=1): raw_text = payload["text"] text = clean_pdf_text(raw_text, repeated_boundary_lines) formula_blocks = payload.get("formula_blocks", []) extraction_backend = payload.get("backend", "pypdf") if not text.strip(): documents.extend( make_formula_documents( path=path, page_index=page_index, formula_blocks=formula_blocks, file_hash=file_hash, extraction_backend=extraction_backend, ) ) continue documents.extend( split_pdf_page_into_sections( path=path, page_index=page_index, text=text, file_hash=file_hash, section_state=section_state, extraction_backend=extraction_backend, formula_count=len(formula_blocks), ) ) documents.extend( make_formula_documents( path=path, page_index=page_index, formula_blocks=formula_blocks, file_hash=file_hash, extraction_backend=extraction_backend, ) ) return documents def load_txt_file(path: Path) -> List[Document]: text = path.read_text(encoding="utf-8") return [ Document( text=text, metadata={ "source_file": str(path.resolve()), "file_name": path.name, "file_type": "txt", "document_title": path.stem, "file_hash": file_sha256(path), "content_type": "text", "chapter_title": "", "section_title": "", "section_path": "", "char_count": len(text), }, ) ] def iter_source_files(raw_dir: Path) -> Iterable[Path]: supported_suffixes = {".md", ".markdown", ".pdf", ".txt"} for path in sorted(raw_dir.rglob("*")): if path.is_file() and path.suffix.lower() in supported_suffixes: yield path def load_docs(raw_dir: Path = RAW_DIR) -> List[Document]: documents: List[Document] = [] raw_dir = effective_raw_dir(raw_dir) for path in iter_source_files(raw_dir): suffix = path.suffix.lower() if suffix in {".md", ".markdown"}: documents.extend(load_md_documents(path)) elif suffix == ".pdf": documents.extend(load_pdf_file(path)) elif suffix == ".txt": documents.extend(load_txt_file(path)) if not documents: raise ValueError(f"No supported documents found under {raw_dir}") return documents def add_chunk_metadata(nodes: List[BaseNode]) -> List[BaseNode]: counters: dict[str, int] = {} for node in nodes: source_file = node.metadata["source_file"] chunk_index = counters.get(source_file, 0) counters[source_file] = chunk_index + 1 file_hash = node.metadata["file_hash"][:12] page_number = node.metadata.get("page_number", "na") chunk_id = f"{Path(source_file).stem}-{file_hash}-p{page_number}-c{chunk_index}" node.metadata["chunk_id"] = chunk_id node.metadata["chunk_index"] = chunk_index node.metadata[EMBED_MODEL_METADATA_KEY] = EMBED_MODEL_NAME node.id_ = chunk_id return nodes def validate_nodes(nodes: List[BaseNode]) -> None: if not nodes: raise ValueError("No chunks were created from the source documents.") for node in nodes: missing = [key for key in REQUIRED_METADATA if key not in node.metadata] if missing: raise ValueError( f"Node {node.node_id} is missing metadata fields: {missing}") if node.metadata["file_type"] == "pdf" and "page_number" not in node.metadata: raise ValueError( f"PDF node {node.node_id} is missing page_number metadata.") def split_documents(documents: List[Document]) -> List[BaseNode]: splitter = SentenceSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, ) nodes = splitter.get_nodes_from_documents(documents) add_chunk_metadata(nodes) validate_nodes(nodes) return nodes def build_nodes(raw_dir: Path = RAW_DIR) -> List[BaseNode]: documents = load_docs(raw_dir) return split_documents(documents) def load_source_file(path: Path) -> List[Document]: suffix = path.suffix.lower() if suffix in {".md", ".markdown"}: return load_md_documents(path) if suffix == ".pdf": return load_pdf_file(path) if suffix == ".txt": return load_txt_file(path) return [] def list_current_sources(raw_dir: Path = RAW_DIR) -> dict[str, dict[str, str]]: raw_dir = effective_raw_dir(raw_dir) sources = {} for path in iter_source_files(raw_dir): resolved = str(path.resolve()) sources[resolved] = { "file_hash": file_sha256(path), "file_type": path.suffix.lower().lstrip("."), } return sources def existing_source_metadata(chroma_collection) -> dict[str, dict[str, str]]: existing: dict[str, dict[str, str]] = {} if chroma_collection.count() == 0: return existing offset = 0 limit = 500 while True: batch = chroma_collection.get( limit=limit, offset=offset, include=["metadatas"], ) metadatas = batch.get("metadatas") or [] if not metadatas: break for metadata in metadatas: source_file = metadata.get("source_file") if not source_file: continue existing[source_file] = { "file_hash": metadata.get("file_hash", ""), "file_type": metadata.get("file_type", ""), "embedding_model": metadata.get(EMBED_MODEL_METADATA_KEY, ""), "extraction_method": metadata.get("extraction_method", ""), } if len(metadatas) < limit: break offset += limit return existing def source_needs_update(current: dict[str, str], existing: dict[str, str] | None) -> bool: if not existing: return True if existing.get("file_hash") != current["file_hash"]: return True if existing.get("embedding_model") != EMBED_MODEL_NAME: return True if current["file_type"] == "pdf" and existing.get("extraction_method") != PDF_EXTRACTION_METHOD: return True return False def incremental_update_index( raw_dir: Path, chroma_collection, storage_context: StorageContext, embed_model, ) -> bool: current_sources = list_current_sources(raw_dir) existing_sources = existing_source_metadata(chroma_collection) deleted_sources = sorted(set(existing_sources) - set(current_sources)) changed_sources = sorted( source_file for source_file, current in current_sources.items() if source_needs_update(current, existing_sources.get(source_file)) ) for source_file in deleted_sources + changed_sources: try: chroma_collection.delete(where={"source_file": source_file}) except Exception as exc: logging.warning("Could not delete stale chunks for %s: %s", source_file, exc) if not changed_sources: if deleted_sources: print(f"Removed {len(deleted_sources)} stale source(s) from collection '{COLLECTION_NAME}'.") return bool(deleted_sources) documents: List[Document] = [] for source_file in changed_sources: documents.extend(load_source_file(Path(source_file))) nodes = split_documents(documents) VectorStoreIndex( nodes, storage_context=storage_context, embed_model=embed_model, show_progress=True, ) print( f"Incrementally indexed {len(nodes)} chunk(s) from {len(changed_sources)} source file(s)." ) return True def collection_needs_rebuild(chroma_collection) -> bool: if chroma_collection.count() == 0: return True try: sample = chroma_collection.peek(limit=min(chroma_collection.count(), 20)) except Exception: return False for metadata in sample.get("metadatas") or []: if metadata.get(EMBED_MODEL_METADATA_KEY) != EMBED_MODEL_NAME: return True if metadata.get("file_type") == "pdf": return metadata.get("extraction_method") != PDF_EXTRACTION_METHOD return False async def build_index(raw_dir: Path = RAW_DIR, rebuild: bool = False) -> VectorStoreIndex: configure_model_cache() from llama_index.embeddings.huggingface import HuggingFaceEmbedding load_dotenv() raw_dir = effective_raw_dir(raw_dir) CHROMA_DB_DIR.mkdir(parents=True, exist_ok=True) db = chromadb.PersistentClient(path=str(CHROMA_DB_DIR)) if rebuild: try: db.delete_collection(COLLECTION_NAME) except (NotFoundError, ValueError): pass chroma_collection = db.get_or_create_collection(COLLECTION_NAME) vector_store = ChromaVectorStore(chroma_collection=chroma_collection) storage_context = StorageContext.from_defaults(vector_store=vector_store) embed_model = HuggingFaceEmbedding( model_name=resolve_embed_model_name(), cache_folder=str(HF_CACHE_DIR / "sentence_transformers"), ) if rebuild or chroma_collection.count() == 0: nodes = build_nodes(raw_dir) index = VectorStoreIndex( nodes, storage_context=storage_context, embed_model=embed_model, show_progress=True, ) print( f"Indexed {len(nodes)} chunks into collection '{COLLECTION_NAME}'") return index incremental_update_index( raw_dir=raw_dir, chroma_collection=chroma_collection, storage_context=storage_context, embed_model=embed_model, ) print( f"Loaded existing collection '{COLLECTION_NAME}' with {chroma_collection.count()} chunks.") return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model) class QueryKnowledgeTool(Tool): name = "query_knowledge" description = ( "Searches the local options trading knowledge base. Use this for option " "concepts, volatility, Greeks, strategies, formulas, equation numbers, " "and citations from reference books." ) inputs = {'query': {'type': 'string', 'description': 'The search query to perform.'}} output_type = "string" @staticmethod def format_results(results, max_chars: int = 800): output = [] for result in results: metadata = result.node.metadata source = metadata.get("file_name", "unknown") page = metadata.get("page_number", "n/a") section = metadata.get("section_path") or metadata.get("section_title") or "n/a" content_type = metadata.get("content_type", "text") formula_id = metadata.get("formula_id", "") score = result.score text = result.node.get_content() if len(text) > max_chars: text = f"{text[:max_chars].rstrip()}..." output.append( f"source:{source}\n" f"page:{page}\n" f"section:{section}\n" f"content_type:{content_type}\n" f"formula_id:{formula_id or 'n/a'}\n" f"score:{score:.4f}\n" f"vector_score:{metadata.get(VECTOR_METADATA_KEY, 'n/a')}\n" f"bm25_score:{metadata.get(BM25_METADATA_KEY, 'n/a')}\n" f"content:{text}" ) return "\n\n---\n\n".join(output) @staticmethod def load_bm25_nodes(collection_name: str = COLLECTION_NAME) -> list[TextNode]: db = chromadb.PersistentClient(path=str(CHROMA_DB_DIR)) try: collection = db.get_collection(collection_name) except Exception: return [] nodes: list[TextNode] = [] offset = 0 limit = 500 while True: batch = collection.get( limit=limit, offset=offset, include=["documents", "metadatas"], ) documents = batch.get("documents") or [] metadatas = batch.get("metadatas") or [] ids = batch.get("ids") or [] if not documents: break for index, text in enumerate(documents): metadata = dict(metadatas[index] or {}) node_id = ids[index] if index < len(ids) else metadata.get("chunk_id", "") nodes.append(TextNode(id_=node_id, text=text or "", metadata=metadata)) if len(documents) < limit: break offset += limit return nodes @staticmethod def merge_results( vector_results: list[NodeWithScore], bm25_results: list[NodeWithScore], top_k: int, ) -> list[NodeWithScore]: merged: dict[str, NodeWithScore] = {} for rank, result in enumerate(vector_results): node_id = result.node.node_id result.node.metadata[VECTOR_METADATA_KEY] = result.score merged[node_id] = NodeWithScore( node=result.node, score=1.0 / (rank + 1), ) for rank, result in enumerate(bm25_results): node_id = result.node.node_id result.node.metadata[BM25_METADATA_KEY] = result.score reciprocal_rank_score = 1.0 / (rank + 1) if node_id in merged: merged[node_id].score = (merged[node_id].score or 0.0) + reciprocal_rank_score merged[node_id].node.metadata[BM25_METADATA_KEY] = result.score else: merged[node_id] = NodeWithScore( node=result.node, score=reciprocal_rank_score, ) results = list(merged.values()) results.sort(key=lambda item: item.score or float("-inf"), reverse=True) return results[:top_k] def __init__( self, max_results=20, top_k=5, use_reranker: Optional[bool] = None, use_hybrid: Optional[bool] = None, reranker_top_n: Optional[int] = None, reranker_model_name: Optional[str] = None, **kwargs, ): super().__init__() self.max_results = max_results self.top_k = top_k self.use_reranker = ( env_flag("RAG_USE_RERANKER", True) if use_reranker is None else use_reranker ) self.use_hybrid = ( env_flag("RAG_USE_HYBRID", True) if use_hybrid is None else use_hybrid ) self.reranker_top_n = reranker_top_n or top_k self.reranker = ( CrossEncoderReranker(reranker_model_name or RERANKER_MODEL_NAME) if self.use_reranker else None ) index = asyncio.run(build_index(rebuild=False)) retrieve_top_k = max(max_results, top_k) if self.use_reranker else top_k self.retriever = index.as_retriever(similarity_top_k=retrieve_top_k) self.bm25_retriever = ( BM25Retriever(self.load_bm25_nodes()) if self.use_hybrid else None ) def forward(self, query: str) -> str: vector_results = self.retriever.retrieve(query) results = vector_results if self.bm25_retriever: bm25_results = self.bm25_retriever.retrieve(query, self.max_results) results = self.merge_results( vector_results=vector_results, bm25_results=bm25_results, top_k=max(self.max_results, self.top_k), ) if self.reranker: try: results = self.reranker.rerank( query, results, top_n=self.reranker_top_n, ) except Exception as exc: logging.warning("Reranker failed; falling back to vector ranking: %s", exc) results = results[:self.top_k] return QueryKnowledgeTool.format_results(results[:self.top_k]) if __name__ == "__main__": query_tool = QueryKnowledgeTool() res: str = query_tool.forward("What is option?") print(res)