Baladithya Balamurugan
Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
bd0c358
"""kill_switch.py — held-out collapse tripwire (the #2 collapse safeguard).
This is the missing RUN-LEVEL / across-generation control for the self-evolving
RL flywheel. The per-task controls already exist in ``composer_replication.datagen``
(the 4-gate solvability validator, the ``HackMonitor`` provenance check, and the
sandbox denylist); this module sits ABOVE them and watches the whole run.
Rationale (the literature is unambiguous that a held-out eval + hard stop is the
load-bearing control, not a nice-to-have):
- **Reward hacking rises monotonically with optimization depth.** Zhao et al.,
"Reward Hacking in Self-Improving Code Agents" (ICLR 2026 Workshop on RSI,
OpenReview ``ikrQWGgxYg``) show that going from 10 -> 100 optimization steps
drives the hacking rate from 26.4% to 57.8% (+31.4 points), and that
73.8% of KernelBench / 46.8% of ALE-Bench optimizations show *proxy gains
without real gains*. They define **Hacking Gap = proxy gain - real gain**;
this module's ``proxy_real_gap()`` is exactly that quantity. They label an
optimization reward-hacking when it "improves the public metric WITHOUT
improving the private metric" — the canonical signature this tripwire fires on.
- **Self-critique alone is insufficient.** The same paper's "retrospection"
self-critique sometimes *increased* hacking; their conclusion: "mitigating
reward hacking likely requires stronger evaluations and constraints beyond
self-critique alone." So we build a genuinely disjoint held-out eval plus a
hard stop, not a critique hook.
- **Held-out eval is necessary but NOT sufficient by itself.** EvilGenie
(arXiv 2511.21654) found "only minimal improvement from the use of held out
test cases" in isolation and that "holdout tests have many surprising failure
modes." This module is therefore explicitly *defense-in-depth*, layered ON
TOP of ``HackMonitor`` (provenance) — neither is sufficient alone, matching
the repo's existing defense-in-depth framing in ``datagen/monitor.py``.
- **Closed-loop RL on self-generated data collapses.** The self-evolving-agents
survey (Gao et al., TMLR 2026; arXiv 2507.21046 v4) §8.3 names "model
collapse from closed-loop RL on static synthetic data" and prescribes
"continuous monitoring ... to detect long-horizon value drift" — i.e. a
per-generation online tripwire, not a one-time eval. Shumailov et al. (Nature
2024, "AI models collapse when trained on recursively generated data") show
self-training first loses the distribution tails, then converges to a
low-variance point estimate; the mitigation that matters here is that the
held-out eval must stay anchored to REAL tasks that are NEVER fed back to the
generator (see ``HeldoutSplit``), otherwise the eval drifts with the train set.
- **KL-to-init hard stop.** The GRPO "healthy progression" band (Orchestra
Research GRPO SKILL) climbs 0.02 -> 0.05 -> 0.08 -> 0.12 nats/token over a
run, with 0.08 the top of the "good progression" band and just below the
code-generation drift zone (0.05-0.15 per-token); >0.5 is "diverging too
much." So 0.08 nats/token is a sound HARD-STOP default. Catastrophic Goodhart
(OpenReview ``UXuBzWoZGK``) proves KL regularization alone does NOT prevent
heavy-tailed reward misspecification, so the KL hard stop is ONE tripwire
among several, never the sole control.
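
A minimal sketch of the Hacking Gap arithmetic described above. The scores and
the helper name ``hacking_gap`` are illustrative assumptions; they are not taken
from the cited paper or from this module's API:

```python
def hacking_gap(proxy_start, proxy_now, real_start, real_now):
    # Proxy gain minus real gain; positive means the proxy outran the real metric.
    return (proxy_now - proxy_start) - (real_now - real_start)


# Hypothetical run: in-loop reward rises 0.50 -> 0.80 while the held-out
# score slips 0.50 -> 0.45, so the proxy gains +0.30 and the real metric -0.05.
gap = hacking_gap(0.50, 0.80, 0.50, 0.45)
print(round(gap, 2))  # 0.35
```

With the default ``max_proxy_real_gap`` of 0.10, a gap of this size would trip
the blowout check once the warm-up window has passed.
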
UNITS GOTCHA (load-bearing): the ``kl_to_init`` this module consumes is
**token-mean KL in nats/token**, matching the repo convention in
``composer_replication.integrations.altered_minds.kl_logging.token_mean_kl``.
A token-mean KL is NOT comparable to a sequence-level / sequence-summed KL
(whose healthy band is ~0.05-10). The 0.08 default is per-token. Do not pass a
sequence-summed KL into the per-token hard stop — it will fire instantly.
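
To illustrate the units mismatch, the sketch below uses hypothetical per-token
KL contributions for one 200-token sequence (both the 0.03 value and the
200-token length are assumptions chosen for illustration):

```python
KL_HARD_STOP = 0.08  # nats/token, the per-token default described above

per_token_kl = [0.03] * 200  # hypothetical per-token KL values, in nats

token_mean_kl = sum(per_token_kl) / len(per_token_kl)  # nats/token
sequence_sum_kl = sum(per_token_kl)                    # nats/sequence

print(token_mean_kl > KL_HARD_STOP)    # False: healthy on the per-token scale
print(sequence_sum_kl > KL_HARD_STOP)  # True: the same data, summed, "breaches"
```

The summed value is 200 times the mean here, so feeding it to the per-token
threshold would report a breach for a policy that is well inside the band.
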
This module is pure-Python: no torch, no cloud deps. ``kl_to_init`` is just a
float the caller passes (computed upstream by ``token_mean_kl``). It is fully
CPU-testable.
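
Because every input is a plain float, the collapse signature can be exercised on
CPU with no trainer attached. The following is a simplified stand-in, not the
``HeldOutGuard`` implementation: it folds two made-up metric series through an
EMA (weight 0.9 on the prior) and counts consecutive checkpoints where the
held-out EMA falls while the in-loop EMA rises. The series, the helper names
``ema_fold`` and ``decline_streak``, and the omission of warm-up and latching
are all assumptions made for brevity:

```python
def ema_fold(prev, x, alpha=0.9):
    # alpha is the weight on the PRIOR value; the first sample seeds the EMA.
    return x if prev is None else alpha * prev + (1.0 - alpha) * x


def decline_streak(in_loop, heldout, eps=1e-4):
    # Count consecutive checkpoints with held-out EMA down AND in-loop EMA up.
    streak = 0
    in_ema = held_ema = None
    for reward, score in zip(in_loop, heldout):
        new_in = ema_fold(in_ema, reward)
        new_held = ema_fold(held_ema, score)
        if in_ema is not None:
            declining = (new_held - held_ema) < -eps
            rising = (new_in - in_ema) > eps
            if declining and rising:
                streak += 1
            elif not declining:
                streak = 0  # any non-declining checkpoint clears the streak
        in_ema, held_ema = new_in, new_held
    return streak


in_loop = [0.50, 0.60, 0.70, 0.80]  # proxy reward keeps climbing
heldout = [0.50, 0.48, 0.46, 0.44]  # real score keeps slipping
print(decline_streak(in_loop, heldout))  # 3
```

A streak of 3 equals the default ``decline_patience``, which is the point at
which the full guard would fire once ``min_steps`` updates have been seen.
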
"""
from __future__ import annotations

from dataclasses import dataclass, field


class CollapseStopError(RuntimeError):
    """Raised (by the caller, optionally) when the tripwire fires a hard stop.

    The trainer loop can either check ``TripwireStatus.fire`` and stop softly,
    or call ``HeldOutGuard.raise_if_fired(status)`` to convert a fired verdict
    into this typed exception. Carries the structured verdict for logging.
    """

    def __init__(self, status: TripwireStatus) -> None:
        super().__init__(status.reason)
        self.status = status


@dataclass(frozen=True)
class TripwireStatus:
    """Structured verdict returned by every ``HeldOutGuard.update(...)`` call.

    Attributes:
        fire: True => the run should HALT (collapse / reward-hacking detected).
        reason: human-readable WHY (empty string when ``fire`` is False), so the
            trainer can log exactly which tripwire tripped, mirroring how
            ``datagen/monitor.py`` logs suspected hacks for review.
        step: the round/generation index this verdict was computed at.
        proxy_real_gap: the RSI "Hacking Gap" at this step = (in-loop reward gain
            since baseline) - (held-out score gain since baseline). Positive and
            widening => the proxy is improving faster than the real score, or
            improving while the real score declines.
        in_loop_ema: EMA of the in-loop / proxy reward at this step.
        heldout_ema: EMA of the held-out / real eval score at this step.
        kl_ema: EMA of ``kl_to_init`` (nats/token), or None if never supplied.
    """

    fire: bool
    reason: str
    step: int
    proxy_real_gap: float
    in_loop_ema: float
    heldout_ema: float
    kl_ema: float | None = None

    # `halt` is a documented alias for `fire`: the task spec describes a
    # `should_halt()` / verdict with a `halt` field, so both names are exposed
    # and callers reading either convention work.
    @property
    def halt(self) -> bool:
        return self.fire


@dataclass
class HeldOutGuard:
    """Across-generation collapse / reward-hacking kill-switch (HeldOutGuard).

    Tracks, per generation/round: in-loop (proxy) oracle reward, held-out (real)
    eval score, and optional KL-to-init / entropy / reward-std. Computes the
    proxy-minus-real "Hacking Gap" tripwire and fires a structured ``halt``
    verdict when collapse is caught in the act.

    The guard is **stateful**: call ``update(round_idx, ...)`` once per checkpoint
    in the trainer loop (the same cadence at which ``DifficultyCurriculum.update``
    is called). It maintains denoised EMAs of every metric (raw single-step
    values are too noisy to threshold, per theneuralbase early-stopping guidance)
    and returns a ``TripwireStatus``.

    Fires (``fire=True``) when ANY of:
      (a) **collapse-caught-in-the-act**: the in-loop reward EMA is RISING while
          the held-out score EMA has DECLINED for >= ``decline_patience``
          consecutive checkpoints (default 3, matching the "monotone for >=3
          checkpoints" rule). This is the canonical reward-hacking signature.
      (b) **KL breach**: the ``kl_to_init`` EMA exceeds ``kl_hard_stop`` (default
          0.08 nats/token) on/after ``min_steps``.
      (c) **proxy-real gap blowout**: the Hacking Gap (proxy gain - real gain
          since baseline) widens beyond ``max_proxy_real_gap``, even if held-out
          has not strictly declined for the full patience window (a fast
          single-generation divergence).

    No tripwire fires before ``min_steps`` (avoids halting on early-run noise,
    when both signals are still warming up).

    The guard is idempotent in the sense that re-querying ``last_status`` or
    calling ``should_halt()`` does not advance state; only ``update`` does.
    """

    # --- thresholds (calibratable; see calibrate_kl_threshold) ---------------
    kl_hard_stop: float = 0.08  # nats/token; top of GRPO "good" band
    max_proxy_real_gap: float = 0.10  # absolute Hacking-Gap blowout ceiling

    # --- temporal gates ------------------------------------------------------
    min_steps: int = 20  # no fire before this many updates
    decline_patience: int = 3  # consecutive held-out declines to fire (a)

    # --- denoising -----------------------------------------------------------
    ema_alpha: float = 0.9  # EMA weight on the PRIOR (0.9 => slow)
    rise_eps: float = 1e-4  # min EMA delta to count as "rising"/"declining"

    # --- internal state (do not set directly) --------------------------------
    _n: int = field(default=0, init=False)
    _in_loop_ema: float | None = field(default=None, init=False)
    _heldout_ema: float | None = field(default=None, init=False)
    _kl_ema: float | None = field(default=None, init=False)
    _entropy_ema: float | None = field(default=None, init=False)
    _reward_std_ema: float | None = field(default=None, init=False)
    _in_loop_baseline: float | None = field(default=None, init=False)
    _heldout_baseline: float | None = field(default=None, init=False)
    _prev_in_loop_ema: float | None = field(default=None, init=False)
    _prev_heldout_ema: float | None = field(default=None, init=False)
    _heldout_decline_streak: int = field(default=0, init=False)
    _last_status: TripwireStatus | None = field(default=None, init=False)
    _fired: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        if not (0.0 <= self.ema_alpha < 1.0):
            raise ValueError(
                f"ema_alpha must be in [0, 1), got {self.ema_alpha!r} "
                "(it is the weight on the PRIOR EMA)."
            )
        if self.kl_hard_stop <= 0.0:
            raise ValueError(f"kl_hard_stop must be > 0, got {self.kl_hard_stop!r}")
        if self.decline_patience < 1:
            raise ValueError(
                f"decline_patience must be >= 1, got {self.decline_patience!r}"
            )

    # ------------------------------------------------------------------------
    # core API
    # ------------------------------------------------------------------------
    def update(
        self,
        round_idx: int,
        in_loop_reward: float,
        heldout_score: float,
        kl_to_init: float | None = None,
        entropy: float | None = None,
        reward_std: float | None = None,
    ) -> TripwireStatus:
        """Fold one checkpoint's metrics in and return the current verdict.

        Args:
            round_idx: the generation / round index (for logging; not used for
                gating: the internal update counter ``_n`` drives ``min_steps``
                so the guard is robust to non-contiguous round indices).
            in_loop_reward: mean in-loop (proxy / oracle) reward this round. This
                is what the policy is optimizing against.
            heldout_score: mean score on the DISJOINT held-out eval pool this
                round, i.e. REAL tasks the generator never trains on. See
                ``composer_replication.safety.holdout`` design notes / the
                ``HeldoutSplit`` discipline; if held-out drifts with the train
                set the gap signal is meaningless.
            kl_to_init: optional token-mean KL(policy || init) in nats/token
                (this repo's ``token_mean_kl`` convention). NOT sequence-level KL.
            entropy: optional policy entropy (early-warning of entropy collapse,
                "the silent killer of RLVR generalization"). Tracked + exposed,
                not currently a hard gate.
            reward_std: optional std of the reward distribution (tracked; a
                collapsing std is an early collapse signal).

        Returns:
            A ``TripwireStatus``. Once the guard has fired, every subsequent
            ``update`` keeps ``fire=True`` (latched) so a transient recovery
            after a detected collapse cannot silently un-halt the run.
        """
        self._n += 1

        # --- EMA folds (alpha on the prior; first sample seeds the EMA) -------
        self._in_loop_ema = self._fold(self._in_loop_ema, float(in_loop_reward))
        self._heldout_ema = self._fold(self._heldout_ema, float(heldout_score))
        if kl_to_init is not None:
            self._kl_ema = self._fold(self._kl_ema, float(kl_to_init))
        if entropy is not None:
            self._entropy_ema = self._fold(self._entropy_ema, float(entropy))
        if reward_std is not None:
            self._reward_std_ema = self._fold(self._reward_std_ema, float(reward_std))

        # --- baselines: seed on the first update so gains are measured from
        #     run start (the RSI Hacking-Gap is a gain-since-baseline quantity).
        if self._in_loop_baseline is None:
            self._in_loop_baseline = self._in_loop_ema
        if self._heldout_baseline is None:
            self._heldout_baseline = self._heldout_ema

        # --- track the held-out decline streak (uses EMA deltas, denoised) ----
        in_loop_rising = (
            self._prev_in_loop_ema is not None
            and (self._in_loop_ema - self._prev_in_loop_ema) > self.rise_eps
        )
        heldout_declining = (
            self._prev_heldout_ema is not None
            and (self._heldout_ema - self._prev_heldout_ema) < -self.rise_eps
        )
        # The collapse signature is held-out DOWN while in-loop UP. A decline
        # only counts toward the streak when in-loop is simultaneously rising:
        # a held-out dip during an in-loop dip is just noise / a hard batch,
        # not reward hacking.
        if heldout_declining and in_loop_rising:
            self._heldout_decline_streak += 1
        elif not heldout_declining:
            self._heldout_decline_streak = 0
        # If held-out declines while in-loop is flat or down, the streak is left
        # unchanged (neither grown nor reset). Any checkpoint where held-out does
        # not decline resets it to zero.

        gap = self.proxy_real_gap()
        status = self._evaluate(round_idx, gap)

        # advance "previous EMA" trackers AFTER evaluation
        self._prev_in_loop_ema = self._in_loop_ema
        self._prev_heldout_ema = self._heldout_ema
        self._last_status = status
        if status.fire:
            self._fired = True
        return status

    def _evaluate(self, round_idx: int, gap: float) -> TripwireStatus:
        """Decide the verdict from current state. Pure (no state mutation)."""
        assert self._in_loop_ema is not None and self._heldout_ema is not None
        base = dict(
            step=round_idx,
            proxy_real_gap=gap,
            in_loop_ema=self._in_loop_ema,
            heldout_ema=self._heldout_ema,
            kl_ema=self._kl_ema,
        )

        # Latched: once fired, stay fired (cannot silently un-halt).
        if self._fired:
            prev_reason = self._last_status.reason if self._last_status else "collapse"
            # Keep a single "latched: " prefix; without this check the reason
            # would gain one more prefix on every post-fire update.
            if not prev_reason.startswith("latched: "):
                prev_reason = f"latched: {prev_reason}"
            return TripwireStatus(fire=True, reason=prev_reason, **base)

        # Warm-up guard: never fire on early-run noise.
        if self._n < self.min_steps:
            return TripwireStatus(fire=False, reason="", **base)

        # (b) KL hard stop, checked first: it's the cheapest unambiguous breach.
        if self._kl_ema is not None and self._kl_ema > self.kl_hard_stop:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"kl_to_init EMA {self._kl_ema:.4f} nats/token exceeds hard "
                    f"stop {self.kl_hard_stop:.4f} (policy drifting from init)"
                ),
                **base,
            )

        # (a) collapse caught in the act: held-out declines while in-loop rises.
        if self._heldout_decline_streak >= self.decline_patience:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"reward-hacking signature: held-out score declined while "
                    f"in-loop reward rose for {self._heldout_decline_streak} "
                    f"consecutive checkpoints (Hacking Gap {gap:.4f})"
                ),
                **base,
            )

        # (c) proxy-real gap blowout: fast single-generation divergence.
        if gap > self.max_proxy_real_gap:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"proxy-real Hacking Gap {gap:.4f} exceeds ceiling "
                    f"{self.max_proxy_real_gap:.4f} (proxy reward improving far "
                    f"faster than real held-out eval)"
                ),
                **base,
            )

        return TripwireStatus(fire=False, reason="", **base)

    # ------------------------------------------------------------------------
    # query helpers (do NOT advance state; idempotent)
    # ------------------------------------------------------------------------
    def should_halt(self) -> bool:
        """True if the most recent ``update`` produced a halt verdict.

        Idempotent: querying does not advance the EMA state.
        """
        return self._last_status is not None and self._last_status.fire

    @property
    def last_status(self) -> TripwireStatus | None:
        """The most recent verdict, or None if ``update`` was never called."""
        return self._last_status

    def raise_if_fired(self, status: TripwireStatus | None = None) -> None:
        """Convert a fired verdict into a typed ``CollapseStopError`` exception.

        Pass the status returned by ``update`` (or omit to use ``last_status``).
        Trainer loops that prefer exception-based control flow call this right
        after ``update``; loops that prefer flag-checking just read
        ``status.fire`` / ``should_halt()``.
        """
        st = status if status is not None else self._last_status
        if st is not None and st.fire:
            raise CollapseStopError(st)

    def proxy_real_gap(self) -> float:
        """The RSI Hacking Gap = (in-loop gain) - (held-out gain), both measured
        as EMA-minus-baseline since run start.

        Returns 0.0 before the first ``update`` (no baseline yet). A positive,
        widening value is the reward-hacking fingerprint: the proxy the policy
        optimizes is improving more than the real held-out objective.
        """
        if (
            self._in_loop_ema is None
            or self._heldout_ema is None
            or self._in_loop_baseline is None
            or self._heldout_baseline is None
        ):
            return 0.0
        in_loop_gain = self._in_loop_ema - self._in_loop_baseline
        heldout_gain = self._heldout_ema - self._heldout_baseline
        return in_loop_gain - heldout_gain

    # ------------------------------------------------------------------------
    # calibration
    # ------------------------------------------------------------------------
    def calibrate_kl_threshold(
        self, baseline_kls: list[float], factor: float = 3.0
    ) -> float:
        """Set ``kl_hard_stop`` to ``factor`` x the mean of early-run baseline KLs.

        theneuralbase guidance: "Record baseline KL during the first ~100 steps,
        set max to 3x that." Single fixed thresholds are dataset-dependent; this
        adapts to the run's own KL scale.

        SAFETY CLAMP: calibration may only ever TIGHTEN the hard stop, never
        loosen it past the documented collapse band. The returned (and stored)
        threshold is ``max(min(factor x baseline mean, current kl_hard_stop),
        1e-6)``, so a noisy / already-drifting baseline cannot raise the ceiling
        above its current value (0.08 nats/token by default), and an all-zero
        baseline cannot drive it to exactly 0.

        Args:
            baseline_kls: per-step token-mean KL values from early in the run.
                KL is non-negative by definition, so every value must be >= 0.
            factor: multiplier on the baseline mean. Must be > 0.

        Returns:
            The new ``kl_hard_stop`` (also stored on the instance), always > 0.

        Raises:
            ValueError: if ``baseline_kls`` is empty, ``factor <= 0``, or any
                baseline KL is negative.
        """
        if not baseline_kls:
            raise ValueError("baseline_kls must be non-empty to calibrate")
        # GUARD (R4): a non-positive factor or a negative baseline would make
        # `calibrated` <= 0, which the clamp below would floor to 1e-6: a ceiling
        # so small that the KL tripwire fires on essentially EVERY healthy step
        # (almost any positive KL EMA exceeds it). KL is non-negative by
        # definition, so these inputs are nonsensical; reject them loudly rather
        # than silently turn the guard into an always-on stop.
        if factor <= 0:
            raise ValueError(f"factor must be > 0, got {factor!r}")
        if any(k < 0 for k in baseline_kls):
            raise ValueError(
                f"baseline_kls must all be >= 0 (KL is non-negative); got a "
                f"negative value in {baseline_kls!r}"
            )
        mean_kl = sum(baseline_kls) / len(baseline_kls)
        calibrated = factor * mean_kl
        # Only tighten: never let calibration loosen past the current ceiling.
        # Floor at a small positive epsilon so an all-zero baseline (mean_kl==0)
        # can't drive the ceiling to exactly 0 and fire on the first positive KL.
        self.kl_hard_stop = max(min(calibrated, self.kl_hard_stop), 1e-6)
        return self.kl_hard_stop

    # ------------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------------
    def _fold(self, prev: float | None, x: float) -> float:
        """EMA fold; the first observation seeds the EMA (no warm-up bias)."""
        if prev is None:
            return x
        return self.ema_alpha * prev + (1.0 - self.ema_alpha) * x


def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
    """Per-token KL trust-region mask, mirroring torchrl's GRPO "KL-Mask".

    torchrl masks any TOKEN whose ``0.5 * (log pi/pi_ref)^2`` (the Schulman k2
    estimator of per-token KL) exceeds a threshold, forming a per-token trust
    region. This helper returns True when the token should be MASKED OUT (its
    KL contribution is too large), so it can be wired into a loss later without
    pulling torch into this module; the caller computes ``0.5 * logratio**2``.

    Args:
        logratio_sq_half: ``0.5 * (log pi/pi_ref)^2`` for one token (nats).
        threshold: per-token KL ceiling (default 0.08 nats, the same band as the
            run-level hard stop).

    Returns:
        True if the token exceeds the trust region and should be masked.
    """
    return logratio_sq_half > threshold