Codeseys's picture
Wave 21b: skip zero-signal SDPO on empty-recovery error turns + real-trace validation
d61036a
Raw
History Blame Contribute Delete
10.1 kB
"""Wave 20 — chat-template alignment regression guard for the PACKAGE collator.
`composer_replication.trainer.data_collator.ComposerDataCollator` builds the
SDPO `sdpo_loss_mask` (and the aligned-student `response_mask`) so that in-loss
positions sit exactly on content tokens. The hard part is that
`apply_chat_template` inserts role/BOS/EOS scaffolding around each message; the
old `_build_segment_mask` tokenized each content string in isolation and
concatenated, so the mask drifted left of the real content tokens. The Wave 19
production audit measured alignment at only ~67% because of this drift. Wave 20's
`_build_chat_aligned_mask` derives the mask from per-message
`apply_chat_template` prefix deltas instead, restoring ~100% alignment.
These tests use a REAL chat-template tokenizer (the stub used by
spikes/005 cannot expose the drift — its `apply_chat_template` adds no
scaffolding). They skip cleanly when transformers / the model cache is absent.
"""
from __future__ import annotations
import pytest
from composer_replication.trainer.data_collator import (
CollatorConfig,
ComposerDataCollator,
)
def _load_real_chat_tokenizer():
    """Return a real tokenizer with a chat template, or None to skip.

    Each candidate checkpoint is tried in order, using locally cached files
    only. The first tokenizer whose ``chat_template`` attribute is truthy is
    returned.

    Returns:
        The loaded tokenizer, or ``None`` when ``transformers`` cannot be
        imported or no candidate loads with a chat template.
    """
    import os

    # Ask the Hugging Face stack to stay offline without overriding a value the
    # caller has already exported.
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    try:
        from transformers import AutoTokenizer
    except Exception:
        # Best-effort by design: any import failure means "skip these tests".
        return None

    for model in ("Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct"):
        try:
            # NOTE(review): the env defaults above may be set too late if the
            # HF libraries were already imported earlier in the process (e.g.
            # by the package import at the top of this module) — confirm.
            # local_files_only makes the cache-only intent explicit on the call
            # itself, so the load never depends on import order.
            t = AutoTokenizer.from_pretrained(model, local_files_only=True)
            if getattr(t, "chat_template", None):
                return t
        except Exception:
            # Not cached / not loadable: fall through to the next candidate.
            continue
    return None
# Resolved once at import time; None means no usable chat-template tokenizer
# was found and the tests that need one are skipped.
_REAL_TOK = _load_real_chat_tokenizer()
# Message handed to pytest.skip() by the real_chat_tok fixture when _REAL_TOK is None.
_SKIP_REASON = "real chat-template tokenizer not available (offline / not cached)"
@pytest.fixture
def real_chat_tok():
    """Provide the module-level real tokenizer, skipping the test if there is none."""
    tokenizer = _REAL_TOK
    if tokenizer is not None:
        return tokenizer
    pytest.skip(_SKIP_REASON)
    return tokenizer
@pytest.fixture
def multiturn_error_trace():
    """Six-turn trace whose error site is the fourth turn.

    The error turn comes after several earlier turns so that chat-template
    scaffolding drift has room to compound (this is what exposed the
    pre-Wave-20 misalignment).
    """
    failed_read = [
        {"role": "user", "content": "Read /etc/app/config.yaml and summarize it."},
        {"role": "assistant", "content": '[TOOL_USE] name=Read input={"path":"/etc/app/config.yaml"}'},
        {"role": "user", "content": "[TOOL_RESULT (ERROR)] (id=t1)\nError: no such file or directory"},
    ]
    # The recovery turn: the only turn carrying tool_error / error_meta.
    recovery = {
        "role": "assistant",
        "content": "The file does not exist there. Let me search for it instead.",
        "tool_error": "file_not_found",
        "error_meta": {"source_role": "user"},
    }
    follow_up = [
        {"role": "user", "content": "[TOOL_RESULT] (id=t2)\nFound /opt/app/config.yaml"},
        {"role": "assistant", "content": "Found it at /opt/app/config.yaml. Reading now."},
    ]
    return {
        "trace_id": "real-align-1",
        "turns": [*failed_read, recovery, *follow_up],
        "final_reward": 0.0,
    }
