# Baladithya Balamurugan
# Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
# bd0c358
# Raw
# History Blame Contribute Delete
# 15.2 kB
"""R1 — HeldOutGuard wired into ComposerReplicationTrainer (the #2 safeguard).
These tests close the "tripwire exists but never fires" gap: the run-level
collapse kill-switch (``composer_replication.safety.HeldOutGuard``) must actually
be folded by the trainer at its logging cadence, fed the in-loop GRPO reward and
an INJECTED held-out eval, and halt the run on a fired verdict.
Acceptance gates:
1. BACKWARD-COMPAT: no ``heldout_guard`` => ``_maybe_update_killswitch`` is a
pure no-op (never touches the eval fn, never logs) — identical behavior.
2. THE WIRING: a fake ``heldout_eval_fn`` that DECLINES while the in-loop
reward RISES drives the guard to fire; strict mode raises
``CollapseStopError`` and the verdict carries the reward-hacking signature.
3. Soft stop: ``strict_killswitch=False`` => no raise, the HF loop is asked to
stop (``control.should_training_stop``).
4. Healthy run (held-out tracks reward) never fires.
5. KL-to-init is read from TRL's logged metric and reaches the guard.
6. Constructor contract: ``heldout_guard`` without ``heldout_eval_fn`` raises.
7. ``_compute_loss`` only folds the guard at the ``logging_steps`` cadence
(not every micro-step), mirroring the loss-component logging.
CPU-only, no model download, no full GRPOTrainer init (stub instance via
``__new__`` + manual attribute wiring, the pattern used by the SDPO tests).
"""
from __future__ import annotations
import pytest
from composer_replication.safety import CollapseStopError, HeldOutGuard
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
# ---------------------------------------------------------------------------
# Stubs — mirror _make_sdpo_trainer in test_sdpo_alignment_indices.py: build the
# trainer via __new__ so we never run GRPOTrainer.__init__ (no TRL setup, no
# model download), then wire only the attributes the kill-switch path reads.
# ---------------------------------------------------------------------------
class _State:
def __init__(self, global_step: int = 0) -> None:
self.global_step = global_step
class _Args:
def __init__(self, logging_steps: int = 1) -> None:
self.logging_steps = logging_steps
class _Control:
def __init__(self) -> None:
self.should_training_stop = False
def _make_killswitch_trainer(
    guard: HeldOutGuard | None,
    eval_fn,
    *,
    strict: bool = True,
    reward: float | None = 0.40,
    kl: float | None = None,
    logging_steps: int = 1,
    global_step: int = 0,
):
    """Build a ComposerReplicationTrainer stub exposing only what the kill-switch reads.

    The instance is created with ``__new__``, so ``GRPOTrainer.__init__``
    (TRL setup, model download) never runs. Only the attributes below are wired.

    Args:
        guard: the HeldOutGuard to wire in, or None to leave the kill-switch off.
        eval_fn: the held-out eval callable handed to the trainer as
            ``heldout_eval_fn``.
        strict: value for ``strict_killswitch``.
        reward: seed value for TRL's per-step "reward" metric series. The
            trainer reads the tail of ``self._metrics["train"][name]``. Pass
            ``reward=None`` to simulate "no reward aggregated yet".
        kl: seed value for the "kl" metric series. None means no "kl" series.
        logging_steps: value for ``args.logging_steps``, i.e. the logging
            cadence. Defaults to 1 so every step is a cadence step.
        global_step: initial ``state.global_step``.

    Returns:
        The stub trainer. Everything passed to ``self.log(...)`` is captured in
        ``obj.logged`` instead of going through HF ``Trainer.log``.
    """
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)

    # Kill-switch configuration.
    obj.heldout_guard = guard
    obj.heldout_eval_fn = eval_fn
    obj.strict_killswitch = strict

    # Minimal HF loop state.
    obj.state = _State(global_step=global_step)
    obj.args = _Args(logging_steps=logging_steps)
    obj.control = _Control()

    # TRL-style metric series; only seed the ones explicitly requested.
    train_metrics: dict[str, list] = {}
    if reward is not None:
        train_metrics["reward"] = [reward]
    if kl is not None:
        train_metrics["kl"] = [kl]
    obj._metrics = {"train": train_metrics}

    # Capture self.log(...) calls instead of routing through HF Trainer.log.
    obj.logged: list[dict] = []
    obj.log = obj.logged.append  # type: ignore[assignment]
    return obj
