Baladithya Balamurugan
Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
bd0c358
Raw
History Blame Contribute Delete
15.6 kB
"""Tests for the held-out collapse kill-switch (HeldOutGuard).
CPU-only, pure-Python — no torch, no cloud. Mirrors the
``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral
asserts). Covers:
- no-halt on a healthy co-rising run (the held-out-twin "within noise" case);
- HALT on the canonical signature: held-out declines while in-loop rises;
- HALT on KL-to-init hard-stop breach;
- HALT on a fast proxy-real Hacking-Gap blowout;
- window / patience behavior (min_steps warm-up; decline_patience streak);
- calibration tightens-only;
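- idempotent query + latched-fire edge cases;
- constructor config validation and the ``kl_token_trust_filter`` helper;
- calibration input guards (R4) and the gap-blowout divergence-rate gate (R10).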
"""
from __future__ import annotations
import pytest
from composer_replication.safety import (
CollapseStopError,
HeldOutGuard,
TripwireStatus,
kl_token_trust_filter,
)
def _guard(**kw) -> HeldOutGuard:
    """Build a ``HeldOutGuard`` with fast test defaults, overridden by ``kw``.

    The small ``min_steps`` keeps each test short while the warm-up gate is
    still exercised. Any keyword in ``kw`` replaces the matching default or is
    passed through unchanged.
    """
    defaults = {
        "min_steps": 3,
        "decline_patience": 3,
        "ema_alpha": 0.5,
        "kl_hard_stop": 0.08,
    }
    return HeldOutGuard(**{**defaults, **kw})
# --- healthy run: both rise => never halt -----------------------------------
def test_no_halt_when_both_rise():
    """Healthy run: proxy and held-out climb in lockstep and KL stays in band.

    The held-out score trails the in-loop reward by a constant small offset
    ("within noise"), so no tripwire path may fire at any of the 30 steps.
    """
    guard = _guard()
    for i in range(30):
        proxy = 0.30 + 0.01 * i
        real = 0.28 + 0.01 * i  # tracks proxy within noise
        status = guard.update(
            i, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03
        )
        assert not status.fire, f"fired unexpectedly at step {i}: {status.reason}"

    assert not guard.should_halt()
    # Both series gained the same amount, so the gap must stay near zero.
    assert abs(guard.proxy_real_gap()) < 0.05
# --- canonical signature: held-out declines while in-loop rises -------------
def test_halt_on_heldout_declines_while_reward_rises():
    """Canonical collapse signature: held-out falls while in-loop reward rises.

    The gap-blowout ceiling is set very high so only the decline-streak path
    can be responsible for the halt.
    """
    guard = _guard(max_proxy_real_gap=10.0)

    # Flat, healthy stretch long enough to clear the min_steps warm-up.
    for step in range(6):
        warm = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03
        )
        assert not warm.fire

    # Reward-hacking fingerprint: proxy up, held-out down, KL still in band.
    # Expected to fire once the streak reaches decline_patience (3).
    fired_at = None
    for k, step in enumerate(range(6, 12), start=1):
        drift = 0.05 * k
        verdict = guard.update(
            step,
            in_loop_reward=0.40 + drift,
            heldout_score=0.40 - drift,
            kl_to_init=0.03,
        )
        if verdict.fire:
            fired_at = step
            break

    assert fired_at is not None, "tripwire never fired on the collapse signature"
    assert guard.should_halt()

    final = guard.last_status
    assert "held-out" in final.reason and "consecutive" in final.reason
    assert final.proxy_real_gap > 0.0  # proxy gained while real lost
def test_does_not_fire_before_patience_window():
    """A divergence shorter than ``decline_patience`` checkpoints must not fire."""
    guard = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for step in range(6):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    # Two divergent checkpoints only -- one short of the patience of 3.
    first = guard.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
    second = guard.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
    assert not (first.fire or second.fire)
def test_decline_streak_resets_on_recovery():
    """A held-out recovery clears the decline streak.

    After two declines and one recovery, a single further decline counts as a
    streak of one and must stay below the patience threshold.
    """
    guard = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for step in range(6):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    # Build a streak of two declines.
    for step, proxy, real in ((6, 0.45, 0.35), (7, 0.50, 0.30)):
        guard.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03)

    # Held-out recovers: the streak is expected to reset here.
    recovered = guard.update(
        8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03
    )
    assert not recovered.fire

    # A fresh decline is only streak=1, still under the patience of 3.
    relapse = guard.update(9, in_loop_reward=0.55, heldout_score=0.40, kl_to_init=0.03)
    assert not relapse.fire
# --- KL hard-stop ------------------------------------------------------------
def test_halt_on_kl_hard_stop_breach():
    """A sustained KL spike above ``kl_hard_stop`` must halt the run.

    Reward and held-out stay flat and the gap ceiling is very high, so the KL
    path is the only one that can fire.
    """
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)

    # Healthy KL during the warm-up.
    for step in range(5):
        calm = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04
        )
        assert not calm.fire

    # KL jumps far above 0.08; the EMA needs a few steps to cross the ceiling.
    # The generator is lazy, so updates stop at the first firing status.
    spiked = (
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.20)
        for step in range(5, 12)
    )
    breach = next((status for status in spiked if status.fire), None)

    assert breach is not None, (
        "KL hard-stop never fired despite KL EMA crossing the ceiling"
    )
    assert "kl_to_init" in breach.reason and "hard stop" in breach.reason
def test_kl_none_never_fires_kl_path():
    """Omitting ``kl_to_init`` (``None``) leaves the KL path inert.

    After 20 flat updates without KL the last status must not fire and its
    ``kl_ema`` must still be ``None``.
    """
    guard = _guard(max_proxy_real_gap=10.0)
    latest = None
    for step in range(20):
        latest = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None
        )

    assert latest is not None
    assert not latest.fire
    assert latest.kl_ema is None
# --- proxy-real gap blowout (fast divergence) -------------------------------
def test_halt_on_proxy_real_gap_blowout():
    """A sudden proxy jump with flat held-out fires via the gap-blowout path.

    ``decline_patience`` is set to 100 so the decline-streak path cannot be the
    cause of the halt.
    """
    guard = _guard(max_proxy_real_gap=0.10, decline_patience=100)
    for step in range(5):
        guard.update(step, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)

    # Proxy blows up while held-out stays flat; with ema_alpha=0.5 the gap is
    # expected to cross 0.10 quickly. Updates stop at the first firing status.
    diverging = (
        guard.update(step, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03)
        for step in range(5, 12)
    )
    blowout = next((status for status in diverging if status.fire), None)

    assert blowout is not None, "gap-blowout tripwire never fired"
    assert "Hacking Gap" in blowout.reason
    assert blowout.proxy_real_gap > 0.10
# --- warm-up window (min_steps) ---------------------------------------------
def test_respects_min_steps_no_early_fire():
    """No halt during warm-up, even when every signal is tripped.

    Nine egregiously bad updates precede ``min_steps=10``; only the tenth
    update is allowed to fire.
    """
    guard = _guard(
        min_steps=10,
        decline_patience=2,
        kl_hard_stop=0.08,
        max_proxy_real_gap=0.01,
    )

    # Bad from step 0: KL huge, proxy climbing, held-out falling.
    for i in range(9):
        s = guard.update(
            i,
            in_loop_reward=0.10 + 0.1 * i,
            heldout_score=0.90 - 0.1 * i,
            kl_to_init=0.9,
        )
        assert not s.fire, f"fired during warm-up at step {i}: {s.reason}"

    # Tenth update: the internal count is no longer below min_steps.
    tenth = guard.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
    assert tenth.fire
# --- calibration -------------------------------------------------------------
def test_calibrate_kl_threshold_tightens_only():
    """A low baseline tightens the ceiling: mean 0.01 * factor 3 = 0.03 < 0.08."""
    guard = _guard(kl_hard_stop=0.08)
    baseline = [0.008, 0.010, 0.012]

    returned = guard.calibrate_kl_threshold(baseline, factor=3.0)

    tightened = pytest.approx(0.03, abs=1e-9)
    assert returned == tightened
    assert guard.kl_hard_stop == tightened
def test_calibrate_never_loosens_past_band():
    """A drifting baseline (mean 0.05 * 3 = 0.15) must leave the 0.08 ceiling."""
    guard = _guard(kl_hard_stop=0.08)
    drifting = [0.05, 0.05, 0.05]

    returned = guard.calibrate_kl_threshold(drifting, factor=3.0)

    unchanged = pytest.approx(0.08, abs=1e-9)
    assert returned == unchanged
    assert guard.kl_hard_stop == unchanged
def test_calibrate_empty_raises():
    """Calibrating on an empty baseline raises ``ValueError`` ("non-empty")."""
    guard = _guard()
    no_samples = []
    with pytest.raises(ValueError, match="non-empty"):
        guard.calibrate_kl_threshold(no_samples)
# --- proxy_real_gap definition ----------------------------------------------
def test_proxy_real_gap_is_gain_difference():
    """The gap equals (in-loop EMA gain) minus (held-out EMA gain).

    ``min_steps=100`` and a very high gap ceiling keep the guard from firing,
    so only the arithmetic is under test.
    """
    guard = _guard(min_steps=100, max_proxy_real_gap=10.0)

    # (in_loop_reward, heldout_score): the first pair is the baseline.
    samples = [(0.20, 0.20), (0.60, 0.30)]
    for step, (proxy, real) in enumerate(samples):
        guard.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.02)

    # ema_alpha=0.5 moves each EMA halfway towards the second sample:
    #   in-loop:  0.5*0.20 + 0.5*0.60 = 0.40 -> gain 0.20
    #   held-out: 0.5*0.20 + 0.5*0.30 = 0.25 -> gain 0.05
    #   gap = 0.20 - 0.05 = 0.15
    assert guard.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)
def test_proxy_real_gap_zero_before_update():
    """A freshly built guard reports a gap of exactly 0.0."""
    fresh = _guard()
    gap = fresh.proxy_real_gap()
    assert gap == 0.0
# --- idempotency / edge cases -----------------------------------------------
def test_should_halt_is_idempotent_query():
    """Repeated ``should_halt()`` calls neither change state nor the verdict."""
    guard = _guard(max_proxy_real_gap=10.0)
    for step in range(6):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    gap_before = guard.proxy_real_gap()
    for _ in range(2):
        assert guard.should_halt() is False

    # Querying must leave the gap and the last recorded status untouched.
    assert guard.proxy_real_gap() == gap_before
    assert guard.last_status is not None
    assert not guard.last_status.fire
def test_fire_is_latched():
    """Once the guard has fired, a later healthy update cannot un-halt it."""
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for step in range(5):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)

    # Drive a KL breach; any() short-circuits on the first firing update.
    breached = any(
        guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30
        ).fire
        for step in range(5, 12)
    )
    assert breached

    # KL returns to a healthy value, but the verdict must stay fired.
    after = guard.update(99, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.01)
    assert after.fire and after.reason.startswith("latched:")
    assert guard.should_halt()
def test_raise_if_fired_raises_typed_exception():
    """``raise_if_fired`` raises ``CollapseStopError`` carrying the status.

    The exception must expose the exact status object passed in and must have
    a non-empty string form.
    """
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for step in range(5):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)

    # Push KL well above the ceiling until the guard fires.
    tripped = None
    for step in range(5, 12):
        tripped = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30
        )
        if tripped.fire:
            break
    assert tripped is not None
    assert tripped.fire

    with pytest.raises(CollapseStopError) as caught:
        guard.raise_if_fired(tripped)

    raised = caught.value
    assert raised.status is tripped
    message = str(raised)
    assert isinstance(message, str) and message
def test_raise_if_fired_noop_when_clean():
    """With no fire, ``raise_if_fired`` returns quietly.

    Both call forms are covered: an explicit status, and no argument (which,
    per the original test note, uses ``last_status``).
    """
    guard = _guard(max_proxy_real_gap=10.0)
    clean = guard.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    guard.raise_if_fired(clean)  # explicit status
    guard.raise_if_fired()  # argument omitted
def test_status_halt_alias_matches_fire():
    """``halt`` equals ``fire`` on a clean ``TripwireStatus`` (both False)."""
    guard = _guard(max_proxy_real_gap=10.0)
    first = guard.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    # Same two conditions as the chained form `halt == fire is False`.
    assert first.halt == first.fire and first.fire is False
    assert isinstance(first, TripwireStatus)
def test_non_contiguous_round_idx_uses_internal_counter():
    """Warm-up counts updates, not ``round_idx`` values.

    Sparse, very large round indices are passed in; the guard must still fire
    on the third update, when the update count reaches ``min_steps=3``.
    """
    guard = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)

    # (round_idx, in_loop_reward, heldout_score); KL is far above the ceiling.
    sparse_rounds = (
        (1000, 0.10, 0.90),
        (2000, 0.50, 0.50),
        (3000, 0.90, 0.10),
    )
    last = None
    for round_idx, proxy, real in sparse_rounds:
        last = guard.update(
            round_idx, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.9
        )

    assert last.fire  # third update: count == 3 is not < min_steps
# --- config validation -------------------------------------------------------
def test_bad_ema_alpha_rejected():
    """``ema_alpha`` values of 1.0 and -0.1 are rejected with ``ValueError``."""
    for bad_alpha in (1.0, -0.1):
        with pytest.raises(ValueError, match="ema_alpha"):
            HeldOutGuard(ema_alpha=bad_alpha)
def test_bad_kl_hard_stop_rejected():
    """A ``kl_hard_stop`` of zero is rejected with ``ValueError``."""
    zero_ceiling = 0.0
    with pytest.raises(ValueError, match="kl_hard_stop"):
        HeldOutGuard(kl_hard_stop=zero_ceiling)
def test_bad_decline_patience_rejected():
    """A ``decline_patience`` of zero is rejected with ``ValueError``."""
    no_patience = 0
    with pytest.raises(ValueError, match="decline_patience"):
        HeldOutGuard(decline_patience=no_patience)
# --- kl_token_trust_filter helper -------------------------------------------
def test_kl_token_trust_filter_masks_above_threshold():
    """Values strictly above the threshold are masked; the boundary is not."""
    # Original note: 0.5 * logratio^2; mask when it exceeds the per-token KL
    # ceiling.
    ceiling = 0.08
    cases = (
        (0.20, True),  # too large -> mask
        (0.05, False),  # within trust region
        (0.08, False),  # boundary, not masked
    )
    for value, masked in cases:
        assert kl_token_trust_filter(value, threshold=ceiling) is masked
# --- R4: calibrate_kl_threshold input guards (negative factor / baseline) -----
def test_calibrate_rejects_nonpositive_factor():
    """R4: ``factor <= 0`` is rejected loudly.

    Such a factor would calibrate the ceiling to a non-positive value, after
    which the KL tripwire would fire on every healthy step.
    """
    guard = _guard()
    baseline = [0.01, 0.02]
    for bad_factor in (-3.0, 0.0):
        with pytest.raises(ValueError, match="factor must be > 0"):
            guard.calibrate_kl_threshold(baseline, factor=bad_factor)
def test_calibrate_rejects_negative_baseline_kl():
    """R4: a baseline containing a negative KL sample is rejected.

    KL is non-negative by definition, so a negative sample is nonsensical and
    could invert the ceiling.
    """
    guard = _guard()
    with_negative = [0.01, -0.5, 0.02]
    with pytest.raises(ValueError, match="non-negative"):
        guard.calibrate_kl_threshold(with_negative)
def test_calibrate_never_yields_nonpositive_threshold():
    """R4: an all-zero baseline must still leave a strictly positive ceiling.

    Otherwise any later positive KL reading would fire spuriously.
    """
    guard = _guard()
    all_zero = [0.0, 0.0, 0.0]

    returned = guard.calibrate_kl_threshold(all_zero, factor=3.0)

    assert returned > 0.0
    assert guard.kl_hard_stop > 0.0
# --- R10: path-(c) gap-blowout is a divergence-RATE gate, not a real-decline --
def test_gap_blowout_fires_even_when_real_still_rising():
    """R10: path (c) is a divergence-RATE gate, not a real-decline gate.

    The proxy sprints while the held-out score keeps rising slowly. Path (c)
    must still fire once the proxy gain outpaces the real gain past the
    ceiling. This is INTENTIONAL and is locked here so a later change cannot
    silently turn path (c) into a real-decline gate like path (a).
    """
    # decline_patience=99 keeps path (a) out of the picture.
    guard = _guard(max_proxy_real_gap=0.1, decline_patience=99)

    final = None
    for step in range(8):
        final = guard.update(
            step,
            in_loop_reward=0.30 + 0.20 * step,  # proxy sprints
            heldout_score=0.30 + 0.01 * step,  # real rises, but slowly
            kl_to_init=0.02,
        )

    assert final.fire, "path (c) should fire on a fast proxy/real divergence"
    assert "gap" in final.reason.lower()

    # The held-out EMA must sit above its starting value, i.e. the fire was
    # not decline-driven. Presumably _fold(None, x) is the EMA seeded from its
    # first sample -- confirm against HeldOutGuard.
    seeded = guard._fold(None, 0.30)  # type: ignore[attr-defined]
    assert final.heldout_ema > seeded