def _hint_gen(kind, _meta):
    """Build the hint text for an error of the given kind; ``_meta`` is ignored."""
    template = "The path was wrong (kind: {}). Search with Glob before reading."
    return template.format(kind)
def test_real_chat_template_sdpo_mask_fully_aligned(real_chat_tok, multiturn_error_trace):
    """The Wave 20 guarantee, checked against a REAL chat template.

    At positions where ``sdpo_loss_mask == 1`` the student and teacher token
    ids must agree; the test requires at least 95% agreement. Before the fix
    alignment was ~67% because the mask came from per-segment tokenization
    that ignored apply_chat_template scaffolding.
    """
    collator = ComposerDataCollator(
        tokenizer=real_chat_tok,
        config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
    )
    batch = collator([multiturn_error_trace])
    assert "sdpo_loss_mask" in batch, "SDPO channel did not fire on the error trace"

    student_ids = batch["input_ids"]
    teacher_ids = batch["ctx_teacher_input_ids"]
    loss_mask = batch["sdpo_loss_mask"]
    assert student_ids.shape == teacher_ids.shape == loss_mask.shape

    n_aligned = 0
    n_total = 0
    for row in range(student_ids.shape[0]):
        active = loss_mask[row] == 1
        active_count = int(active.sum())
        if active_count == 0:
            # Rows with no in-loss positions contribute nothing to the ratio.
            continue
        matches = student_ids[row][active] == teacher_ids[row][active]
        n_aligned += int(matches.sum().item())
        n_total += int(active.sum().item())

    assert n_total > 0, "No in-loss positions — SDPO mask is empty"
    ratio = n_aligned / n_total
    assert ratio >= 0.95, (
        f"SDPO mask alignment is only {100 * ratio:.1f}% ({n_aligned}/{n_total}); "
        f"the chat-template drift fix has regressed. Expected ~100%."
    )
def test_real_chat_template_in_loss_tokens_are_content_not_scaffolding(
    real_chat_tok, multiturn_error_trace
):
    """Decode the in-loss teacher tokens of row 0 and inspect the text.

    The decoded span must contain the recovery turn's CONTENT and none of the
    chat-template markers (<|im_start|>, <|im_end|>, <|endoftext|>).
    """
    collator = ComposerDataCollator(
        tokenizer=real_chat_tok,
        config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
    )
    batch = collator([multiturn_error_trace])

    teacher_row = batch["ctx_teacher_input_ids"][0]
    mask_row = batch["sdpo_loss_mask"][0]
    selected_ids = teacher_row[mask_row == 1].tolist()
    decoded = real_chat_tok.decode(selected_ids)

    assert "does not exist" in decoded, (
        f"In-loss tokens don't contain the recovery content; got: {decoded!r}"
    )
    forbidden_markers = ("<|im_start|>", "<|im_end|>", "<|endoftext|>")
    for marker in forbidden_markers:
        assert marker not in decoded, (
            f"Chat-template marker {marker!r} leaked into the in-loss span: {decoded!r}"
        )
def test_real_chat_template_student_teacher_shapes_match(real_chat_tok, multiturn_error_trace):
    """Student and teacher id tensors must have identical shapes.

    The SDPO gate requires student_logits.shape == teacher_logits.shape, so
    the aligned-student path has to produce matching sequence lengths.
    """
    collator = ComposerDataCollator(
        tokenizer=real_chat_tok,
        config=CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False),
    )
    batch = collator([multiturn_error_trace])
    student_shape = batch["input_ids"].shape
    teacher_shape = batch["ctx_teacher_input_ids"].shape
    assert student_shape == teacher_shape