def _set_step_reward(obj, step: int, reward: float, kl: float | None = None) -> None:
obj.state.global_step = step
obj._metrics["train"].setdefault("reward", []).append(reward)
if kl is not None:
obj._metrics["train"].setdefault("kl", []).append(kl)
# ---------------------------------------------------------------------------
# Gate 1 — backward compatibility: absent guard is a pure no-op
# ---------------------------------------------------------------------------
def test_absent_guard_is_noop():
    """Backward-compat gate: without a guard, the kill-switch helper does nothing.

    With ``heldout_guard=None`` the helper must never call the held-out eval
    fn, never log, never raise, and never ask the loop to stop. Passing no
    kwarg must mean identical behavior.
    """
    eval_calls = []

    def eval_fn() -> float:
        eval_calls.append(None)
        return 0.0

    # The eval fn IS supplied, but with no guard the helper must never reach it.
    trainer = _make_killswitch_trainer(guard=None, eval_fn=eval_fn)
    for step in range(50):
        _set_step_reward(trainer, step, reward=0.40 + 0.05 * step)
        trainer._maybe_update_killswitch()  # expected: no-op

    assert not eval_calls, "held-out eval fn was called even though no guard set"
    assert trainer.logged == [], "kill-switch logged even though no guard configured"
    assert trainer.control.should_training_stop is False
def test_constructor_defaults_leave_killswitch_off():
    """The default kill-switch attributes make the fold helper inert.

    The test uses the values ``__init__`` assigns when no kill-switch kwargs
    are given: guard and eval fn None, strict on. It does not run the real
    constructor; that is covered by the Gate-6 tests.

    No step/metric/control/log state is wired onto the stub, so the call is
    expected to succeed only because the helper returns as soon as it sees
    the None guard.
    """
    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Mirror the default-kwarg assignment performed by __init__.
    trainer.heldout_guard, trainer.heldout_eval_fn, trainer.strict_killswitch = (
        None,
        None,
        True,
    )
    trainer._maybe_update_killswitch()
# ---------------------------------------------------------------------------
# Gate 2 — THE WIRING: declining held-out + rising reward => guard fires & raises
# ---------------------------------------------------------------------------
def test_guard_fires_and_raises_on_reward_hacking_signature():
    """A falling held-out score with a rising in-loop reward must trip the guard.

    This is the canonical reward-hacking signature. The fake held-out eval
    declines on every call while the TRL reward metric climbs every step. In
    strict mode the wired guard must fire and raise CollapseStopError whose
    verdict names the held-out decline.
    """
    # Small min_steps keeps the test fast. The loose max_proxy_real_gap keeps
    # the gap check out of the way, so only the decline-streak path is tested.
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        max_proxy_real_gap=10.0,
    )
    heldout_value = [0.80]

    def declining_eval() -> float:
        # Each call reports a held-out score 0.05 below the previous one.
        heldout_value[0] -= 0.05
        return heldout_value[0]

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=True)

    def first_collapse() -> CollapseStopError | None:
        """Drive steps 1..29 and return the first CollapseStopError, if any."""
        for step in range(1, 30):
            # The in-loop reward (tail of the TRL metric) rises every step.
            _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
            try:
                trainer._maybe_update_killswitch()
            except CollapseStopError as exc:
                return exc
        return None

    raised = first_collapse()

    assert raised is not None, "guard never fired on the reward-hacking signature"
    assert guard.should_halt()
    verdict = raised.status
    assert verdict.fire
    # The verdict must be the held-out-declines-while-reward-rises signature.
    assert "held-out" in verdict.reason
    assert verdict.proxy_real_gap > 0.0  # proxy gained while real lost
    # The kill-switch must also have logged the verdict before raising.
    assert any("killswitch/fire" in entry for entry in trainer.logged)
