Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
Baladithya Balamurugan
Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
bd0c358 | """Tests for the held-out collapse kill-switch (HeldOutGuard). | |
| CPU-only, pure-Python — no torch, no cloud. Mirrors the | |
| ``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral | |
| asserts). Covers: | |
| - no-halt on a healthy co-rising run (the held-out-twin "within noise" case); | |
| - HALT on the canonical signature: held-out declines while in-loop rises; | |
| - HALT on KL-to-init hard-stop breach; | |
| - HALT on a fast proxy-real Hacking-Gap blowout; | |
| - window / patience behavior (min_steps warm-up; decline_patience streak); | |
| - calibration tightens-only; | |
| - idempotent query + latched-fire edge cases. | |
| """ | |
| from __future__ import annotations | |
| import pytest | |
| from composer_replication.safety import ( | |
| CollapseStopError, | |
| HeldOutGuard, | |
| TripwireStatus, | |
| kl_token_trust_filter, | |
| ) | |
def _guard(**kw) -> HeldOutGuard:
    """Build a ``HeldOutGuard`` with fast test defaults; ``kw`` overrides them.

    A small ``min_steps`` keeps tests short while still exercising the warm-up
    gate. Any keyword supplied by the caller wins over the defaults below.
    """
    defaults = {
        "min_steps": 3,
        "decline_patience": 3,
        "ema_alpha": 0.5,
        "kl_hard_stop": 0.08,
    }
    return HeldOutGuard(**{**defaults, **kw})
| # --- healthy run: both rise => never halt ----------------------------------- | |
def test_no_halt_when_both_rise():
    """A healthy run never trips the guard.

    In-loop reward and held-out score climb together (held-out trails the proxy
    by a constant 0.02, i.e. within noise), and KL to init stays at 0.03, inside
    the band. No checkpoint may fire. The proxy-real gap must stay near zero
    because both signals gained the same amount.
    """
    g = _guard()
    for i in range(30):
        proxy = 0.30 + 0.01 * i
        real = 0.28 + 0.01 * i  # tracks proxy within noise
        status = g.update(i, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03)
        assert not status.fire, f"fired unexpectedly at step {i}: {status.reason}"

    assert not g.should_halt()
    # Equal gains on both sides => gap stays near zero.
    assert abs(g.proxy_real_gap()) < 0.05
| # --- canonical signature: held-out declines while in-loop rises ------------- | |
def test_halt_on_heldout_declines_while_reward_rises():
    """Fire on the canonical reward-hacking fingerprint.

    After a flat warm-up, the in-loop (proxy) reward rises by 0.05 per
    checkpoint while the held-out score falls by 0.05. A huge
    ``max_proxy_real_gap`` neutralises the gap-blowout path, so only the
    decline-streak path can trip. It should fire once the streak reaches
    ``decline_patience`` (3).
    """
    g = _guard(max_proxy_real_gap=10.0)  # disable gap-blowout path to isolate (a)

    for step in range(6):
        warm = g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
        assert not warm.fire

    fired_at = None
    for offset, step in enumerate(range(6, 12), start=1):
        result = g.update(
            step,
            in_loop_reward=0.40 + 0.05 * offset,  # proxy rising
            heldout_score=0.40 - 0.05 * offset,  # held-out declining
            kl_to_init=0.03,  # KL stays in band
        )
        if result.fire:
            fired_at = step
            break

    assert fired_at is not None, "tripwire never fired on the collapse signature"
    assert g.should_halt()
    last = g.last_status
    assert "held-out" in last.reason and "consecutive" in last.reason
    assert last.proxy_real_gap > 0.0  # proxy gained while real lost
def test_does_not_fire_before_patience_window():
    """Fewer divergent checkpoints than ``decline_patience`` must not fire.

    After a flat warm-up, exactly two checkpoints show held-out falling while
    in-loop rises. That is one short of the patience of 3, so neither may
    trip the guard.
    """
    g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for step in range(6):
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    divergent = [
        g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03),
        g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03),
    ]
    assert not any(status.fire for status in divergent)
