Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
  # Load model directly
  from transformers import AutoModel
  model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
composer-replication-framework/composer_replication/trainer/tests/test_killswitch_integration.py
Baladithya Balamurugan
Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
bd0c358
"""R1 — HeldOutGuard wired into ComposerReplicationTrainer (the #2 safeguard).

These tests close the "tripwire exists but never fires" gap: the run-level
collapse kill-switch (``composer_replication.safety.HeldOutGuard``) must actually
be folded in by the trainer at its logging cadence, fed the in-loop GRPO reward
and an INJECTED held-out eval, and must halt the run on a fired verdict.

Acceptance gates:
  1. BACKWARD-COMPAT: no ``heldout_guard`` => ``_maybe_update_killswitch`` is a
     pure no-op (never touches the eval fn, never logs) — identical behavior.
  2. THE WIRING: a fake ``heldout_eval_fn`` that DECLINES while the in-loop
     reward RISES drives the guard to fire; strict mode raises
     ``CollapseStopError`` and the verdict carries the reward-hacking signature.
  3. Soft stop: ``strict_killswitch=False`` => no raise; the HF loop is asked to
     stop (``control.should_training_stop``).
  4. Healthy run (held-out tracks reward) never fires.
  5. KL-to-init is read from TRL's logged metric and reaches the guard; an
     empty reward series skips the fold instead of feeding the guard a 0.0.
  6. Constructor contract: ``heldout_guard`` without ``heldout_eval_fn`` raises.
  7. ``_compute_loss`` only folds the guard at the ``logging_steps`` cadence
     (not every micro-step), mirroring the loss-component logging.

CPU-only, no model download, no full GRPOTrainer init (stub instance via
``__new__`` + manual attribute wiring, the pattern used by the SDPO tests).
"""
| from __future__ import annotations | |
| import pytest | |
| from composer_replication.safety import CollapseStopError, HeldOutGuard | |
| from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer | |
| # --------------------------------------------------------------------------- | |
| # Stubs — mirror _make_sdpo_trainer in test_sdpo_alignment_indices.py: build the | |
| # trainer via __new__ so we never run GRPOTrainer.__init__ (no TRL setup, no | |
| # model download), then wire only the attributes the kill-switch path reads. | |
| # --------------------------------------------------------------------------- | |
| class _State: | |
| def __init__(self, global_step: int = 0) -> None: | |
| self.global_step = global_step | |
| class _Args: | |
| def __init__(self, logging_steps: int = 1) -> None: | |
| self.logging_steps = logging_steps | |
| class _Control: | |
| def __init__(self) -> None: | |
| self.should_training_stop = False | |
def _make_killswitch_trainer(
    guard: HeldOutGuard | None,
    eval_fn,
    *,
    strict: bool = True,
    reward: float | None = 0.40,
    kl: float | None = None,
    logging_steps: int = 1,
):
    """Build a ComposerReplicationTrainer stub exposing only what the kill-switch reads.

    The instance is created via ``__new__`` so ``GRPOTrainer.__init__`` never
    runs; only the attributes the kill-switch path touches are wired by hand.

    Args:
        guard: the ``HeldOutGuard`` to wire, or ``None`` for the OFF path.
        eval_fn: the injected held-out eval callable (``heldout_eval_fn``).
        strict: value for ``strict_killswitch`` (raise vs. soft stop).
        reward: seed for TRL's per-step ``_metrics["train"]["reward"]`` series
            (the trainer reads the tail of ``self._metrics["train"][name]``).
            Pass ``None`` to simulate "no reward aggregated yet".
        kl: optional seed for the ``_metrics["train"]["kl"]`` series.
        logging_steps: value for ``args.logging_steps``. The default of 1 keeps
            the previous behavior; cadence tests can pass a larger value
            instead of rebuilding the stub by hand.

    Returns:
        The stub trainer. Payloads passed to ``self.log`` are captured in
        ``obj.logged`` instead of being routed through HF ``Trainer.log``.
    """
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    obj.heldout_guard = guard
    obj.heldout_eval_fn = eval_fn
    obj.strict_killswitch = strict
    obj.state = _State(global_step=0)
    obj.args = _Args(logging_steps=logging_steps)
    obj.control = _Control()

    train_metrics: dict[str, list] = {}
    if reward is not None:
        train_metrics["reward"] = [reward]
    if kl is not None:
        train_metrics["kl"] = [kl]
    obj._metrics = {"train": train_metrics}

    logged: list[dict] = []
    obj.logged = logged

    def _capture_log(logs, *_args, **_kwargs) -> None:
        # Record only the payload. Extra arguments are accepted and ignored in
        # case the trainer forwards them (recent ``transformers`` ``Trainer.log``
        # takes an optional ``start_time``). A bare ``list.append`` would raise
        # TypeError on them.
        logged.append(logs)

    obj.log = _capture_log  # type: ignore[assignment]
    return obj
