"""R1 — HeldOutGuard wired into ComposerReplicationTrainer (the #2 safeguard). These tests close the "tripwire exists but never fires" gap: the run-level collapse kill-switch (``composer_replication.safety.HeldOutGuard``) must actually be folded by the trainer at its logging cadence, fed the in-loop GRPO reward and an INJECTED held-out eval, and halt the run on a fired verdict. Acceptance gates: 1. BACKWARD-COMPAT: no ``heldout_guard`` => ``_maybe_update_killswitch`` is a pure no-op (never touches the eval fn, never logs) — identical behavior. 2. THE WIRING: a fake ``heldout_eval_fn`` that DECLINES while the in-loop reward RISES drives the guard to fire; strict mode raises ``CollapseStopError`` and the verdict carries the reward-hacking signature. 3. Soft stop: ``strict_killswitch=False`` => no raise, the HF loop is asked to stop (``control.should_training_stop``). 4. Healthy run (held-out tracks reward) never fires. 5. KL-to-init is read from TRL's logged metric and reaches the guard. 6. Constructor contract: ``heldout_guard`` without ``heldout_eval_fn`` raises. 7. ``_compute_loss`` only folds the guard at the ``logging_steps`` cadence (not every micro-step), mirroring the loss-component logging. CPU-only, no model download, no full GRPOTrainer init (stub instance via ``__new__`` + manual attribute wiring, the pattern used by the SDPO tests). """ from __future__ import annotations import pytest from composer_replication.safety import CollapseStopError, HeldOutGuard from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer # --------------------------------------------------------------------------- # Stubs — mirror _make_sdpo_trainer in test_sdpo_alignment_indices.py: build the # trainer via __new__ so we never run GRPOTrainer.__init__ (no TRL setup, no # model download), then wire only the attributes the kill-switch path reads. 
# ---------------------------------------------------------------------------


class _State:
    """Stub trainer state exposing only ``global_step``."""

    def __init__(self, global_step: int = 0) -> None:
        self.global_step = global_step


class _Args:
    """Stub training arguments exposing only ``logging_steps``."""

    def __init__(self, logging_steps: int = 1) -> None:
        self.logging_steps = logging_steps


class _Control:
    """Stub trainer control exposing only ``should_training_stop``."""

    def __init__(self) -> None:
        self.should_training_stop = False


def _make_killswitch_trainer(
    guard: HeldOutGuard | None,
    eval_fn,
    *,
    strict: bool = True,
    reward: float | None = 0.40,
    kl: float | None = None,
    logging_steps: int = 1,
):
    """A ComposerReplicationTrainer stub exposing only what the kill-switch reads.

    ``guard`` / ``eval_fn`` / ``strict`` are stored as ``heldout_guard`` /
    ``heldout_eval_fn`` / ``strict_killswitch``. ``eval_fn`` is a zero-argument
    callable returning the held-out score (or None).

    ``reward`` / ``kl`` seed TRL's per-step metric series (the trainer reads the
    tail of ``self._metrics["train"][name]``). Pass reward=None to simulate
    "no reward aggregated yet".

    ``logging_steps`` is forwarded to the stub args so cadence-sensitive tests
    can reuse this helper instead of re-wiring the stub by hand.

    Returns the stub; every dict passed to ``obj.log`` is captured in
    ``obj.logged``.
    """
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    obj.heldout_guard = guard
    obj.heldout_eval_fn = eval_fn
    obj.strict_killswitch = strict
    obj.state = _State(global_step=0)
    obj.args = _Args(logging_steps=logging_steps)
    obj.control = _Control()

    train_metrics: dict[str, list] = {}
    if reward is not None:
        train_metrics["reward"] = [reward]
    if kl is not None:
        train_metrics["kl"] = [kl]
    obj._metrics = {"train": train_metrics}

    obj.logged: list[dict] = []
    # capture self.log(...) instead of routing through HF Trainer.log
    obj.log = obj.logged.append  # type: ignore[assignment]
    return obj


def _set_step_reward(obj, step: int, reward: float, kl: float | None = None) -> None:
    """Move the stub to ``step`` and append that step's reward (and KL, if given)
    to the ``_metrics["train"]`` series the kill-switch reads."""
    obj.state.global_step = step
    obj._metrics["train"].setdefault("reward", []).append(reward)
    if kl is not None:
        obj._metrics["train"].setdefault("kl", []).append(kl)


def _drifting_eval(start: float, delta: float):
    """Build a fake held-out eval whose score moves by ``delta`` on every call.

    The first call returns ``start + delta``; each later call adds ``delta``
    again. A negative ``delta`` gives a declining held-out score, a positive one
    a rising score.
    """
    current = start

    def eval_fn() -> float:
        nonlocal current
        current += delta
        return current

    return eval_fn


def _stub_parent_init(monkeypatch) -> None:
    """Replace the first base class's ``__init__`` with a no-op so the real
    ``ComposerReplicationTrainer.__init__`` can run without a TRL/model setup."""
    parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(parent, "__init__", lambda self, *a, **k: None, raising=False)