def test_decline_streak_resets_on_recovery():
    """A clean checkpoint (held-out recovers) resets the decline streak, so a
    later short divergence does not inherit prior declines.

    The post-recovery checkpoint must be a genuine decline whether the guard
    compares raw scores or their EMAs. With ``ema_alpha=0.5``, a raw held-out of
    0.40 right after 0.45 still *raises* the held-out EMA (0.39375 -> 0.396875).
    That value would let a guard that never resets its streak pass this test. A
    raw 0.30 lowers both the raw score and the EMA (0.39375 -> 0.346875), so
    only a correct reset keeps the guard quiet.
    """
    g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for i in range(6):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    # 2 declines...
    g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
    g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
    # ...then held-out recovers (resets streak)...
    s = g.update(8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03)
    assert not s.fire
    # ...then one unambiguous decline while the proxy keeps rising. With the
    # reset this is streak=1 (below patience). A guard that failed to reset
    # would reach streak=3 here and fire.
    s = g.update(9, in_loop_reward=0.60, heldout_score=0.30, kl_to_init=0.03)
    assert not s.fire
| # --- KL hard-stop ------------------------------------------------------------ | |
def test_halt_on_kl_hard_stop_breach():
    """Fire when the KL-to-init EMA crosses ``kl_hard_stop``.

    Reward and held-out stay flat at 0.40, so neither divergence path can
    trip. Only the KL spike (0.04 during warm-up, then 0.20 against a 0.08
    ceiling) can halt the run.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for step in range(5):
        warm = g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
        assert not warm.fire

    # The EMA may take a step or two to cross; stop at the first fire.
    spikes = (
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.20)
        for step in range(5, 12)
    )
    tripped = next((status for status in spikes if status.fire), None)

    assert tripped is not None, "KL hard-stop never fired despite KL EMA crossing the ceiling"
    assert "kl_to_init" in tripped.reason and "hard stop" in tripped.reason
def test_kl_none_never_fires_kl_path():
    """Omitting ``kl_to_init`` leaves the KL path inert.

    KL is optional. Every update passes ``kl_to_init=None`` with flat,
    identical reward and held-out signals. After 20 updates the latest status
    must not have fired and must report ``kl_ema`` as ``None``.
    """
    g = _guard(max_proxy_real_gap=10.0)
    history = [
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None)
        for step in range(20)
    ]
    latest = history[-1]
    assert latest is not None and not latest.fire
    assert latest.kl_ema is None
| # --- proxy-real gap blowout (fast divergence) ------------------------------- | |
def test_halt_on_proxy_real_gap_blowout():
    """A fast proxy/real divergence trips the gap-blowout path.

    ``decline_patience=100`` effectively switches off the decline-streak path.
    After a flat warm-up at 0.30, the proxy jumps to 0.90 while the held-out
    score stays at 0.30. With ``ema_alpha=0.5`` the gap crosses the 0.10
    ceiling within a few updates.
    """
    g = _guard(max_proxy_real_gap=0.10, decline_patience=100)  # disable (a)
    for step in range(5):
        g.update(step, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)

    tripped = None
    for step in range(5, 12):
        candidate = g.update(step, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03)
        if candidate.fire:
            tripped = candidate
            break

    assert tripped is not None, "gap-blowout tripwire never fired"
    assert "Hacking Gap" in tripped.reason
    assert tripped.proxy_real_gap > 0.10
| # --- warm-up window (min_steps) --------------------------------------------- | |
def test_respects_min_steps_no_early_fire():
    """No fire before ``min_steps`` updates, however bad the signals.

    Every path is primed to trip from the first update: KL is 0.9 against a
    0.08 ceiling, the proxy rises while held-out falls, patience is 2 and the
    gap ceiling is 0.01. The first 9 updates are still inside the warm-up
    window (``min_steps=10``) and must stay quiet. The 10th is allowed to
    fire and does.
    """
    g = _guard(
        min_steps=10,
        decline_patience=2,
        kl_hard_stop=0.08,
        max_proxy_real_gap=0.01,
    )
    for i in range(9):  # 9 updates, all < min_steps=10
        s = g.update(
            i,
            in_loop_reward=0.10 + 0.1 * i,
            heldout_score=0.90 - 0.1 * i,
            kl_to_init=0.9,
        )
        assert not s.fire, f"fired during warm-up at step {i}: {s.reason}"

    tenth = g.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
    assert tenth.fire
| # --- calibration ------------------------------------------------------------- | |
def test_calibrate_kl_threshold_tightens_only():
    """Calibration may lower the KL ceiling below the configured band.

    A baseline mean of 0.01 times factor 3 gives 0.03, which is below 0.08.
    The ceiling therefore tightens to 0.03, and that value is both returned
    and stored on the guard.
    """
    g = _guard(kl_hard_stop=0.08)
    tightened = g.calibrate_kl_threshold([0.008, 0.010, 0.012], factor=3.0)
    expected = pytest.approx(0.03, abs=1e-9)
    assert tightened == expected
    assert g.kl_hard_stop == expected
def test_calibrate_never_loosens_past_band():
    """Calibration never raises the KL ceiling above the configured band.

    A drifting baseline (mean 0.05, times factor 3 = 0.15) would loosen the
    ceiling. Instead the ceiling must stay at the configured 0.08, both in the
    returned value and on the guard.
    """
    g = _guard(kl_hard_stop=0.08)
    drifting_baseline = [0.05] * 3
    result = g.calibrate_kl_threshold(drifting_baseline, factor=3.0)
    band = pytest.approx(0.08, abs=1e-9)
    assert result == band
    assert g.kl_hard_stop == band
def test_calibrate_empty_raises():
    """An empty baseline is rejected with a ``ValueError`` mentioning 'non-empty'."""
    guard = _guard()
    no_samples = []
    with pytest.raises(ValueError, match="non-empty"):
        guard.calibrate_kl_threshold(no_samples)
| # --- proxy_real_gap definition ---------------------------------------------- | |
def test_proxy_real_gap_is_gain_difference():
    """``proxy_real_gap`` = (in-loop EMA gain) - (held-out EMA gain).

    The first update is the 0.20/0.20 baseline. With ``ema_alpha=0.5``, the
    second update moves each EMA halfway toward its new sample:

      in-loop : 0.5*0.20 + 0.5*0.60 = 0.40  -> gain 0.20
      held-out: 0.5*0.20 + 0.5*0.30 = 0.25  -> gain 0.05

    so the gap is 0.20 - 0.05 = 0.15. ``min_steps=100`` and a huge gap ceiling
    keep the guard from firing while it is inspected.
    """
    g = _guard(min_steps=100, max_proxy_real_gap=10.0)  # disable firing
    samples = [(0.20, 0.20), (0.60, 0.30)]  # (in-loop, held-out); first is baseline
    for step, (proxy, real) in enumerate(samples):
        g.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.02)
    assert g.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)
def test_proxy_real_gap_zero_before_update():
    """A fresh guard that has seen no updates reports a gap of exactly 0.0."""
    fresh = _guard()
    gap = fresh.proxy_real_gap()
    assert gap == 0.0
| # --- idempotency / edge cases ----------------------------------------------- | |
def test_should_halt_is_idempotent_query():
    """Read-only queries neither advance state nor change the verdict.

    After a clean warm-up, ``should_halt`` is asked twice and must say
    ``False`` both times. ``proxy_real_gap`` must be unchanged by the
    queries, and the last recorded status must exist and be unfired.
    """
    g = _guard(max_proxy_real_gap=10.0)
    for step in range(6):
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    gap_before = g.proxy_real_gap()
    for _ in range(2):
        assert g.should_halt() is False
    assert g.proxy_real_gap() == gap_before  # unchanged by querying

    last = g.last_status
    assert last is not None and not last.fire
def test_fire_is_latched():
    """Once fired, a subsequent recovery cannot silently un-halt the run.

    A KL breach (0.30 against a 0.08 ceiling) is driven until the guard fires.
    KL then drops to a healthy 0.01. The next status must still report a fire
    with a ``latched:`` reason, and ``should_halt`` must stay true.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for step in range(5):
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)

    breach = (
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
        for step in range(5, 12)
    )
    fired = any(status.fire for status in breach)  # stops at the first fire
    assert fired

    after = g.update(99, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.01)
    assert after.fire and after.reason.startswith("latched:")
    assert g.should_halt()