| def _set_step_reward(obj, step: int, reward: float, kl: float | None = None) -> None: | |
| obj.state.global_step = step | |
| obj._metrics["train"].setdefault("reward", []).append(reward) | |
| if kl is not None: | |
| obj._metrics["train"].setdefault("kl", []).append(kl) | |
| # --------------------------------------------------------------------------- | |
| # Gate 1 — backward compatibility: absent guard is a pure no-op | |
| # --------------------------------------------------------------------------- | |
def test_absent_guard_is_noop():
    """No heldout_guard => the kill-switch path is a pure no-op.

    Backward-compat gate: omitting the guard must leave behavior identical.
    The held-out eval fn is never invoked, nothing is logged, nothing raises,
    and training is never asked to stop, even over many rising-reward steps.
    """
    eval_invocations: list[None] = []

    def eval_fn() -> float:
        eval_invocations.append(None)
        return 0.0

    # The eval fn IS supplied; only the guard is missing, so it must stay unused.
    trainer = _make_killswitch_trainer(guard=None, eval_fn=eval_fn)
    for step in range(50):
        _set_step_reward(trainer, step, reward=0.40 + 0.05 * step)
        trainer._maybe_update_killswitch()  # must be a no-op

    assert len(eval_invocations) == 0, "held-out eval fn was called even though no guard set"
    assert trainer.logged == [], "kill-switch logged even though no guard configured"
    assert trainer.control.should_training_stop is False
def test_constructor_defaults_leave_killswitch_off():
    """Default kill-switch attribute values make ``_maybe_update_killswitch`` inert.

    This does NOT call the real constructor; that path is covered by
    ``test_guard_with_eval_fn_constructs_and_stays_off_when_absent``. Instead it
    hand-assigns, on a bare ``__new__`` instance, the values ``__init__`` uses
    when the kwargs are omitted: ``heldout_guard=None``, ``heldout_eval_fn=None``
    and ``strict_killswitch=True``.

    ``state``, ``args``, ``control``, ``_metrics`` and ``log`` are deliberately
    left unset. The call passes only if the helper returns before touching any
    of them; otherwise it would hit an AttributeError.
    """
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Simulate the default-kwarg assignment __init__ performs.
    obj.heldout_guard = None
    obj.heldout_eval_fn = None
    obj.strict_killswitch = True
    # No state wired on purpose: a None guard must short-circuit immediately.
    obj._maybe_update_killswitch()
# ---------------------------------------------------------------------------
# Gates 2 & 3 — THE WIRING: declining held-out + rising reward => guard fires;
# strict mode raises, soft mode requests a graceful stop
# ---------------------------------------------------------------------------
def test_guard_fires_and_raises_on_reward_hacking_signature():
    """Reward-hacking signature => the wired guard fires and strict mode raises.

    The injected held-out eval DECLINES on every call while the in-loop reward
    (TRL's metric tail) RISES on every step. The guard folded by the trainer
    must fire. With ``strict_killswitch=True`` the trainer must raise
    ``CollapseStopError`` whose verdict names the held-out decline.
    """
    # Small min_steps keeps the test fast; the settings isolate the
    # decline-streak path.
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        max_proxy_real_gap=10.0,
    )

    heldout_score = 0.80

    def declining_eval() -> float:
        nonlocal heldout_score
        heldout_score -= 0.05
        return heldout_score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=True)

    def first_collapse_error():
        # Feed a rising reward each step; hand back the first CollapseStopError.
        for step in range(1, 30):
            _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
            try:
                trainer._maybe_update_killswitch()
            except CollapseStopError as exc:
                return exc
        return None

    raised = first_collapse_error()

    assert raised is not None, "guard never fired on the reward-hacking signature"
    assert guard.should_halt()
    verdict = raised.status
    assert verdict.fire
    # The fired verdict must be the held-out-declines-while-reward-rises signature.
    assert "held-out" in verdict.reason
    assert verdict.proxy_real_gap > 0.0  # proxy gained while real lost
    # The kill-switch logged the verdict before raising.
    assert any("killswitch/fire" in payload for payload in trainer.logged)
