Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
  ```python
  # Load model directly
  from transformers import AutoModel
  model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
  ```
- Notebooks
- Google Colab
- Kaggle
| """data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch. | |
| Pipeline: | |
| 1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003). | |
| 2. Tokenize each turn of the trace. | |
| 3. Detect error sites (turns where a tool call failed) using a configurable predicate. | |
| 4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary. | |
| 5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position. | |
| 6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, 0 elsewhere. | |
| 7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step. | |
| The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its | |
| `inputs` argument. See `trl_path/composer_trainer.py` for the consumer side. | |
| Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss | |
| requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's | |
| why we pad/align here rather than inside the loss function. The post-hint section of | |
| ctx_teacher must have token-by-token alignment with the same section of ctx_student. | |
| """ | |
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

import torch
| # --------------------------------------------------------------------------- | |
| # Types | |
| # --------------------------------------------------------------------------- | |
class TraceTurn(TypedDict, total=False):
    """A single turn of an agentic trace.

    All keys are optional at the type level (``total=False``); consumers read
    them with ``.get`` and treat a missing/empty ``content`` as "skip".
    """

    # Speaker of the turn: "user", "assistant" or "tool".
    role: str
    # Message text, or the tool's output for tool turns.
    content: str
    # Parsed tool call when the assistant issued one.
    tool_call: dict | None
    # error_kind reported by the environment (e.g. "tool_not_found"); a
    # non-None value marks the turn as an error site.
    tool_error: str | None
    # Extra context forwarded to the hint generator (available_tools, ...).
    error_meta: dict
class TraceExample(TypedDict, total=False):
    """A single training example: one agentic trace plus optional DPO pairs.

    Keys are optional at the type level (``total=False``).
    """

    # Identifier of the trace.
    trace_id: str
    # Ordered turns of the trace.
    turns: list[TraceTurn]
    # RLVR scalar (test-pass etc.) observed at the end of the trajectory.
    final_reward: float
    # Teacher-disagreement pairs from teacher_replay.extract_dpo_pairs, if any.
    dpo_pairs: list[dict] | None