def test_raise_if_fired_raises_typed_exception():
    """``raise_if_fired`` turns a fired status into ``CollapseStopError``.

    A KL breach (0.30 against a 0.08 ceiling) is driven until the guard fires.
    Passing that status must raise ``CollapseStopError``, carrying the very
    same status object and a non-empty message.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for i in range(5):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
    status = None
    for i in range(5, 12):
        status = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
        if status.fire:
            break
    assert status is not None and status.fire
    with pytest.raises(CollapseStopError) as exc:
        g.raise_if_fired(status)
    assert exc.value.status is status
    # str() always returns a str, so an isinstance check proves nothing; the
    # meaningful property is that the message is non-empty.
    assert str(exc.value)
def test_raise_if_fired_noop_when_clean():
    """``raise_if_fired`` does nothing for a clean status.

    This is checked twice: with the status passed explicitly, and with the
    argument omitted, in which case the guard is expected to fall back to
    ``last_status``.
    """
    g = _guard(max_proxy_real_gap=10.0)
    clean = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    g.raise_if_fired(clean)
    g.raise_if_fired()
def test_status_halt_alias_matches_fire():
    """``TripwireStatus.halt`` mirrors ``fire`` in both states.

    A single update (before ``min_steps=3``) cannot fire, so both attributes
    must be ``False``. A KL breach (0.30 against a 0.08 ceiling) is then
    driven until the guard fires, and ``halt`` must follow ``fire`` there too.
    """
    g = _guard(max_proxy_real_gap=10.0)
    s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    assert isinstance(s, TripwireStatus)
    # Spelled out rather than ``s.halt == s.fire is False``: that chained
    # comparison means ``s.halt == s.fire and s.fire is False``, which is easy
    # to misread as ``(s.halt == s.fire) is False``.
    assert s.halt == s.fire
    assert s.fire is False

    fired = None
    for i in range(1, 12):
        fired = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
        if fired.fire:
            break
    assert fired is not None and fired.fire
    assert fired.halt == fired.fire
def test_non_contiguous_round_idx_uses_internal_counter():
    """min_steps gates on the internal update counter, not round_idx, so a caller
    that logs sparse / non-contiguous round indices still warms up correctly.

    Every update carries KL 0.9, far above the default 0.08 ceiling. A guard
    that gated on ``round_idx`` (1000 >= min_steps) would therefore be free to
    fire on the very first update. Because fires are latched, the 3rd status
    would then read as fired anyway. The first two updates must be checked
    explicitly for this test to catch that bug.
    """
    g = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)
    # Pass huge round_idx values; only the 3rd UPDATE clears warm-up.
    s1 = g.update(1000, in_loop_reward=0.10, heldout_score=0.90, kl_to_init=0.9)
    assert not s1.fire, f"fired during warm-up (update 1): {s1.reason}"
    s2 = g.update(2000, in_loop_reward=0.50, heldout_score=0.50, kl_to_init=0.9)
    assert not s2.fire, f"fired during warm-up (update 2): {s2.reason}"
    s3 = g.update(3000, in_loop_reward=0.90, heldout_score=0.10, kl_to_init=0.9)
    assert s3.fire  # 3rd update, n==3 not < min_steps
| # --- config validation ------------------------------------------------------- | |
def test_bad_ema_alpha_rejected():
    """Constructing a guard with ``ema_alpha`` of 1.0 or -0.1 raises ``ValueError``."""
    for bad_alpha in (1.0, -0.1):
        with pytest.raises(ValueError, match="ema_alpha"):
            HeldOutGuard(ema_alpha=bad_alpha)
def test_bad_kl_hard_stop_rejected():
    """A zero KL ceiling is rejected at construction with ``ValueError``."""
    zero_ceiling = {"kl_hard_stop": 0.0}
    with pytest.raises(ValueError, match="kl_hard_stop"):
        HeldOutGuard(**zero_ceiling)
def test_bad_decline_patience_rejected():
    """A ``decline_patience`` of zero is rejected at construction with ``ValueError``."""
    zero_patience = {"decline_patience": 0}
    with pytest.raises(ValueError, match="decline_patience"):
        HeldOutGuard(**zero_patience)
| # --- kl_token_trust_filter helper ------------------------------------------- | |
def test_kl_token_trust_filter_masks_above_threshold():
    """``kl_token_trust_filter`` masks a token only when its value strictly
    exceeds ``threshold``.

    The first argument is compared directly against the 0.08 ceiling (0.20
    masks, 0.05 does not, and exactly 0.08 is not masked). Presumably it is the
    per-token KL estimate (e.g. 0.5 * logratio**2) rather than the raw
    log-ratio; confirm against ``composer_replication.safety``.
    """
    cases = [
        (0.20, True),   # too large -> mask
        (0.05, False),  # within trust region
        (0.08, False),  # boundary, not masked
    ]
    for value, should_mask in cases:
        assert kl_token_trust_filter(value, threshold=0.08) is should_mask
| # --- R4: calibrate_kl_threshold input guards (negative factor / baseline) ----- | |
def test_calibrate_rejects_nonpositive_factor():
    """R4: a non-positive ``factor`` is rejected loudly.

    With factor <= 0 the calibrated ceiling would be <= 0, and
    ``min(<=0, 0.08)`` would stay <= 0. After that the KL tripwire would fire
    on every healthy step. Both a negative factor and zero must raise.
    """
    g = _guard()
    for bad_factor in (-3.0, 0.0):
        with pytest.raises(ValueError, match="factor must be > 0"):
            g.calibrate_kl_threshold([0.01, 0.02], factor=bad_factor)
def test_calibrate_rejects_negative_baseline_kl():
    """R4: a baseline containing a negative KL sample is rejected.

    KL is non-negative by definition, so a negative sample is nonsensical and
    could invert the ceiling.
    """
    g = _guard()
    baseline = [0.01, -0.5, 0.02]  # one impossible negative sample
    with pytest.raises(ValueError, match="non-negative"):
        g.calibrate_kl_threshold(baseline)
def test_calibrate_never_yields_nonpositive_threshold():
    """R4: an all-zero baseline (mean 0) still leaves a positive KL ceiling.

    Otherwise any later positive KL would fire spuriously. Both the returned
    value and the stored ``kl_hard_stop`` must be > 0.
    """
    g = _guard()
    threshold = g.calibrate_kl_threshold([0.0] * 3, factor=3.0)
    assert threshold > 0.0
    assert g.kl_hard_stop > 0.0
| # --- R10: path-(c) gap-blowout is a divergence-RATE gate, not a real-decline -- | |
def test_gap_blowout_fires_even_when_real_still_rising():
    """R10: path (c) fires when the proxy gain outpaces the real gain beyond the
    ceiling EVEN WHILE the held-out (real) score is still genuinely RISING. This
    is INTENTIONAL — path (c) is a divergence-RATE gate (fast single-generation
    hacking), distinct from path (a)'s real-decline streak. Locking the intended
    behavior so a future change can't silently turn it into a real-decline gate.

    The proxy rises by 0.20 per update while the held-out score rises by only
    0.01. The loop stops at the first fire, so the assertions inspect the
    status that actually tripped rather than a later latched one. "Still
    rising" is checked on every recorded held-out EMA, using the public
    ``heldout_ema`` status field rather than a private helper.
    """
    g = _guard(max_proxy_real_gap=0.1, decline_patience=99)  # isolate path (c) from (a)
    status = None
    heldout_emas = []
    for i in range(8):
        status = g.update(
            i,
            in_loop_reward=0.30 + 0.20 * i,  # proxy sprints
            heldout_score=0.30 + 0.01 * i,  # real still rising, but slowly
            kl_to_init=0.02,
        )
        heldout_emas.append(status.heldout_ema)
        if status.fire:
            break
    assert status is not None and status.fire, (
        "path (c) should fire on a fast proxy/real divergence"
    )
    assert "gap" in status.reason.lower()
    # And the real score WAS rising the whole time (not a decline-driven fire):
    # every held-out EMA up to and including the firing step strictly increases.
    assert len(heldout_emas) >= 2
    assert all(later > earlier for earlier, later in zip(heldout_emas, heldout_emas[1:]))