| | """ |
| | Reranker Agent |
| | |
| | Cross-encoder based reranking for improved retrieval precision. |
| | Follows FAANG best practices for production RAG systems. |
| | |
| | Key Features: |
| | - LLM-based cross-encoder reranking |
| | - Relevance scoring with explanations |
| | - Diversity promotion to avoid redundancy |
| | - Quality filtering (removes low-quality chunks) |
| | - Chunk deduplication |
| | """ |
| |
|
| | from typing import List, Optional, Dict, Any, Tuple |
| | from pydantic import BaseModel, Field |
| | from loguru import logger |
| | from dataclasses import dataclass |
| | import json |
| | import re |
| | from difflib import SequenceMatcher |
| |
|
| | try: |
| | import httpx |
| | HTTPX_AVAILABLE = True |
| | except ImportError: |
| | HTTPX_AVAILABLE = False |
| |
|
| | from .retriever import RetrievalResult |
| |
|
| |
|
class RerankerConfig(BaseModel):
    """Configuration for reranking."""

    # LLM endpoint settings (Ollama-compatible HTTP API; consumed by
    # RerankerAgent._score_passage). Low temperature keeps scoring stable.
    model: str = Field(default="llama3.2:3b")
    base_url: str = Field(default="http://localhost:11434")
    temperature: float = Field(default=0.1)

    # Output size and quality floor: results whose relevance_score falls
    # below min_relevance_score are filtered out before final selection.
    top_k: int = Field(default=5, ge=1)
    min_relevance_score: float = Field(default=0.3, ge=0.0, le=1.0)

    # Diversity (MMR-style) selection toggle; diversity_threshold is the
    # intended cap on pairwise text similarity between returned chunks.
    enable_diversity: bool = Field(default=True)
    diversity_threshold: float = Field(default=0.8, description="Max similarity between chunks")

    # Chunks whose text similarity to an already-kept chunk exceeds this
    # are dropped during deduplication.
    dedup_threshold: float = Field(default=0.9, description="Similarity threshold for dedup")

    # If False (or httpx is unavailable), a cheap heuristic replaces the
    # LLM cross-encoder scoring.
    use_llm_rerank: bool = Field(default=True)
| |
|
| |
|
class RankedResult(BaseModel):
    """A reranked result with relevance score."""
    chunk_id: str
    document_id: str
    text: str
    original_score: float  # first-stage retrieval score (RetrievalResult.score)
    relevance_score: float  # reranker relevance, typically on a 0-1 scale
    final_score: float  # weighted blend of original_score and relevance_score
    relevance_explanation: Optional[str] = None  # brief LLM rationale, when available

    # Provenance / display metadata copied through from the source chunk.
    page: Optional[int] = None
    chunk_type: Optional[str] = None
    source_path: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    bbox: Optional[Dict[str, float]] = None
