File size: 15,202 Bytes
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
"""R1 — HeldOutGuard wired into ComposerReplicationTrainer (the #2 safeguard).

These tests close the "tripwire exists but never fires" gap: the run-level
collapse kill-switch (``composer_replication.safety.HeldOutGuard``) must actually
be folded by the trainer at its logging cadence, fed the in-loop GRPO reward and
an INJECTED held-out eval, and halt the run on a fired verdict.

Acceptance gates:
  1. BACKWARD-COMPAT: no ``heldout_guard`` => ``_maybe_update_killswitch`` is a
     pure no-op (never touches the eval fn, never logs) — identical behavior.
  2. THE WIRING: a fake ``heldout_eval_fn`` that DECLINES while the in-loop
     reward RISES drives the guard to fire; strict mode raises
     ``CollapseStopError`` and the verdict carries the reward-hacking signature.
  3. Soft stop: ``strict_killswitch=False`` => no raise, the HF loop is asked to
     stop (``control.should_training_stop``).
  4. Healthy run (held-out tracks reward) never fires.
  5. KL-to-init is read from TRL's logged metric and reaches the guard.
  6. Constructor contract: ``heldout_guard`` without ``heldout_eval_fn`` raises.
  7. ``_compute_loss`` only folds the guard at the ``logging_steps`` cadence
     (not every micro-step), mirroring the loss-component logging.

CPU-only, no model download, no full GRPOTrainer init (stub instance via
``__new__`` + manual attribute wiring, the pattern used by the SDPO tests).
"""
from __future__ import annotations

import pytest

from composer_replication.safety import CollapseStopError, HeldOutGuard
from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer

# ---------------------------------------------------------------------------
# Stubs — mirror _make_sdpo_trainer in test_sdpo_alignment_indices.py: build the
# trainer via __new__ so we never run GRPOTrainer.__init__ (no TRL setup, no
# model download), then wire only the attributes the kill-switch path reads.
# ---------------------------------------------------------------------------

class _State:
    def __init__(self, global_step: int = 0) -> None:
        self.global_step = global_step


class _Args:
    def __init__(self, logging_steps: int = 1) -> None:
        self.logging_steps = logging_steps


class _Control:
    def __init__(self) -> None:
        self.should_training_stop = False


def _make_killswitch_trainer(
    guard: HeldOutGuard | None,
    eval_fn,
    *,
    strict: bool = True,
    reward: float | None = 0.40,
    kl: float | None = None,
    logging_steps: int = 1,
):
    """Build a ComposerReplicationTrainer stub exposing only what the kill-switch reads.

    The instance is created with ``__new__``, so ``GRPOTrainer.__init__`` never
    runs. Only the attributes the kill-switch path reads are wired by hand.

    Args:
        guard: The ``HeldOutGuard`` to wire in, or ``None`` for the OFF path.
        eval_fn: Stand-in for ``heldout_eval_fn`` (called with no arguments).
        strict: Value assigned to ``strict_killswitch``.
        reward: Seeds TRL's per-step ``"reward"`` metric series. The trainer
            reads the tail of ``self._metrics["train"][name]``. Pass
            ``reward=None`` to simulate "no reward aggregated yet".
        kl: Seeds the ``"kl"`` metric series. ``None`` leaves the series absent.
        logging_steps: Value assigned to ``args.logging_steps``. Defaults to
            1, which was previously hard-coded.

    Returns:
        The stub trainer. Every dict passed to ``obj.log`` is collected in
        ``obj.logged``.
    """
    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    obj.heldout_guard = guard
    obj.heldout_eval_fn = eval_fn
    obj.strict_killswitch = strict
    obj.state = _State(global_step=0)
    obj.args = _Args(logging_steps=logging_steps)
    obj.control = _Control()

    train_metrics: dict[str, list] = {}
    if reward is not None:
        train_metrics["reward"] = [reward]
    if kl is not None:
        train_metrics["kl"] = [kl]
    obj._metrics = {"train": train_metrics}

    logged: list[dict] = []
    obj.logged = logged
    # capture self.log(...) instead of routing through HF Trainer.log
    obj.log = logged.append  # type: ignore[assignment]
    return obj


