Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 15,202 Bytes
bd0c358
"""R1 — HeldOutGuard wired into ComposerReplicationTrainer (the #2 safeguard).
These tests close the "tripwire exists but never fires" gap: the run-level
collapse kill-switch (``composer_replication.safety.HeldOutGuard``) must actually
be folded by the trainer at its logging cadence, fed the in-loop GRPO reward and
an INJECTED held-out eval, and halt the run on a fired verdict.
Acceptance gates:
1. BACKWARD-COMPAT: no ``heldout_guard`` => ``_maybe_update_killswitch`` is a
pure no-op (never touches the eval fn, never logs) — identical behavior.
2. THE WIRING: a fake ``heldout_eval_fn`` that DECLINES while the in-loop
reward RISES drives the guard to fire; strict mode raises
``CollapseStopError`` and the verdict carries the reward-hacking signature.
3. Soft stop: ``strict_killswitch=False`` => no raise, the HF loop is asked to
stop (``control.should_training_stop``).
4. Healthy run (held-out tracks reward) never fires.
5. KL-to-init is read from TRL's logged metric and reaches the guard.
6. Constructor contract: ``heldout_guard`` without ``heldout_eval_fn`` raises.
7. ``_compute_loss`` only folds the guard at the ``logging_steps`` cadence
(not every micro-step), mirroring the loss-component logging.
CPU-only, no model download, no full GRPOTrainer init (stub instance via
``__new__`` + manual attribute wiring, the pattern used by the SDPO tests).
"""
from __future__ import annotations
import pytest
from composer_replication.safety import CollapseStopError, HeldOutGuard
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer
# ---------------------------------------------------------------------------
# Stubs — mirror _make_sdpo_trainer in test_sdpo_alignment_indices.py: build the
# trainer via __new__ so we never run GRPOTrainer.__init__ (no TRL setup, no
# model download), then wire only the attributes the kill-switch path reads.
# ---------------------------------------------------------------------------
class _State:
def __init__(self, global_step: int = 0) -> None:
self.global_step = global_step
class _Args:
def __init__(self, logging_steps: int = 1) -> None:
self.logging_steps = logging_steps
class _Control:
def __init__(self) -> None:
self.should_training_stop = False
def _make_killswitch_trainer(
    guard: HeldOutGuard | None,
    eval_fn,
    *,
    strict: bool = True,
    reward: float | None = 0.40,
    kl: float | None = None,
):
    """Build a bare ComposerReplicationTrainer wired only for the kill-switch path.

    The instance comes from ``__new__``, so the parent trainer's ``__init__``
    never runs. Only the attributes the kill-switch reads are attached by hand.

    Args:
        guard: HeldOutGuard to wire in, or None for the "no guard" path.
        eval_fn: Injected held-out eval callable.
        strict: Value stored as ``strict_killswitch``.
        reward: First entry of TRL's ``"reward"`` metric series. None leaves the
            series absent, which simulates "no reward aggregated yet".
        kl: First entry of TRL's ``"kl"`` metric series. None leaves it absent.

    Returns:
        The stub trainer. ``trainer.logged`` collects every dict passed to
        ``trainer.log``, which replaces HF ``Trainer.log``.
    """
    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)

    # Kill-switch configuration.
    trainer.heldout_guard = guard
    trainer.heldout_eval_fn = eval_fn
    trainer.strict_killswitch = strict

    # Minimal loop scaffolding: step counter, logging cadence, stop flag.
    trainer.state = _State(global_step=0)
    trainer.args = _Args(logging_steps=1)
    trainer.control = _Control()

    # Seed TRL's per-step metric series (the trainer reads the tail of
    # self._metrics["train"][name]); None-valued series are left out.
    seeded = {
        name: [value]
        for name, value in (("reward", reward), ("kl", kl))
        if value is not None
    }
    trainer._metrics = {"train": seeded}

    # Capture self.log(...) calls instead of routing them through HF Trainer.log.
    trainer.logged = []
    trainer.log = trainer.logged.append  # type: ignore[assignment]
    return trainer
