# Source: First_agent_template / tools / query_knowledge.py
# Author: mathidot
# Commit 8f1601b: build option trading agent modules
from smolagents.tools import Tool
import asyncio
from collections import Counter
import hashlib
import logging
import math
import os
from pathlib import Path
from typing import Iterable, List, Optional
import re
from dotenv import load_dotenv
import chromadb
from chromadb.errors import NotFoundError
from pypdf import PdfReader
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.schema import Document, BaseNode, NodeWithScore, TextNode
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.chroma import ChromaVectorStore
load_dotenv()

BASE_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = BASE_DIR.parent
# Older layouts kept the knowledge base next to this module; the canonical
# location is the project root. ``effective_raw_dir`` only falls back to the
# legacy "raw" directory when the canonical one has no source files.
LEGACY_KNOWLEDGE_BASE_DIR = BASE_DIR / "knowledge_base"
KNOWLEDGE_BASE_DIR = PROJECT_ROOT / "knowledge_base"
RAW_DIR = KNOWLEDGE_BASE_DIR / "raw"
CHROMA_DB_DIR = KNOWLEDGE_BASE_DIR / "chroma_db"
HF_CACHE_DIR = PROJECT_ROOT / "hf_cache"

COLLECTION_NAME = "options_knowledge"
EMBED_MODEL_NAME = os.getenv("RAG_EMBED_MODEL", "BAAI/bge-small-en-v1.5")
RERANKER_MODEL_NAME = os.getenv(
    "RAG_RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
RERANKER_BATCH_SIZE = int(os.getenv("RAG_RERANKER_BATCH_SIZE", "16"))

# Metadata keys written onto chunks by the indexing / retrieval helpers.
EMBED_MODEL_METADATA_KEY = "embedding_model"
BM25_METADATA_KEY = "bm25_score"
VECTOR_METADATA_KEY = "vector_score"

# Passed to SentenceSplitter in split_documents.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 150

# Header/footer detection: upper bound on the number of pages a boundary line
# must repeat on, and how many lines at the top and bottom of each page count
# as "boundary" lines.
PDF_REPEATED_LINE_MIN_PAGES = 3
PDF_BOUNDARY_LINE_COUNT = 4
# Recorded on every PDF-derived Document.
PDF_EXTRACTION_METHOD = "pymupdf_formula_blocks_v5"
# Max y distance for two pypdf text runs to be placed on the same visual line.
PDF_LINE_Y_TOLERANCE = 3.0
# Pending section text shorter than this is merged into the next section.
PDF_MIN_SECTION_CHARS = 240

PDF_STRONG_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_σΣΔδθΘλΛμρπΠφΦτν𝜎𝜇𝜌𝜃𝜕")
PDF_WEAK_MATH_SYMBOLS = set("+-−*/∕<>")
PDF_MATH_SYMBOLS = PDF_STRONG_MATH_SYMBOLS | PDF_WEAK_MATH_SYMBOLS
# NOTE(review): not referenced in the visible part of this file; confirm before removing.
PDF_OPERATOR_MATH_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_+-−*/∕<>")
PDF_FORMULA_TRIGGER_SYMBOLS = set("=∂∫∑∏√∞≈≠≤≥±×÷^_∕<>")

# Silence pypdf's warning-level chatter about malformed PDFs.
logging.getLogger("pypdf").setLevel(logging.ERROR)
def load_pymupdf():
    """Return the PyMuPDF module (imported as ``fitz``) or None if it is not installed.

    PyMuPDF is optional: ``extract_pdf_pages_with_pymupdf`` returns None when
    this does, and ``load_pdf_file`` then uses pypdf instead.
    """
    try:
        import fitz as pymupdf_module
    except ImportError:
        pymupdf_module = None
    return pymupdf_module
# Metadata keys every chunk must carry before indexing (checked by validate_nodes).
REQUIRED_METADATA = [
    "source_file", "file_name", "file_type", "document_title",
    "file_hash", "chunk_id", "chunk_index",
]
def configure_model_cache() -> None:
    """Point Hugging Face / sentence-transformers caches at ``HF_CACHE_DIR``.

    Creates the cache directory and sets cache-related environment variables
    with ``os.environ.setdefault``, so values already exported by the user win.

    Offline mode (``HF_HUB_OFFLINE`` / ``TRANSFORMERS_OFFLINE``) is enabled only
    when both the embedding model and the reranker model have a local snapshot.
    Enabling it when only the embedder is cached would stop the reranker from
    being downloaded on first use.
    """
    HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("HF_HOME", str(HF_CACHE_DIR))
    os.environ.setdefault(
        "SENTENCE_TRANSFORMERS_HOME", str(HF_CACHE_DIR / "sentence_transformers")
    )
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    all_models_cached = all(
        local_model_snapshot(model_name)
        for model_name in (EMBED_MODEL_NAME, RERANKER_MODEL_NAME)
    )
    if all_models_cached:
        os.environ.setdefault("HF_HUB_OFFLINE", "1")
        os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
def local_model_snapshot(model_name: str) -> Optional[Path]:
    """Return the locally cached snapshot directory for *model_name*, or None.

    Looks in the hub-style cache layout
    ``HF_CACHE_DIR/sentence_transformers/models--<org>--<name>/``. A snapshot
    is usable when it contains ``config.json``.

    The revision recorded in ``refs/main`` is preferred when it names a usable
    snapshot. Otherwise the last usable snapshot in sorted name order is
    returned. Snapshot directory names are commit hashes, so sort order alone
    does not identify the current revision.
    """
    cached_model_dir = (
        HF_CACHE_DIR
        / "sentence_transformers"
        / f"models--{model_name.replace('/', '--')}"
    )
    snapshots_dir = cached_model_dir / "snapshots"
    if not snapshots_dir.is_dir():
        return None

    ref_file = cached_model_dir / "refs" / "main"
    if ref_file.is_file():
        try:
            revision = ref_file.read_text(encoding="utf-8").strip()
        except OSError:
            revision = ""
        if revision:
            referenced = snapshots_dir / revision
            if (referenced / "config.json").exists():
                return referenced

    snapshots = sorted(path for path in snapshots_dir.iterdir() if path.is_dir())
    for snapshot in reversed(snapshots):
        if (snapshot / "config.json").exists():
            return snapshot
    return None
def resolve_embed_model_name() -> str:
    """Return the cached snapshot path of the embedding model if present,
    otherwise the configured ``EMBED_MODEL_NAME``."""
    cached = local_model_snapshot(EMBED_MODEL_NAME)
    return str(cached) if cached else EMBED_MODEL_NAME
def resolve_reranker_model_name(model_name: str = RERANKER_MODEL_NAME) -> str:
    """Return the cached snapshot path for *model_name* if present, else the name itself."""
    cached = local_model_snapshot(model_name)
    if cached is None:
        return model_name
    return str(cached)
def env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean.

    Returns *default* when the variable is unset. Otherwise returns True only
    for "1", "true", "yes" or "on", ignoring case and surrounding whitespace.
    """
    raw_value = os.environ.get(name)
    if raw_value is not None:
        return raw_value.strip().lower() in ("1", "true", "yes", "on")
    return default
def effective_raw_dir(raw_dir: Path = RAW_DIR) -> Path:
    """Pick the directory that source documents are read from.

    Returns *raw_dir* when it contains any supported source file. Otherwise,
    if the legacy ``LEGACY_KNOWLEDGE_BASE_DIR / "raw"`` directory has files,
    logs a migration warning and returns that directory. Falls back to
    *raw_dir* when neither has files.
    """
    if any(iter_source_files(raw_dir)):
        return raw_dir
    legacy_raw_dir = LEGACY_KNOWLEDGE_BASE_DIR / "raw"
    if not any(iter_source_files(legacy_raw_dir)):
        return raw_dir
    logging.warning(
        "Using legacy knowledge base path %s. Move files to %s when convenient.",
        legacy_raw_dir,
        raw_dir,
    )
    return legacy_raw_dir
class CrossEncoderReranker:
    """Rerank retrieval hits with a sentence-transformers ``CrossEncoder``.

    The model is constructed lazily by ``_load_model`` on the first ``rerank``
    call that has results, so creating an instance is cheap.
    """

    def __init__(
        self,
        model_name: str = RERANKER_MODEL_NAME,
        batch_size: int = RERANKER_BATCH_SIZE,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self._model = None  # created on first use by _load_model()

    def _load_model(self):
        """Return the cached CrossEncoder, constructing it on first call."""
        if self._model is None:
            # Deferred import: sentence_transformers is only imported once a
            # model is actually needed.
            from sentence_transformers import CrossEncoder

            self._model = CrossEncoder(
                resolve_reranker_model_name(self.model_name),
                max_length=512,
                cache_folder=str(HF_CACHE_DIR / "sentence_transformers"),
            )
        return self._model

    def rerank(
        self,
        query: str,
        results: list[NodeWithScore],
        top_n: Optional[int] = None,
    ) -> list[NodeWithScore]:
        """Score every (query, node text) pair and return hits best-first.

        Args:
            query: The user query.
            results: Candidate hits. Their nodes are reused; their scores are
                replaced by the cross-encoder scores.
            top_n: Keep only the best *top_n* hits. ``None`` (or 0) keeps all.

        Returns:
            New ``NodeWithScore`` objects sorted by descending cross-encoder score.
        """
        if not results:
            return []
        pairs = [(query, result.node.get_content()) for result in results]
        scores = self._load_model().predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
        )
        reranked = [
            NodeWithScore(node=result.node, score=float(score))
            for result, score in zip(results, scores)
        ]
        # Sort on the real score. A ``score or -inf`` key would treat a
        # legitimate 0.0 as -inf and rank it below negative logits.
        reranked.sort(
            key=lambda item: item.score if item.score is not None else float("-inf"),
            reverse=True,
        )
        return reranked[:top_n] if top_n else reranked
class BM25Retriever:
    """In-memory Okapi BM25 keyword retriever over a fixed list of nodes.

    Uses k1 = 1.5, b = 0.75 and idf = log(1 + (N - df + 0.5) / (df + 0.5)).
    Documents are tokenized once at construction. Per-document term counts are
    cached as well, so a query does not re-count every document's tokens.
    """

    # Words (optionally joined by internal - or '), numbers (with dotted parts),
    # or any single non-space, non-ASCII-alphanumeric character (math symbols,
    # Greek letters, punctuation).
    _TOKEN_PATTERN = re.compile(
        r"[A-Za-z]+(?:[-'][A-Za-z]+)*|\d+(?:\.\d+)*|[^\sA-Za-z0-9]"
    )
    _K1 = 1.5
    _B = 0.75

    def __init__(self, nodes: list[TextNode]):
        self.nodes = nodes
        self.tokenized_docs = [self.tokenize(node.get_content()) for node in nodes]
        # Cached term frequencies, parallel to tokenized_docs.
        self._doc_term_counts = [Counter(tokens) for tokens in self.tokenized_docs]
        self.doc_freqs: Counter[str] = Counter()
        for tokens in self.tokenized_docs:
            self.doc_freqs.update(set(tokens))
        self.avg_doc_len = (
            sum(len(tokens) for tokens in self.tokenized_docs) / len(self.tokenized_docs)
            if self.tokenized_docs
            else 0.0
        )

    @staticmethod
    def tokenize(text: str) -> list[str]:
        """Lower-cased word, number and single-symbol tokens of *text*."""
        return [
            token.lower()
            for token in BM25Retriever._TOKEN_PATTERN.findall(text)
            if token.strip()
        ]

    def score(self, query_tokens: list[str], doc_tokens: list[str]) -> float:
        """BM25 score of one tokenized document for *query_tokens* (0.0 if either is empty)."""
        if not query_tokens or not doc_tokens:
            return 0.0
        return self._score_counts(query_tokens, Counter(doc_tokens), len(doc_tokens))

    def _score_counts(
        self, query_tokens: list[str], token_counts: Counter, doc_len: int
    ) -> float:
        """BM25 score from precomputed term counts and document length."""
        total_docs = len(self.tokenized_docs)
        length_norm = 1 - self._B + self._B * doc_len / max(self.avg_doc_len, 1.0)
        score = 0.0
        # Repeated query tokens contribute once per occurrence, as before.
        for token in query_tokens:
            term_freq = token_counts.get(token, 0)
            if term_freq == 0:
                continue
            doc_freq = self.doc_freqs.get(token, 0)
            idf = math.log(1 + (total_docs - doc_freq + 0.5) / (doc_freq + 0.5))
            denominator = term_freq + self._K1 * length_norm
            score += idf * (term_freq * (self._K1 + 1)) / denominator
        return score

    def retrieve(self, query: str, top_k: int) -> list[NodeWithScore]:
        """Return up to *top_k* nodes with a positive BM25 score, best first.

        Side effect: each returned node's ``metadata[BM25_METADATA_KEY]`` is set
        to its score. The node objects are shared with ``self.nodes``.
        """
        query_tokens = self.tokenize(query)
        if not query_tokens:
            return []
        scored: list[NodeWithScore] = []
        for node, term_counts, doc_tokens in zip(
            self.nodes, self._doc_term_counts, self.tokenized_docs
        ):
            score = self._score_counts(query_tokens, term_counts, len(doc_tokens))
            if score <= 0:
                continue
            node.metadata[BM25_METADATA_KEY] = score
            scored.append(NodeWithScore(node=node, score=score))
        # Every kept score is > 0, so no None/zero handling is needed in the key.
        scored.sort(key=lambda item: item.score, reverse=True)
        return scored[:top_k]
def file_sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    hasher = hashlib.sha256()
    chunk_size = 1024 * 1024
    with path.open("rb") as handle:
        while chunk := handle.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()
def load_md_file(path: Path) -> Document:
    """Load a whole Markdown file as a single Document.

    This is the fallback ``load_md_documents`` uses when no section has
    non-whitespace text. It emits the same metadata keys as the section
    loader. ``file_type`` comes from the suffix ("md" or "markdown"), matching
    ``load_md_documents`` and ``list_current_sources`` for the same file
    instead of always saying "md".
    """
    text = path.read_text(encoding="utf-8")
    return Document(
        text=text,
        metadata={
            "source_file": str(path.resolve()),
            "file_name": path.name,
            "file_type": path.suffix.lower().lstrip("."),
            "document_title": path.stem,
            "file_hash": file_sha256(path),
            "content_type": "markdown_section",
            "chapter_title": "",
            "section_title": "",
            "section_path": "",
            "char_count": len(text),
        },
    )
def load_md_documents(path: Path) -> List[Document]:
    """Split a Markdown file into one Document per ATX-heading section.

    A section starts at a ``#`` to ``######`` heading line, which stays in the
    section text, and runs to the next heading. Text before the first heading
    forms a section with an empty ``section_title``. Heading-like lines inside
    fenced code blocks (opened by 3+ backticks or tildes) are not treated as
    headings, so a ``# comment`` in a code sample does not split the section.

    Falls back to ``load_md_file`` when no section has non-whitespace text.
    """
    text = path.read_text(encoding="utf-8")
    file_hash = file_sha256(path)
    documents: List[Document] = []
    current_heading = ""
    current_lines: List[str] = []
    open_fence = ""  # marker of the enclosing code fence; "" when outside one

    def flush() -> None:
        """Emit the accumulated lines as one section Document (if non-blank)."""
        nonlocal current_lines
        section_text = "\n".join(current_lines).strip()
        current_lines = []
        if not section_text:
            return
        documents.append(
            Document(
                text=section_text,
                metadata={
                    "source_file": str(path.resolve()),
                    "file_name": path.name,
                    "file_type": path.suffix.lower().lstrip("."),
                    "document_title": path.stem,
                    "file_hash": file_hash,
                    "content_type": "markdown_section",
                    "chapter_title": "",
                    "section_title": current_heading,
                    "section_path": current_heading,
                    "char_count": len(section_text),
                },
            )
        )

    for line in text.splitlines():
        fence_match = re.match(r"^ {0,3}(`{3,}|~{3,})(.*)$", line)
        if fence_match:
            marker, rest = fence_match.groups()
            if not open_fence:
                open_fence = marker
            elif (
                marker[0] == open_fence[0]
                and len(marker) >= len(open_fence)
                and not rest.strip()
            ):
                # A closing fence uses the same character, is at least as long,
                # and carries no info string.
                open_fence = ""
        elif not open_fence:
            heading_match = re.match(r"^(#{1,6})\s+(.+?)\s*$", line)
            if heading_match:
                flush()
                current_heading = heading_match.group(2).strip()
        current_lines.append(line)
    flush()
    return documents or [load_md_file(path)]
def append_visual_fragment(line_parts: List[str], text: str, baseline_y: float, item: dict) -> None:
    """Append *text* to *line_parts*, marking small raised or lowered runs.

    A fragment is "small" when ``item["font_size"]`` is under 82% of
    ``item["line_font_size"]``. A small fragment whose ``item["y"]`` differs
    from *baseline_y* by more than max(1.5, 18% of the line font size) is
    wrapped as ``^{...}`` when its y is greater and ``_{...}`` when smaller.
    Blank fragments are ignored.
    """
    stripped = (text or "").strip()
    if not stripped:
        return
    line_font_size = item["line_font_size"]
    threshold = max(1.5, line_font_size * 0.18)
    offset = item["y"] - baseline_y
    if item["font_size"] < line_font_size * 0.82:
        if offset > threshold:
            line_parts.append(f"^{{{stripped}}}")
            return
        if offset < -threshold:
            line_parts.append(f"_{{{stripped}}}")
            return
    line_parts.append(stripped)
def join_visual_line(items: List[dict]) -> str:
    """Join text fragments that share a visual line into one normalized string.

    Fragments are ordered by ``x``. A space is inserted when the gap to the
    running right edge exceeds max(2.5, 28% of the line's largest font size).
    The baseline is the upper-median fragment ``y``.

    Side effect: every item dict gets a ``line_font_size`` key, which
    ``append_visual_fragment`` reads.
    """
    if not items:
        return ""
    ordered = sorted(items, key=lambda fragment: fragment["x"])
    baseline_y = sorted(fragment["y"] for fragment in ordered)[len(ordered) // 2]
    line_font_size = max(fragment["font_size"] for fragment in ordered)
    gap_threshold = max(2.5, line_font_size * 0.28)
    parts: List[str] = []
    right_edge = None
    for fragment in ordered:
        fragment["line_font_size"] = line_font_size
        if right_edge is not None and fragment["x"] - right_edge > gap_threshold:
            parts.append(" ")
        append_visual_fragment(parts, fragment["text"], baseline_y, fragment)
        right_edge = max(right_edge or fragment["x"], fragment["x"] + fragment["width"])
    return normalize_pdf_line("".join(parts))
def extract_pdf_text_by_position(page) -> str:
    """Rebuild a pypdf page's text from text-run positions, keeping sub/superscripts.

    Every non-blank run reported to pypdf's ``visitor_text`` callback is
    recorded with its text-matrix origin (``tm[4]``, ``tm[5]``), an estimated
    width and its font size. Runs are processed in descending y, then
    ascending x. Each run joins the first existing line whose first run lies
    within ``PDF_LINE_Y_TOLERANCE`` of it, and each line is rendered with
    ``join_visual_line``.

    Returns "" when pypdf raises or reports no text.
    """
    fragments: List[dict] = []

    def visitor_text(text, cm, tm, font_dict, font_size):
        if not text or not text.strip():
            return
        # Guard once and reuse: pypdf may report a missing or zero size. Calling
        # float(font_size) unguarded for the width would raise on None and
        # abort the whole page's extraction.
        size = float(font_size or 1.0)
        fragments.append(
            {
                "text": text,
                # NOTE(review): only the text matrix is used; the CTM (cm) is
                # not applied, so pages positioned through cm may group lines
                # imperfectly.
                "x": float(tm[4]),
                "y": float(tm[5]),
                # Rough width estimate from character count and font size.
                "width": max(len(text.strip()) * size * 0.45, size),
                "font_size": size,
            }
        )

    try:
        page.extract_text(visitor_text=visitor_text)
    except Exception:
        return ""
    if not fragments:
        return ""
    lines: List[List[dict]] = []
    for fragment in sorted(fragments, key=lambda value: (-value["y"], value["x"])):
        for line in lines:
            if abs(line[0]["y"] - fragment["y"]) <= PDF_LINE_Y_TOLERANCE:
                line.append(fragment)
                break
        else:
            lines.append([fragment])
    return "\n".join(join_visual_line(line) for line in lines)
def math_text_score(text: str) -> float:
    """Score how much mathematical structure an extraction preserved.

    Higher is better. The score is the non-whitespace length, plus 12 per math
    symbol, 20 per ``^{`` or ``_{`` marker, 8 per formula-like line, and 12
    more per formula-like line that has a formula-like neighbour. Blank text
    scores 0. ``extract_pdf_text`` uses this to pick among candidate
    extractions.
    """
    if not text.strip():
        return 0.0
    lines = [line for line in text.splitlines() if line.strip()]
    # Classify each line once instead of up to four times.
    formula_flags = [is_formula_like(line) for line in lines]
    compact_length = len(re.sub(r"\s+", "", text))
    math_symbol_count = sum(1 for char in text if char in PDF_MATH_SYMBOLS)
    superscript_markers = text.count("^{") + text.count("_{")
    formula_line_count = sum(formula_flags)
    adjacent_formula_count = sum(
        1
        for index, is_formula in enumerate(formula_flags)
        if is_formula
        and (
            (index > 0 and formula_flags[index - 1])
            or (index + 1 < len(formula_flags) and formula_flags[index + 1])
        )
    )
    return (
        compact_length
        + math_symbol_count * 12
        + superscript_markers * 20
        + formula_line_count * 8
        + adjacent_formula_count * 12
    )
def extract_pdf_text(page) -> str:
    """Extract a pypdf page's text with the best of three strategies.

    The candidates are position-based reconstruction, pypdf "layout" mode and
    pypdf's default mode, tried in that order. Failures and blank results are
    dropped. The candidate with the highest ``math_text_score`` wins, with the
    earlier candidate winning ties. Returns "" if every candidate is blank.
    """
    def safe_extract(**options) -> str:
        try:
            return page.extract_text(**options) or ""
        except Exception:
            return ""

    candidates = [
        candidate
        for candidate in (
            extract_pdf_text_by_position(page),
            safe_extract(extraction_mode="layout"),
            safe_extract(),
        )
        if candidate.strip()
    ]
    return max(candidates, key=math_text_score) if candidates else ""
def pymupdf_span_text(span: dict) -> str:
    """Return the normalized text of a PyMuPDF "dict"-mode span ("" if it has none)."""
    raw_text = span.get("text", "")
    return normalize_pdf_line(raw_text)
def pymupdf_line_text(line: dict) -> str:
    """Concatenate a PyMuPDF line's normalized span texts and normalize the result."""
    span_texts = map(pymupdf_span_text, line.get("spans", []))
    return normalize_pdf_line("".join(span_texts))
def pymupdf_block_text(block: dict) -> str:
    """Return a PyMuPDF text block as its non-empty line texts joined by newlines."""
    line_texts = (pymupdf_line_text(line) for line in block.get("lines", []))
    return "\n".join(filter(None, line_texts))
def pymupdf_span_has_math_font(span: dict) -> bool:
    """Return True if the span's font name looks like a math or symbol font.

    Matches, case-insensitively, "math", "symbol", "cmmi", "cmsy", "cmex" or
    "stix" anywhere in ``span["font"]``.
    """
    font_name = span.get("font", "").lower()
    for marker in ("math", "symbol", "cmmi", "cmsy", "cmex", "stix"):
        if marker in font_name:
            return True
    return False
def is_formula_block_line(line: str) -> bool:
    """Heuristically decide whether one extracted line is a display formula.

    Rejects blank lines, lines with fewer than 3 non-space characters, and bare
    (optionally parenthesised) numbers. Otherwise accepts the line when any of
    these hold ("chars" = non-whitespace characters; "content words" = ASCII
    letter runs other than and/or/the/where/then/with/for):

    * ends with an equation number such as "(2.14)" and has at most 240 chars;
    * contains "=", has at most 260 chars and at most 12 content words;
    * contains a strong operator (∂∫∑∏√∞≈≠≤≥±×÷), has at most 220 chars and at
      most 10 content words;
    * has 2+ formula trigger symbols, at most 120 chars and at most 6 content words;
    * has a trigger symbol, a digit, at most 18 letters and at most 100 chars.
    """
    stripped = line.strip()
    if not stripped:
        return False
    compact_length = len(re.sub(r"\s+", "", stripped))
    if compact_length < 3:
        return False
    if re.fullmatch(r"\(?\d+(\.\d+)?\)?", stripped):
        return False

    filler_words = {"and", "or", "the", "where", "then", "with", "for"}
    content_word_count = sum(
        1
        for word in re.findall(r"[A-Za-z]+", stripped)
        if word.lower() not in filler_words
    )
    trigger_count = sum(1 for char in stripped if char in PDF_FORMULA_TRIGGER_SYMBOLS)
    digit_count = sum(1 for char in stripped if char.isdigit())
    letter_count = sum(1 for char in stripped if char.isalpha())
    has_strong_operator = any(char in stripped for char in "∂∫∑∏√∞≈≠≤≥±×÷")
    has_equation_number = (
        re.search(r"\(\d+(\.\d+)+[a-z]?\)$", stripped) is not None
        and compact_length <= 240
    )
    return (
        has_equation_number
        or ("=" in stripped and compact_length <= 260 and content_word_count <= 12)
        or (has_strong_operator and compact_length <= 220 and content_word_count <= 10)
        or (trigger_count >= 2 and compact_length <= 120 and content_word_count <= 6)
        or (
            trigger_count >= 1
            and digit_count >= 1
            and letter_count <= 18
            and compact_length <= 100
        )
    )
def is_formula_block(block: dict) -> bool:
    """Heuristically decide whether a whole PyMuPDF text block is a formula.

    True if any non-blank line passes ``is_formula_block_line``. Otherwise the
    block-level tests apply ("chars" = non-whitespace characters):

    * 2+ non-empty spans in math fonts and at most 220 chars; or
    * 3+ strong math symbols and at most 260 chars; or
    * a strong math symbol, a digit, at most 20 letters and at most 160 chars.
    """
    text = pymupdf_block_text(block)
    if not text:
        return False
    if any(is_formula_block_line(line) for line in text.splitlines() if line.strip()):
        return True

    spans = [
        span
        for line in block.get("lines", [])
        for span in line.get("spans", [])
        if pymupdf_span_text(span)
    ]
    if not spans:
        return False

    math_font_spans = sum(1 for span in spans if pymupdf_span_has_math_font(span))
    strong_symbol_count = sum(1 for char in text if char in PDF_STRONG_MATH_SYMBOLS)
    letter_count = sum(1 for char in text if char.isalpha())
    digit_count = sum(1 for char in text if char.isdigit())
    compact_length = len(re.sub(r"\s+", "", text))
    return (
        (math_font_spans >= 2 and compact_length <= 220)
        or (strong_symbol_count >= 3 and compact_length <= 260)
        or (
            strong_symbol_count >= 1
            and digit_count >= 1
            and letter_count <= 20
            and compact_length <= 160
        )
    )
def block_bbox_string(block: dict) -> str:
    """Format a block's 4-value ``bbox`` as "x0,y0,x1,y1" with 2 decimals.

    Returns "" when the bbox is missing or does not have exactly 4 values.
    """
    coords = block.get("bbox") or []
    if len(coords) == 4:
        return ",".join(format(float(coord), ".2f") for coord in coords)
    return ""
def line_bbox_string(line: dict) -> str:
    """Format a line's 4-value ``bbox`` as "x0,y0,x1,y1" with 2 decimals ("" if malformed)."""
    coords = line.get("bbox") or []
    if len(coords) != 4:
        return ""
    x0, y0, x1, y1 = (float(value) for value in coords)
    return f"{x0:.2f},{y0:.2f},{x1:.2f},{y1:.2f}"
def pymupdf_line_has_math_font(line: dict) -> bool:
    """Return True if any non-empty span of *line* uses a math or symbol font."""
    for span in line.get("spans", []):
        if pymupdf_span_text(span) and pymupdf_span_has_math_font(span):
            return True
    return False
def should_extract_formula_line(line: dict) -> bool:
    """Decide whether a PyMuPDF line should be collected as (part of) a formula.

    True when the line text passes ``is_formula_block_line``. Also true when
    the line uses a math or symbol font, contains a formula trigger symbol,
    has at most 180 non-space characters and has at most 6 ASCII words.
    """
    text = pymupdf_line_text(line)
    if not text:
        return False
    if is_formula_block_line(text):
        return True
    if not pymupdf_line_has_math_font(line):
        return False
    has_trigger = any(char in PDF_FORMULA_TRIGGER_SYMBOLS for char in text)
    return (
        has_trigger
        and len(re.sub(r"\s+", "", text)) <= 180
        and len(re.findall(r"[A-Za-z]+", text)) <= 6
    )
def is_formula_continuation_line(text: str) -> bool:
    """Return True if *text* may continue a formula begun on an earlier line.

    Accepts short fragments (at most 90 non-space characters) that are a lone
    bracket, brace or "√". Also accepts fragments with at most 4 ASCII words
    that contain a math symbol or a digit.
    """
    stripped = text.strip()
    compact = re.sub(r"\s+", "", stripped)
    if not stripped or len(compact) > 90:
        return False
    if compact in {"(", ")", "[", "]", "{", "}", "√"}:
        return True
    word_count = len(re.findall(r"[A-Za-z]+", stripped))
    has_math_symbol = any(char in PDF_MATH_SYMBOLS for char in stripped)
    has_digit = any(char.isdigit() for char in stripped)
    return word_count <= 4 and (has_math_symbol or has_digit)
def append_formula_block(
    formula_blocks: List[dict],
    body_blocks: List[str],
    page_number: int,
    formula_index: int,
    formula_lines: List[str],
    formula_bboxes: List[str],
) -> int:
    """Flush accumulated formula lines as one formula block.

    The joined lines are cleaned with ``clean_formula_text``. If the result
    passes ``is_useful_formula_text``, a ``{"id", "text", "bbox"}`` record is
    added to *formula_blocks*. Its id is ``formula-<page>-<index>`` and its
    bbox is the union of *formula_bboxes*. A ``[FORMULA id=...]`` ...
    ``[/FORMULA]`` marker is also appended to *body_blocks*.

    Returns ``formula_index + 1`` if a block was recorded, else *formula_index*.
    """
    formula_text = clean_formula_text("\n".join(formula_lines))
    if is_useful_formula_text(formula_text):
        formula_id = f"formula-{page_number}-{formula_index}"
        formula_blocks.append(
            {
                "id": formula_id,
                "text": formula_text,
                "bbox": merge_bbox_strings(formula_bboxes),
            }
        )
        body_blocks.append(f"[FORMULA id={formula_id}]\n{formula_text}\n[/FORMULA]")
        formula_index += 1
    return formula_index
def merge_bbox_strings(bbox_strings: List[str]) -> str:
    """Return the union of "x0,y0,x1,y1" box strings as one formatted box.

    Empty entries, entries without exactly 4 fields and non-numeric entries
    are skipped. Returns "" when no valid box remains.
    """
    boxes: List[List[float]] = []
    for bbox_string in filter(None, bbox_strings):
        fields = bbox_string.split(",")
        if len(fields) != 4:
            continue
        try:
            box = [float(field) for field in fields]
        except ValueError:
            continue
        boxes.append(box)
    if not boxes:
        return ""
    x0s, y0s, x1s, y1s = zip(*boxes)
    return f"{min(x0s):.2f},{min(y0s):.2f},{max(x1s):.2f},{max(y1s):.2f}"
def is_useful_formula_text(text: str) -> bool:
    """Decide whether a collected formula candidate is worth keeping.

    Rejects text with fewer than 6 non-whitespace characters. Accepts text
    containing an equation number such as "(2.14)". Text with a strong
    operator (∂∫∑∏∞≈≠≤≥±×÷) is accepted if it contains "=" or has at most 12
    ASCII words. Otherwise some line must contain "=", be at most 260
    characters and have at most 12 content words (filler words such as
    and/or/the excluded).
    """
    stripped = text.strip()
    if len(re.sub(r"\s+", "", stripped)) < 6:
        return False
    if re.search(r"\(\d+(\.\d+)+[a-z]?\)", stripped):
        return True
    if any(char in stripped for char in "∂∫∑∏∞≈≠≤≥±×÷"):
        return "=" in stripped or len(re.findall(r"[A-Za-z]+", stripped)) <= 12

    filler_words = {"and", "or", "the", "where", "then", "with", "for"}
    for raw_line in stripped.splitlines():
        candidate = raw_line.strip()
        if "=" not in candidate:
            continue
        content_words = [
            word
            for word in re.findall(r"[A-Za-z]+", candidate)
            if word.lower() not in filler_words
        ]
        if len(content_words) <= 12 and len(candidate) <= 260:
            return True
    return False
def extract_pymupdf_page(page) -> dict:
    """Extract one PyMuPDF page into body text plus separately tracked formulas.

    Walks the text blocks of ``page.get_text("dict", sort=True)`` line by line.
    A line joins the formula being collected when ``should_extract_formula_line``
    accepts it, or when a formula is already open and
    ``is_formula_continuation_line`` accepts it. Any other line closes the
    open formula via ``append_formula_block``, and a formula still open at the
    end of the page is flushed too. Formula collection is not reset at block
    boundaries, so one formula may span several blocks.

    Returns:
        dict with "text" (body chunks joined by newlines, with inline
        ``[FORMULA id=...]`` markers), "formula_blocks" (list of
        ``{"id", "text", "bbox"}``) and "backend" ("pymupdf").
    """
    page_dict = page.get_text("dict", sort=True)
    body_blocks: List[str] = []  # body text chunks in reading order
    formula_blocks: List[dict] = []  # accepted formulas
    formula_lines: List[str] = []  # lines of the formula currently being collected
    formula_bboxes: List[str] = []  # bbox strings of those lines
    formula_index = 0  # next formula number on this page
    page_number = page.number + 1  # used in formula ids; assumes page.number is 0-based
    for block in page_dict.get("blocks", []):
        # Only text blocks (type 0) are processed.
        if block.get("type") != 0:
            continue
        normal_lines: List[str] = []
        for line in block.get("lines", []):
            line_text = pymupdf_line_text(line)
            if not line_text:
                continue
            if should_extract_formula_line(line) or (
                formula_lines and is_formula_continuation_line(line_text)
            ):
                # Emit prose collected so far so it precedes the formula.
                if normal_lines:
                    body_blocks.append("\n".join(normal_lines))
                    normal_lines = []
                formula_lines.append(line_text)
                formula_bboxes.append(line_bbox_string(line))
            else:
                # A prose line closes any open formula (possibly from an earlier block).
                if formula_lines:
                    formula_index = append_formula_block(
                        formula_blocks=formula_blocks,
                        body_blocks=body_blocks,
                        page_number=page_number,
                        formula_index=formula_index,
                        formula_lines=formula_lines,
                        formula_bboxes=formula_bboxes,
                    )
                    formula_lines = []
                    formula_bboxes = []
                normal_lines.append(line_text)
        if normal_lines:
            body_blocks.append("\n".join(normal_lines))
    # Flush a formula still open at the end of the page; the returned index is unused.
    if formula_lines:
        append_formula_block(
            formula_blocks=formula_blocks,
            body_blocks=body_blocks,
            page_number=page_number,
            formula_index=formula_index,
            formula_lines=formula_lines,
            formula_bboxes=formula_bboxes,
        )
    return {
        "text": "\n".join(body_blocks),
        "formula_blocks": formula_blocks,
        "backend": "pymupdf",
    }
def extract_pdf_pages_with_pymupdf(path: Path) -> Optional[List[dict]]:
    """Extract every page of the PDF at *path* with PyMuPDF.

    Returns one ``extract_pymupdf_page`` payload per page. Returns None when
    PyMuPDF is not installed, cannot open the file, or fails while extracting
    any page. ``load_pdf_file`` treats None as "use pypdf instead", so a
    problematic page now triggers that fallback rather than aborting the whole
    load. The document is always closed.
    """
    fitz = load_pymupdf()
    if fitz is None:
        return None
    try:
        document = fitz.open(str(path))
    except Exception:
        return None
    try:
        pages = [extract_pymupdf_page(page) for page in document]
    except Exception:
        logging.warning(
            "PyMuPDF failed while extracting %s; falling back to pypdf.",
            path,
            exc_info=True,
        )
        return None
    finally:
        document.close()
    return pages
def clean_formula_text(text: str) -> str:
    """Normalize raw formula text.

    Runs ``page_lines`` (normalize and drop noise lines), rejoins the lines
    with newlines, collapses runs of spaces/tabs, limits blank-line runs and
    strips the result. Returns "" when nothing survives.
    """
    kept_lines = page_lines(text)
    if not kept_lines:
        return ""
    joined = re.sub(r"[ \t]+", " ", "\n".join(kept_lines))
    return re.sub(r"\n{3,}", "\n\n", joined).strip()
def normalize_pdf_line(line: str) -> str:
    """Clean one extracted PDF line.

    Replaces NUL characters with spaces, expands the Unicode ligatures
    U+FB00 to U+FB04 (ff, fi, fl, ffi, ffl), collapses runs of spaces/tabs to a
    single space, and strips surrounding whitespace.
    """
    replacements = (
        ("\x00", " "),
        ("\ufb00", "ff"),
        ("\ufb01", "fi"),
        ("\ufb02", "fl"),
        ("\ufb03", "ffi"),
        ("\ufb04", "ffl"),
    )
    for original, replacement in replacements:
        line = line.replace(original, replacement)
    return re.sub(r"[ \t]+", " ", line).strip()
def is_noise_line(line: str) -> bool:
    """Return True for lines with no content.

    Noise lines are empty strings, bare numbers, "Page N" / "Page N of M"
    markers (any case), and rules of 3+ "-", "_", "=" or whitespace
    characters.
    """
    if not line:
        return True
    noise_patterns = (
        (r"\d+", 0),
        (r"page\s+\d+(\s+of\s+\d+)?", re.IGNORECASE),
        (r"[-_=\s]{3,}", 0),
    )
    return any(
        re.fullmatch(pattern, line, flags=flags) is not None
        for pattern, flags in noise_patterns
    )
def is_formula_like(line: str) -> bool:
    """Heuristically classify a text line as (part of) a formula.

    Used by math_text_score, preserve_math_line_breaks and is_section_heading.
    Returns True as soon as one rule matches, checked in this order:

    1. contains a "={", "^{" or "_{" marker (spaces ignored);
    2. is a lone bracket or brace;
    3. has at most 40 characters (spaces removed) and contains any math symbol.
       This set includes "-" and "/", so short hyphenated phrases also match;
    4. has 2+ strong math symbols and at most 180 characters;
    5. has both a strong and a weak math symbol and at most 180 characters;
    6. contains "=", at least 2 letters/digits, and at most 220 characters;
    7. looks like a function application, e.g. "d(", "exp(", "ln[", "max(";
    8. has at most 4 letters, at least one math symbol and at least one digit.
    """
    stripped = line.strip()
    if not stripped:
        return False
    strong_math_count = sum(1 for char in stripped if char in PDF_STRONG_MATH_SYMBOLS)
    weak_math_count = sum(1 for char in stripped if char in PDF_WEAK_MATH_SYMBOLS)
    alpha_count = sum(1 for char in stripped if char.isalpha())
    digit_count = sum(1 for char in stripped if char.isdigit())
    compact = stripped.replace(" ", "")  # only spaces are removed, not tabs
    if "={" in compact or "^{" in compact or "_{" in compact:
        return True
    if compact in {"(", ")", "[", "]", "{", "}"}:
        return True
    if len(compact) <= 40 and any(char in compact for char in PDF_MATH_SYMBOLS):
        return True
    if strong_math_count >= 2 and len(stripped) <= 180:
        return True
    if strong_math_count >= 1 and weak_math_count >= 1 and len(stripped) <= 180:
        return True
    if "=" in stripped and (alpha_count + digit_count) >= 2 and len(stripped) <= 220:
        return True
    if re.search(r"\b(d|D|exp|ln|sqrt|max|min|var|cov)\s*[\(\[]", stripped):
        return True
    if alpha_count <= 4 and (strong_math_count + weak_math_count) >= 1 and digit_count >= 1:
        return True
    return False
def normalized_line_key(line: str) -> str:
    """Return a header/footer comparison key for *line*.

    The key is lower-cased, has every digit run replaced by "#", and is
    stripped, so "Page 3" and "page 17" share a key.
    """
    lowered = line.lower()
    return re.sub(r"\d+", "#", lowered).strip()
def page_lines(text: str) -> List[str]:
    """Split *text* into normalized lines, dropping noise lines.

    CRLF, CR and LF line endings are all accepted. Each line goes through
    ``normalize_pdf_line`` and is kept unless ``is_noise_line`` flags it,
    which also removes empty lines.
    """
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    normalized = (normalize_pdf_line(raw_line) for raw_line in unified.split("\n"))
    return [candidate for candidate in normalized if not is_noise_line(candidate)]
def find_repeated_boundary_lines(raw_pages: List[str]) -> set[str]:
    """Detect running headers and footers across the pages of one PDF.

    For each page, the first and last ``PDF_BOUNDARY_LINE_COUNT`` non-noise
    lines that are 3 to 140 characters long are reduced to
    ``normalized_line_key`` keys. A key is returned when it appears on at least
    ``min(PDF_REPEATED_LINE_MIN_PAGES, max(2, len(raw_pages) // 3))``
    distinct pages.

    Keys are de-duplicated per page before counting. On pages shorter than
    twice ``PDF_BOUNDARY_LINE_COUNT`` lines the head and tail windows overlap.
    Without de-duplication, those lines would be counted twice for a single
    page, which alone reaches the threshold of 2. Real content on a one-page
    PDF would then be stripped as a "repeated header".
    """
    page_counts: Counter[str] = Counter()
    for raw_text in raw_pages:
        lines = page_lines(raw_text)
        boundary_lines = lines[:PDF_BOUNDARY_LINE_COUNT] + lines[-PDF_BOUNDARY_LINE_COUNT:]
        page_keys = {
            normalized_line_key(line)
            for line in boundary_lines
            if 3 <= len(line) <= 140
        }
        page_counts.update(page_keys)
    min_pages = min(
        PDF_REPEATED_LINE_MIN_PAGES,
        max(2, len(raw_pages) // 3),
    )
    return {key for key, count in page_counts.items() if count >= min_pages}
def clean_pdf_text(text: str, repeated_boundary_lines: set[str]) -> str:
    """Clean one page of extracted PDF text.

    Steps:
    1. split into normalized, non-noise lines (``page_lines``);
    2. drop lines in the first/last ``PDF_BOUNDARY_LINE_COUNT`` positions whose
       ``normalized_line_key`` is in *repeated_boundary_lines*;
    3. rejoin words hyphenated across lines (a line ending in "-" followed by a
       line starting with a lowercase letter; the hyphen is removed);
    4. reflow with ``preserve_math_line_breaks``, then collapse spaces/tabs and
       runs of 3+ newlines, and strip.
    """
    lines = page_lines(text)
    line_count = len(lines)
    kept_lines = [
        line
        for position, line in enumerate(lines)
        if not (
            (
                position < PDF_BOUNDARY_LINE_COUNT
                or position >= line_count - PDF_BOUNDARY_LINE_COUNT
            )
            and normalized_line_key(line) in repeated_boundary_lines
        )
    ]
    joined_lines: List[str] = []
    for line in kept_lines:
        if joined_lines and joined_lines[-1].endswith("-") and line[:1].islower():
            joined_lines[-1] = joined_lines[-1][:-1] + line
            continue
        joined_lines.append(line)
    reflowed = preserve_math_line_breaks("\n".join(joined_lines))
    reflowed = re.sub(r"[ \t]+", " ", reflowed)
    return re.sub(r"\n{3,}", "\n\n", reflowed).strip()
def preserve_math_line_breaks(text: str) -> str:
    """Reflow wrapped prose lines while keeping formula lines on their own lines.

    Lines are processed in order against the last output line:

    * The line is kept separate when it or the previous output line is
      formula-like (``is_formula_like``), or while a formula block is open.
      A block opens on a formula-like line. It stays open across following
      lines until one ends with ".", ";", ":", "?" or "!".
    * Otherwise it is kept separate if the previous output line ends with
      ".", ":", ";", "?", "!" or ")".
    * Otherwise it is joined onto the previous output line with a space.
    """
    lines = text.split("\n")
    # NOTE: str.split always returns at least one element, so this guard never fires.
    if not lines:
        return ""
    output = [lines[0]]
    in_formula_block = is_formula_like(lines[0])
    for line in lines[1:]:
        previous = output[-1]  # may already contain joined prose
        line_is_formula = is_formula_like(line)
        previous_is_formula = is_formula_like(previous)
        if previous_is_formula or line_is_formula or in_formula_block:
            output.append(line)
            in_formula_block = line_is_formula or (
                in_formula_block
                and not line.endswith((".", ";", ":", "?", "!"))
            )
        elif previous.endswith((".", ":", ";", "?", "!", ")")):
            output.append(line)
            in_formula_block = False
        else:
            # Wrapped prose: merge into the previous output line.
            output[-1] = f"{previous} {line}"
            in_formula_block = False
    return "\n".join(output)
def is_chapter_heading(line: str) -> bool:
    """Return True for bare chapter/appendix labels such as "Chapter 3",
    "CHAPTER IV" or "Appendix B". Matching is case-insensitive and nothing
    else may appear on the line."""
    label_pattern = r"(chapter|appendix)\s+([0-9]+|[ivxlcdm]+|[a-z])"
    match = re.fullmatch(label_pattern, line.strip(), flags=re.IGNORECASE)
    return match is not None
def titlecase_word_ratio(words: List[str]) -> float:
    """Fraction of alphabetic words that are capitalized or minor title words.

    Words without any letter are ignored. Surrounding ()[]{}:;,. characters
    are stripped before checking. Minor words (a, an, and, for, in, of, on,
    or, the, to, with) count as title case. Returns 0.0 when no word contains
    a letter.
    """
    minor_words = {"a", "an", "and", "for", "in", "of", "on", "or", "the", "to", "with"}
    total = 0
    titled = 0
    for word in words:
        if not any(char.isalpha() for char in word):
            continue
        cleaned = word.strip("()[]{}:;,.")
        total += 1
        if cleaned[:1].isupper() or cleaned.lower() in minor_words:
            titled += 1
    return titled / total if total else 0.0
def uppercase_letter_ratio(text: str) -> float:
    """Fraction of letters in *text* that are uppercase (0.0 when there are no letters)."""
    upper_count = 0
    letter_count = 0
    for char in text:
        if char.isalpha():
            letter_count += 1
            upper_count += char.isupper()
    return upper_count / letter_count if letter_count else 0.0
def is_section_heading(line: str) -> bool:
    """Heuristically decide whether a PDF text line is a section heading.

    Rejected:
    * lines outside 4 to 150 characters, or with fewer than 6 letters or
      fewer than 2 alphabetic words;
    * digit-heavy lines: more digits than max(4, letters), or containing "%"
      with digits at least half the letter count;
    * lines starting with a digit, unless numbered like "2.3 Title";
    * lines opening with a prose or caption word (in, from, where, thus, then,
      now, let, because, while, figure, table, for), which covers "Figure 3" /
      "Table 2" captions;
    * formula-like lines, and lines ending in ".", "," or ";".

    Accepted: numbered headings, mostly-uppercase lines (72% or more of letters
    uppercase) of at most 16 words, and lines of 4 to 16 words where at least
    68% of words are title case.

    Word-internal hyphens ("Black-Scholes", "Put-Call") are turned into spaces
    for the formula check only. ``is_formula_like`` treats any short line
    containing "-" as math, which would otherwise reject such headings.
    """
    stripped = line.strip()
    if not 4 <= len(stripped) <= 150:
        return False
    letters = [char for char in stripped if char.isalpha()]
    digit_count = sum(1 for char in stripped if char.isdigit())
    alpha_words = [
        word.strip("()[]{}:;,.")
        for word in stripped.split()
        if any(char.isalpha() for char in word)
    ]
    if len(letters) < 6 or len(alpha_words) < 2:
        return False
    if digit_count > max(4, len(letters)):
        return False
    if "%" in stripped and digit_count >= len(letters) / 2:
        return False
    numbered_heading = bool(re.match(r"^\d+(\.\d+)+\s+", stripped))
    if stripped[:1].isdigit() and not numbered_heading:
        return False
    if re.match(
        r"^(in|from|where|thus|then|now|let|because|while|figure|table|for)\b",
        stripped,
        flags=re.IGNORECASE,
    ):
        return False
    formula_probe = re.sub(r"(?<=[A-Za-z])-(?=[A-Za-z])", " ", stripped)
    if is_formula_like(formula_probe):
        return False
    if stripped.endswith((".", ",", ";")):
        return False
    if numbered_heading:
        return True
    words = stripped.split()
    if len(words) > 16:
        return False
    if uppercase_letter_ratio(stripped) >= 0.72 and len(words) >= 2:
        return True
    return len(words) >= 4 and titlecase_word_ratio(words) >= 0.68
def make_section_path(chapter_title: str, section_title: str) -> str:
    """Build a "Chapter > Section" breadcrumb.

    Falls back to whichever title is non-empty (section first) when one is
    missing or both are equal.
    """
    if not (chapter_title and section_title) or section_title == chapter_title:
        return section_title or chapter_title
    return f"{chapter_title} > {section_title}"
def split_pdf_page_into_sections(
    path: Path,
    page_index: int,
    text: str,
    file_hash: str,
    section_state: dict,
    extraction_backend: str,
    formula_count: int,
) -> List[Document]:
    """Split one cleaned PDF page into section Documents.

    *section_state* holds ``chapter_title``, ``section_title`` and
    ``pending_chapter_label`` across pages and is updated in place.

    Non-blank lines accumulate into a pending section. A chapter or section
    heading starts a new section. The pending text is only flushed as its own
    Document once it reaches ``PDF_MIN_SECTION_CHARS`` characters; shorter
    text is carried into the new section. Whatever is pending at the end of
    the page is flushed.

    A chapter label ("Chapter 3") is remembered until the next section
    heading, which becomes the chapter's title. When the label is the most
    recent pending line, both merge into "Chapter 3: <title>".

    Returns this page's Documents, each tagged with page, extraction and
    chapter/section metadata. *formula_count* is copied onto each one.
    """
    documents = []
    lines = text.splitlines()
    pending_lines: List[str] = []
    pending_metadata = {
        "chapter_title": section_state.get("chapter_title", ""),
        "section_title": section_state.get("section_title", ""),
    }

    def flush_pending() -> None:
        """Emit the pending lines as one Document (if non-blank) and reset them."""
        nonlocal pending_lines, pending_metadata
        section_text = "\n".join(line for line in pending_lines if line.strip()).strip()
        if not section_text:
            pending_lines = []
            return
        chapter_title = pending_metadata.get("chapter_title", "")
        section_title = pending_metadata.get("section_title", "")
        documents.append(
            Document(
                text=section_text,
                metadata={
                    "source_file": str(path.resolve()),
                    "file_name": path.name,
                    "file_type": "pdf",
                    "document_title": path.stem,
                    "file_hash": file_hash,
                    "page_number": page_index,
                    "extraction_method": PDF_EXTRACTION_METHOD,
                    "extraction_backend": extraction_backend,
                    "char_count": len(section_text),
                    "formula_count": formula_count,
                    "content_type": "text",
                    "chapter_title": chapter_title,
                    "section_title": section_title,
                    "section_path": make_section_path(chapter_title, section_title),
                },
            )
        )
        pending_lines = []

    for line in lines:
        stripped = line.strip()
        if not stripped:
            continue
        if is_chapter_heading(stripped):
            if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS:
                flush_pending()
            section_state["pending_chapter_label"] = stripped.title()
            section_state["chapter_title"] = stripped.title()
            section_state["section_title"] = stripped.title()
            pending_metadata = {
                "chapter_title": section_state["chapter_title"],
                "section_title": section_state["section_title"],
            }
            pending_lines.append(stripped)
            continue
        chapter_label = section_state.get("pending_chapter_label")
        if chapter_label and is_section_heading(stripped):
            # The label is stored title-cased, while the page text keeps its
            # original case (e.g. "CHAPTER 3"). Compare through .title() against
            # the most recent pending line. An exact comparison with the whole
            # pending list misses upper-case labels and labels preceded by
            # short, unflushed text.
            if pending_lines and pending_lines[-1].title() == chapter_label:
                pending_lines[-1] = f"{chapter_label}: {stripped}"
            else:
                pending_lines.append(stripped)
            section_state["chapter_title"] = pending_lines[-1]
            section_state["section_title"] = pending_lines[-1]
            section_state["pending_chapter_label"] = ""
            pending_metadata = {
                "chapter_title": section_state["chapter_title"],
                "section_title": section_state["section_title"],
            }
            continue
        if is_section_heading(stripped):
            if len("\n".join(pending_lines)) >= PDF_MIN_SECTION_CHARS:
                flush_pending()
            section_state["section_title"] = stripped
            section_state["pending_chapter_label"] = ""
            pending_metadata = {
                "chapter_title": section_state.get("chapter_title", ""),
                "section_title": section_state["section_title"],
            }
        pending_lines.append(stripped)
    flush_pending()
    return documents
def make_formula_documents(
    path: Path,
    page_index: int,
    formula_blocks: List[dict],
    file_hash: str,
    extraction_backend: str,
) -> List[Document]:
    """Wrap each non-empty formula block of a page in its own Document.

    The text is the formula between ``[FORMULA]`` and ``[/FORMULA]`` lines.
    Metadata records the page, the formula id (falling back to
    ``formula-<page>-<position>``), its position in *formula_blocks*, its bbox
    string and empty chapter/section fields.
    """
    documents: List[Document] = []
    for position, formula in enumerate(formula_blocks):
        body = formula.get("text", "").strip()
        if not body:
            continue
        metadata = {
            "source_file": str(path.resolve()),
            "file_name": path.name,
            "file_type": "pdf",
            "document_title": path.stem,
            "file_hash": file_hash,
            "page_number": page_index,
            "extraction_method": PDF_EXTRACTION_METHOD,
            "extraction_backend": extraction_backend,
            "char_count": len(body),
            "content_type": "formula",
            "formula_id": formula.get("id", f"formula-{page_index}-{position}"),
            "formula_index": position,
            "formula_bbox": formula.get("bbox", ""),
            "formula_count": 1,
            "chapter_title": "",
            "section_title": "",
            "section_path": "",
        }
        documents.append(
            Document(text=f"[FORMULA]\n{body}\n[/FORMULA]", metadata=metadata)
        )
    return documents
def load_pdf_file(path: Path) -> List[Document]:
    """Load a PDF as section Documents plus standalone formula Documents.

    Pages are extracted with PyMuPDF when it is available and succeeds;
    otherwise pypdf is used. The pypdf reader is only opened on that fallback
    path. PDFs that PyMuPDF handles are therefore not parsed twice and cannot
    fail inside pypdf.

    Running headers and footers are detected across all pages and removed
    from each page. Pages with body text are split into sections, and every
    page contributes one Document per extracted formula. Chapter and section
    context carries across pages.
    """
    page_payloads = extract_pdf_pages_with_pymupdf(path)
    if not page_payloads:
        reader = PdfReader(str(path))
        page_payloads = [
            {
                "text": extract_pdf_text(page),
                "formula_blocks": [],
                "backend": "pypdf",
            }
            for page in reader.pages
        ]
    repeated_boundary_lines = find_repeated_boundary_lines(
        [payload["text"] for payload in page_payloads]
    )
    file_hash = file_sha256(path)
    section_state: dict = {
        "chapter_title": "",
        "section_title": "",
        "pending_chapter_label": "",
    }
    documents: List[Document] = []
    for page_index, payload in enumerate(page_payloads, start=1):
        text = clean_pdf_text(payload["text"], repeated_boundary_lines)
        formula_blocks = payload.get("formula_blocks", [])
        extraction_backend = payload.get("backend", "pypdf")
        if text.strip():
            documents.extend(
                split_pdf_page_into_sections(
                    path=path,
                    page_index=page_index,
                    text=text,
                    file_hash=file_hash,
                    section_state=section_state,
                    extraction_backend=extraction_backend,
                    formula_count=len(formula_blocks),
                )
            )
        # Formula documents are emitted for every page, including pages with no body text.
        documents.extend(
            make_formula_documents(
                path=path,
                page_index=page_index,
                formula_blocks=formula_blocks,
                file_hash=file_hash,
                extraction_backend=extraction_backend,
            )
        )
    return documents
def load_txt_file(path: Path) -> List[Document]:
    """Load a UTF-8 plain-text file as a single Document with whole-file metadata."""
    content = path.read_text(encoding="utf-8")
    metadata = {
        "source_file": str(path.resolve()),
        "file_name": path.name,
        "file_type": "txt",
        "document_title": path.stem,
        "file_hash": file_sha256(path),
        "content_type": "text",
        "chapter_title": "",
        "section_title": "",
        "section_path": "",
        "char_count": len(content),
    }
    return [Document(text=content, metadata=metadata)]
def iter_source_files(raw_dir: Path) -> Iterable[Path]:
    """Yield supported source files (.md, .markdown, .pdf, .txt; suffix case
    ignored) anywhere under *raw_dir*, in sorted path order."""
    supported_suffixes = {".md", ".markdown", ".pdf", ".txt"}
    candidates = sorted(raw_dir.rglob("*"))
    yield from (
        candidate
        for candidate in candidates
        if candidate.is_file() and candidate.suffix.lower() in supported_suffixes
    )
def load_docs(raw_dir: Path = RAW_DIR) -> List[Document]:
    """Load every supported source file under the effective raw directory.

    Each file is dispatched through ``load_source_file``, so there is a single
    suffix-to-loader mapping instead of two copies that could drift apart.

    Raises:
        ValueError: if no documents were produced.
    """
    raw_dir = effective_raw_dir(raw_dir)
    documents: List[Document] = []
    for path in iter_source_files(raw_dir):
        documents.extend(load_source_file(path))
    if not documents:
        raise ValueError(f"No supported documents found under {raw_dir}")
    return documents
def add_chunk_metadata(nodes: List[BaseNode]) -> List[BaseNode]:
    """Give each node a deterministic chunk id and index, in place.

    ``chunk_index`` counts chunks per ``source_file`` in input order. The id is
    ``<file stem>-<first 12 chars of file_hash>-p<page_number or "na">-c<chunk_index>``
    and also becomes the node id. The embedding model name is recorded under
    ``EMBED_MODEL_METADATA_KEY``. Returns the same list.
    """
    next_index: dict[str, int] = {}
    for node in nodes:
        metadata = node.metadata
        source_file = metadata["source_file"]
        chunk_index = next_index.get(source_file, 0)
        next_index[source_file] = chunk_index + 1
        chunk_id = "{}-{}-p{}-c{}".format(
            Path(source_file).stem,
            metadata["file_hash"][:12],
            metadata.get("page_number", "na"),
            chunk_index,
        )
        metadata.update(
            {
                "chunk_id": chunk_id,
                "chunk_index": chunk_index,
                EMBED_MODEL_METADATA_KEY: EMBED_MODEL_NAME,
            }
        )
        node.id_ = chunk_id
    return nodes
def validate_nodes(nodes: List[BaseNode]) -> None:
    """Ensure the chunk list is usable for indexing; raise ``ValueError`` otherwise.

    Three checks are made:
    - at least one node exists;
    - every node carries all keys listed in ``REQUIRED_METADATA``;
    - PDF nodes additionally record a ``page_number``.
    """
    if not nodes:
        raise ValueError("No chunks were created from the source documents.")
    for candidate in nodes:
        meta = candidate.metadata
        absent = [field for field in REQUIRED_METADATA if field not in meta]
        if absent:
            raise ValueError(
                f"Node {candidate.node_id} is missing metadata fields: {absent}")
        is_pdf = meta["file_type"] == "pdf"
        if is_pdf and "page_number" not in meta:
            raise ValueError(
                f"PDF node {candidate.node_id} is missing page_number metadata.")
def split_documents(documents: List[Document]) -> List[BaseNode]:
    """Chunk *documents* with a sentence-aware splitter and validate the result.

    Chunk size and overlap come from ``CHUNK_SIZE`` and ``CHUNK_OVERLAP``. Ids
    are assigned by ``add_chunk_metadata``, and ``validate_nodes`` raises
    ``ValueError`` if the chunks are empty or lack required metadata.
    """
    chunker = SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    chunks = chunker.get_nodes_from_documents(documents)
    add_chunk_metadata(chunks)
    validate_nodes(chunks)
    return chunks
def build_nodes(raw_dir: Path = RAW_DIR) -> List[BaseNode]:
    """Load every supported source under *raw_dir* and return its validated chunks."""
    loaded = load_docs(raw_dir)
    chunks = split_documents(loaded)
    return chunks
def load_source_file(path: Path) -> List[Document]:
    """Load one source file with the loader that matches its extension.

    The extension is compared case-insensitively. Unsupported extensions
    yield an empty list rather than an error.
    """
    loaders = {
        ".md": load_md_documents,
        ".markdown": load_md_documents,
        ".pdf": load_pdf_file,
        ".txt": load_txt_file,
    }
    loader = loaders.get(path.suffix.lower())
    return [] if loader is None else loader(path)
def list_current_sources(raw_dir: Path = RAW_DIR) -> dict[str, dict[str, str]]:
    """Describe every supported source file currently on disk.

    Returns:
        A map from each file's absolute path to ``{"file_hash", "file_type"}``.
        ``file_type`` is the lower-cased extension without its leading dot.
    """
    root = effective_raw_dir(raw_dir)
    return {
        str(source.resolve()): {
            "file_hash": file_sha256(source),
            "file_type": source.suffix.lower().lstrip("."),
        }
        for source in iter_source_files(root)
    }
def existing_source_metadata(chroma_collection) -> dict[str, dict[str, str]]:
    """Summarise, per indexed source file, the metadata stored with its chunks.

    Pages through ``chroma_collection.get`` in batches of 500 records. For each
    ``source_file`` it records the stored file hash, file type, embedding model
    and extraction method. When a file has several chunks, the values from the
    last chunk read win. Records with no metadata, or without a
    ``source_file``, are ignored.

    Args:
        chroma_collection: A Chroma collection exposing ``count()`` and
            paginated ``get(limit=, offset=, include=)``.

    Returns:
        A map from source file path to its stored metadata summary.
    """
    existing: dict[str, dict[str, str]] = {}
    if chroma_collection.count() == 0:
        return existing
    offset = 0
    limit = 500
    while True:
        batch = chroma_collection.get(
            limit=limit,
            offset=offset,
            include=["metadatas"],
        )
        metadatas = batch.get("metadatas") or []
        if not metadatas:
            break
        for metadata in metadatas:
            # Chroma may return None for records stored without metadata;
            # calling .get on it would raise AttributeError.
            if not metadata:
                continue
            source_file = metadata.get("source_file")
            if not source_file:
                continue
            existing[source_file] = {
                "file_hash": metadata.get("file_hash", ""),
                "file_type": metadata.get("file_type", ""),
                "embedding_model": metadata.get(EMBED_MODEL_METADATA_KEY, ""),
                "extraction_method": metadata.get("extraction_method", ""),
            }
        # A short page means we have reached the end of the collection.
        if len(metadatas) < limit:
            break
        offset += limit
    return existing
def source_needs_update(current: dict[str, str], existing: dict[str, str] | None) -> bool:
    """Decide whether a source file's chunks must be (re)indexed.

    A file is stale when any of the following holds:
    - it has no indexed chunks yet;
    - its content hash changed;
    - it was embedded with a different model than ``EMBED_MODEL_NAME``;
    - it is a PDF extracted with an older ``PDF_EXTRACTION_METHOD``.
    """
    if not existing:
        return True
    return (
        existing.get("file_hash") != current["file_hash"]
        or existing.get("embedding_model") != EMBED_MODEL_NAME
        or (
            current["file_type"] == "pdf"
            and existing.get("extraction_method") != PDF_EXTRACTION_METHOD
        )
    )
def incremental_update_index(
    raw_dir: Path,
    chroma_collection,
    storage_context: StorageContext,
    embed_model,
) -> bool:
    """Sync the Chroma collection with the source files currently under *raw_dir*.

    The sync runs in three steps:
    1. Chunks of files that were removed from disk are deleted.
    2. Chunks of files that ``source_needs_update`` flags as stale are deleted.
    3. The stale files are reloaded, re-chunked and inserted.

    Args:
        raw_dir: Directory holding the raw source files.
        chroma_collection: The Chroma collection backing the index.
        storage_context: Storage context wrapping the same vector store.
        embed_model: Embedding model used for newly inserted chunks.

    Returns:
        True when the collection was modified (chunks deleted or inserted),
        otherwise False.
    """
    current_sources = list_current_sources(raw_dir)
    existing_sources = existing_source_metadata(chroma_collection)
    deleted_sources = sorted(set(existing_sources) - set(current_sources))
    changed_sources = sorted(
        source_file
        for source_file, current in current_sources.items()
        if source_needs_update(current, existing_sources.get(source_file))
    )
    # Drop all chunks of removed/modified files first, so a file never ends up
    # with a mix of old and new chunks. Deletion is best-effort per file.
    for source_file in deleted_sources + changed_sources:
        try:
            chroma_collection.delete(where={"source_file": source_file})
        except Exception as exc:
            logging.warning("Could not delete stale chunks for %s: %s", source_file, exc)
    if not changed_sources:
        if deleted_sources:
            print(f"Removed {len(deleted_sources)} stale source(s) from collection '{COLLECTION_NAME}'.")
        return bool(deleted_sources)

    removed_any = bool(deleted_sources) or any(
        source_file in existing_sources for source_file in changed_sources
    )
    documents: List[Document] = []
    for source_file in changed_sources:
        documents.extend(load_source_file(Path(source_file)))
    # Files that yield no text (e.g. an empty .txt) can leave split_documents
    # with no chunks, and validate_nodes then raises. Such a file is never
    # stored, so it would be flagged as changed (and crash startup) on every
    # run. Skip empty documents instead.
    documents = [document for document in documents if document.text.strip()]
    if not documents:
        logging.warning(
            "Changed source(s) produced no indexable text: %s",
            ", ".join(changed_sources),
        )
        return removed_any

    nodes = split_documents(documents)
    VectorStoreIndex(
        nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )
    print(
        f"Incrementally indexed {len(nodes)} chunk(s) from {len(changed_sources)} source file(s)."
    )
    return True
def collection_needs_rebuild(chroma_collection) -> bool:
    """Heuristically decide whether the whole collection must be rebuilt.

    Up to 20 stored records are sampled with ``peek``. A rebuild is requested
    when any of the following holds:
    - the collection is empty;
    - any sampled chunk was embedded with a model other than
      ``EMBED_MODEL_NAME``, or carries no metadata at all;
    - any sampled PDF chunk used an extraction method other than
      ``PDF_EXTRACTION_METHOD``.

    If sampling fails, a warning is logged and no rebuild is requested
    (best-effort check).
    """
    total = chroma_collection.count()
    if total == 0:
        return True
    try:
        sample = chroma_collection.peek(limit=min(total, 20))
    except Exception as exc:
        logging.warning("Could not sample collection metadata: %s", exc)
        return False
    for metadata in sample.get("metadatas") or []:
        # Chunks written by this module always carry metadata; a missing dict
        # is treated as "not tagged with the current model" and hence stale.
        metadata = metadata or {}
        if metadata.get(EMBED_MODEL_METADATA_KEY) != EMBED_MODEL_NAME:
            return True
        # Previously this returned the comparison for the first PDF sample,
        # which skipped every remaining sample. Only an actual mismatch may
        # short-circuit.
        if (
            metadata.get("file_type") == "pdf"
            and metadata.get("extraction_method") != PDF_EXTRACTION_METHOD
        ):
            return True
    return False
async def build_index(raw_dir: Path = RAW_DIR, rebuild: bool = False) -> VectorStoreIndex:
    """Open (or create) the persistent Chroma collection and return an index over it.

    The collection is rebuilt from scratch when *rebuild* is True or the
    collection is empty. Otherwise it is brought up to date with
    ``incremental_update_index`` and wrapped with
    ``VectorStoreIndex.from_vector_store``.

    Args:
        raw_dir: Directory holding the raw sources; resolved through
            ``effective_raw_dir``.
        rebuild: Drop the existing collection before indexing.

    Returns:
        A ``VectorStoreIndex`` bound to the Chroma collection.
    """
    configure_model_cache()
    # Imported lazily, after configure_model_cache() has run.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    load_dotenv()
    source_root = effective_raw_dir(raw_dir)
    CHROMA_DB_DIR.mkdir(parents=True, exist_ok=True)
    client = chromadb.PersistentClient(path=str(CHROMA_DB_DIR))
    if rebuild:
        try:
            client.delete_collection(COLLECTION_NAME)
        except (NotFoundError, ValueError):
            pass  # Nothing to drop: the collection does not exist yet.
    collection = client.get_or_create_collection(COLLECTION_NAME)
    vector_store = ChromaVectorStore(chroma_collection=collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    embed_model = HuggingFaceEmbedding(
        model_name=resolve_embed_model_name(),
        cache_folder=str(HF_CACHE_DIR / "sentence_transformers"),
    )

    needs_full_build = rebuild or collection.count() == 0
    if not needs_full_build:
        incremental_update_index(
            raw_dir=source_root,
            chroma_collection=collection,
            storage_context=storage_context,
            embed_model=embed_model,
        )
        print(
            f"Loaded existing collection '{COLLECTION_NAME}' with {collection.count()} chunks.")
        return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)

    nodes = build_nodes(source_root)
    index = VectorStoreIndex(
        nodes,
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )
    print(
        f"Indexed {len(nodes)} chunks into collection '{COLLECTION_NAME}'")
    return index
class QueryKnowledgeTool(Tool):
    """smolagents tool that searches the local options-trading knowledge base.

    Per query, the pipeline runs these stages (each configurable in
    ``__init__``):
    1. Dense vector retrieval from the Chroma-backed index built by
       ``build_index``.
    2. Optional BM25 keyword retrieval over every chunk stored in the
       collection, fused with the vector hits by reciprocal-rank scoring
       (``merge_results``).
    3. Optional cross-encoder reranking of the fused candidates.

    The best ``top_k`` hits are rendered as plain-text records by
    ``format_results``.
    """

    name = "query_knowledge"
    description = (
        "Searches the local options trading knowledge base. Use this for option "
        "concepts, volatility, Greeks, strategies, formulas, equation numbers, "
        "and citations from reference books."
    )
    inputs = {'query': {'type': 'string',
                        'description': 'The search query to perform.'}}
    output_type = "string"

    @staticmethod
    def format_results(results, max_chars: int = 800):
        """Render retrieval hits as ``---``-separated plain-text records.

        Each record lists the source file, page, section, content type,
        formula id and the ranking score. It also shows the raw vector and
        BM25 scores stored in node metadata, then the chunk text truncated to
        *max_chars* characters (with a trailing "...").
        """
        records = []
        for result in results:
            metadata = result.node.metadata
            source = metadata.get("file_name", "unknown")
            page = metadata.get("page_number", "n/a")
            section = metadata.get("section_path") or metadata.get("section_title") or "n/a"
            content_type = metadata.get("content_type", "text")
            formula_id = metadata.get("formula_id", "")
            score = result.score
            # NodeWithScore.score is Optional; formatting None with ":.4f"
            # would raise TypeError and abort the whole tool call.
            score_text = "n/a" if score is None else f"{score:.4f}"
            text = result.node.get_content()
            if len(text) > max_chars:
                text = f"{text[:max_chars].rstrip()}..."
            records.append(
                f"source:{source}\n"
                f"page:{page}\n"
                f"section:{section}\n"
                f"content_type:{content_type}\n"
                f"formula_id:{formula_id or 'n/a'}\n"
                f"score:{score_text}\n"
                f"vector_score:{metadata.get(VECTOR_METADATA_KEY, 'n/a')}\n"
                f"bm25_score:{metadata.get(BM25_METADATA_KEY, 'n/a')}\n"
                f"content:{text}"
            )
        return "\n\n---\n\n".join(records)

    @staticmethod
    def load_bm25_nodes(collection_name: str = COLLECTION_NAME) -> list[TextNode]:
        """Read every stored chunk of *collection_name* back as ``TextNode`` objects.

        Records are fetched in pages of 500. If the collection cannot be
        opened, a warning is logged and an empty list is returned, so callers
        can run without BM25 (best-effort).
        """
        db = chromadb.PersistentClient(path=str(CHROMA_DB_DIR))
        try:
            collection = db.get_collection(collection_name)
        except Exception as exc:
            logging.warning("BM25 corpus unavailable (collection %r): %s", collection_name, exc)
            return []
        nodes: list[TextNode] = []
        offset = 0
        limit = 500
        while True:
            batch = collection.get(
                limit=limit,
                offset=offset,
                include=["documents", "metadatas"],
            )
            documents = batch.get("documents") or []
            metadatas = batch.get("metadatas") or []
            ids = batch.get("ids") or []
            if not documents:
                break
            for index, text in enumerate(documents):
                raw_metadata = metadatas[index] if index < len(metadatas) else None
                metadata = dict(raw_metadata or {})
                node_id = ids[index] if index < len(ids) else metadata.get("chunk_id", "")
                nodes.append(TextNode(id_=node_id, text=text or "", metadata=metadata))
            # A short page means we have reached the end of the collection.
            if len(documents) < limit:
                break
            offset += limit
        return nodes

    @staticmethod
    def merge_results(
        vector_results: list[NodeWithScore],
        bm25_results: list[NodeWithScore],
        top_k: int,
        rrf_k: int = 0,
    ) -> list[NodeWithScore]:
        """Fuse vector and BM25 hits with reciprocal-rank scoring.

        A hit at 0-based rank ``r`` in either list contributes
        ``1 / (rrf_k + r + 1)``. Contributions for the same node id are
        summed. The raw retriever scores are stored in node metadata under
        ``VECTOR_METADATA_KEY`` / ``BM25_METADATA_KEY`` for display.
        ``rrf_k=0`` keeps the original weighting; larger values shrink the
        gap between the top ranks.

        Returns:
            At most *top_k* hits, highest fused score first.
        """
        merged: dict[str, NodeWithScore] = {}
        for rank, result in enumerate(vector_results):
            node_id = result.node.node_id
            result.node.metadata[VECTOR_METADATA_KEY] = result.score
            merged[node_id] = NodeWithScore(
                node=result.node,
                score=1.0 / (rrf_k + rank + 1),
            )
        for rank, result in enumerate(bm25_results):
            node_id = result.node.node_id
            result.node.metadata[BM25_METADATA_KEY] = result.score
            reciprocal_rank_score = 1.0 / (rrf_k + rank + 1)
            if node_id in merged:
                merged[node_id].score = (merged[node_id].score or 0.0) + reciprocal_rank_score
                # The merged entry holds the vector-side node object; record
                # the BM25 score there too so format_results can show it.
                merged[node_id].node.metadata[BM25_METADATA_KEY] = result.score
            else:
                merged[node_id] = NodeWithScore(
                    node=result.node,
                    score=reciprocal_rank_score,
                )
        results = list(merged.values())
        results.sort(key=lambda item: item.score or float("-inf"), reverse=True)
        return results[:top_k]

    @staticmethod
    def _run_coroutine_sync(coroutine):
        """Run *coroutine* to completion from synchronous code.

        ``asyncio.run`` raises RuntimeError when this thread already has a
        running event loop (e.g. Jupyter or an async web host). In that case
        the coroutine runs on a fresh loop in a short-lived worker thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coroutine)
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coroutine).result()

    def __init__(
        self,
        max_results=20,
        top_k=5,
        use_reranker: Optional[bool] = None,
        use_hybrid: Optional[bool] = None,
        reranker_top_n: Optional[int] = None,
        reranker_model_name: Optional[str] = None,
        **kwargs,
    ):
        """Build or load the index and set up the retrievers.

        Args:
            max_results: Candidate pool size for BM25 and for the fused list
                (and for vector retrieval when reranking is enabled).
            top_k: Number of passages returned by ``forward``.
            use_reranker: Enable cross-encoder reranking. ``None`` reads the
                ``RAG_USE_RERANKER`` env flag (default True).
            use_hybrid: Enable BM25 fusion. ``None`` reads ``RAG_USE_HYBRID``
                (default True).
            reranker_top_n: Passages kept by the reranker; defaults to *top_k*.
            reranker_model_name: Cross-encoder model; defaults to
                ``RERANKER_MODEL_NAME``.
            **kwargs: Accepted for compatibility and currently ignored.
        """
        super().__init__()
        self.max_results = max_results
        self.top_k = top_k
        self.use_reranker = (
            env_flag("RAG_USE_RERANKER", True)
            if use_reranker is None
            else use_reranker
        )
        self.use_hybrid = (
            env_flag("RAG_USE_HYBRID", True)
            if use_hybrid is None
            else use_hybrid
        )
        self.reranker_top_n = reranker_top_n or top_k
        self.reranker = (
            CrossEncoderReranker(reranker_model_name or RERANKER_MODEL_NAME)
            if self.use_reranker
            else None
        )
        index = self._run_coroutine_sync(build_index(rebuild=False))
        # The reranker needs a wider candidate pool than the final top_k.
        retrieve_top_k = max(max_results, top_k) if self.use_reranker else top_k
        self.retriever = index.as_retriever(similarity_top_k=retrieve_top_k)
        self.bm25_retriever = (
            BM25Retriever(self.load_bm25_nodes())
            if self.use_hybrid
            else None
        )

    def forward(self, query: str) -> str:
        """Return the top ``top_k`` knowledge-base passages for *query* as text.

        An explanatory message is returned instead of an empty string when the
        query is blank or nothing is retrieved.
        """
        if not query or not query.strip():
            return "Please provide a non-empty search query."
        vector_results = self.retriever.retrieve(query)
        results = vector_results
        if self.bm25_retriever:
            bm25_results = self.bm25_retriever.retrieve(query, self.max_results)
            results = self.merge_results(
                vector_results=vector_results,
                bm25_results=bm25_results,
                top_k=max(self.max_results, self.top_k),
            )
        if self.reranker:
            try:
                results = self.reranker.rerank(
                    query,
                    results,
                    top_n=self.reranker_top_n,
                )
            except Exception as exc:
                # Results may be vector-only or hybrid-fused at this point.
                logging.warning("Reranker failed; falling back to retrieval ranking: %s", exc)
                results = results[:self.top_k]
        if not results:
            return "No matching passages were found in the knowledge base."
        return QueryKnowledgeTool.format_results(results[:self.top_k])
if __name__ == "__main__":
    # Manual smoke test: build/load the index and run one sample query.
    demo_tool = QueryKnowledgeTool()
    answer: str = demo_tool.forward("What is option?")
    print(answer)