def _set_step_reward(obj, step: int, reward: float, kl: float | None = None) -> None:
    obj.state.global_step = step
    obj._metrics["train"].setdefault("reward", []).append(reward)
    if kl is not None:
        obj._metrics["train"].setdefault("kl", []).append(kl)


# ---------------------------------------------------------------------------
# Gate 1 — backward compatibility: absent guard is a pure no-op
# ---------------------------------------------------------------------------

def test_absent_guard_is_noop():
    """Gate 1: with ``heldout_guard=None`` the kill-switch must do nothing.

    An eval fn is supplied, but it must never be invoked. Nothing may be
    logged, nothing may be raised, and training must not be stopped. This is
    the backward-compat guarantee: omitting the kwarg changes no behavior.
    """
    eval_calls = []

    def eval_fn() -> float:
        eval_calls.append(None)
        return 0.0

    # The eval fn is provided but there is NO guard, so it must stay unreached.
    trainer = _make_killswitch_trainer(guard=None, eval_fn=eval_fn)
    for step in range(50):
        _set_step_reward(trainer, step, reward=0.40 + 0.05 * step)
        trainer._maybe_update_killswitch()  # expected to be a no-op

    assert not eval_calls, "held-out eval fn was called even though no guard set"
    assert trainer.logged == [], "kill-switch logged even though no guard configured"
    assert trainer.control.should_training_stop is False


def test_constructor_defaults_leave_killswitch_off():
    """Check the helper is inert under the default kill-switch attributes.

    The test assigns the same values ``__init__`` uses when no kill-switch
    kwargs are given: guard and eval fn ``None``, strict mode on. It then
    checks that ``_maybe_update_killswitch`` is inert. No
    ``state``/``args``/``control`` are wired, because the helper is expected
    to return immediately when the guard is ``None``.
    """
    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    default_attrs = (
        ("heldout_guard", None),
        ("heldout_eval_fn", None),
        ("strict_killswitch", True),
    )
    for attr_name, default in default_attrs:
        setattr(trainer, attr_name, default)
    trainer._maybe_update_killswitch()


# ---------------------------------------------------------------------------
# Gates 2 & 3 — THE WIRING: declining held-out + rising reward => guard fires;
# strict mode raises, soft mode asks the HF loop to stop gracefully
# ---------------------------------------------------------------------------

def test_guard_fires_and_raises_on_reward_hacking_signature():
    """Gate 2: held-out score falls while the in-loop reward climbs.

    This divergence is the canonical reward-hacking signature. The wired guard
    must fire, and because strict mode is on it must raise
    ``CollapseStopError`` with a verdict carrying the reward-hacking reason.
    """
    # Small min_steps keeps the test fast; the large gap threshold isolates
    # the decline-streak path.
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )

    heldout_score = 0.80

    def declining_eval() -> float:
        # Every held-out eval comes back 0.05 lower than the previous one.
        nonlocal heldout_score
        heldout_score -= 0.05
        return heldout_score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=True)

    raised = None
    step = 1
    while raised is None and step < 30:
        # The in-loop reward rises each step via the TRL metric tail.
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        try:
            trainer._maybe_update_killswitch()
        except CollapseStopError as exc:
            raised = exc
        step += 1

    assert raised is not None, "guard never fired on the reward-hacking signature"
    assert guard.should_halt()
    verdict = raised.status
    assert verdict.fire
    # The fired verdict must be the held-out-declines-while-reward-rises signature.
    assert "held-out" in verdict.reason
    assert verdict.proxy_real_gap > 0.0  # proxy gained while real lost
    # The verdict must have been logged before the raise.
    assert any("killswitch/fire" in entry for entry in trainer.logged)


def test_soft_stop_sets_control_instead_of_raising():
    """Gate 3: with ``strict_killswitch=False`` a fired verdict must not raise.

    Instead it sets the HF loop's ``control.should_training_stop`` so that
    training ends gracefully.
    """
    guard = HeldOutGuard(
        min_steps=3, decline_patience=3, ema_alpha=0.5, max_proxy_real_gap=10.0
    )
    heldout_score = 0.80

    def declining_eval() -> float:
        nonlocal heldout_score
        heldout_score -= 0.05
        return heldout_score

    trainer = _make_killswitch_trainer(guard, declining_eval, strict=False)

    step = 1
    while step < 30 and not trainer.control.should_training_stop:
        _set_step_reward(trainer, step, reward=0.30 + 0.03 * step)
        trainer._maybe_update_killswitch()  # soft mode: must NOT raise
        step += 1

    assert trainer.control.should_training_stop is True, (
        "soft-stop guard fired but did not request training stop"
    )
    assert guard.should_halt()