def _set_step_reward(obj, step: int, reward: float, kl: float | None = None) -> None:
obj.state.global_step = step
obj._metrics["train"].setdefault("reward", []).append(reward)
if kl is not None:
obj._metrics["train"].setdefault("kl", []).append(kl)
# ---------------------------------------------------------------------------
# Gate 1 — backward compatibility: absent guard is a pure no-op
# ---------------------------------------------------------------------------
def test_absent_guard_is_noop():
    """Backward compatibility: without a heldout_guard the kill-switch is inert.

    An eval fn is supplied, but it must never be invoked. Nothing is logged, no
    exception escapes, and training is never asked to stop.
    """
    invocations = []

    def eval_fn() -> float:
        invocations.append(None)
        return 0.0

    # The eval fn is wired but the guard is NOT, so the helper must never reach it.
    trainer = _make_killswitch_trainer(guard=None, eval_fn=eval_fn)
    for step in range(50):
        _set_step_reward(trainer, step, reward=0.40 + 0.05 * step)
        trainer._maybe_update_killswitch()  # expected to be a no-op

    assert len(invocations) == 0, "held-out eval fn was called even though no guard set"
    assert trainer.logged == [], "kill-switch logged even though no guard configured"
    assert trainer.control.should_training_stop is False
def test_constructor_defaults_leave_killswitch_off():
    """Default kill-switch kwargs are inert.

    This does not run the real ``__init__``. It reproduces the default
    assignments (heldout_guard=None, heldout_eval_fn=None, strict_killswitch=True)
    on a bare instance and checks that the kill-switch call does nothing.
    """
    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Reproduce the default-kwarg assignments __init__ performs.
    trainer.heldout_guard, trainer.heldout_eval_fn, trainer.strict_killswitch = (
        None,
        None,
        True,
    )
    # state / args / control / _metrics are deliberately NOT attached. With no
    # guard, the call is expected to return before reading any of them.
    trainer._maybe_update_killswitch()
# ---------------------------------------------------------------------------
# Gates 2 & 3 — THE WIRING: declining held-out + rising reward => guard fires (strict: raise; soft: stop request)
# ---------------------------------------------------------------------------
def test_guard_fires_and_raises_on_reward_hacking_signature():
    """Reward-hacking signature: held-out falls while the in-loop reward climbs.

    The wired guard must fire. In strict mode it must raise CollapseStopError
    whose verdict names the held-out decline and shows a positive
    proxy/real gap.
    """
    # A small min_steps keeps the test fast; the wide gap bound isolates the
    # decline-streak path.
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )

    # The held-out score drops by 0.05 on every call.
    score = 0.80

    def declining_eval() -> float:
        nonlocal score
        score -= 0.05
        return score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=True)

    caught = None
    for step in range(1, 30):
        # The in-loop reward rises every step via the TRL metric tail.
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        try:
            trainer._maybe_update_killswitch()
        except CollapseStopError as err:
            caught = err
            break

    assert caught is not None, "guard never fired on the reward-hacking signature"
    assert guard.should_halt()
    assert caught.status.fire
    # The verdict must be the held-out-declines-while-reward-rises signature.
    assert "held-out" in caught.status.reason
    assert caught.status.proxy_real_gap > 0.0  # proxy gained while real lost
    # The verdict was logged before the raise.
    assert any("killswitch/fire" in entry for entry in trainer.logged)
def test_soft_stop_sets_control_instead_of_raising():
    """Soft mode (strict_killswitch=False): a fired verdict must not raise.

    Instead it sets ``control.should_training_stop`` so the training loop can
    end gracefully.
    """
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )

    score = 0.80

    def declining_eval() -> float:
        nonlocal score
        score -= 0.05
        return score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=False)
    for step in range(1, 30):
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        trainer._maybe_update_killswitch()  # soft mode: must NOT raise
        if trainer.control.should_training_stop:
            break

    assert trainer.control.should_training_stop is True, (
        "soft-stop guard fired but did not request training stop"
    )
    assert guard.should_halt()