| |
|
| |
|
class RerankerAgent:
    """
    Reranks retrieval results for improved precision.

    Pipeline (see ``rerank``):
    1. Chunk deduplication (near-duplicate text removal)
    2. Cross-encoder style relevance scoring (LLM, with heuristic fallback)
    3. Quality filtering (drops results below ``min_relevance_score``)
    4. Diversity-aware selection (MMR-style, capped by ``diversity_threshold``)
    """

    RERANK_PROMPT = """Score the relevance of this text passage to the given query.

Query: {query}

Passage: {passage}

Score the relevance on a scale of 0-10 where:
- 0-2: Completely irrelevant, no useful information
- 3-4: Marginally relevant, tangentially related
- 5-6: Somewhat relevant, contains some useful information
- 7-8: Highly relevant, directly addresses the query
- 9-10: Perfectly relevant, comprehensive answer to query

Respond with ONLY a JSON object:
{{"score": <number>, "explanation": "<brief reason>"}}"""

    def __init__(self, config: Optional[RerankerConfig] = None):
        """
        Initialize Reranker Agent.

        Args:
            config: Reranker configuration (defaults to ``RerankerConfig()``).
        """
        self.config = config or RerankerConfig()
        logger.info(f"RerankerAgent initialized (model={self.config.model})")

    def rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: Optional[int] = None,
    ) -> List[RankedResult]:
        """
        Rerank retrieval results by relevance to query.

        Args:
            query: Original search query
            results: Retrieval results to rerank
            top_k: Number of results to return (defaults to ``config.top_k``)

        Returns:
            Reranked results with relevance scores, best first.
        """
        if not results:
            return []

        top_k = top_k or self.config.top_k

        # 1. Drop near-duplicates first so we do not spend LLM calls on them.
        deduped = self._deduplicate(results)

        # 2. Relevance scoring: LLM cross-encoder when enabled and available,
        #    otherwise a cheap lexical heuristic.
        if self.config.use_llm_rerank and HTTPX_AVAILABLE:
            scored = self._llm_rerank(query, deduped)
        else:
            scored = self._heuristic_rerank(query, deduped)

        # 3. Quality filter: discard chunks below the relevance floor.
        filtered = [
            r for r in scored
            if r.relevance_score >= self.config.min_relevance_score
        ]

        # 4. Final selection: MMR-style diversity or plain top-k by score.
        if self.config.enable_diversity:
            return self._promote_diversity(filtered, top_k)
        return sorted(filtered, key=lambda x: x.final_score, reverse=True)[:top_k]

    def _deduplicate(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Remove near-duplicate chunks (greedy; keeps the first occurrence)."""
        if not results:
            return []

        deduped: List[RetrievalResult] = [results[0]]
        for result in results[1:]:
            # any() short-circuits on the first near-duplicate found.
            is_dup = any(
                self._text_similarity(result.text, kept.text) > self.config.dedup_threshold
                for kept in deduped
            )
            if not is_dup:
                deduped.append(result)

        if len(results) != len(deduped):
            logger.debug(f"Deduplication: {len(results)} -> {len(deduped)} chunks")

        return deduped

    def _text_similarity(self, text1: str, text2: str) -> float:
        """Compute case-insensitive text similarity in [0, 1] via SequenceMatcher."""
        return SequenceMatcher(None, text1.lower(), text2.lower()).ratio()

    @staticmethod
    def _to_ranked(
        result: RetrievalResult,
        *,
        relevance_score: float,
        final_score: float,
        explanation: Optional[str] = None,
    ) -> RankedResult:
        """Copy a RetrievalResult into a RankedResult with the given scores."""
        return RankedResult(
            chunk_id=result.chunk_id,
            document_id=result.document_id,
            text=result.text,
            original_score=result.score,
            relevance_score=relevance_score,
            final_score=final_score,
            relevance_explanation=explanation,
            page=result.page,
            chunk_type=result.chunk_type,
            source_path=result.source_path,
            metadata=result.metadata,
            bbox=result.bbox,
        )

    def _llm_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Use LLM for cross-encoder style reranking.

        On a per-passage failure (network error, unparseable response) the
        chunk is kept with its raw retrieval score rather than dropped.
        """
        ranked: List[RankedResult] = []

        for result in results:
            try:
                relevance_score, explanation = self._score_passage(query, result.text)

                # Blend: favor the cross-encoder judgment (70%) over the
                # first-stage retrieval score (30%); LLM score is 0-10.
                final_score = 0.3 * result.score + 0.7 * (relevance_score / 10.0)

                ranked.append(self._to_ranked(
                    result,
                    relevance_score=relevance_score / 10.0,
                    final_score=final_score,
                    explanation=explanation,
                ))
            except Exception as e:
                logger.warning(f"Failed to score passage: {e}")
                # Best-effort fallback: reuse the retrieval score so the
                # chunk still participates in filtering and selection.
                ranked.append(self._to_ranked(
                    result,
                    relevance_score=result.score,
                    final_score=result.score,
                ))

        return ranked

    def _score_passage(self, query: str, passage: str) -> Tuple[float, str]:
        """Score a single passage using the LLM.

        Returns:
            (score in 0-10, brief explanation).

        Raises:
            httpx.HTTPError: on transport or HTTP-status failures
            (handled by the caller's per-passage fallback).
        """
        prompt = self.RERANK_PROMPT.format(
            query=query,
            passage=passage[:1000],  # cap prompt size; first 1000 chars only
        )

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{self.config.base_url}/api/generate",
                json={
                    "model": self.config.model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {
                        "temperature": self.config.temperature,
                        "num_predict": 256,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        response_text = result.get("response", "")
        return self._parse_score_response(response_text)

    def _parse_score_response(self, text: str) -> Tuple[float, str]:
        """Parse score and explanation from LLM response.

        Tries the requested JSON shape first, then falls back to the first
        standalone number in the text; defaults to a neutral 5.0 when
        nothing parseable is present. Scores are clamped to [0, 10].
        """
        try:
            json_match = re.search(r'\{[\s\S]*\}', text)
            if json_match:
                data = json.loads(json_match.group())
                score = float(data.get("score", 5))
                explanation = data.get("explanation", "")
                return min(max(score, 0), 10), explanation
        except Exception:
            pass  # malformed JSON -> fall through to the numeric fallback

        # Fallback: first standalone number. The previous pattern
        # r'\b([0-9]|10)\b' missed decimals such as "7.5"; this one accepts
        # them and clamps out-of-range values so they cannot skew ranking.
        num_match = re.search(r'\b(\d{1,2}(?:\.\d+)?)\b', text)
        if num_match:
            return min(max(float(num_match.group(1)), 0.0), 10.0), ""

        return 5.0, "Could not parse score"

    def _heuristic_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Fast heuristic reranking: term overlap + exact-phrase + length prior."""
        query_terms = set(query.lower().split())
        query_lower = query.lower()
        ranked: List[RankedResult] = []

        for result in results:
            text_lower = result.text.lower()

            # Fraction of query terms that appear in the passage.
            text_terms = set(text_lower.split())
            overlap = len(query_terms & text_terms) / len(query_terms) if query_terms else 0

            # Bonus when the full query appears verbatim.
            phrase_bonus = 0.2 if query_lower in text_lower else 0

            # Mild preference for substantive chunks (saturates at 500 chars).
            length_score = min(len(result.text), 500) / 500

            relevance = 0.5 * overlap + 0.3 * phrase_bonus + 0.2 * length_score
            final_score = 0.4 * result.score + 0.6 * relevance

            ranked.append(self._to_ranked(
                result,
                relevance_score=relevance,
                final_score=final_score,
            ))

        return ranked

    def _promote_diversity(
        self,
        results: List[RankedResult],
        top_k: int,
    ) -> List[RankedResult]:
        """
        Promote diversity using MMR-style selection.

        Maximal Marginal Relevance balances relevance with diversity.
        Candidates whose similarity to an already-selected chunk exceeds
        ``config.diversity_threshold`` are skipped when possible (previously
        this config field was never consulted); if every remaining candidate
        exceeds the cap, the best-MMR candidate is taken anyway so the
        result count is not reduced.
        """
        if not results:
            return []

        # Seed with the single highest-scoring result.
        sorted_results = sorted(results, key=lambda x: x.final_score, reverse=True)
        selected = [sorted_results[0]]
        remaining = sorted_results[1:]

        while len(selected) < top_k and remaining:
            best_mmr: Optional[float] = None
            best_idx = 0
            best_allowed_mmr: Optional[float] = None
            best_allowed_idx: Optional[int] = None

            for i, candidate in enumerate(remaining):
                # Penalize by the closest match among already-selected chunks.
                max_sim = max(
                    self._text_similarity(candidate.text, s.text)
                    for s in selected
                )
                mmr = 0.7 * candidate.final_score - 0.3 * max_sim

                if best_mmr is None or mmr > best_mmr:
                    best_mmr, best_idx = mmr, i
                within_cap = max_sim <= self.config.diversity_threshold
                if within_cap and (best_allowed_mmr is None or mmr > best_allowed_mmr):
                    best_allowed_mmr, best_allowed_idx = mmr, i

            # Prefer a candidate under the similarity cap; fall back to the
            # overall best so we still return top_k results when possible.
            pick = best_allowed_idx if best_allowed_idx is not None else best_idx
            selected.append(remaining.pop(pick))

        return selected
| |
|