# ---------------------------------------------------------------------------
# Gate 4 — healthy run never fires
# ---------------------------------------------------------------------------

def test_healthy_run_never_fires():
    """Gate 4: a healthy run must never trip the guard.

    The held-out score rises together with the in-loop reward, and KL stays
    below ``kl_hard_stop``. The wired guard must never fire, and training is
    never asked to stop.
    """
    kl_in_band = 0.03
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=3,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,
    )
    heldout_score = 0.28

    def rising_eval() -> float:
        nonlocal heldout_score
        heldout_score += 0.01
        return heldout_score

    trainer = _make_killswitch_trainer(guard, rising_eval, strict=True, kl=kl_in_band)

    for step in range(1, 40):
        _set_step_reward(trainer, step, reward=0.30 + 0.01 * step, kl=kl_in_band)
        trainer._maybe_update_killswitch()  # strict, yet must never raise here

    assert not guard.should_halt()
    assert trainer.control.should_training_stop is False


# ---------------------------------------------------------------------------
# Gate 5 — KL-to-init from TRL's logged metric reaches the guard; a missing
# reward metric makes the fold skip cleanly
# ---------------------------------------------------------------------------

def test_kl_to_init_is_forwarded_to_guard():
    """Gate 5: the KL read from TRL's ``"kl"`` metric must reach the guard.

    The warm-up checks that the guard's ``kl_ema`` gets populated, proving the
    kl_to_init wiring. A subsequent KL breach must then fire through the KL
    path.
    """
    guard = HeldOutGuard(
        min_steps=3,
        decline_patience=100,
        ema_alpha=0.5,
        kl_hard_stop=0.08,
        max_proxy_real_gap=10.0,  # isolate the KL path
    )
    trainer = _make_killswitch_trainer(guard, lambda: 0.40, strict=True, kl=0.04)

    # Warm-up: healthy KL and flat metrics, so only the KL path could ever fire.
    for step in range(1, 5):
        _set_step_reward(trainer, step, reward=0.40, kl=0.04)
        trainer._maybe_update_killswitch()
    last = guard.last_status
    assert last is not None and last.kl_ema is not None, (
        "kl_to_init never reached the guard — KL wiring is broken"
    )

    def spike_and_fold(step: int):
        """Feed a KL spike for ``step``; return the CollapseStopError if raised."""
        _set_step_reward(trainer, step, reward=0.40, kl=0.30)
        try:
            trainer._maybe_update_killswitch()
        except CollapseStopError as exc:
            return exc
        return None

    # KL sits above the hard stop; its EMA climbs until it crosses => fire.
    # The generator is lazy, so folding stops at the first raised verdict.
    fold_results = (spike_and_fold(step) for step in range(5, 20))
    raised = next((exc for exc in fold_results if exc is not None), None)

    assert raised is not None, "KL hard-stop never fired through the wired guard"
    assert "kl_to_init" in raised.status.reason


def test_no_reward_metric_yet_skips_cleanly():
    """With no aggregated reward yet, the fold must be skipped entirely.

    Before TRL has aggregated any reward, the metric series is empty. The
    helper must then skip the fold rather than feed a fabricated 0.0 into the
    guard's EMA.
    """
    guard = HeldOutGuard(min_steps=3, ema_alpha=0.5)
    eval_calls = []

    def eval_fn() -> float:
        eval_calls.append(None)
        return 0.40

    trainer = _make_killswitch_trainer(guard, eval_fn, reward=None)
    trainer._maybe_update_killswitch()  # empty reward series => skip

    assert not eval_calls, "eval fn called despite no in-loop reward yet"
    assert guard.last_status is None, "guard advanced despite no reward metric"


# ---------------------------------------------------------------------------
# Gate 6 — constructor contract
# ---------------------------------------------------------------------------

