| from __future__ import annotations |
|
|
| from typing import List, Dict, Any, Optional, Tuple |
| import math |
|
|
| from ...models.model_configs import LLMConfig |
| from ...agents.customize_agent import CustomizeAgent |
|
|
|
|
| def _tokenize(text: str) -> List[str]: |
| return [t for t in text.lower().split() if t.strip()] |
|
|
|
|
| def _tf_vector(tokens: List[str]) -> Dict[str, float]: |
| vec: Dict[str, float] = {} |
| for t in tokens: |
| vec[t] = vec.get(t, 0.0) + 1.0 |
| |
| norm = math.sqrt(sum(v * v for v in vec.values())) or 1.0 |
| for k in list(vec.keys()): |
| vec[k] /= norm |
| return vec |
|
|
|
|
| def _cosine_sim(a: Dict[str, float], b: Dict[str, float]) -> float: |
| if len(a) < len(b): |
| a, b = b, a |
| return sum(v * b.get(k, 0.0) for k, v in a.items()) |
|
|
|
|
| def _js_divergence(p: Dict[str, float], q: Dict[str, float]) -> float: |
| |
| vocab = set(p.keys()) | set(q.keys()) |
| eps = 1e-9 |
| def _norm(d: Dict[str, float]) -> Dict[str, float]: |
| s = sum(d.get(w, 0.0) for w in vocab) or 1.0 |
| return {w: (d.get(w, 0.0) + eps) / (s + eps * len(vocab)) for w in vocab} |
| P = _norm(p) |
| Q = _norm(q) |
| M = {w: 0.5 * (P[w] + Q[w]) for w in vocab} |
| def _kl(X, Y): |
| return sum(X[w] * math.log((X[w] + eps) / (Y[w] + eps)) for w in vocab) |
| return 0.5 * _kl(P, M) + 0.5 * _kl(Q, M) |
|
|
|
|
class PruningPipeline:
    """Pluggable pruning pipeline: quality pruning (QP) -> diversity
    pruning (DP) -> misunderstanding rebuttal (MR).

    Candidate input format: ``List[{"agent_id": int, "text": str}]``.
    The output keeps the same structure; stages may attach optional
    metrics to each entry (``qp_score``, ``mr_issues``, ``mr_rebuttal``).
    """

    def __init__(
        self,
        enable_qp: bool = True,
        enable_dp: bool = True,
        enable_mr: bool = False,
        qp_threshold: float = 0.15,
        qp_top_k: Optional[int] = None,
        dp_similarity_threshold: float = 0.92,
        dp_max_candidates: Optional[int] = None,
        mr_llm_config: Optional[LLMConfig] = None,
        min_keep_count: Optional[int] = None,
    ) -> None:
        """Configure the pipeline stages.

        Args:
            enable_qp: Run the quality-pruning stage.
            enable_dp: Run the diversity-pruning stage.
            enable_mr: Run the LLM-based misunderstanding-rebuttal stage
                (requires ``mr_llm_config`` to actually do anything).
            qp_threshold: Minimum problem/candidate cosine similarity a
                candidate needs to survive QP.
            qp_top_k: If set (> 0), keep at most this many top-scoring
                candidates before the threshold is applied.
            dp_similarity_threshold: A candidate at least this similar to
                an already-kept candidate is dropped in DP.
            dp_max_candidates: If set, DP stops once this many candidates
                have been kept.
            mr_llm_config: LLM config used to build the critic agent.
            min_keep_count: If set, QP and DP each backfill pruned
                candidates so at least this many survive the stage.
        """
        self.enable_qp = enable_qp
        self.enable_dp = enable_dp
        self.enable_mr = enable_mr
        self.qp_threshold = qp_threshold
        self.qp_top_k = qp_top_k
        self.dp_similarity_threshold = dp_similarity_threshold
        self.dp_max_candidates = dp_max_candidates
        self.mr_llm_config = mr_llm_config
        self.min_keep_count = min_keep_count

    def _qp_score(self, problem: str, text: str) -> float:
        """Cosine similarity between the problem statement and a candidate
        text, over L2-normalized term-frequency vectors."""
        qv = _tf_vector(_tokenize(problem))
        tv = _tf_vector(_tokenize(text))
        return _cosine_sim(qv, tv)

    def _quality_prune(self, problem: str, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Drop candidates whose relevance to the problem is below
        ``qp_threshold``.

        Each surviving entry is a *copy* of the input dict with a
        ``qp_score`` field added; the caller's dicts are never mutated.
        At least one candidate (the best-scoring one) always survives,
        and ``min_keep_count`` backfills further candidates by score.
        """
        if not self.enable_qp or len(candidates) <= 1:
            return candidates
        scored: List[Tuple[float, Dict[str, Any]]] = []
        for c in candidates:
            s = self._qp_score(problem, c.get("text", ""))
            c = dict(c)  # copy so the caller's dict is untouched
            c["qp_score"] = s
            scored.append((s, c))
        scored.sort(key=lambda x: x[0], reverse=True)
        if self.qp_top_k is not None and self.qp_top_k > 0:
            scored = scored[: self.qp_top_k]
        kept = [c for s, c in scored if s >= self.qp_threshold]
        if not kept:
            # Never return an empty list: keep the single best candidate.
            kept = [scored[0][1]]
        if self.min_keep_count and len(kept) < self.min_keep_count:
            # Backfill with the next best-scoring candidates
            # (``scored`` is already sorted by score, descending).
            existing_ids = set(id(obj) for obj in kept)
            for _, c in scored:
                if id(c) not in existing_ids:
                    kept.append(c)
                if len(kept) >= self.min_keep_count:
                    break
        return kept

    def _diversity_prune(self, candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Greedily drop near-duplicate candidates.

        A candidate is kept only if its similarity to every previously
        kept candidate stays below ``dp_similarity_threshold``.  When
        ``min_keep_count`` is set, pruned candidates are backfilled in
        descending ``qp_score`` order.
        """
        if not self.enable_dp or len(candidates) <= 1:
            return candidates
        vecs = [_tf_vector(_tokenize(c.get("text", ""))) for c in candidates]
        kept: List[int] = []
        for i, v in enumerate(vecs):
            diverse = True
            for j in kept:
                sim = _cosine_sim(v, vecs[j])
                if sim >= self.dp_similarity_threshold:
                    diverse = False
                    break
            if diverse:
                kept.append(i)
                if self.dp_max_candidates and len(kept) >= self.dp_max_candidates:
                    break
        pruned = [candidates[i] for i in kept]
        if self.min_keep_count and len(pruned) < self.min_keep_count:
            # Backfill the highest-scoring dropped candidates (qp_score may
            # be absent when QP was disabled; treat that as 0.0).
            ranked = sorted(
                range(len(candidates)),
                key=lambda idx: float(candidates[idx].get("qp_score") or 0.0),
                reverse=True,
            )
            chosen = set(kept)
            for idx in ranked:
                if idx in chosen:
                    continue
                pruned.append(candidates[idx])
                chosen.add(idx)
                if len(pruned) >= self.min_keep_count:
                    break
        return pruned

    def _build_critic(self) -> Optional[CustomizeAgent]:
        """Build the critic agent used by the MR stage, or return ``None``
        when no ``mr_llm_config`` was provided."""
        if not self.mr_llm_config:
            return None
        prompt = (
            """
You are a critical reviewer. Given a problem and a set of condensed candidate answers, identify common misunderstandings or mistakes, and propose a corrected consolidated answer.

Problem:
{problem}

Candidates:
{candidates_text}

Return XML:
<response>
<issues>Common mistakes found</issues>
<rebuttal>How to fix them</rebuttal>
<corrected>Single corrected final answer</corrected>
</response>
"""
        ).strip()
        inputs = [
            {"name": "problem", "type": "str", "description": "Problem statement"},
            {"name": "candidates_text", "type": "str", "description": "Concatenated candidates"},
        ]
        outputs = [
            {"name": "issues", "type": "str", "description": "Common mistakes", "required": True},
            {"name": "rebuttal", "type": "str", "description": "Corrections", "required": True},
            {"name": "corrected", "type": "str", "description": "Corrected final answer", "required": True},
        ]
        return CustomizeAgent(
            name="CriticAgent",
            description="Detects misunderstandings and proposes corrected answer",
            prompt=prompt,
            llm_config=self.mr_llm_config,
            inputs=inputs,
            outputs=outputs,
            parse_mode="xml",
        )

    def _misunderstanding_rebuttal(self, problem: str, candidates: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, str]]]:
        """Run the critic over all candidates and attach its findings.

        Returns ``(annotated_candidates, suggested)`` where ``suggested``
        holds the critic's issues/rebuttal/corrected answer, or ``None``
        when the stage is disabled or no critic could be built.  Candidate
        dicts are copied before annotation (consistent with
        ``_quality_prune``) so the caller's inputs are never mutated.
        """
        if not self.enable_mr:
            return candidates, None
        critic = self._build_critic()
        if critic is None:
            return candidates, None
        concat = "\n\n".join(
            # ``or ''`` guards against an explicit None text value.
            f"#{c.get('agent_id')}: {(c.get('text') or '').strip()}" for c in candidates
        )
        msg = critic(inputs={"problem": problem, "candidates_text": concat})
        # Guard against the agent returning no structured payload.
        st = msg.content.get_structured_data() or {}
        annotated: List[Dict[str, Any]] = []
        for c in candidates:
            c = dict(c)
            c["mr_issues"] = st.get("issues", "")
            c["mr_rebuttal"] = st.get("rebuttal", "")
            annotated.append(c)
        suggested = {
            "issues": st.get("issues", ""),
            "rebuttal": st.get("rebuttal", ""),
            "corrected": st.get("corrected", ""),
        }
        return annotated, suggested

    def apply(self, problem: str, candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run QP -> DP -> MR and return
        ``{"candidates": pruned, "mr_suggested": optional}``."""
        step1 = self._quality_prune(problem, candidates)
        step2 = self._diversity_prune(step1)
        step3, suggested = self._misunderstanding_rebuttal(problem, step2)
        return {"candidates": step3, "mr_suggested": suggested}
|
|
|
|
|
|