def test_soft_stop_sets_control_instead_of_raising():
    """Non-strict mode: a fired verdict requests a graceful stop and never raises.

    With ``strict_killswitch=False`` the trainer must set the HF loop's
    ``control.should_training_stop`` instead of raising ``CollapseStopError``.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        max_proxy_real_gap=10.0,
    )

    heldout_score = 0.80

    def declining_eval() -> float:
        nonlocal heldout_score
        heldout_score -= 0.05
        return heldout_score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=False)

    # Run until the soft stop is requested or the step budget (1..29) runs out.
    step = 1
    while step < 30 and not trainer.control.should_training_stop:
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        trainer._maybe_update_killswitch()  # soft mode: must NOT raise
        step += 1

    assert trainer.control.should_training_stop is True, (
        "soft-stop guard fired but did not request training stop"
    )
    assert guard.should_halt()
| # --------------------------------------------------------------------------- | |
| # Gate 4 — healthy run never fires | |
| # --------------------------------------------------------------------------- | |
def test_healthy_run_never_fires():
    """Healthy run: held-out and in-loop reward rise together, KL stays in band.

    KL is held at 0.03, under the 0.08 hard stop. The wired guard must never
    fire, nothing may raise, and training must never be asked to stop.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )

    heldout_score = 0.28

    def rising_eval() -> float:
        nonlocal heldout_score
        heldout_score += 0.01
        return heldout_score

    trainer = _make_killswitch_trainer(guard, rising_eval, strict=True, kl=0.03)
    for step in range(1, 40):
        _set_step_reward(trainer, step, reward=0.30 + 0.01 * step, kl=0.03)
        # Strict mode: any spurious fire would surface here as CollapseStopError.
        trainer._maybe_update_killswitch()

    assert not guard.should_halt()
    assert trainer.control.should_training_stop is False