def test_guard_without_eval_fn_raises_at_construction(monkeypatch):
    """Gate 6: a guard without a held-out eval must be rejected by ``__init__``.

    The tripwire needs the held-out signal, so a guard with no eval fn is
    meaningless. The REAL ``__init__`` must reject it loudly. The GRPOTrainer
    parent ``__init__`` is stubbed out so the validation clause runs without a
    full TRL/model setup.
    """

    def _noop_parent_init(self, *args, **kwargs):
        return None

    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_parent, "__init__", _noop_parent_init, raising=False)

    lone_guard = HeldOutGuard(min_steps=3)
    with pytest.raises(ValueError, match="heldout_eval_fn"):
        ComposerReplicationTrainer(heldout_guard=lone_guard)  # eval fn omitted


def test_guard_with_eval_fn_constructs_and_stays_off_when_absent(monkeypatch):
    """Gate 6: the real ``__init__`` wires the kill-switch attributes.

    With guard and eval fn both given, construction succeeds and strict mode
    is the default. With neither given, both stay ``None``, which is the OFF
    backward-compat path.
    """

    def _noop_parent_init(self, *args, **kwargs):
        return None

    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_parent, "__init__", _noop_parent_init, raising=False)

    # Guard + eval fn supplied => constructs with the guard wired in.
    guard = HeldOutGuard(min_steps=3)
    wired = ComposerReplicationTrainer(heldout_guard=guard, heldout_eval_fn=lambda: 0.4)
    assert wired.heldout_guard is guard
    assert wired.strict_killswitch is True  # strict by default

    # No kill-switch kwargs => it stays OFF.
    bare = ComposerReplicationTrainer()
    assert bare.heldout_guard is None
    assert bare.heldout_eval_fn is None


# ---------------------------------------------------------------------------
# Gate 7 — _compute_loss only folds the guard at the logging cadence
# ---------------------------------------------------------------------------

def test_compute_loss_folds_guard_only_at_logging_cadence(monkeypatch):
    """Gate 7: ``_compute_loss`` folds the guard only on logging-cadence steps.

    The test drives ``_compute_loss`` end to end, with the GRPO parent loss
    and the SDPO/replay channels stubbed out. It checks that the kill-switch
    fold shares the cadence gate of the loss-component logging instead of
    running on every micro-step.
    """
    import torch

    eval_calls = []
    guard = HeldOutGuard(min_steps=10_000, ema_alpha=0.5)  # never fires in this test

    def counting_eval() -> float:
        eval_calls.append(None)
        return 0.40

    trainer = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    logged = []
    stub_attrs = {
        "alpha_sdpo": 0.0,
        "beta_replay": 0.0,
        "heldout_guard": guard,
        "heldout_eval_fn": counting_eval,
        "strict_killswitch": True,
        "state": _State(global_step=0),
        "args": _Args(logging_steps=10),
        "control": _Control(),
        "_metrics": {"train": {"reward": [0.40]}},
        "logged": logged,
        "log": logged.append,
    }
    for attr_name, value in stub_attrs.items():
        setattr(trainer, attr_name, value)

    def parent_loss(self, model, inputs):
        return torch.tensor(1.0)

    def zero_loss(self, model, inputs):
        return torch.tensor(0.0)

    # The real `super()._compute_loss` needs a full TRL trainer. Patch the
    # PARENT class so `super()._compute_loss(...)` resolves to the stub, and
    # zero out the SDPO / replay channels on the subclass itself.
    grpo_parent = ComposerReplicationTrainer.__bases__[0]
    monkeypatch.setattr(grpo_parent, "_compute_loss", parent_loss, raising=False)
    for channel in ("_compute_sdpo_loss", "_compute_trace_replay_loss"):
        monkeypatch.setattr(ComposerReplicationTrainer, channel, zero_loss, raising=True)

    reward_series = trainer._metrics["train"]["reward"]
    for step in range(35):
        trainer.state.global_step = step
        reward_series.append(0.40 + 0.001 * step)
        total = trainer._compute_loss(model=object(), inputs={})
        assert float(total.detach()) == pytest.approx(1.0)

    # The only cadence hits are steps 0, 10, 20 and 30, so expect exactly 4 folds.
    fold_count = len(eval_calls)
    assert fold_count == 4, f"expected 4 cadence folds, got {fold_count}"