# ---------------------------------------------------------------------------
# Gate 4 — healthy run never fires
# ---------------------------------------------------------------------------
def test_healthy_run_never_fires():
    """Healthy run: held-out rises with the in-loop reward and KL stays in band.

    The wired guard must never fire, and training must never be asked to stop.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )

    score = 0.28

    def rising_eval() -> float:
        nonlocal score
        score += 0.01
        return score

    trainer = _make_killswitch_trainer(guard, rising_eval, strict=True, kl=0.03)
    for step in range(1, 40):
        _set_step_reward(trainer, step, reward=0.30 + 0.01 * step, kl=0.03)
        # A CollapseStopError escaping here fails the test; healthy runs never trip.
        trainer._maybe_update_killswitch()

    assert not guard.should_halt()
    assert trainer.control.should_training_stop is False
# ---------------------------------------------------------------------------
# Gate 5 — KL-to-init from TRL's logged metric reaches the guard (plus the no-reward-yet skip)
# ---------------------------------------------------------------------------
def test_kl_to_init_is_forwarded_to_guard():
    """The trainer forwards TRL's ``"kl"`` metric to the guard as KL-to-init.

    After warm-up, the guard's ``kl_ema`` must be populated. A sustained KL
    spike must then fire through the KL hard-stop path.
    """
    # decline_patience=100 and a wide gap bound keep the other fire paths quiet,
    # so only the KL hard stop can trip.
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=100,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )

    def flat_eval() -> float:
        return 0.40

    trainer = _make_killswitch_trainer(guard, flat_eval, strict=True, kl=0.04)

    # Phase 1: warm up with in-band KL and flat reward / held-out.
    for step in range(1, 5):
        _set_step_reward(trainer, step, reward=0.40, kl=0.04)
        trainer._maybe_update_killswitch()

    assert guard.last_status is not None and guard.last_status.kl_ema is not None, (
        "kl_to_init never reached the guard — KL wiring is broken"
    )

    # Phase 2: KL jumps above the hard stop; the EMA climbs until it crosses.
    caught = None
    for step in range(5, 20):
        _set_step_reward(trainer, step, reward=0.40, kl=0.30)
        try:
            trainer._maybe_update_killswitch()
        except CollapseStopError as err:
            caught = err
            break

    assert caught is not None, "KL hard-stop never fired through the wired guard"
    assert "kl_to_init" in caught.status.reason
def test_no_reward_metric_yet_skips_cleanly():
    """No reward series yet: the fold is skipped, not fed a made-up 0.0.

    Before TRL has aggregated any reward (the ``"reward"`` series is absent),
    the guard's EMA must not be advanced.
    """
    guard = HeldOutGuard(min_steps=3, ema_alpha=0.5)

    eval_calls = 0

    def eval_fn() -> float:
        nonlocal eval_calls
        eval_calls += 1
        return 0.40

    trainer = _make_killswitch_trainer(guard, eval_fn, reward=None)
    trainer._maybe_update_killswitch()  # no reward series => expected to skip

    assert eval_calls == 0, "eval fn called despite no in-loop reward yet"
    assert guard.last_status is None, "guard advanced despite no reward metric"
# ---------------------------------------------------------------------------
# Gate 6 — constructor contract
# ---------------------------------------------------------------------------
def test_guard_without_eval_fn_raises_at_construction(monkeypatch):
    """A guard without a held-out eval fn is rejected by the REAL ``__init__``.

    The tripwire needs a held-out signal, so this combination must fail with a
    ValueError that mentions ``heldout_eval_fn``. The parent ``__init__`` is
    stubbed so the validation runs without TRL or model setup.
    """
    # Neutralise the parent class's __init__ so only the subclass's own
    # checks execute.
    grpo_base = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_base, "__init__", lambda self, *a, **k: None, raising=False)

    lonely_guard = HeldOutGuard(min_steps=3)
    with pytest.raises(ValueError, match="heldout_eval_fn"):
        ComposerReplicationTrainer(heldout_guard=lonely_guard)  # eval fn omitted
def test_guard_with_eval_fn_constructs_and_stays_off_when_absent(monkeypatch):
    """The real ``__init__`` stores the kill-switch attributes.

    With a guard and an eval fn, construction succeeds, the guard is stored,
    and strict mode is on by default. With neither kwarg, both attributes stay
    None, which is the OFF / backward-compatible default.
    """
    grpo_base = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_base, "__init__", lambda self, *a, **k: None, raising=False)

    # Guard + eval fn => constructs with the guard wired, strict by default.
    guard = HeldOutGuard(min_steps=3)
    wired = ComposerReplicationTrainer(heldout_guard=guard, heldout_eval_fn=lambda: 0.4)
    assert wired.heldout_guard is guard
    assert wired.strict_killswitch is True

    # No kill-switch kwargs => guard and eval fn both stay None (OFF).
    bare = ComposerReplicationTrainer()
    assert bare.heldout_guard is None
    assert bare.heldout_eval_fn is None
# ---------------------------------------------------------------------------
# Gate 7 — _compute_loss only folds the guard at the logging cadence
# ---------------------------------------------------------------------------
def test_compute_loss_folds_guard_only_at_logging_cadence(monkeypatch):
    """The guard is folded only on logging-cadence steps of ``_compute_loss``.

    ``_compute_loss`` is driven end-to-end, with the GRPO parent loss and the
    SDPO / trace-replay channels stubbed. The test asserts that the guard fold
    shares the loss-component logging's cadence gate rather than running on
    every micro-step.
    """
    import torch

    fold_count = 0

    def counting_eval() -> float:
        nonlocal fold_count
        fold_count += 1
        return 0.40

    # min_steps is far beyond the steps driven here, so the guard never fires.
    guard = HeldOutGuard(min_steps=10_000, ema_alpha=0.5)

    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    # Zero the SDPO / replay coefficients (presumably those channels' loss weights).
    trainer.alpha_sdpo = 0.0
    trainer.beta_replay = 0.0
    # Kill-switch wiring.
    trainer.heldout_guard = guard
    trainer.heldout_eval_fn = counting_eval
    trainer.strict_killswitch = True
    # Loop scaffolding with a logging cadence of every 10 steps.
    trainer.state = _State(global_step=0)
    trainer.args = _Args(logging_steps=10)
    trainer.control = _Control()
    trainer._metrics = {"train": {"reward": [0.40]}}
    trainer.logged = []
    trainer.log = trainer.logged.append  # type: ignore[assignment]

    def grpo_parent_loss(self, model, inputs):
        return torch.tensor(1.0)

    def zero_channel_loss(self, model, inputs):
        return torch.tensor(0.0)

    # The real super()._compute_loss needs a full TRL trainer, so patch the
    # PARENT class's method (the one super() resolves to). Also stub the SDPO
    # and replay channels to zero.
    grpo_base = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_base, "_compute_loss", grpo_parent_loss, raising=False)
    monkeypatch.setattr(
        ComposerReplicationTrainer, "_compute_sdpo_loss", zero_channel_loss, raising=True
    )
    monkeypatch.setattr(
        ComposerReplicationTrainer,
        "_compute_trace_replay_loss",
        zero_channel_loss,
        raising=True,
    )

    for step in range(35):
        trainer.state.global_step = step
        trainer._metrics["train"]["reward"].append(0.40 + 0.001 * step)
        total = trainer._compute_loss(model=object(), inputs={})
        assert float(total.detach()) == pytest.approx(1.0)

    # Only steps 0, 10, 20 and 30 hit the cadence => exactly 4 guard folds.
    assert fold_count == 4, f"expected 4 cadence folds, got {fold_count}"
|