"""kill_switch.py — held-out collapse tripwire (the #2 collapse safeguard). This is the missing RUN-LEVEL / across-generation control for the self-evolving RL flywheel. The per-task controls already exist in ``composer_replication.datagen`` (the 4-gate solvability validator, the ``HackMonitor`` provenance check, and the sandbox denylist); this module sits ABOVE them and watches the whole run. Rationale (the literature is unambiguous that a held-out eval + hard stop is the load-bearing control, not a nice-to-have): - **Reward hacking rises monotonically with optimization depth.** Zhao et al., "Reward Hacking in Self-Improving Code Agents" (ICLR 2026 Workshop on RSI, OpenReview ``ikrQWGgxYg``) show that going from 10 -> 100 optimization steps drives the hacking rate from 26.4% to 57.8% (+31.4 points), and that 73.8% of KernelBench / 46.8% of ALE-Bench optimizations show *proxy gains without real gains*. They define **Hacking Gap = proxy gain - real gain**; this module's ``proxy_real_gap()`` is exactly that quantity. They label an optimization reward-hacking when it "improves the public metric WITHOUT improving the private metric" — the canonical signature this tripwire fires on. - **Self-critique alone is insufficient.** The same paper's "retrospection" self-critique sometimes *increased* hacking; their conclusion: "mitigating reward hacking likely requires stronger evaluations and constraints beyond self-critique alone." So we build a genuinely disjoint held-out eval plus a hard stop, not a critique hook. - **Held-out eval is necessary but NOT sufficient by itself.** EvilGenie (arXiv 2511.21654) found "only minimal improvement from the use of held out test cases" in isolation and that "holdout tests have many surprising failure modes." This module is therefore explicitly *defense-in-depth*, layered ON TOP of ``HackMonitor`` (provenance) — neither is sufficient alone, matching the repo's existing defense-in-depth framing in ``datagen/monitor.py``. 
- **Closed-loop RL on self-generated data collapses.** The self-evolving-agents
  survey (Gao et al., TMLR 2026; arXiv 2507.21046 v4) §8.3 names "model
  collapse from closed-loop RL on static synthetic data" and prescribes
  "continuous monitoring ... to detect long-horizon value drift", i.e. a
  per-generation online tripwire, not a one-time eval. Shumailov et al.
  (Nature 2024, "AI models collapse when trained on recursively generated
  data") show self-training first loses the distribution tails, then converges
  to a low-variance point estimate; the mitigation that matters here is that
  the held-out eval must stay anchored to REAL tasks that are NEVER fed back to
  the generator (see ``HeldoutSplit``), otherwise the eval drifts with the
  train set.
- **KL-to-init hard stop.** The GRPO "healthy progression" band (Orchestra
  Research GRPO SKILL) climbs 0.02 -> 0.05 -> 0.08 -> 0.12 nats/token over a
  run, with 0.08 the top of the "good progression" band and just below the
  code-generation drift zone (0.05-0.15 per-token); >0.5 is "diverging too
  much." So 0.08 nats/token is a sound HARD-STOP default. Catastrophic Goodhart
  (OpenReview ``UXuBzWoZGK``) proves KL regularization alone does NOT prevent
  heavy-tailed reward misspecification, so the KL hard stop is ONE tripwire
  among several, never the sole control.

UNITS GOTCHA (load-bearing): the ``kl_to_init`` this module consumes is
**token-mean KL in nats/token**, matching the repo convention in
``composer_replication.integrations.altered_minds.kl_logging.token_mean_kl``.
A token-mean KL is NOT comparable to a sequence-level / sequence-summed KL
(whose healthy band is ~0.05-10). The 0.08 default is per-token. Do not pass a
sequence-summed KL into the per-token hard stop: it will fire instantly.

This module is pure-Python: no torch, no cloud deps. ``kl_to_init`` is just a
float the caller passes (computed upstream by ``token_mean_kl``). It is fully
CPU-testable.
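
To make the units gotcha concrete, the sketch below contrasts the two
reductions on made-up per-token KL values. The helper names (``token_mean``,
``sequence_sum``) and the numbers are illustrative assumptions; this is not the
repo's ``token_mean_kl`` implementation.

```python
# Hypothetical per-token KL values (nats) for one 200-token completion.
# Real values would come from the trainer's log-probabilities.
per_token_kl = [0.03] * 200

KL_HARD_STOP = 0.08  # nats/token, the default used by this module


def token_mean(kls):
    # Mean KL per token: the unit the hard stop expects.
    return sum(kls) / len(kls)


def sequence_sum(kls):
    # KL summed over the sequence: NOT comparable to a per-token ceiling.
    return sum(kls)


mean_kl = token_mean(per_token_kl)      # about 0.03 nats/token
summed_kl = sequence_sum(per_token_kl)  # about 6.0 nats for the sequence

print(mean_kl > KL_HARD_STOP)    # False
print(summed_kl > KL_HARD_STOP)  # True
```

The same healthy policy looks fine under the token-mean reduction and trips
the 0.08 ceiling immediately under the summed one, which is why the reduction
must be fixed before any threshold is chosen.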
""" from __future__ import annotations from dataclasses import dataclass, field class CollapseStopError(RuntimeError): """Raised (by the caller, optionally) when the tripwire fires a hard stop. The trainer loop can either check ``TripwireStatus.fire`` and stop softly, or call ``HeldOutGuard.raise_if_fired(status)`` to convert a fired verdict into this typed exception. Carries the structured verdict for logging. """ def __init__(self, status: TripwireStatus) -> None: super().__init__(status.reason) self.status = status @dataclass(frozen=True) class TripwireStatus: """Structured verdict returned by every ``HeldOutGuard.update(...)`` call. Attributes: fire: True => the run should HALT (collapse / reward-hacking detected). reason: human-readable WHY (empty string when ``fire`` is False), so the trainer can log exactly which tripwire tripped, mirroring how ``datagen/monitor.py`` logs suspected hacks for review. step: the round/generation index this verdict was computed at. proxy_real_gap: the RSI "Hacking Gap" at this step = (in-loop reward gain since baseline) - (held-out score gain since baseline). Positive and widening => proxy improving faster than (or while) real declines. in_loop_ema: EMA of the in-loop / proxy reward at this step. heldout_ema: EMA of the held-out / real eval score at this step. kl_ema: EMA of ``kl_to_init`` (nats/token), or None if never supplied. """ fire: bool reason: str step: int proxy_real_gap: float in_loop_ema: float heldout_ema: float kl_ema: float | None = None # `halt` is a documented alias for `fire` — the task spec describes a # `should_halt()` / verdict with a `halt` field; expose both names so callers # reading either convention work. @property def halt(self) -> bool: return self.fire @dataclass class HeldOutGuard: """Across-generation collapse / reward-hacking kill-switch (HeldOutGuard). Tracks, per generation/round: in-loop (proxy) oracle reward, held-out (real) eval score, and optional KL-to-init / entropy / reward-std. 
    Computes the proxy-minus-real "Hacking Gap" tripwire and fires a structured
    ``halt`` verdict when collapse is caught in the act.

    The guard is **stateful**: call ``update(round_idx, ...)`` once per
    checkpoint in the trainer loop (the same cadence at which
    ``DifficultyCurriculum.update`` is called). It maintains denoised EMAs of
    every metric (raw single-step values are too noisy to threshold, per
    theneuralbase early-stopping guidance) and returns a ``TripwireStatus``.

    Fires (``fire=True``) when ANY of:

    (a) **collapse-caught-in-the-act**: the in-loop reward EMA is RISING while
        the held-out score EMA has DECLINED for >= ``decline_patience``
        consecutive checkpoints (default 3, matching the "monotone for >=3
        checkpoints" rule). This is the canonical reward-hacking signature.
    (b) **KL breach**: the ``kl_to_init`` EMA exceeds ``kl_hard_stop``
        (default 0.08 nats/token) on/after ``min_steps``.
    (c) **proxy-real gap blowout**: the Hacking Gap (proxy gain - real gain
        since baseline) widens beyond ``max_proxy_real_gap``, even if held-out
        has not strictly declined for the full patience window (a fast
        single-generation divergence).

    No tripwire fires before ``min_steps`` (avoids halting on early-run noise,
    when both signals are still warming up). The guard is idempotent in the
    sense that re-querying ``last_status`` or calling ``should_halt()`` does not
    advance state; only ``update`` does.
    """

    # --- thresholds (calibratable; see calibrate_kl_threshold) ---------------
    kl_hard_stop: float = 0.08  # nats/token; top of GRPO "good" band
    max_proxy_real_gap: float = 0.10  # absolute Hacking-Gap blowout ceiling
    # --- temporal gates ------------------------------------------------------
    min_steps: int = 20  # no fire before this many updates
    decline_patience: int = 3  # consecutive held-out declines to fire (a)
    # --- denoising -----------------------------------------------------------
    ema_alpha: float = 0.9  # EMA weight on the PRIOR (0.9 => slow)
    rise_eps: float = 1e-4  # min EMA delta to count as "rising"/"declining"
    # --- internal state (do not set directly) --------------------------------
    _n: int = field(default=0, init=False)
    _in_loop_ema: float | None = field(default=None, init=False)
    _heldout_ema: float | None = field(default=None, init=False)
    _kl_ema: float | None = field(default=None, init=False)
    _entropy_ema: float | None = field(default=None, init=False)
    _reward_std_ema: float | None = field(default=None, init=False)
    _in_loop_baseline: float | None = field(default=None, init=False)
    _heldout_baseline: float | None = field(default=None, init=False)
    _prev_in_loop_ema: float | None = field(default=None, init=False)
    _prev_heldout_ema: float | None = field(default=None, init=False)
    _heldout_decline_streak: int = field(default=0, init=False)
    _last_status: TripwireStatus | None = field(default=None, init=False)
    _fired: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        if not (0.0 <= self.ema_alpha < 1.0):
            raise ValueError(
                f"ema_alpha must be in [0, 1), got {self.ema_alpha!r} "
                "(it is the weight on the PRIOR EMA)."
            )
        if self.kl_hard_stop <= 0.0:
            raise ValueError(f"kl_hard_stop must be > 0, got {self.kl_hard_stop!r}")
        if self.decline_patience < 1:
            raise ValueError(
                f"decline_patience must be >= 1, got {self.decline_patience!r}"
            )

    # ------------------------------------------------------------------------
    # core API
    # ------------------------------------------------------------------------
    def update(
        self,
        round_idx: int,
        in_loop_reward: float,
        heldout_score: float,
        kl_to_init: float | None = None,
        entropy: float | None = None,
        reward_std: float | None = None,
    ) -> TripwireStatus:
        """Fold one checkpoint's metrics in and return the current verdict.

        Args:
            round_idx: the generation / round index (for logging; not used for
                gating: the internal update counter ``_n`` drives ``min_steps``
                so the guard is robust to non-contiguous round indices).
            in_loop_reward: mean in-loop (proxy / oracle) reward this round.
                This is what the policy is optimizing against.
            heldout_score: mean score on the DISJOINT held-out eval pool this
                round: REAL tasks the generator never trains on. See
                ``composer_replication.safety.holdout`` design notes / the
                ``HeldoutSplit`` discipline; if held-out drifts with the train
                set the gap signal is meaningless.
            kl_to_init: optional token-mean KL(policy || init) in nats/token
                (this repo's ``token_mean_kl`` convention). NOT sequence-level
                KL.
            entropy: optional policy entropy (early-warning of entropy
                collapse, "the silent killer of RLVR generalization"). Tracked
                + exposed, not currently a hard gate.
            reward_std: optional std of the reward distribution (tracked; a
                collapsing std is an early collapse signal).

        Returns:
            A ``TripwireStatus``. Once the guard has fired, every subsequent
            ``update`` keeps ``fire=True`` (latched) so a transient recovery
            after a detected collapse cannot silently un-halt the run.
        """
        self._n += 1

        # --- EMA folds (alpha on the prior; first sample seeds the EMA) -------
        self._in_loop_ema = self._fold(self._in_loop_ema, float(in_loop_reward))
        self._heldout_ema = self._fold(self._heldout_ema, float(heldout_score))
        if kl_to_init is not None:
            self._kl_ema = self._fold(self._kl_ema, float(kl_to_init))
        if entropy is not None:
            self._entropy_ema = self._fold(self._entropy_ema, float(entropy))
        if reward_std is not None:
            self._reward_std_ema = self._fold(self._reward_std_ema, float(reward_std))

        # --- baselines: seed on the first update so gains are measured from
        # run start (the RSI Hacking-Gap is a gain-since-baseline quantity).
        if self._in_loop_baseline is None:
            self._in_loop_baseline = self._in_loop_ema
        if self._heldout_baseline is None:
            self._heldout_baseline = self._heldout_ema

        # --- track the held-out decline streak (uses EMA deltas, denoised) ----
        in_loop_rising = (
            self._prev_in_loop_ema is not None
            and (self._in_loop_ema - self._prev_in_loop_ema) > self.rise_eps
        )
        heldout_declining = (
            self._prev_heldout_ema is not None
            and (self._heldout_ema - self._prev_heldout_ema) < -self.rise_eps
        )
        # The collapse signature is held-out DOWN while in-loop UP. We only count
        # a decline toward the streak when in-loop is simultaneously rising: a
        # held-out dip during an in-loop dip is just noise / a hard batch, not
        # reward hacking.
        if heldout_declining and in_loop_rising:
            self._heldout_decline_streak += 1
        elif not heldout_declining:
            self._heldout_decline_streak = 0
        # (If held-out declines but in-loop is flat/down we neither grow nor
        # reset the streak immediately, but the elif above resets on any
        # non-decline, so a single clean checkpoint clears it.)

        gap = self.proxy_real_gap()
        status = self._evaluate(round_idx, gap)

        # advance "previous EMA" trackers AFTER evaluation
        self._prev_in_loop_ema = self._in_loop_ema
        self._prev_heldout_ema = self._heldout_ema
        self._last_status = status
        if status.fire:
            self._fired = True
        return status

    def _evaluate(self, round_idx: int, gap: float) -> TripwireStatus:
        """Decide the verdict from current state. Pure (no state mutation)."""
        assert self._in_loop_ema is not None and self._heldout_ema is not None
        base = dict(
            step=round_idx,
            proxy_real_gap=gap,
            in_loop_ema=self._in_loop_ema,
            heldout_ema=self._heldout_ema,
            kl_ema=self._kl_ema,
        )

        # Latched: once fired, stay fired (cannot silently un-halt).
        if self._fired:
            prev_reason = self._last_status.reason if self._last_status else "collapse"
            # Add the "latched: " prefix only once; otherwise every later update
            # would prepend another copy and the reason would grow without bound.
            if not prev_reason.startswith("latched: "):
                prev_reason = f"latched: {prev_reason}"
            return TripwireStatus(fire=True, reason=prev_reason, **base)

        # Warm-up guard: never fire on early-run noise.
        if self._n < self.min_steps:
            return TripwireStatus(fire=False, reason="", **base)

        # (b) KL hard stop: checked first; it's the cheapest unambiguous breach.
        if self._kl_ema is not None and self._kl_ema > self.kl_hard_stop:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"kl_to_init EMA {self._kl_ema:.4f} nats/token exceeds hard "
                    f"stop {self.kl_hard_stop:.4f} (policy drifting from init)"
                ),
                **base,
            )

        # (a) collapse caught in the act: held-out declines while in-loop rises.
        if self._heldout_decline_streak >= self.decline_patience:
            return TripwireStatus(
                fire=True,
                reason=(
                    "reward-hacking signature: held-out score declined while "
                    f"in-loop reward rose for {self._heldout_decline_streak} "
                    f"consecutive checkpoints (Hacking Gap {gap:.4f})"
                ),
                **base,
            )

        # (c) proxy-real gap blowout: fast single-generation divergence.
        if gap > self.max_proxy_real_gap:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"proxy-real Hacking Gap {gap:.4f} exceeds ceiling "
                    f"{self.max_proxy_real_gap:.4f} (proxy reward improving far "
                    f"faster than real held-out eval)"
                ),
                **base,
            )

        return TripwireStatus(fire=False, reason="", **base)

    # ------------------------------------------------------------------------
    # query helpers (do NOT advance state; idempotent)
    # ------------------------------------------------------------------------
    def should_halt(self) -> bool:
        """True if the most recent ``update`` produced a halt verdict.

        Idempotent: querying does not advance the EMA state.
        """
        return self._last_status is not None and self._last_status.fire

    @property
    def last_status(self) -> TripwireStatus | None:
        """The most recent verdict, or None if ``update`` was never called."""
        return self._last_status

    def raise_if_fired(self, status: TripwireStatus | None = None) -> None:
        """Convert a fired verdict into a typed ``CollapseStopError`` exception.

        Pass the status returned by ``update`` (or omit to use ``last_status``).
        Trainer loops that prefer exception-based control flow call this right
        after ``update``; loops that prefer flag-checking just read
        ``status.fire`` / ``should_halt()``.
        """
        st = status if status is not None else self._last_status
        if st is not None and st.fire:
            raise CollapseStopError(st)

    def proxy_real_gap(self) -> float:
        """The RSI Hacking Gap = (in-loop gain) - (held-out gain), both measured
        as EMA-minus-baseline since run start.

        Returns 0.0 before the first ``update`` (no baseline yet). A positive,
        widening value is the reward-hacking fingerprint: the proxy the policy
        optimizes is improving more than the real held-out objective.
        """
        if (
            self._in_loop_ema is None
            or self._heldout_ema is None
            or self._in_loop_baseline is None
            or self._heldout_baseline is None
        ):
            return 0.0
        in_loop_gain = self._in_loop_ema - self._in_loop_baseline
        heldout_gain = self._heldout_ema - self._heldout_baseline
        return in_loop_gain - heldout_gain

    # ------------------------------------------------------------------------
    # calibration
    # ------------------------------------------------------------------------
    def calibrate_kl_threshold(
        self, baseline_kls: list[float], factor: float = 3.0
    ) -> float:
        """Set ``kl_hard_stop`` to ``factor`` x the mean of early-run baseline KLs.

        theneuralbase guidance: "Record baseline KL during the first ~100 steps,
        set max to 3x that." Single fixed thresholds are dataset-dependent; this
        adapts to the run's own KL scale.

        SAFETY CLAMP: calibration may only ever TIGHTEN the hard stop, never
        loosen it past the documented collapse band. The returned (and stored)
        threshold is ``min(factor x baseline mean, current kl_hard_stop)``,
        floored at 1e-6, so a noisy / already-drifting baseline cannot raise the
        ceiling above its current value (0.08 nats/token by default).

        Args:
            baseline_kls: per-step token-mean KL values from early in the run.
                KL is non-negative by definition, so every value must be >= 0.
            factor: multiplier on the baseline mean. Must be > 0.

        Returns:
            The new ``kl_hard_stop`` (also stored on the instance), always > 0.

        Raises:
            ValueError: if ``baseline_kls`` is empty, ``factor <= 0``, or any
                baseline KL is negative.
        """
        if not baseline_kls:
            raise ValueError("baseline_kls must be non-empty to calibrate")
        # GUARD (R4): a non-positive factor or a negative baseline can make
        # `calibrated` <= 0, and the clamp below would then collapse
        # kl_hard_stop to its 1e-6 floor, after which the KL tripwire fires on
        # EVERY healthy step (any realistic positive KL EMA exceeds that
        # ceiling). KL is non-negative by definition, so these inputs are
        # nonsensical; reject them loudly rather than silently
        # disarm-by-inverting the guard.
        if factor <= 0:
            raise ValueError(f"factor must be > 0, got {factor!r}")
        if any(k < 0 for k in baseline_kls):
            raise ValueError(
                "baseline_kls must all be >= 0 (KL is non-negative); got a "
                f"negative value in {baseline_kls!r}"
            )
        mean_kl = sum(baseline_kls) / len(baseline_kls)
        calibrated = factor * mean_kl
        # Only tighten: never let calibration loosen past the current ceiling.
        # Floor at a small positive epsilon so an all-zero baseline (mean_kl==0)
        # can't drive the ceiling to exactly 0 and fire on the first positive KL.
        self.kl_hard_stop = max(min(calibrated, self.kl_hard_stop), 1e-6)
        return self.kl_hard_stop

    # ------------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------------
    def _fold(self, prev: float | None, x: float) -> float:
        """EMA fold; the first observation seeds the EMA (no warm-up bias)."""
        if prev is None:
            return x
        return self.ema_alpha * prev + (1.0 - self.ema_alpha) * x


def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
    """Per-token KL trust-region mask, mirroring torchrl's GRPO "KL-Mask".

    torchrl masks any TOKEN whose ``0.5 * (log pi/pi_ref)^2`` (the Schulman k2
    estimator of per-token KL) exceeds a threshold, forming a per-token trust
    region. This helper returns True when the token should be MASKED OUT (its KL
    contribution is too large), so it can be wired into a loss later without
    pulling torch into this module; the caller computes ``0.5 * logratio**2``.

    Args:
        logratio_sq_half: ``0.5 * (log pi/pi_ref)^2`` for one token (nats).
        threshold: per-token KL ceiling (default 0.08 nats, the same band as the
            run-level hard stop).

    Returns:
        True if the token exceeds the trust region and should be masked.
    """
    return logratio_sq_half > threshold
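

"""Usage sketch for the per-token trust filter (illustrative; not module API).

The docstring of ``kl_token_trust_filter`` leaves the ``0.5 * logratio**2``
computation to the caller. The sketch below shows one way that caller-side
step might look in plain Python. The probabilities are made-up values and the
names ``pi``, ``pi_ref``, ``k2`` and ``mask`` are assumptions for illustration;
the comparison is inlined so the snippet runs without importing this module.

```python
import math

THRESHOLD = 0.08  # per-token ceiling in nats, same as the helper's default

# Hypothetical probabilities of three sampled tokens under the current policy
# and under the reference policy.
pi = [0.50, 0.30, 0.90]
pi_ref = [0.48, 0.10, 0.88]

# Caller-side k2 estimate: 0.5 * (log pi/pi_ref)^2 for each token.
k2 = [0.5 * math.log(p / q) ** 2 for p, q in zip(pi, pi_ref)]

# True means "mask this token out", matching the helper's return convention.
mask = [value > THRESHOLD for value in k2]
print(mask)  # [False, True, False]
```

Only the second token, whose probability tripled relative to the reference,
leaves the trust region; the other two moved by a few percent and are kept.
"""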