def test_soft_stop_sets_control_instead_of_raising():
    """In soft mode, a fired verdict stops training gracefully instead of raising.

    With ``strict_killswitch=False`` the helper must set
    ``control.should_training_stop`` so the HF loop ends on its own terms.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        max_proxy_real_gap=10.0,
    )
    heldout_value = [0.80]

    def declining_eval() -> float:
        heldout_value[0] -= 0.05
        return heldout_value[0]

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=False)

    # Keep stepping until the loop is asked to stop, or step 29 is done.
    step = 1
    while step < 30 and not trainer.control.should_training_stop:
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        trainer._maybe_update_killswitch()  # soft mode: must NOT raise
        step += 1

    assert trainer.control.should_training_stop is True, (
        "soft-stop guard fired but did not request training stop"
    )
    assert guard.should_halt()
# ---------------------------------------------------------------------------
# Gate 4 — healthy run never fires
# ---------------------------------------------------------------------------
def test_healthy_run_never_fires():
    """A healthy run must never trip the guard.

    Here the held-out score rises together with the in-loop reward and KL
    stays inside the hard-stop band. The guard must never fire, and training
    must never be asked to stop.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )
    heldout_value = [0.28]

    def rising_eval() -> float:
        # The held-out score climbs 0.01 per call, matching the reward's slope.
        heldout_value[0] += 0.01
        return heldout_value[0]

    trainer = _make_killswitch_trainer(guard, rising_eval, strict=True, kl=0.03)
    for step in range(1, 40):
        _set_step_reward(trainer, step, reward=0.30 + 0.01 * step, kl=0.03)
        # Strict mode: any fire would surface here as CollapseStopError.
        trainer._maybe_update_killswitch()

    assert not guard.should_halt()
    assert trainer.control.should_training_stop is False
# ---------------------------------------------------------------------------
# Gate 5 — KL-to-init from TRL's logged metric reaches the guard
# ---------------------------------------------------------------------------
def test_kl_to_init_is_forwarded_to_guard():
    """TRL's logged "kl" metric must reach the guard, and a KL breach must fire.

    First, after a healthy warm-up, the guard's ``kl_ema`` must be populated;
    this proves the kl_to_init wiring. Second, a sustained KL spike must fire
    through the KL path.
    """
    # A huge decline_patience and a loose gap limit leave the KL check as the
    # only path that can fire.
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=100,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )
    trainer = _make_killswitch_trainer(guard, lambda: 0.40, strict=True, kl=0.04)

    # Phase 1: flat reward and held-out score, KL inside the band.
    for step in range(1, 5):
        _set_step_reward(trainer, step, reward=0.40, kl=0.04)
        trainer._maybe_update_killswitch()
    warmup_status = guard.last_status
    assert warmup_status is not None and warmup_status.kl_ema is not None, (
        "kl_to_init never reached the guard — KL wiring is broken"
    )

    # Phase 2: KL jumps well above the hard stop. The EMA climbs past it and
    # the guard should fire via the KL path.
    raised = None
    for step in range(5, 20):
        _set_step_reward(trainer, step, reward=0.40, kl=0.30)
        try:
            trainer._maybe_update_killswitch()
        except CollapseStopError as exc:
            raised = exc
            break

    assert raised is not None, "KL hard-stop never fired through the wired guard"
    assert "kl_to_init" in raised.status.reason
def test_no_reward_metric_yet_skips_cleanly():
    """With no reward metric yet, the helper must skip the fold entirely.

    Before TRL has aggregated any reward (no "reward" series), the helper must
    skip rather than feed a fabricated 0.0 into the guard's EMA. The held-out
    eval must not run and the guard must not advance.
    """
    guard = HeldOutGuard(min_steps=3, ema_alpha=0.5)
    eval_calls = 0

    def eval_fn() -> float:
        nonlocal eval_calls
        eval_calls += 1
        return 0.40

    trainer = _make_killswitch_trainer(guard, eval_fn, reward=None)
    trainer._maybe_update_killswitch()  # no reward series => skip

    assert eval_calls == 0, "eval fn called despite no in-loop reward yet"
    assert guard.last_status is None, "guard advanced despite no reward metric"