# ----------------------------------------------------------------------------
# Empty-recovery guard (Wave 21 — discovered on real Claude Code traces)
# ----------------------------------------------------------------------------
#
# ~67% of real error sites have EMPTY recovery content: when strip_thinking=True
# the recovery turn (which was pure [THINKING] reasoning) becomes empty. Injecting
# an SDPO hint with no recovery content yields an all-ignore_index mask — a
# zero-signal row that wastes a forward pass and dilutes the channel. The collator
# must treat empty-recovery error turns as non-error sites. These use a stub
# tokenizer (pure logic, no model needed) so they always run.
class _StubTok:
    """Deterministic whitespace-word tokenizer stub.

    Words are assigned ids in first-seen order; ``apply_chat_template`` joins
    the message contents with spaces and adds no scaffolding of its own.
    """

    pad_token_id = 0

    def __init__(self) -> None:
        # Ids 0-2 are pre-assigned; every new word takes the next free id.
        self._vocab: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2}

    def _id(self, w: str) -> int:
        """Return the id for ``w``, registering it on first sight."""
        return self._vocab.setdefault(w, len(self._vocab))

    def __call__(self, text, **_k):
        """Tokenize ``text`` on whitespace; falsy text yields no ids."""
        if not text:
            return {"input_ids": []}
        return {"input_ids": list(map(self._id, text.split()))}

    def apply_chat_template(self, messages, tokenize=True, **_k):  # noqa: ARG002
        """Return ids for the space-joined contents; ``tokenize`` is ignored."""
        contents = (message.get("content", "") for message in messages)
        joined = " ".join(contents)
        return list(map(self._id, joined.split()))
def _hint_for_tnf(kind, _meta):
    """Return a fixed hint for "tool_not_found" errors and None for any other kind."""
    if kind == "tool_not_found":
        return "HINT use a real tool"
    return None
def test_empty_recovery_does_not_fire_sdpo():
    """A trace whose only error turn has empty content must produce no SDPO mask."""
    turns = [
        {"role": "user", "content": "do the thing"},
        # Error turn with an empty recovery body.
        {"role": "assistant", "content": "", "tool_error": "tool_not_found", "error_meta": {}},
        {"role": "user", "content": "tool not found"},
    ]
    trace = {"trace_id": "empty-recovery", "turns": turns, "final_reward": 0.0}

    collator = ComposerDataCollator(
        tokenizer=_StubTok(),
        config=CollatorConfig(hint_generator=_hint_for_tnf),
    )
    batch = collator([trace])

    assert "sdpo_loss_mask" not in batch, (
        "Empty-recovery error turn fired a zero-signal SDPO mask; it must be skipped."
    )
def test_mixed_recovery_fires_on_nonempty_only():
    """With one empty and one non-empty error turn, SDPO must still fire.

    The batch must contain ``sdpo_loss_mask`` and at least one position equal
    to 1.
    """
    empty_error_turn = {
        "role": "assistant",
        "content": "",
        "tool_error": "tool_not_found",
        "error_meta": {},
    }
    real_recovery_turn = {
        "role": "assistant",
        "content": "let me use a real tool instead",
        "tool_error": "tool_not_found",
        "error_meta": {},
    }
    trace = {
        "trace_id": "mixed-recovery",
        "turns": [
            {"role": "user", "content": "first task"},
            empty_error_turn,
            {"role": "user", "content": "tool not found"},
            real_recovery_turn,
        ],
        "final_reward": 0.0,
    }

    collator = ComposerDataCollator(
        tokenizer=_StubTok(),
        config=CollatorConfig(hint_generator=_hint_for_tnf),
    )
    batch = collator([trace])

    assert "sdpo_loss_mask" in batch
    active_positions = int((batch["sdpo_loss_mask"] == 1).sum())
    assert active_positions > 0
def test_empty_recovery_keeps_student_teacher_shapes_matched():
    """Student and teacher shapes must match on a mixed empty/non-empty trace.

    Even when an empty-recovery turn is skipped, SDPO firing elsewhere in the
    trace must leave both sides in lockstep.
    """
    turns = [
        {"role": "user", "content": "task"},
        {"role": "assistant", "content": "", "tool_error": "tool_not_found", "error_meta": {}},
        {"role": "user", "content": "tool not found"},
        {
            "role": "assistant",
            "content": "recover now with a real tool",
            "tool_error": "tool_not_found",
            "error_meta": {},
        },
    ]
    trace = {"trace_id": "mixed-shape", "turns": turns, "final_reward": 0.0}

    collator = ComposerDataCollator(
        tokenizer=_StubTok(),
        config=CollatorConfig(hint_generator=_hint_for_tnf),
    )
    batch = collator([trace])

    student_shape = batch["input_ids"].shape
    teacher_shape = batch["ctx_teacher_input_ids"].shape
    assert student_shape == teacher_shape