# ---------------------------------------------------------------------------
# Gate 1 — backward compatibility: absent guard is a pure no-op
# ---------------------------------------------------------------------------


def test_absent_guard_is_noop():
    """No heldout_guard => the kill-switch path does nothing: the eval fn is
    never called, nothing is logged, no exception. This is the backward-compat
    guarantee (no kwarg => identical behavior)."""
    calls = {"n": 0}

    def eval_fn() -> float:
        calls["n"] += 1
        return 0.0

    # Pass eval_fn but NO guard — the helper must never reach the eval fn.
    obj = _make_killswitch_trainer(guard=None, eval_fn=eval_fn)
    for step in range(50):
        _set_step_reward(obj, step, reward=0.40 + 0.05 * step)
        obj._maybe_update_killswitch()  # must be a no-op

    assert calls["n"] == 0, "held-out eval fn was called even though no guard set"
    assert obj.logged == [], "kill-switch logged even though no guard configured"
    assert obj.control.should_training_stop is False


def test_constructor_defaults_leave_killswitch_off():
    """The constructor defaults: a trainer built without the kwargs has
    heldout_guard=None / strict_killswitch defaulting on but inert."""
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Simulate the default-kwarg assignment __init__ performs.
    obj.heldout_guard = None
    obj.heldout_eval_fn = None
    obj.strict_killswitch = True
    obj._maybe_update_killswitch()  # no state needed: returns immediately on None


# ---------------------------------------------------------------------------
# Gate 2 — THE WIRING: declining held-out + rising reward => guard fires & raises
# ---------------------------------------------------------------------------


def test_guard_fires_and_raises_on_reward_hacking_signature():
    """Fake heldout_eval_fn DECLINES while the in-loop reward RISES — the
    canonical reward-hacking signature. The wired guard must fire and (strict
    mode) raise CollapseStopError with the reward-hacking reason."""
    # min_steps small so the test is fast; isolate the decline-streak path.
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )
    # held-out declines every call; reward (TRL metric) rises every step.
    declining_eval = _drifting_eval(0.80, -0.05)

    obj = _make_killswitch_trainer(guard, declining_eval, strict=True)
    raised = None
    for step in range(1, 30):
        # rising in-loop reward fed via the TRL metric tail
        _set_step_reward(obj, step, reward=0.30 + 0.03 * step)
        try:
            obj._maybe_update_killswitch()
        except CollapseStopError as exc:
            raised = exc
            break

    assert raised is not None, "guard never fired on the reward-hacking signature"
    assert guard.should_halt()
    assert raised.status.fire
    # The fired verdict must be the held-out-declines-while-reward-rises signature.
    assert "held-out" in raised.status.reason
    assert raised.status.proxy_real_gap > 0.0  # proxy gained while real lost
    # And the kill-switch logged the verdict before raising.
    assert any("killswitch/fire" in d for d in obj.logged)


# ---------------------------------------------------------------------------
# Gate 3 — soft stop: strict_killswitch=False requests a stop instead of raising
# ---------------------------------------------------------------------------


def test_soft_stop_sets_control_instead_of_raising():
    """strict_killswitch=False => a fired verdict does NOT raise; it sets the HF
    loop's control.should_training_stop so training ends gracefully."""
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )
    declining_eval = _drifting_eval(0.80, -0.05)

    obj = _make_killswitch_trainer(guard, declining_eval, strict=False)
    for step in range(1, 30):
        _set_step_reward(obj, step, reward=0.30 + 0.03 * step)
        obj._maybe_update_killswitch()  # must NOT raise
        if obj.control.should_training_stop:
            break

    assert obj.control.should_training_stop is True, (
        "soft-stop guard fired but did not request training stop"
    )
    assert guard.should_halt()


# ---------------------------------------------------------------------------
# Gate 4 — healthy run never fires
# ---------------------------------------------------------------------------


def test_healthy_run_never_fires():
    """Held-out tracks the in-loop reward (both rise together), KL in band =>
    the wired guard never fires and training is never asked to stop."""
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )
    rising_eval = _drifting_eval(0.28, 0.01)

    obj = _make_killswitch_trainer(guard, rising_eval, strict=True, kl=0.03)
    for step in range(1, 40):
        _set_step_reward(obj, step, reward=0.30 + 0.01 * step, kl=0.03)
        obj._maybe_update_killswitch()  # must never raise on a healthy run

    assert not guard.should_halt()
    assert obj.control.should_training_stop is False


# ---------------------------------------------------------------------------
# Gate 5 — KL-to-init from TRL's logged metric reaches the guard
# ---------------------------------------------------------------------------