| # --------------------------------------------------------------------------- | |
| # Tokenizer protocol — duck-typed against HF AutoTokenizer | |
| # --------------------------------------------------------------------------- | |
class TokenizerLike(Protocol):
    """Minimal structural protocol the collator needs from a tokenizer.

    Declared as a ``typing.Protocol`` so that HuggingFace ``AutoTokenizer``
    instances (the typical case) and simple unit-test stubs satisfy it by
    shape alone, without inheriting from this class. Stubs may still
    subclass it explicitly.

    Members used by the collator:
      * ``pad_token_id`` — id of the padding token.
      * ``__call__(text, **kwargs)`` — returns a mapping with an
        ``"input_ids"`` entry.
      * ``apply_chat_template(messages, **kwargs)`` — chat-formatted
        tokenization; may be absent, in which case the collator falls back to
        plain text.
    """

    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        ...
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
@dataclass
class CollatorConfig:
    """Tunables for ComposerDataCollator.

    A dataclass, so callers can override individual fields by keyword, e.g.
    ``CollatorConfig(max_seq_len=8192, hint_generator=my_hints)``. Without the
    decorator the class only carried class-level defaults, and any keyword
    construction raised ``TypeError``.
    """

    # Cap on the padded length of the GRPO / SDPO (B, T) tensors; longer
    # sequences are right-truncated.
    max_seq_len: int = 4096
    # Cap on the padded length of the DPO chosen/rejected tensors.
    max_dpo_seq_len: int = 2048
    # Token id used to right-pad input_ids.
    pad_token_id: int = 0
    ignore_index: int = -100  # standard HF "ignore in loss" sentinel

    # SDPO behavior
    enable_sdpo: bool = True
    # Callable (error_kind, error_meta) -> hint_text, or None to skip that
    # error site. Leaving the whole field as None disables SDPO field building.
    hint_generator: Callable[[str, dict], str | None] | None = None

    # Trace-replay DPO behavior
    enable_replay_dpo: bool = True

    # Reward shaping: key read from each TraceExample as the RLVR reward
    # (0.0 when the key is absent).
    rlvr_reward_key: str = "final_reward"
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _is_error_turn(turn: TraceTurn) -> bool: | |
| """Predicate: is this turn an error site that should trigger SDPO?""" | |
| return turn.get("tool_error") is not None | |
| def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]: | |
| """Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template.""" | |
| return [ | |
| {"role": t["role"], "content": t["content"]} | |
| for t in turns if t.get("content") | |
| ] | |
| def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]: | |
| """Right-pad with pad_id, or right-truncate to target_len.""" | |
| if len(seq) >= target_len: | |
| return seq[:target_len] | |
| return seq + [pad_id] * (target_len - len(seq)) | |
| def _mask_to_padded_indices( | |
| mask: torch.Tensor, # (B, T) where nonzero/True == valid position | |
| pad_sentinel: int = -1, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Convert a (B,T) bool/0-1 mask → (B,K_max) index tensor + (B,K_max) validity mask. | |
| Each row's K valid positions are written left-aligned into ``idx``; the | |
| ragged tail (rows with fewer than K_max positions) is padded with | |
| ``pad_sentinel`` (default -1). ``valid`` is True exactly where ``idx`` | |
| holds a real position. | |
| ADR-011: the SDPO loss gathers post-hint response logits via these indices, | |
| then masks the sentinel/padding positions so they contribute 0. K_max=0 | |
| (no valid positions anywhere) returns (B,0) tensors. | |
| """ | |
| B, T = mask.shape | |
| bool_mask = mask != 0 | |
| counts = bool_mask.sum(dim=1).long() # (B,) — K per row | |
| K_max = int(counts.max().item()) if counts.numel() else 0 | |
| if K_max == 0: | |
| return ( | |
| torch.full((B, 0), pad_sentinel, dtype=torch.long, device=mask.device), | |
| torch.zeros(B, 0, dtype=torch.bool, device=mask.device), | |
| ) | |
| idx = torch.full((B, K_max), pad_sentinel, dtype=torch.long, device=mask.device) | |
| valid = torch.zeros(B, K_max, dtype=torch.bool, device=mask.device) | |
| # torch.nonzero on a 2D bool tensor yields (total_K, 2): (batch_idx, pos_idx), | |
| # row-major so positions are already in per-row, ascending order. | |
| nz = torch.nonzero(bool_mask, as_tuple=False) # (total_K, 2) | |
| pos_idx = nz[:, 1] | |
| offsets = torch.zeros(B + 1, dtype=torch.long, device=mask.device) | |
| offsets[1:] = counts.cumsum(dim=0) | |
| for b in range(B): | |
| start, end = int(offsets[b].item()), int(offsets[b + 1].item()) | |
| k = end - start | |
| if k > 0: | |
| idx[b, :k] = pos_idx[start:end] | |
| valid[b, :k] = True | |
| return idx, valid | |
| # --------------------------------------------------------------------------- | |
| # The collator | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids:        (B, T_max)
        - attention_mask:   (B, T_max), 1 over real tokens, 0 over right padding
        - response_mask:    (B, T_max)
        - rewards:          (B,)
        # Channel 2 (SDPO hint-distill) — present when any example has error turns
        - ctx_teacher_input_ids:   (B, T_max)
        - sdpo_loss_mask:          (B, T_max), 1 at post-hint error-turn tokens
        - student_response_idx / teacher_response_idx:     (B, K_max) positions
          of the recovery tokens in input_ids / ctx_teacher_input_ids
          (-1 padded); index k on both sides refers to the same recovery token
        - student_response_valid / teacher_response_valid: (B, K_max) bool
        # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
        - dpo_chosen_input_ids:        (B', T_dpo)
        - dpo_chosen_response_mask:    (B', T_dpo)
        - dpo_rejected_input_ids:      (B', T_dpo)
        - dpo_rejected_response_mask:  (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.
    """

    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Collate ``batch`` into the trainer input dict described on the class.

        Channels 2 and 3 are added only when enabled in ``self.config`` AND the
        batch actually yields error sites / DPO pairs, respectively.
        """
        out: dict[str, torch.Tensor] = {}

        # --- Channel 1: GRPO core fields ---
        out.update(self._build_grpo_fields(batch))

        # --- Channel 2: SDPO hint-distill fields ---
        # Student-side (B, T) mask of the recovery (post-hint) tokens. It stays
        # None unless the aligned student sequence below is built.
        student_sdpo_mask = None
        if self.config.enable_sdpo:
            sdpo = self._build_sdpo_fields(batch)
            if sdpo is not None:
                out.update(sdpo)
                # Reconcile student vs teacher shapes for compose_loss's
                # `student_logits.shape == teacher_logits.shape` gate.
                #
                # CRITICAL: hint injection adds tokens IN THE MIDDLE of the
                # teacher sequence (before the recovery turn). Right-padding the
                # student to teacher length WOULD ALIAS PAD TOKENS to the
                # sdpo_loss_mask region. That gives a degenerate ~ln(2) JSD
                # signal which LOOKS healthy but is meaningless
                # (Gemini W19 R1 BLOCKER). Instead, the student is rebuilt with
                # a same-length placeholder where the teacher has the hint, so
                # post-hint positions land at (nearly) the same indices.
                aligned = self._build_aligned_student_for_sdpo(
                    batch, teacher_len=out["ctx_teacher_input_ids"].shape[1]
                )
                if aligned is not None:
                    out["input_ids"] = aligned["input_ids"]
                    out["attention_mask"] = aligned["attention_mask"]
                    out["response_mask"] = aligned["response_mask"]
                    student_sdpo_mask = aligned.get("student_sdpo_mask")

        # --- ADR-011: emit SDPO alignment indices ---
        # The loss (strict mode, default) needs explicit per-token alignment
        # indices into each sequence so the JSD compares corresponding
        # post-hint recovery tokens. Teacher positions come from
        # sdpo_loss_mask == 1, which covers recovery-turn content only.
        # Student positions must select the SAME logical tokens, so they come
        # from the student's own recovery-token mask. They must NOT come from
        # response_mask: that mask also covers every non-error assistant turn
        # (it feeds GRPO), which would put unrelated positions at the front of
        # student_response_idx and pair student token k with a different
        # teacher token k. Deriving each side from its own mask also keeps
        # the pairing correct when the placeholder is a token longer or
        # shorter than the hint.
        if "sdpo_loss_mask" in out and "response_mask" in out:
            student_source = (
                student_sdpo_mask if student_sdpo_mask is not None else out["response_mask"]
            )
            t_idx, t_valid = _mask_to_padded_indices(out["sdpo_loss_mask"] == 1)
            s_idx, s_valid = _mask_to_padded_indices(student_source == 1)
            out["student_response_idx"] = s_idx
            out["teacher_response_idx"] = t_idx
            out["student_response_valid"] = s_valid
            out["teacher_response_valid"] = t_valid

        # --- Channel 3: trace-replay DPO fields ---
        if self.config.enable_replay_dpo:
            dpo = self._build_dpo_fields(batch)
            if dpo is not None:
                out.update(dpo)

        return out

    # ----------------------------------------------------------------------
    # Shared padding helper
    # ----------------------------------------------------------------------
    @staticmethod
    def _attention_mask_from_lengths(seqs: Sequence[list[int]], max_len: int) -> torch.Tensor:
        """Build a (B, max_len) mask: 1 over each sequence's kept tokens, 0 over padding.

        This is derived from sequence lengths rather than
        ``input_ids != pad_token_id``, so a real token whose id happens to
        equal the pad id (``pad_token_id`` defaults to 0) stays attended.
        """
        return torch.tensor(
            [_pad_or_truncate([1] * len(s), max_len, 0) for s in seqs],
            dtype=torch.long,
        )

    # ----------------------------------------------------------------------
    # Channel 1: standard GRPO inputs
    # ----------------------------------------------------------------------
    def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Tokenize each trace (plain text, no chat template) into padded GRPO tensors.

        Rows are right-padded/truncated to the longest trace, capped at
        ``config.max_seq_len``. ``rewards`` reads ``config.rlvr_reward_key``
        from each example (0.0 when absent).
        """
        input_ids_list: list[list[int]] = []
        response_masks_list: list[list[int]] = []
        rewards: list[float] = []
        for ex in batch:
            ids, resp_mask = self._tokenize_trace(ex["turns"])
            input_ids_list.append(ids)
            response_masks_list.append(resp_mask)
            rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))

        max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list))
        input_ids = torch.tensor(
            [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in input_ids_list],
            dtype=torch.long,
        )
        response_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
            dtype=torch.long,
        )
        attention_mask = self._attention_mask_from_lengths(input_ids_list, max_len)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "response_mask": response_mask,
            "rewards": torch.tensor(rewards, dtype=torch.float),
        }

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill inputs
    # ----------------------------------------------------------------------
    def _build_sdpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Build the hint-injected teacher ids and the SDPO loss mask.

        Both are padded to the longest teacher sequence, capped at
        ``config.max_seq_len``. The mask is padded with ``config.ignore_index``.
        Returns None when no hint generator is configured or no example in
        the batch has a usable error site.
        """
        if self.config.hint_generator is None:
            return None  # nothing to do without a hint generator

        ctx_teacher_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_error_sites = False
        for ex in batch:
            ctx_teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(ex["turns"])
            ctx_teacher_list.append(ctx_teacher_ids)
            sdpo_mask_list.append(sdpo_mask)
            any_error_sites = any_error_sites or has_errors

        if not any_error_sites:
            return None  # batch has no error sites — SDPO is a no-op for this step

        max_len = min(self.config.max_seq_len, max(len(s) for s in ctx_teacher_list))
        ctx_teacher = torch.tensor(
            [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list],
            dtype=torch.long,
        )
        sdpo_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list],
            dtype=torch.long,
        )
        return {
            "ctx_teacher_input_ids": ctx_teacher,
            "sdpo_loss_mask": sdpo_mask,
        }

    def _build_hint_injected_trace(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Walk the trace; at each error turn, inject a hint and mark the
        post-hint (recovery) tokens as in-loss.

        Returns:
            (ctx_teacher_ids, sdpo_loss_mask, any_error_sites). The mask has
            the same length as ctx_teacher_ids, with 1 on recovery content
            tokens and ``config.ignore_index`` elsewhere.
        """
        if self.config.hint_generator is None:
            # Caller responsibility — short-circuited by the dispatch.
            return [], [], False

        teacher_messages: list[dict] = []
        teacher_loss_segments: list[tuple[bool, str]] = []  # (is_loss_segment, text)
        any_errors = False

        for turn in turns:
            if _is_error_turn(turn):
                hint_text = self.config.hint_generator(
                    turn.get("tool_error", "unknown"),
                    turn.get("error_meta", {}),
                )
                # Only treat this as an SDPO error site when BOTH a hint was
                # produced AND the recovery turn has content to distill against.
                # Real traces frequently have empty recovery content (e.g.
                # strip_thinking=True removing a pure-[THINKING] turn). A hint
                # with no recovery content would yield an all-ignore_index mask:
                # a zero-signal SDPO row. Skip it; the turn then falls through
                # to the (also-skipped) empty passthrough.
                if hint_text and turn.get("content"):
                    any_errors = True
                    recovery_content = turn.get("content") or ""
                    # Inject hint as a system-style addendum BEFORE the response.
                    teacher_messages.append({"role": "system", "content": hint_text})
                    teacher_loss_segments.append((False, hint_text))
                    teacher_messages.append({
                        "role": turn.get("role", "assistant"),
                        "content": recovery_content,
                    })
                    teacher_loss_segments.append((True, recovery_content))  # post-hint tokens = loss
                    continue

            # Non-error turn (or hint generator returned None / empty recovery) — passthrough
            content = turn.get("content")
            if content:
                teacher_messages.append({
                    "role": turn.get("role", "assistant"),
                    "content": content,
                })
                teacher_loss_segments.append((False, content))

        teacher_ids = self._tokenize_messages(teacher_messages)
        # Mask derived from per-message chat-template deltas (Wave 20 fix), so
        # loss bits land on content tokens rather than drifting left of them
        # by the template's scaffolding tokens.
        sdpo_mask = self._build_chat_aligned_mask(
            teacher_messages, teacher_loss_segments, teacher_ids
        )
        # Fit the mask to teacher_ids in case tokenization round-tripped differently.
        sdpo_mask = _pad_or_truncate(sdpo_mask, len(teacher_ids), self.config.ignore_index)
        return teacher_ids, sdpo_mask, any_errors

    def _build_aligned_student_for_sdpo(
        self,
        batch: Sequence[TraceExample],
        teacher_len: int,
    ) -> dict[str, torch.Tensor] | None:
        """Build student tensors that line up with the hint-injected teacher.

        Strategy: the student messages mirror the teacher messages, EXCEPT
        each hint system-message is replaced with a placeholder system-message
        whose content tokenizes to (approximately) the hint's length. Both
        sides go through ``apply_chat_template``, so chat-template markers are
        added identically and the recovery-turn tokens land at (nearly) the
        same indices in both tensors.

        Returns None if no error sites exist. Otherwise returns a dict of
        (B, teacher_len) tensors:
          * ``input_ids`` / ``attention_mask``
          * ``response_mask``: all assistant content, used by GRPO
          * ``student_sdpo_mask``: recovery-turn content only, the student-side
            counterpart of ``sdpo_loss_mask``
        """
        if self.config.hint_generator is None:
            return None

        student_ids_list: list[list[int]] = []
        response_mask_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_errors = False
        for ex in batch:
            ids, resp_mask, sdpo_mask, has_errors = self._build_aligned_student_with_sdpo_mask(
                ex["turns"]
            )
            student_ids_list.append(ids)
            response_mask_list.append(resp_mask)
            sdpo_mask_list.append(sdpo_mask)
            any_errors = any_errors or has_errors

        if not any_errors:
            return None

        max_len = teacher_len  # match teacher exactly
        pad_id = self.config.pad_token_id
        input_ids = torch.tensor(
            [_pad_or_truncate(s, max_len, pad_id) for s in student_ids_list],
            dtype=torch.long,
        )
        response_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in response_mask_list],
            dtype=torch.long,
        )
        student_sdpo_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in sdpo_mask_list],
            dtype=torch.long,
        )
        attention_mask = self._attention_mask_from_lengths(student_ids_list, max_len)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "response_mask": response_mask,
            "student_sdpo_mask": student_sdpo_mask,
        }

    def _make_placeholder_for_hint_length(self, hint_text: str) -> str:
        """Build a placeholder string whose tokenization length matches hint_text's.

        Start with a repeating ". " filler, then grow or trim it until the
        standalone token count matches the hint's. This is approximate at the
        character/token boundary and is bounded to 6 refinement rounds, so the
        result can be off by a token; downstream alignment indices are derived
        per side to tolerate that.
        """
        target_len = len(self._tokenize_text(hint_text))
        if target_len == 0:
            return ""
        # Use a content-free placeholder that tokenizes predictably.
        placeholder = ". " * target_len
        ph_len = len(self._tokenize_text(placeholder))
        for _ in range(6):
            if ph_len == target_len:
                break
            if ph_len > target_len:
                # Trim char-by-char until at or below the target.
                while placeholder and ph_len > target_len:
                    placeholder = placeholder[:-1]
                    ph_len = len(self._tokenize_text(placeholder))
            else:
                placeholder = placeholder + ". "
                ph_len = len(self._tokenize_text(placeholder))
        return placeholder

    def _build_aligned_student_one(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Build one trace's placeholder-aligned STUDENT sequence.

        Returns (student_ids, response_mask, any_error_sites); see
        ``_build_aligned_student_with_sdpo_mask`` for the construction.
        """
        student_ids, resp_mask, _sdpo_mask, any_errors = (
            self._build_aligned_student_with_sdpo_mask(turns)
        )
        return student_ids, resp_mask, any_errors

    def _build_aligned_student_with_sdpo_mask(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], list[int], bool]:
        """Walk one trace's turns, building STUDENT messages that mirror the
        teacher's, with each hint system-message replaced by a same-length
        placeholder.

        Returns (student_ids, response_mask, sdpo_mask, any_error_sites), all
        masks the same length as student_ids:
          * response_mask: 1 on assistant content tokens (GRPO), 0 elsewhere.
          * sdpo_mask: 1 on recovery-turn content tokens (mirroring the
            teacher's sdpo_loss_mask, regardless of role), 0 elsewhere.
        """
        if self.config.hint_generator is None:
            return [], [], [], False

        student_messages: list[dict] = []
        # Both segment lists run in lockstep with student_messages.
        response_segments: list[tuple[bool, str]] = []  # (is_assistant, text)
        sdpo_segments: list[tuple[bool, str]] = []  # (is_recovery, text)
        any_errors = False

        for turn in turns:
            if _is_error_turn(turn):
                hint_text = self.config.hint_generator(
                    turn.get("tool_error", "unknown"),
                    turn.get("error_meta", {}),
                )
                # MUST mirror the teacher path's condition exactly (hint AND
                # recovery content), or the student/teacher message lists
                # diverge and the SDPO shape-match gate breaks.
                if hint_text and turn.get("content"):
                    any_errors = True
                    recovery_content = turn.get("content") or ""
                    placeholder = self._make_placeholder_for_hint_length(hint_text)
                    # Student gets a placeholder system-msg at the SAME slot
                    # where the teacher gets the hint system-msg.
                    student_messages.append({"role": "system", "content": placeholder})
                    response_segments.append((False, placeholder))
                    sdpo_segments.append((False, placeholder))
                    student_messages.append({
                        "role": turn.get("role", "assistant"),
                        "content": recovery_content,
                    })
                    is_assistant = turn.get("role") == "assistant"
                    response_segments.append((is_assistant, recovery_content))
                    sdpo_segments.append((True, recovery_content))
                    continue

            content = turn.get("content")
            if content:
                student_messages.append({
                    "role": turn.get("role", "assistant"),
                    "content": content,
                })
                is_assistant = turn.get("role") == "assistant"
                response_segments.append((is_assistant, content))
                sdpo_segments.append((False, content))

        # Tokenize via apply_chat_template, like the teacher path, so the
        # chat-template markers are identical.
        student_ids = self._tokenize_messages(student_messages)
        n_tokens = len(student_ids)

        # _build_chat_aligned_mask emits 1 / ignore_index; remap to the
        # 1 / 0 convention used by the student-side masks.
        raw_resp = self._build_chat_aligned_mask(student_messages, response_segments, student_ids)
        resp_mask = _pad_or_truncate([1 if v == 1 else 0 for v in raw_resp], n_tokens, 0)

        if any_errors:
            raw_sdpo = self._build_chat_aligned_mask(student_messages, sdpo_segments, student_ids)
            sdpo_mask = _pad_or_truncate([1 if v == 1 else 0 for v in raw_sdpo], n_tokens, 0)
        else:
            sdpo_mask = [0] * n_tokens  # no recovery segments to mark
        return student_ids, resp_mask, sdpo_mask, any_errors

    def _build_segment_mask(
        self, segments: Sequence[tuple[bool, str]]
    ) -> list[int]:
        """For each (is_loss, text) segment, tokenize and emit per-token mask values.

        Loss-active tokens get 1; non-loss tokens get ``config.ignore_index``.

        NOTE (Wave 20): this naive per-segment concatenation IGNORES the
        chat-template scaffolding that ``apply_chat_template`` inserts, so it
        drifts out of alignment with ``_tokenize_messages`` output. It is kept
        only for callers that build sequences via raw segment concatenation;
        the SDPO/response-mask paths use ``_build_chat_aligned_mask``.
        """
        out: list[int] = []
        for is_loss, text in segments:
            seg_ids = self._tokenize_text(text)
            mask_value = 1 if is_loss else self.config.ignore_index
            out.extend([mask_value] * len(seg_ids))
        return out

    @staticmethod
    def _find_subseq(haystack: list[int], needle: list[int], start: int = 0) -> int:
        """Return the index where ``needle`` first occurs in ``haystack`` at or
        after ``start``, or -1 if absent. An empty needle matches at ``start``.

        A staticmethod: it is invoked as ``self._find_subseq(...)``, which
        without the decorator would bind ``self`` to ``haystack`` and raise
        ``TypeError``. Linear scan (spans are short).
        """
        if not needle:
            return start
        n, m = len(haystack), len(needle)
        for i in range(start, n - m + 1):
            if haystack[i:i + m] == needle:
                return i
        return -1

    def _build_chat_aligned_mask(
        self,
        messages: Sequence[dict],
        segments: Sequence[tuple[bool, str]],
        full_ids: list[int],
    ) -> list[int]:
        """Build a per-token loss mask aligned to a chat-template tokenization.

        ``messages[k]`` and ``segments[k] = (is_loss, content_text)`` describe
        the same logical chunk. The result has ``len(full_ids)`` entries: 1 on
        the content tokens of loss segments, ``config.ignore_index`` elsewhere.

        Algorithm — per-message prefix deltas:
            message k occupies full_ids[len(tmpl(messages[:k])) :
                                        len(tmpl(messages[:k+1]))]
            (content plus its own scaffolding). For loss segments, locate the
            content's standalone tokenization inside that span and mark only
            those positions.

        Fallbacks:
          * Without a usable chat template, ``_tokenize_messages`` returns a
            plain concatenation and each span is (mostly) just the content.
          * If the content run can't be located inside its span (tokenizer
            merges across the content/scaffolding boundary), the whole span is
            marked so SDPO signal is never silently dropped.

        Cost: re-tokenizes every message prefix, i.e. quadratic in the number
        of messages.
        """
        mask = [self.config.ignore_index] * len(full_ids)
        prev_len = 0
        search_from = 0
        for k, msg in enumerate(messages):
            prefix_ids = self._tokenize_messages(list(messages[: k + 1]))
            cur_len = len(prefix_ids)
            span_start, span_end = prev_len, cur_len
            prev_len = cur_len
            if span_end <= span_start:
                continue
            is_loss = segments[k][0] if k < len(segments) else False
            content = segments[k][1] if k < len(segments) else msg.get("content", "")
            if not is_loss:
                search_from = span_end
                continue
            # Loss segment: mark only the content tokens within the span. The
            # search is capped at span_end and anchored at span_start so
            # content from another message can't match.
            content_ids = self._tokenize_text(content)
            idx = self._find_subseq(
                full_ids[:span_end], content_ids, start=max(span_start, search_from)
            )
            if idx != -1 and idx >= span_start:
                for p in range(idx, min(idx + len(content_ids), span_end)):
                    mask[p] = 1
                search_from = idx + len(content_ids)
            else:
                # Fallback: over-include the message's scaffolding rather than
                # drop the SDPO signal.
                for p in range(span_start, span_end):
                    mask[p] = 1
                search_from = span_end
        return mask

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO inputs
    # ----------------------------------------------------------------------
    def _build_dpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Tokenize chosen/rejected pairs from teacher disagreement.

        Per pair:
          - chosen_input_ids   = prompt + chosen_response
          - rejected_input_ids = prompt + rejected_response
          - response masks: 0 over prompt tokens, 1 over response tokens

        All rows are right-padded/truncated to a shared length capped at
        ``config.max_dpo_seq_len``. NOTE(review): truncation is on the right,
        so a prompt longer than the cap removes the whole response (mask all
        0) — confirm that is acceptable for the trainer.
        Returns None when the batch has no DPO pairs.
        """
        all_chosen: list[list[int]] = []
        all_rejected: list[list[int]] = []
        all_chosen_resp_mask: list[list[int]] = []
        all_rejected_resp_mask: list[list[int]] = []

        for ex in batch:
            for pair in ex.get("dpo_pairs") or []:
                prompt_msgs = pair.get("state_messages", [])
                prompt_ids = self._tokenize_messages(prompt_msgs)
                chosen_ids = self._tokenize_text(pair["chosen"])
                rejected_ids = self._tokenize_text(pair["rejected"])
                all_chosen.append(prompt_ids + chosen_ids)
                all_rejected.append(prompt_ids + rejected_ids)
                all_chosen_resp_mask.append([0] * len(prompt_ids) + [1] * len(chosen_ids))
                all_rejected_resp_mask.append([0] * len(prompt_ids) + [1] * len(rejected_ids))

        if not all_chosen:
            return None  # no DPO pairs in this batch

        cap = self.config.max_dpo_seq_len
        max_len = min(cap, max(len(s) for s in (*all_chosen, *all_rejected)))
        return {
            "dpo_chosen_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_chosen],
                dtype=torch.long,
            ),
            "dpo_chosen_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask],
                dtype=torch.long,
            ),
            "dpo_rejected_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in all_rejected],
                dtype=torch.long,
            ),
            "dpo_rejected_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask],
                dtype=torch.long,
            ),
        }

    # ----------------------------------------------------------------------
    # Tokenization helpers
    # ----------------------------------------------------------------------
    def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
        """Tokenize an entire trace as concatenated plain text; return (ids, response_mask).

        response_mask is 1 over assistant turns (the GRPO loss-bearing tokens)
        and 0 over every other role. Turns with empty content are skipped.
        """
        all_ids: list[int] = []
        resp_mask: list[int] = []
        for turn in turns:
            if not turn.get("content"):
                continue
            ids = self._tokenize_text(turn["content"])
            mask_value = 1 if turn.get("role") == "assistant" else 0
            all_ids.extend(ids)
            resp_mask.extend([mask_value] * len(ids))
        return all_ids, resp_mask

    def _tokenize_text(self, text: str) -> list[int]:
        """Tokenize plain text (no special tokens) via the tokenizer's __call__."""
        result = self.tokenizer(text, add_special_tokens=False)
        ids = result["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # HF tokenizers may return list[list[int]] when batch-shaped; unwrap.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)

    def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
        """Tokenize a chat-formatted list of messages.

        Uses ``apply_chat_template(tokenize=True, add_generation_prompt=False)``.
        It falls back to newline-joined content when the tokenizer has no such
        method (stubs) or no chat template. HF raises ``ValueError`` in the
        latter case, so that exception is included in the fallback.

        NOTE: ``apply_chat_template(tokenize=True)`` is not consistently typed
        across tokenizer families. Some return ``list[int]``, others a
        ``BatchEncoding`` with an ``input_ids`` key (Qwen2.5 returns the
        latter). Both shapes are handled.
        """
        if not messages:
            return []
        try:
            raw = self.tokenizer.apply_chat_template(
                list(messages), tokenize=True, add_generation_prompt=False
            )
        except (AttributeError, NotImplementedError, TypeError, ValueError):
            # Stub tokenizer or no chat template defined — fall back to concatenated content.
            text = "\n".join(m.get("content", "") for m in messages)
            return self._tokenize_text(text)

        # BatchEncoding (Qwen2.5 etc.): extract input_ids and unwrap if batched.
        ids = raw["input_ids"] if hasattr(raw, "keys") and "input_ids" in raw else raw
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # If we got list[list[int]] (batch shape), unwrap the single example.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)
# Public names exported by ``from <module> import *``; the ``_``-prefixed
# helpers stay module-private.
__all__ = [
    "ComposerDataCollator",
    "CollatorConfig",
    "TraceTurn",
    "TraceExample",
    "TokenizerLike",
]