# ---------------------------------------------------------------------------
# Gate 6 — constructor contract
# ---------------------------------------------------------------------------
def test_guard_without_eval_fn_raises_at_construction(monkeypatch):
    """The real ``__init__`` must reject a guard that has no held-out eval fn.

    A guard without a held-out signal can never trip, so construction must
    fail loudly with a ValueError that names ``heldout_eval_fn``. The GRPO
    parent ``__init__`` is stubbed so the validation runs without a TRL/model
    setup.
    """
    grpo_parent, *_ = ComposerReplicationTrainer.__bases__
    monkeypatch.setattr(
        grpo_parent, "__init__", lambda self, *a, **k: None, raising=False
    )

    lonely_guard = HeldOutGuard(min_steps=3)
    with pytest.raises(ValueError, match="heldout_eval_fn"):
        ComposerReplicationTrainer(heldout_guard=lonely_guard)  # no heldout_eval_fn
def test_guard_with_eval_fn_constructs_and_stays_off_when_absent(monkeypatch):
    """The real ``__init__`` wires the kill-switch attributes correctly.

    With both a guard and an eval fn, the trainer constructs and keeps the
    guard, with strict mode on by default. With neither, the guard stays None,
    which is the default, backward-compatible OFF path.
    """
    grpo_parent, *_ = ComposerReplicationTrainer.__bases__
    monkeypatch.setattr(
        grpo_parent, "__init__", lambda self, *a, **k: None, raising=False
    )

    # Both provided: the trainer constructs and the guard is wired.
    guard = HeldOutGuard(min_steps=3)
    wired = ComposerReplicationTrainer(heldout_guard=guard, heldout_eval_fn=lambda: 0.4)
    assert wired.heldout_guard is guard
    assert wired.strict_killswitch is True  # strict is the default

    # Neither provided: the kill-switch stays OFF.
    bare = ComposerReplicationTrainer()
    assert bare.heldout_guard is None
    assert bare.heldout_eval_fn is None
# ---------------------------------------------------------------------------
# Gate 7 — _compute_loss only folds the guard at the logging cadence
# ---------------------------------------------------------------------------
def test_compute_loss_folds_guard_only_at_logging_cadence(monkeypatch):
    """``_compute_loss`` must fold the guard only on logging-cadence steps.

    The test drives ``_compute_loss`` end-to-end, with the GRPO parent loss
    and the SDPO / replay channels stubbed. The kill-switch fold must share the
    cadence gate used by the loss-component logging rather than running on
    every micro-step.
    """
    import torch

    fold_count = 0

    def counting_eval() -> float:
        nonlocal fold_count
        fold_count += 1
        return 0.40

    # min_steps is far beyond the 35 steps driven here, so the guard never fires.
    guard = HeldOutGuard(min_steps=10_000, ema_alpha=0.5)

    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Channel weights: only the (stubbed) GRPO loss contributes.
    trainer.alpha_sdpo = 0.0
    trainer.beta_replay = 0.0
    # Kill-switch wiring.
    trainer.heldout_guard = guard
    trainer.heldout_eval_fn = counting_eval
    trainer.strict_killswitch = True
    # HF loop state, with a logging cadence of 10 steps.
    trainer.state = _State(global_step=0)
    trainer.args = _Args(logging_steps=10)
    trainer.control = _Control()
    trainer._metrics = {"train": {"reward": [0.40]}}
    trainer.logged = []
    trainer.log = trainer.logged.append  # type: ignore[assignment]

    # The PARENT's _compute_loss is patched so that `super()._compute_loss(...)`
    # resolves to the stub; the real one would need a full TRL trainer.
    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    patches = (
        (grpo_parent, "_compute_loss",
         lambda self, model, inputs: torch.tensor(1.0), False),
        (ComposerReplicationTrainer, "_compute_sdpo_loss",
         lambda self, model, inputs: torch.tensor(0.0), True),
        (ComposerReplicationTrainer, "_compute_trace_replay_loss",
         lambda self, model, inputs: torch.tensor(0.0), True),
    )
    for target, attr_name, replacement, must_exist in patches:
        monkeypatch.setattr(target, attr_name, replacement, raising=must_exist)

    for step in range(35):
        trainer.state.global_step = step
        trainer._metrics["train"]["reward"].append(0.40 + 0.001 * step)
        total = trainer._compute_loss(model=object(), inputs={})
        assert float(total.detach()) == pytest.approx(1.0)

    # Steps 0, 10, 20 and 30 are the only cadence hits, so exactly 4 folds.
    assert fold_count == 4, f"expected 4 cadence folds, got {fold_count}"