def test_kl_to_init_is_forwarded_to_guard():
    """The KL the trainer reads from TRL's "kl" metric must reach the guard's
    kl_ema (proves kl_to_init wiring), and a KL breach fires via the KL path."""
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=100,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,  # isolate the KL path
    )
    obj = _make_killswitch_trainer(guard, lambda: 0.40, strict=True, kl=0.04)

    # Warm-up with healthy KL; metrics flat so only the KL path can fire.
    for step in range(1, 5):
        _set_step_reward(obj, step, reward=0.40, kl=0.04)
        obj._maybe_update_killswitch()
    assert guard.last_status is not None and guard.last_status.kl_ema is not None, (
        "kl_to_init never reached the guard — KL wiring is broken"
    )

    # KL spikes above the hard stop; EMA climbs and crosses => fire via KL path.
    raised = None
    for step in range(5, 20):
        _set_step_reward(obj, step, reward=0.40, kl=0.30)
        try:
            obj._maybe_update_killswitch()
        except CollapseStopError as exc:
            raised = exc
            break

    assert raised is not None, "KL hard-stop never fired through the wired guard"
    assert "kl_to_init" in raised.status.reason


def test_no_reward_metric_yet_skips_cleanly():
    """Before TRL has aggregated any reward (empty metric series), the helper
    skips the fold rather than feeding a fabricated 0.0 into the guard's EMA."""
    guard = HeldOutGuard(min_steps=3, ema_alpha=0.5)
    calls = {"n": 0}

    def eval_fn() -> float:
        calls["n"] += 1
        return 0.40

    obj = _make_killswitch_trainer(guard, eval_fn, reward=None)
    obj._maybe_update_killswitch()  # no reward series => skip

    assert calls["n"] == 0, "eval fn called despite no in-loop reward yet"
    assert guard.last_status is None, "guard advanced despite no reward metric"


# ---------------------------------------------------------------------------
# Gate 6 — constructor contract
# ---------------------------------------------------------------------------


def test_guard_without_eval_fn_raises_at_construction(monkeypatch):
    """A guard with no held-out eval is meaningless (the tripwire needs the
    held-out signal) => the REAL __init__ must reject it loudly. We stub the
    GRPOTrainer parent __init__ so the validation clause runs without a full
    TRL/model setup."""
    _stub_parent_init(monkeypatch)

    guard = HeldOutGuard(min_steps=3)
    with pytest.raises(ValueError, match="heldout_eval_fn"):
        ComposerReplicationTrainer(heldout_guard=guard)  # no heldout_eval_fn


def test_guard_with_eval_fn_constructs_and_stays_off_when_absent(monkeypatch):
    """The real __init__ wires the kill-switch attributes; with both provided it
    constructs cleanly, and with neither provided the guard stays None (the
    default = OFF backward-compat path)."""
    _stub_parent_init(monkeypatch)

    # Both provided => constructs, guard wired.
    guard = HeldOutGuard(min_steps=3)
    t = ComposerReplicationTrainer(heldout_guard=guard, heldout_eval_fn=lambda: 0.4)
    assert t.heldout_guard is guard
    assert t.strict_killswitch is True  # strict default

    # Neither provided => guard stays None (OFF).
    t2 = ComposerReplicationTrainer()
    assert t2.heldout_guard is None
    assert t2.heldout_eval_fn is None


# ---------------------------------------------------------------------------
# Gate 7 — _compute_loss only folds the guard at the logging cadence
# ---------------------------------------------------------------------------


def test_compute_loss_folds_guard_only_at_logging_cadence(monkeypatch):
    """Drive _compute_loss end-to-end (with the GRPO parent loss + SDPO/replay
    channels stubbed) and assert the guard is folded ONLY on logging-cadence
    steps — i.e. the kill-switch fold sits inside the same cadence gate as the
    loss-component logging, not on every micro-step."""
    import torch

    folds = {"n": 0}
    guard = HeldOutGuard(min_steps=10_000, ema_alpha=0.5)  # never fires in this test

    def counting_eval() -> float:
        folds["n"] += 1
        return 0.40

    # Same stub wiring as the other gates, but with a 10-step logging cadence;
    # _compute_loss additionally reads the two channel weights set below.
    obj = _make_killswitch_trainer(
        guard, counting_eval, strict=True, reward=0.40, logging_steps=10
    )
    obj.alpha_sdpo = 0.0
    obj.beta_replay = 0.0

    # Stub the GRPO parent loss (the real `super()._compute_loss` would need a
    # full TRL trainer) and the SDPO / replay channels to zero. We patch the
    # PARENT class's _compute_loss so `super()._compute_loss(...)` resolves to it.
    parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(
        parent,
        "_compute_loss",
        lambda self, model, inputs: torch.tensor(1.0),
        raising=False,
    )
    monkeypatch.setattr(
        ComposerReplicationTrainer,
        "_compute_sdpo_loss",
        lambda self, model, inputs: torch.tensor(0.0),
        raising=True,
    )
    monkeypatch.setattr(
        ComposerReplicationTrainer,
        "_compute_trace_replay_loss",
        lambda self, model, inputs: torch.tensor(0.0),
        raising=True,
    )

    for step in range(0, 35):
        _set_step_reward(obj, step, reward=0.40 + 0.001 * step)
        total = obj._compute_loss(model=object(), inputs={})
        assert float(total.detach()) == pytest.approx(1.0)

    # steps 0, 10, 20, 30 are the only cadence hits => exactly 4 guard folds.
    assert folds["n"] == 4, f"expected 4 cadence folds, got {folds['n']}"