# ---------------------------------------------------------------------------
# Gate 5 — KL-to-init from TRL's logged metric reaches the guard
# (plus: an empty reward series skips the fold cleanly)
# ---------------------------------------------------------------------------
def test_kl_to_init_is_forwarded_to_guard():
    """KL read from TRL's "kl" metric reaches the guard and can fire it.

    Phase 1 is a warm-up with KL=0.04, under the 0.08 hard stop. It proves the
    value lands in the guard's ``kl_ema``. Phase 2 spikes KL to 0.30 so the
    EMA climbs past the hard stop and the guard fires through the KL path.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=100,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,  # isolate the KL path
    )
    trainer = _make_killswitch_trainer(guard, lambda: 0.40, strict=True, kl=0.04)

    # Phase 1: reward and held-out are flat, so only the KL path could fire.
    for step in range(1, 5):
        _set_step_reward(trainer, step, reward=0.40, kl=0.04)
        trainer._maybe_update_killswitch()

    status = guard.last_status
    assert status is not None and status.kl_ema is not None, (
        "kl_to_init never reached the guard — KL wiring is broken"
    )

    # Phase 2: KL far above the hard stop; return the first CollapseStopError.
    def fire_on_kl_spike():
        for step in range(5, 20):
            _set_step_reward(trainer, step, reward=0.40, kl=0.30)
            try:
                trainer._maybe_update_killswitch()
            except CollapseStopError as exc:
                return exc
        return None

    raised = fire_on_kl_spike()
    assert raised is not None, "KL hard-stop never fired through the wired guard"
    assert "kl_to_init" in raised.status.reason
def test_no_reward_metric_yet_skips_cleanly():
    """Empty reward series => the fold is skipped, not fed a fabricated value.

    Before TRL has aggregated any reward, ``_maybe_update_killswitch`` must
    neither call the held-out eval nor advance the guard. This way no
    placeholder 0.0 ever enters the guard's EMA.
    """
    guard = HeldOutGuard(min_steps=3, ema_alpha=0.5)
    eval_calls = 0

    def eval_fn() -> float:
        nonlocal eval_calls
        eval_calls += 1
        return 0.40

    trainer = _make_killswitch_trainer(guard, eval_fn, reward=None)
    trainer._maybe_update_killswitch()  # no "reward" series exists => skip

    assert eval_calls == 0, "eval fn called despite no in-loop reward yet"
    assert guard.last_status is None, "guard advanced despite no reward metric"
| # --------------------------------------------------------------------------- | |
| # Gate 6 — constructor contract | |
| # --------------------------------------------------------------------------- | |
def test_guard_without_eval_fn_raises_at_construction(monkeypatch):
    """Constructor contract: a guard without a held-out eval is rejected.

    The tripwire is meaningless without the held-out signal, so the REAL
    ``__init__`` must raise ``ValueError`` naming ``heldout_eval_fn``. The
    GRPOTrainer parent ``__init__`` is stubbed out so the validation runs
    without a full TRL/model setup.
    """

    def _skip_parent_init(self, *args, **kwargs):
        return None

    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_parent, "__init__", _skip_parent_init, raising=False)

    lonely_guard = HeldOutGuard(min_steps=3)
    with pytest.raises(ValueError, match="heldout_eval_fn"):
        ComposerReplicationTrainer(heldout_guard=lonely_guard)  # eval fn omitted
def test_guard_with_eval_fn_constructs_and_stays_off_when_absent(monkeypatch):
    """The real __init__ wires the kill-switch attributes.

    * Guard and eval fn both supplied => constructs cleanly, the guard is
      attached, and strict mode is on by default.
    * Neither supplied => guard and eval fn stay ``None``. This is the OFF,
      backward-compatible default.

    The GRPOTrainer parent ``__init__`` is stubbed out so only this class's own
    constructor logic runs.
    """

    def _skip_parent_init(self, *args, **kwargs):
        return None

    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_parent, "__init__", _skip_parent_init, raising=False)

    # Both kwargs supplied => the guard is wired in.
    guard = HeldOutGuard(min_steps=3)
    wired = ComposerReplicationTrainer(heldout_guard=guard, heldout_eval_fn=lambda: 0.4)
    assert wired.heldout_guard is guard
    assert wired.strict_killswitch is True  # strict is the default

    # No kwargs => the kill-switch stays OFF.
    bare = ComposerReplicationTrainer()
    assert bare.heldout_guard is None
    assert bare.heldout_eval_fn is None
| # --------------------------------------------------------------------------- | |
| # Gate 7 — _compute_loss only folds the guard at the logging cadence | |
| # --------------------------------------------------------------------------- | |
def test_compute_loss_folds_guard_only_at_logging_cadence(monkeypatch):
    """_compute_loss folds the guard ONLY on logging-cadence steps.

    Drives ``_compute_loss`` end-to-end with the GRPO parent loss and the
    SDPO / trace-replay channels stubbed. It then checks WHICH steps the guard
    was folded on, not just how many times. The kill-switch fold must sit
    inside the same cadence gate as the loss-component logging and must not
    run on every micro-step.

    Pinning the exact steps catches an off-by-one cadence (e.g. folding at
    1, 11, 21, 31), which a count-only check would miss because it also yields
    four folds over steps 0..34.
    """
    import torch

    folded_at_steps: list[int] = []
    guard = HeldOutGuard(min_steps=10_000, ema_alpha=0.5)  # never fires in this test

    def counting_eval() -> float:
        # The held-out eval runs once per guard fold; record the step it ran on.
        folded_at_steps.append(obj.state.global_step)
        return 0.40

    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    obj.alpha_sdpo = 0.0
    obj.beta_replay = 0.0
    obj.heldout_guard = guard
    obj.heldout_eval_fn = counting_eval
    obj.strict_killswitch = True
    obj.state = _State(global_step=0)
    obj.args = _Args(logging_steps=10)
    obj.control = _Control()
    obj._metrics = {"train": {"reward": [0.40]}}
    obj.logged = []
    obj.log = obj.logged.append  # type: ignore[assignment]

    # Stub the GRPO parent loss and zero out the SDPO / replay channels. The
    # real ``super()._compute_loss`` would need a full TRL trainer. The PARENT
    # class's _compute_loss is patched so ``super()._compute_loss(...)``
    # resolves to the stub.
    parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(
        parent, "_compute_loss",
        lambda self, model, inputs: torch.tensor(1.0),
        raising=False,
    )
    monkeypatch.setattr(
        ComposerReplicationTrainer, "_compute_sdpo_loss",
        lambda self, model, inputs: torch.tensor(0.0),
        raising=True,
    )
    monkeypatch.setattr(
        ComposerReplicationTrainer, "_compute_trace_replay_loss",
        lambda self, model, inputs: torch.tensor(0.0),
        raising=True,
    )

    for step in range(0, 35):
        obj.state.global_step = step
        obj._metrics["train"]["reward"].append(0.40 + 0.001 * step)
        total = obj._compute_loss(model=object(), inputs={})
        assert float(total.detach()) == pytest.approx(1.0)

    # With logging_steps=10 over steps 0..34, the cadence hits are 0, 10, 20, 30.
    assert folded_at_steps == [0, 10, 20, 30], (
        f"expected cadence folds at steps [0, 10, 20, 30], got {folded_at_steps}"
    )