"""Wave 20 — chat-template alignment regression guard for the PACKAGE collator. `composer_replication.trainer.data_collator.ComposerDataCollator` builds the SDPO `sdpo_loss_mask` (and the aligned-student `response_mask`) so that in-loss positions sit exactly on content tokens. The hard part is that `apply_chat_template` inserts role/BOS/EOS scaffolding around each message; the old `_build_segment_mask` tokenized each content string in isolation and concatenated, so the mask drifted left of the real content tokens. The Wave 19 production audit measured this drift at ~67% aligned. Wave 20's `_build_chat_aligned_mask` derives the mask from per-message `apply_chat_template` prefix deltas instead, restoring ~100% alignment. These tests use a REAL chat-template tokenizer (the stub used by spikes/005 cannot expose the drift — its `apply_chat_template` adds no scaffolding). They skip cleanly when transformers / the model cache is absent. """ from __future__ import annotations import pytest from composer_replication.trainer.data_collator import ( CollatorConfig, ComposerDataCollator, ) def _load_real_chat_tokenizer(): """Return a real tokenizer with a chat template, or None to skip.""" try: import os os.environ.setdefault("HF_HUB_OFFLINE", "1") os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") from transformers import AutoTokenizer except Exception: return None for model in ("Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct"): try: t = AutoTokenizer.from_pretrained(model) if getattr(t, "chat_template", None): return t except Exception: continue return None _REAL_TOK = _load_real_chat_tokenizer() _SKIP_REASON = "real chat-template tokenizer not available (offline / not cached)" @pytest.fixture def real_chat_tok(): if _REAL_TOK is None: pytest.skip(_SKIP_REASON) return _REAL_TOK @pytest.fixture def multiturn_error_trace(): """Multi-turn trace with an error site after several turns, so the chat-template scaffolding drift compounds (what exposed the old 33%).""" return { "trace_id": "real-align-1", "turns": [ {"role": "user", "content": "Read /etc/app/config.yaml and summarize it."}, {"role": "assistant", "content": '[TOOL_USE] name=Read input={"path":"/etc/app/config.yaml"}'}, {"role": "user", "content": "[TOOL_RESULT (ERROR)] (id=t1)\nError: no such file or directory"}, { "role": "assistant", "content": "The file does not exist there. Let me search for it instead.", "tool_error": "file_not_found", "error_meta": {"source_role": "user"}, }, {"role": "user", "content": "[TOOL_RESULT] (id=t2)\nFound /opt/app/config.yaml"}, {"role": "assistant", "content": "Found it at /opt/app/config.yaml. Reading now."}, ], "final_reward": 0.0, } def _hint_gen(kind, _meta): return f"The path was wrong (kind: {kind}). Search with Glob before reading." def test_real_chat_template_sdpo_mask_fully_aligned(real_chat_tok, multiturn_error_trace): """THE Wave 20 guarantee: with a REAL chat template, every in-loss sdpo_loss_mask position must have student==teacher token id. Before the fix this drifted to ~67% because the mask was built from per-segment tokenization that ignored apply_chat_template scaffolding.""" cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False) collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg) batch = collator([multiturn_error_trace]) assert "sdpo_loss_mask" in batch, "SDPO channel did not fire on the error trace" s_in = batch["input_ids"] t_in = batch["ctx_teacher_input_ids"] m_in = batch["sdpo_loss_mask"] assert s_in.shape == t_in.shape == m_in.shape n_aligned = n_total = 0 for row in range(s_in.shape[0]): in_loss = m_in[row] == 1 if int(in_loss.sum()) == 0: continue s_at = s_in[row][in_loss] t_at = t_in[row][in_loss] n_aligned += int((s_at == t_at).sum().item()) n_total += int(in_loss.sum().item()) assert n_total > 0, "No in-loss positions — SDPO mask is empty" ratio = n_aligned / n_total assert ratio >= 0.95, ( f"SDPO mask alignment is only {100 * ratio:.1f}% ({n_aligned}/{n_total}); " f"the chat-template drift fix has regressed. Expected ~100%." ) def test_real_chat_template_in_loss_tokens_are_content_not_scaffolding( real_chat_tok, multiturn_error_trace ): """The in-loss teacher tokens must decode to the recovery turn's CONTENT, not chat-template markers (<|im_start|>, role strings, etc.).""" cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False) collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg) batch = collator([multiturn_error_trace]) t_in = batch["ctx_teacher_input_ids"][0] m_in = batch["sdpo_loss_mask"][0] in_loss = m_in == 1 decoded = real_chat_tok.decode(t_in[in_loss].tolist()) assert "does not exist" in decoded, ( f"In-loss tokens don't contain the recovery content; got: {decoded!r}" ) for marker in ("<|im_start|>", "<|im_end|>", "<|endoftext|>"): assert marker not in decoded, ( f"Chat-template marker {marker!r} leaked into the in-loss span: {decoded!r}" ) def test_real_chat_template_student_teacher_shapes_match(real_chat_tok, multiturn_error_trace): """The SDPO gate requires student_logits.shape == teacher_logits.shape; verify the aligned-student path produces matching sequence lengths.""" cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False) collator = ComposerDataCollator(tokenizer=real_chat_tok, config=cfg) batch = collator([multiturn_error_trace]) assert batch["input_ids"].shape == batch["ctx_teacher_input_ids"].shape # ---------------------------------------------------------------------------- # Empty-recovery guard (Wave 21 — discovered on real Claude Code traces) # ---------------------------------------------------------------------------- # # ~67% of real error sites have EMPTY recovery content: when strip_thinking=True # the recovery turn (which was pure [THINKING] reasoning) becomes empty. Injecting # an SDPO hint with no recovery content yields an all-ignore_index mask — a # zero-signal row that wastes a forward pass and dilutes the channel. The collator # must treat empty-recovery error turns as non-error sites. These use a stub # tokenizer (pure logic, no model needed) so they always run. class _StubTok: """Word-level deterministic tokenizer; apply_chat_template space-joins.""" pad_token_id = 0 def __init__(self) -> None: self._v: dict[str, int] = {"": 0, "": 1, "": 2} def _id(self, w: str) -> int: if w not in self._v: self._v[w] = len(self._v) return self._v[w] def __call__(self, text, **_k): return {"input_ids": [self._id(w) for w in text.split()] if text else []} def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002 return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()] def _hint_for_tnf(kind, _meta): return "HINT use a real tool" if kind == "tool_not_found" else None def test_empty_recovery_does_not_fire_sdpo(): """An error turn with empty recovery content must NOT emit an SDPO mask.""" tok = _StubTok() trace = { "trace_id": "empty-recovery", "turns": [ {"role": "user", "content": "do the thing"}, {"role": "assistant", "content": "", "tool_error": "tool_not_found", "error_meta": {}}, {"role": "user", "content": "tool not found"}, ], "final_reward": 0.0, } cfg = CollatorConfig(hint_generator=_hint_for_tnf) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace]) assert "sdpo_loss_mask" not in batch, ( "Empty-recovery error turn fired a zero-signal SDPO mask; it must be skipped." ) def test_mixed_recovery_fires_on_nonempty_only(): """A trace mixing empty + non-empty recovery turns fires SDPO from the non-empty one and has loss-active positions.""" tok = _StubTok() trace = { "trace_id": "mixed-recovery", "turns": [ {"role": "user", "content": "first task"}, {"role": "assistant", "content": "", "tool_error": "tool_not_found", "error_meta": {}}, {"role": "user", "content": "tool not found"}, {"role": "assistant", "content": "let me use a real tool instead", "tool_error": "tool_not_found", "error_meta": {}}, ], "final_reward": 0.0, } cfg = CollatorConfig(hint_generator=_hint_for_tnf) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace]) assert "sdpo_loss_mask" in batch assert int((batch["sdpo_loss_mask"] == 1).sum()) > 0 def test_empty_recovery_keeps_student_teacher_shapes_matched(): """Even with a skipped empty-recovery turn, when SDPO DOES fire elsewhere the student/teacher shapes must still match (lockstep skip on both sides).""" tok = _StubTok() trace = { "trace_id": "mixed-shape", "turns": [ {"role": "user", "content": "task"}, {"role": "assistant", "content": "", "tool_error": "tool_not_found", "error_meta": {}}, {"role": "user", "content": "tool not found"}, {"role": "assistant", "content": "recover now with a real tool", "tool_error": "tool_not_found", "error_meta": {}}, ], "final_reward": 0.0, } cfg = CollatorConfig(hint_generator=_hint_for_tnf) collator = ComposerDataCollator(tokenizer=tok, config=cfg) batch = collator([trace]) assert batch["input_ids"].shape == batch["ctx_teacher_input_ids"].shape