# Hub file-view metadata captured with this file (not part of the module):
# Author: Codeseys
# Commit d02d724 — feat(wave-a): close ADR-011 (SDPO alignment indices) + ADR-012 (review findings)
# File size: 37.1 kB
"""data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch.
Pipeline:
1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003).
2. Tokenize each turn of the trace.
3. Detect error sites (turns where a tool call failed) using a configurable predicate.
4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary.
5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position.
6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, 0 elsewhere.
7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step.
The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its
`inputs` argument. See `trl_path/composer_trainer.py` for the consumer side.
Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss
requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's
why we pad/align here rather than inside the loss function. The post-hint section of
ctx_teacher must have token-by-token alignment with the same section of ctx_student.
"""
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

import torch
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class TraceTurn(TypedDict, total=False):
    """One turn of an agentic trace.

    Every key is optional (``total=False``). The collator reads them with
    ``.get``, skips turns whose ``content`` is empty, and treats a turn whose
    ``tool_error`` is not None as an SDPO error site.
    """

    role: str  # "user" | "assistant" | "tool"
    content: str  # text or tool result
    tool_call: dict | None  # parsed tool call, if assistant-issued
    tool_error: str | None  # error_kind from the env, e.g. "tool_not_found"; passed to hint_generator
    error_meta: dict  # extra info for hint generator (available_tools, etc.)
class TraceExample(TypedDict, total=False):
    """One training example: a (trace, optional DPO pairs) tuple.

    ``final_reward`` is the key the collator reads for the RLVR scalar by
    default (see ``CollatorConfig.rlvr_reward_key``). Each ``dpo_pairs`` entry
    is read for its ``state_messages`` (prompt chat), ``chosen`` and
    ``rejected`` texts.
    """

    trace_id: str
    turns: list[TraceTurn]
    final_reward: float  # RLVR scalar (test-pass etc.) at trajectory end
    dpo_pairs: list[dict] | None  # from teacher_replay.extract_dpo_pairs
# ---------------------------------------------------------------------------
# Tokenizer protocol — duck-typed against HF AutoTokenizer
# ---------------------------------------------------------------------------
class TokenizerLike(Protocol):
    """Structural (duck-typed) interface the collator needs from a tokenizer.

    Declared as a ``typing.Protocol`` so HuggingFace ``AutoTokenizer``
    instances (the typical case) satisfy it for static type checkers without
    inheriting from it. Simple unit-test stubs may still subclass it
    explicitly.
    """

    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        """Tokenize ``text``; the result must expose an ``input_ids`` entry.

        The collator calls this with ``add_special_tokens=False``.
        """
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        """Render a chat message list; the collator passes ``tokenize=True``.

        Some tokenizer families return a dict-like ``BatchEncoding`` with an
        ``input_ids`` key instead of a plain ``list[int]``; the collator
        handles both shapes.
        """
        ...
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class CollatorConfig:
    """Knobs that control how ComposerDataCollator assembles a batch.

    Groups:
      * sequence limits and padding (``max_seq_len``, ``max_dpo_seq_len``,
        ``pad_token_id``, ``ignore_index``);
      * SDPO hint-distillation (``enable_sdpo``, ``hint_generator``);
      * trace-replay DPO (``enable_replay_dpo``);
      * reward lookup (``rlvr_reward_key``).
    """

    # -- sequence limits / padding -------------------------------------------
    max_seq_len: int = field(default=4096)
    max_dpo_seq_len: int = field(default=2048)
    pad_token_id: int = field(default=0)
    # Standard HF "ignore in loss" sentinel.
    ignore_index: int = field(default=-100)

    # -- SDPO behavior ---------------------------------------------------------
    enable_sdpo: bool = field(default=True)
    # Called as hint_generator(error_kind, error_meta) -> hint text, or None
    # to skip that error site.
    hint_generator: Callable[[str, dict], str | None] | None = field(default=None)

    # -- trace-replay DPO behavior --------------------------------------------
    enable_replay_dpo: bool = field(default=True)

    # -- reward shaping --------------------------------------------------------
    # Key in each TraceExample holding the RLVR scalar.
    rlvr_reward_key: str = field(default="final_reward")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_error_turn(turn: TraceTurn) -> bool:
"""Predicate: is this turn an error site that should trigger SDPO?"""
return turn.get("tool_error") is not None
def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]:
"""Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template."""
return [
{"role": t["role"], "content": t["content"]}
for t in turns if t.get("content")
]
def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
"""Right-pad with pad_id, or right-truncate to target_len."""
if len(seq) >= target_len:
return seq[:target_len]
return seq + [pad_id] * (target_len - len(seq))
def _mask_to_padded_indices(
mask: torch.Tensor, # (B, T) where nonzero/True == valid position
pad_sentinel: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert a (B,T) bool/0-1 mask → (B,K_max) index tensor + (B,K_max) validity mask.
Each row's K valid positions are written left-aligned into ``idx``; the
ragged tail (rows with fewer than K_max positions) is padded with
``pad_sentinel`` (default -1). ``valid`` is True exactly where ``idx``
holds a real position.
ADR-011: the SDPO loss gathers post-hint response logits via these indices,
then masks the sentinel/padding positions so they contribute 0. K_max=0
(no valid positions anywhere) returns (B,0) tensors.
"""
B, T = mask.shape
bool_mask = mask != 0
counts = bool_mask.sum(dim=1).long() # (B,) — K per row
K_max = int(counts.max().item()) if counts.numel() else 0
if K_max == 0:
return (
torch.full((B, 0), pad_sentinel, dtype=torch.long, device=mask.device),
torch.zeros(B, 0, dtype=torch.bool, device=mask.device),
)
idx = torch.full((B, K_max), pad_sentinel, dtype=torch.long, device=mask.device)
valid = torch.zeros(B, K_max, dtype=torch.bool, device=mask.device)
# torch.nonzero on a 2D bool tensor yields (total_K, 2): (batch_idx, pos_idx),
# row-major so positions are already in per-row, ascending order.
nz = torch.nonzero(bool_mask, as_tuple=False) # (total_K, 2)
pos_idx = nz[:, 1]
offsets = torch.zeros(B + 1, dtype=torch.long, device=mask.device)
offsets[1:] = counts.cumsum(dim=0)
for b in range(B):
start, end = int(offsets[b].item()), int(offsets[b + 1].item())
k = end - start
if k > 0:
idx[b, :k] = pos_idx[start:end]
valid[b, :k] = True
return idx, valid
# ---------------------------------------------------------------------------
# The collator
# ---------------------------------------------------------------------------
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids: (B, T_max)
        - attention_mask: (B, T_max)
        - response_mask: (B, T_max)
        - rewards: (B,)
        # Channel 2 (SDPO hint-distill) — present when any example has an
        # error turn that yields a hint and has recovery content
        - ctx_teacher_input_ids: (B, T_max)
        - sdpo_loss_mask: (B, T_max), 1 at post-hint error-turn tokens
        - student_response_idx / teacher_response_idx: (B, K_max) long
          ADR-011 alignment indices. Slot k of both rows points at the k-th
          recovery token of the student / teacher sequence; -1 where invalid.
        - student_response_valid / teacher_response_valid: (B, K_max) bool
        # Channel 3 (trace-replay DPO) — present when any example has dpo_pairs
        - dpo_chosen_input_ids: (B', T_dpo)
        - dpo_chosen_response_mask: (B', T_dpo)
        - dpo_rejected_input_ids: (B', T_dpo)
        - dpo_rejected_response_mask: (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.

    When the aligned SDPO student is built, the Channel-1 tensors are
    rebuilt from the chat-template tokenization and share the teacher's
    length.
    """

    tokenizer: TokenizerLike
    config: CollatorConfig = field(default_factory=CollatorConfig)

    def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Collate ``batch`` into the tensor dict described on the class."""
        out: dict[str, torch.Tensor] = {}

        # --- Channel 1: GRPO core fields ---
        out.update(self._build_grpo_fields(batch))

        # (B, T) with 1 at the STUDENT's recovery-content tokens. Set only
        # when the aligned student sequence is built below.
        student_sdpo_mask: torch.Tensor | None = None

        # --- Channel 2: SDPO hint-distill fields ---
        if self.config.enable_sdpo:
            sdpo = self._build_sdpo_fields(batch)
            if sdpo is not None:
                out.update(sdpo)
                # Reconcile student vs teacher shapes for compose_loss's
                # `student_logits.shape == teacher_logits.shape` gate.
                #
                # CRITICAL: hint injection adds tokens IN THE MIDDLE of
                # the teacher sequence (before the recovery turn). The
                # recovery turn lives at teacher positions
                # [hint_end .. hint_end + len(recovery)] but at student
                # positions [recovery_start .. recovery_start + len(recovery)]
                # where recovery_start < hint_end. Right-padding the student
                # to the teacher's length WOULD ALIAS PAD TOKENS to the
                # sdpo_loss_mask region. That gives a degenerate ~ln(2) JSD
                # signal that LOOKS healthy but is meaningless
                # (Gemini W19 R1 BLOCKER).
                #
                # The student is instead rebuilt with a same-length
                # placeholder where the teacher has the hint, so post-hint
                # positions land at the same indices in both.
                aligned = self._build_aligned_student_for_sdpo(
                    batch, teacher_len=out["ctx_teacher_input_ids"].shape[1]
                )
                if aligned is not None:
                    out["input_ids"] = aligned["input_ids"]
                    out["attention_mask"] = aligned["attention_mask"]
                    out["response_mask"] = aligned["response_mask"]
                    student_sdpo_mask = aligned["student_sdpo_mask"]

        # --- ADR-011: emit SDPO alignment indices ---
        # The loss (strict mode, default) requires explicit per-token
        # alignment indices into each sequence so the JSD compares
        # corresponding post-hint response tokens.
        #
        # Teacher positions come from sdpo_loss_mask==1. Student positions
        # come from the student's OWN recovery-token mask: the same recovery
        # segments, located in the placeholder-bearing student sequence.
        # Slot k therefore names the k-th recovery token on both sides.
        #
        # Using response_mask here instead would be wrong: it covers every
        # assistant token, so earlier assistant turns would be paired with
        # teacher recovery tokens. When the placeholder tokenizes to exactly
        # the hint's length (the usual case), s_idx == t_idx at every valid
        # slot.
        if "sdpo_loss_mask" in out and "response_mask" in out:
            t_mask = out["sdpo_loss_mask"] == 1
            if student_sdpo_mask is not None:
                s_mask = student_sdpo_mask == 1
            else:
                # Aligned student unavailable — legacy derivation.
                s_mask = out["response_mask"] == 1
            t_idx, t_valid = _mask_to_padded_indices(t_mask)
            s_idx, s_valid = _mask_to_padded_indices(s_mask)
            s_idx, t_idx, valid = self._pair_alignment_indices(
                s_idx, s_valid, t_idx, t_valid
            )
            out["student_response_idx"] = s_idx
            out["teacher_response_idx"] = t_idx
            out["student_response_valid"] = valid
            out["teacher_response_valid"] = valid.clone()

        # --- Channel 3: trace-replay DPO fields ---
        if self.config.enable_replay_dpo:
            dpo = self._build_dpo_fields(batch)
            if dpo is not None:
                out.update(dpo)
        return out

    @staticmethod
    def _pair_alignment_indices(
        s_idx: torch.Tensor,
        s_valid: torch.Tensor,
        t_idx: torch.Tensor,
        t_valid: torch.Tensor,
        pad_sentinel: int = -1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pair student/teacher (B, K) index rows slot-by-slot (ADR-011).

        Both index tensors are widened to a common K with ``pad_sentinel``.
        A slot stays valid only where BOTH sides hold a real position, and
        invalid slots are reset to ``pad_sentinel`` on both sides. This keeps
        slot k meaning "k-th recovery token" even when truncation dropped a
        tail on only one side.

        Returns:
            (s_idx, t_idx, valid), all shaped (B, K).
        """
        width = max(s_idx.shape[1], t_idx.shape[1])

        def _widen(idx: torch.Tensor, ok: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            extra = width - idx.shape[1]
            if extra == 0:
                return idx, ok
            idx = torch.cat([idx, idx.new_full((idx.shape[0], extra), pad_sentinel)], dim=1)
            ok = torch.cat([ok, ok.new_zeros((ok.shape[0], extra))], dim=1)
            return idx, ok

        s_idx, s_valid = _widen(s_idx, s_valid)
        t_idx, t_valid = _widen(t_idx, t_valid)
        valid = s_valid & t_valid
        return (
            s_idx.masked_fill(~valid, pad_sentinel),
            t_idx.masked_fill(~valid, pad_sentinel),
            valid,
        )

    @staticmethod
    def _attention_mask_from_lengths(
        seqs: Sequence[Sequence[int]], max_len: int
    ) -> torch.Tensor:
        """Right-padding attention mask: 1 on the first min(len, max_len) slots.

        The mask is built from true sequence lengths, not from
        ``input_ids != pad_token_id``. Real tokens that share the pad id
        (pad_token_id defaults to 0, and many tokenizers reuse EOS as pad)
        therefore stay visible.
        """
        rows: list[list[int]] = []
        for seq in seqs:
            n = min(len(seq), max_len)
            rows.append([1] * n + [0] * (max_len - n))
        return torch.tensor(rows, dtype=torch.long)

    # ----------------------------------------------------------------------
    # Channel 1: standard GRPO inputs
    # ----------------------------------------------------------------------
    def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
        """Tokenize each trace's raw turn content for the GRPO/RLVR channel.

        Rows are right-padded to the longest trace, capped at ``max_seq_len``.
        The reward is read from ``config.rlvr_reward_key`` and defaults to
        0.0 when the key is absent.

        NOTE(review): this path concatenates raw turn content WITHOUT the chat
        template, whereas the SDPO path (which overwrites these fields when it
        fires) uses ``apply_chat_template``. Channel 1's token format therefore
        depends on whether the batch has SDPO error sites — confirm the
        trainer tolerates that.
        """
        input_ids_list: list[list[int]] = []
        response_masks_list: list[list[int]] = []
        rewards: list[float] = []
        for ex in batch:
            ids, resp_mask = self._tokenize_trace(ex["turns"])
            input_ids_list.append(ids)
            response_masks_list.append(resp_mask)
            rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))
        max_len = min(
            self.config.max_seq_len,
            max((len(s) for s in input_ids_list), default=0),
        )
        pad_id = self.config.pad_token_id
        input_ids = torch.tensor(
            [_pad_or_truncate(s, max_len, pad_id) for s in input_ids_list],
            dtype=torch.long,
        )
        response_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
            dtype=torch.long,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": self._attention_mask_from_lengths(input_ids_list, max_len),
            "response_mask": response_mask,
            "rewards": torch.tensor(rewards, dtype=torch.float),
        }

    # ----------------------------------------------------------------------
    # Channel 2: SDPO hint-distill inputs
    # ----------------------------------------------------------------------
    def _build_sdpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Build ctx_teacher + sdpo_loss_mask for every example in ``batch``.

        Rows are right-padded (with pad_token_id / ignore_index) to the
        longest hint-injected TEACHER sequence, capped at ``max_seq_len``.
        The student side is aligned to that length afterwards by
        ``_build_aligned_student_for_sdpo``.

        Returns None when there is no hint generator, or when no example
        yields an SDPO error site (SDPO is then a no-op for this step).
        """
        if self.config.hint_generator is None:
            return None  # nothing to do without a hint generator
        ctx_teacher_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_error_sites = False
        for ex in batch:
            ctx_teacher_ids, sdpo_mask, has_errors = self._build_hint_injected_trace(ex["turns"])
            ctx_teacher_list.append(ctx_teacher_ids)
            sdpo_mask_list.append(sdpo_mask)
            any_error_sites = any_error_sites or has_errors
        if not any_error_sites:
            return None
        max_len = min(
            self.config.max_seq_len,
            max((len(s) for s in ctx_teacher_list), default=0),
        )
        ctx_teacher = torch.tensor(
            [_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in ctx_teacher_list],
            dtype=torch.long,
        )
        sdpo_mask = torch.tensor(
            [_pad_or_truncate(m, max_len, self.config.ignore_index) for m in sdpo_mask_list],
            dtype=torch.long,
        )
        return {
            "ctx_teacher_input_ids": ctx_teacher,
            "sdpo_loss_mask": sdpo_mask,
        }

    def _hint_for_turn(self, turn: TraceTurn) -> str | None:
        """Return the hint text if ``turn`` is an SDPO error site, else None.

        The teacher and student builders share this so both apply the exact
        same site predicate. Their message lists must stay in lockstep, or
        the SDPO shape-match gate breaks.

        A turn qualifies only when ALL of these hold:
          * it is an error turn;
          * the generator returns a non-empty hint;
          * the turn has recovery content to distill against.

        Real Claude Code traces frequently have empty recovery content, e.g.
        when strip_thinking=True nukes a recovery turn that was pure
        [THINKING] reasoning (observed ~67% of real error sites). Injecting a
        hint with no recovery content would produce an all-ignore_index mask:
        a zero-signal SDPO row that wastes a forward pass and silently
        dilutes the channel.

        NOTE(review): the generator is invoked once per builder (teacher and
        student) for the same turn. A non-deterministic hint_generator can
        therefore desynchronize the two sides.
        """
        hint_generator = self.config.hint_generator
        if hint_generator is None or not _is_error_turn(turn):
            return None
        hint_text = hint_generator(
            turn.get("tool_error", "unknown"),
            turn.get("error_meta", {}),
        )
        if hint_text and turn.get("content"):
            return hint_text
        return None

    def _build_hint_injected_trace(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Walk the trace; at each error site, inject a hint and mark the
        post-hint recovery tokens as in-loss.

        Returns:
            (ctx_teacher_ids, sdpo_loss_mask, any_error_sites). The mask has
            the same length as ctx_teacher_ids, with 1 on recovery-content
            tokens and ignore_index elsewhere.
        """
        if self.config.hint_generator is None:
            # Caller responsibility — short-circuited by the dispatch.
            return [], [], False

        teacher_messages: list[dict] = []
        teacher_loss_segments: list[tuple[bool, str]] = []  # (is_loss_segment, text)
        any_errors = False
        for turn in turns:
            role = turn.get("role", "assistant")
            hint_text = self._hint_for_turn(turn)
            if hint_text is not None:
                any_errors = True
                recovery_content = turn.get("content") or ""
                # Hint as a system-style addendum BEFORE the recovery turn.
                teacher_messages.append({"role": "system", "content": hint_text})
                teacher_loss_segments.append((False, hint_text))
                teacher_messages.append({"role": role, "content": recovery_content})
                teacher_loss_segments.append((True, recovery_content))  # post-hint tokens = loss
                continue
            # Non-error turn (or no hint / empty recovery) — passthrough.
            content = turn.get("content")
            if content:
                teacher_messages.append({"role": role, "content": content})
                teacher_loss_segments.append((False, content))

        teacher_ids = self._tokenize_messages(teacher_messages)
        # The mask is aligned to the chat-template tokenization (Wave 20 fix).
        # Tokenizing each segment in isolation ignored the scaffolding tokens
        # that apply_chat_template inserts, which left the mask drifting left
        # of the real content.
        sdpo_mask = self._build_chat_aligned_mask(
            teacher_messages, teacher_loss_segments, teacher_ids
        )
        sdpo_mask = _pad_or_truncate(sdpo_mask, len(teacher_ids), self.config.ignore_index)
        return teacher_ids, sdpo_mask, any_errors

    def _build_aligned_student_for_sdpo(
        self,
        batch: Sequence[TraceExample],
        teacher_len: int,
    ) -> dict[str, torch.Tensor] | None:
        """Build student tensors that align position-by-position with the
        hint-injected teacher sequence.

        For SDPO the gate `student_logits.shape == teacher_logits.shape` must
        pass AND the post-hint positions must point to the SAME content
        tokens in the student. Strategy: build student MESSAGES that mirror
        the teacher messages, EXCEPT each hint system-message is replaced by
        a placeholder system-message whose content tokenizes to (about) the
        hint's length. Both sides go through `apply_chat_template`, so the
        chat-template markers (<|im_start|>system\\n, <|im_end|>\\n, etc.) are
        added identically.

        Returns None if there is no hint generator or no error sites.
        Otherwise returns a dict of (B, teacher_len) tensors:
          * ``input_ids``, ``attention_mask``, ``response_mask``;
          * ``student_sdpo_mask`` — 1 on the student's recovery-content
            tokens, used for the ADR-011 student indices.
        """
        if self.config.hint_generator is None:
            return None
        student_ids_list: list[list[int]] = []
        response_mask_list: list[list[int]] = []
        sdpo_mask_list: list[list[int]] = []
        any_errors = False
        for ex in batch:
            ids, resp_mask, sdpo_mask, has_errors = self._build_aligned_student_one_with_sdpo(
                ex["turns"]
            )
            student_ids_list.append(ids)
            response_mask_list.append(resp_mask)
            sdpo_mask_list.append(sdpo_mask)
            any_errors = any_errors or has_errors
        if not any_errors:
            return None

        max_len = teacher_len  # match teacher exactly
        pad_id = self.config.pad_token_id
        return {
            "input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in student_ids_list],
                dtype=torch.long,
            ),
            "attention_mask": self._attention_mask_from_lengths(student_ids_list, max_len),
            "response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in response_mask_list],
                dtype=torch.long,
            ),
            "student_sdpo_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in sdpo_mask_list],
                dtype=torch.long,
            ),
        }

    def _make_placeholder_for_hint_length(self, hint_text: str) -> str:
        """Build a content-free placeholder whose token count matches hint_text's.

        Seeds with '. ' repeated once per hint token. It then grows by the
        remaining token deficit until it reaches or overshoots the target (a
        bounded number of rounds, so tokenizers that merge runs of '. ' into
        multi-character tokens still converge). Finally it trims char-by-char
        until it is at or below the target. The result may land a token
        short at the character/token boundary. The ADR-011 indices pair
        recovery tokens independently of this, so a small mismatch only
        shifts positions; it does not mispair tokens.
        """
        target_len = len(self._tokenize_text(hint_text))
        if target_len == 0:
            return ""
        placeholder = ". " * target_len
        ph_len = len(self._tokenize_text(placeholder))
        for _ in range(64):
            deficit = target_len - ph_len
            if deficit <= 0:
                break
            placeholder += ". " * deficit
            ph_len = len(self._tokenize_text(placeholder))
        while placeholder and ph_len > target_len:
            placeholder = placeholder[:-1]
            ph_len = len(self._tokenize_text(placeholder))
        return placeholder

    def _build_student_messages(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[dict], list[tuple[bool, str]], list[tuple[bool, str]], bool]:
        """Mirror the teacher message list, swapping each hint for a placeholder.

        Returns (messages, response_segments, sdpo_segments, any_error_sites):
          * response_segments mark assistant-role content (for response_mask);
          * sdpo_segments mark recovery content — the same segments the
            teacher marks as loss — for the ADR-011 student indices.
        """
        messages: list[dict] = []
        response_segments: list[tuple[bool, str]] = []
        sdpo_segments: list[tuple[bool, str]] = []
        any_errors = False
        for turn in turns:
            role = turn.get("role", "assistant")
            is_assistant = turn.get("role") == "assistant"
            hint_text = self._hint_for_turn(turn)
            if hint_text is not None:
                any_errors = True
                recovery_content = turn.get("content") or ""
                placeholder = self._make_placeholder_for_hint_length(hint_text)
                # Placeholder system-msg at the SAME slot as the teacher's hint.
                messages.append({"role": "system", "content": placeholder})
                response_segments.append((False, placeholder))
                sdpo_segments.append((False, placeholder))
                messages.append({"role": role, "content": recovery_content})
                response_segments.append((is_assistant, recovery_content))
                sdpo_segments.append((True, recovery_content))
                continue
            content = turn.get("content")
            if content:
                messages.append({"role": role, "content": content})
                response_segments.append((is_assistant, content))
                sdpo_segments.append((False, content))
        return messages, response_segments, sdpo_segments, any_errors

    def _build_aligned_student_one_with_sdpo(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], list[int], bool]:
        """Tokenize one aligned student trace.

        Returns (student_ids, response_mask, student_sdpo_mask,
        any_error_sites). Both masks use 1/0 and match the length of
        student_ids.
        """
        if self.config.hint_generator is None:
            return [], [], [], False
        messages, resp_segments, sdpo_segments, any_errors = self._build_student_messages(turns)
        # apply_chat_template on both sides so template markers are identical.
        student_ids = self._tokenize_messages(messages)
        # Prefix lengths are shared by both mask passes (tokenization is the cost).
        prefix_lens = self._prefix_lengths(messages)
        n = len(student_ids)
        # _build_chat_aligned_mask emits 1 / ignore_index; remap to 1 / 0.
        raw_resp = self._build_chat_aligned_mask(
            messages, resp_segments, student_ids, prefix_lens=prefix_lens
        )
        raw_sdpo = self._build_chat_aligned_mask(
            messages, sdpo_segments, student_ids, prefix_lens=prefix_lens
        )
        resp_mask = _pad_or_truncate([1 if v == 1 else 0 for v in raw_resp], n, 0)
        sdpo_mask = _pad_or_truncate([1 if v == 1 else 0 for v in raw_sdpo], n, 0)
        return student_ids, resp_mask, sdpo_mask, any_errors

    def _build_aligned_student_one(
        self, turns: Sequence[TraceTurn]
    ) -> tuple[list[int], list[int], bool]:
        """Build one aligned student trace; returns (student_ids, response_mask, any_error_sites).

        The student messages mirror the teacher's, except hint
        system-messages are replaced with placeholders of the same token
        length.
        """
        ids, resp_mask, _sdpo_mask, any_errors = self._build_aligned_student_one_with_sdpo(turns)
        return ids, resp_mask, any_errors

    def _build_segment_mask(
        self, segments: Sequence[tuple[bool, str]]
    ) -> list[int]:
        """For each (is_loss, text) segment, tokenize and emit per-token mask values.

        Loss-active tokens get 1; non-loss tokens get ignore_index.

        NOTE (Wave 20): this naive per-segment concatenation IGNORES the
        chat-template scaffolding that `apply_chat_template` inserts around
        each message, so it does not line up with `_tokenize_messages`
        output. It is kept for callers that build sequences by raw segment
        concatenation (no chat template). The SDPO/response-mask paths use
        `_build_chat_aligned_mask` instead.
        """
        out: list[int] = []
        for is_loss, text in segments:
            seg_ids = self._tokenize_text(text)
            mask_value = 1 if is_loss else self.config.ignore_index
            out.extend([mask_value] * len(seg_ids))
        return out

    @staticmethod
    def _find_subseq(haystack: list[int], needle: list[int], start: int = 0) -> int:
        """Return the index where ``needle`` first occurs in ``haystack`` at or
        after ``start``, or -1 if absent. An empty needle matches at ``start``.
        Linear scan (spans are short)."""
        if not needle:
            return start
        n, m = len(haystack), len(needle)
        for i in range(start, n - m + 1):
            if haystack[i:i + m] == needle:
                return i
        return -1

    def _prefix_lengths(self, messages: Sequence[dict]) -> list[int]:
        """Token length of the chat-template rendering of messages[:k+1], for each k."""
        return [
            len(self._tokenize_messages(list(messages[: k + 1])))
            for k in range(len(messages))
        ]

    def _build_chat_aligned_mask(
        self,
        messages: Sequence[dict],
        segments: Sequence[tuple[bool, str]],
        full_ids: list[int],
        prefix_lens: Sequence[int] | None = None,
    ) -> list[int]:
        """Build a per-token loss mask aligned to a chat-template tokenization.

        ``messages`` and ``segments`` are built in lockstep: element k of
        each describes the same logical chunk, with ``segments[k] =
        (is_loss, content_text)`` and ``messages[k] = {role, content}``.

        Algorithm — per-message prefix deltas:
            prev_len = len(apply_chat_template(messages[:k]))
            cur_len  = len(apply_chat_template(messages[:k+1]))
            # message k occupies full_ids[prev_len : cur_len] (content + its
            # own scaffolding). Locate the content token run inside that span
            # by subsequence match against the content's standalone
            # tokenization, mark THOSE positions with 1, and leave the
            # scaffolding as ignore_index.

        ``prefix_lens`` may be passed to reuse precomputed prefix lengths
        (see ``_prefix_lengths``); otherwise they are computed here.

        Falls back gracefully:
          * Without a usable chat template, ``_tokenize_messages`` returns a
            newline-joined concatenation. The prefix deltas then cover the
            content tokens, and the subsequence match locates them.
          * If a content run can't be located inside its span (tokenizer
            merges across the content/scaffolding boundary), the whole span
            is marked for loss segments. This never silently drops SDPO
            signal; it over-includes at most the message's own scaffolding
            tokens. Marks never extend past ``len(full_ids)``.
        """
        if prefix_lens is None:
            prefix_lens = self._prefix_lengths(messages)
        n_full = len(full_ids)
        mask = [self.config.ignore_index] * n_full
        prev_len = 0
        search_from = 0
        for k, msg in enumerate(messages):
            cur_len = prefix_lens[k]
            span_start, span_end = prev_len, cur_len
            prev_len = cur_len
            if span_end <= span_start:
                continue
            is_loss = segments[k][0] if k < len(segments) else False
            if not is_loss:
                search_from = span_end
                continue
            content = segments[k][1] if k < len(segments) else msg.get("content", "")
            content_ids = self._tokenize_text(content)
            # Anchor at span_start so content from a later message can't match.
            idx = self._find_subseq(
                full_ids[:span_end], content_ids, start=max(span_start, search_from)
            )
            if idx != -1 and idx >= span_start:
                for p in range(idx, min(idx + len(content_ids), span_end)):
                    mask[p] = 1
                search_from = idx + len(content_ids)
            else:
                # Content run not found: mark the whole span (bounded by the
                # full sequence in case prefix tokenization overran it).
                for p in range(span_start, min(span_end, n_full)):
                    mask[p] = 1
                search_from = span_end
        return mask

    # ----------------------------------------------------------------------
    # Channel 3: trace-replay DPO inputs
    # ----------------------------------------------------------------------
    def _build_dpo_fields(
        self, batch: Sequence[TraceExample]
    ) -> dict[str, torch.Tensor] | None:
        """Tokenize chosen/rejected pairs from teacher disagreement.

        DPO accounting requires:
          - chosen_input_ids   = prompt + chosen_response
          - rejected_input_ids = prompt + rejected_response
          - response masks: 1 over response (loss-bearing), 0 over prompt

        Truncation keeps the END of the prompt (as TRL's
        ``truncation_mode="keep_end"`` does). When prompt + longer response
        exceeds ``max_dpo_seq_len``, the oldest prompt tokens are dropped
        first, by the same amount for chosen and rejected so both share one
        prompt. Plain right-truncation would cut the loss-bearing response
        first and could leave a pair with an all-zero response mask. A
        response that exceeds the cap on its own is still right-truncated.

        Returns None when no example in ``batch`` carries DPO pairs.
        """
        cap = self.config.max_dpo_seq_len
        pad_id = self.config.pad_token_id
        all_chosen: list[list[int]] = []
        all_rejected: list[list[int]] = []
        all_chosen_resp_mask: list[list[int]] = []
        all_rejected_resp_mask: list[list[int]] = []
        for ex in batch:
            for pair in ex.get("dpo_pairs") or []:
                prompt_ids = self._tokenize_messages(pair.get("state_messages", []))
                chosen_ids = self._tokenize_text(pair["chosen"])
                rejected_ids = self._tokenize_text(pair["rejected"])
                overflow = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids)) - cap
                if overflow > 0:
                    prompt_ids = prompt_ids[overflow:]  # may become empty
                all_chosen.append(prompt_ids + chosen_ids)
                all_rejected.append(prompt_ids + rejected_ids)
                all_chosen_resp_mask.append([0] * len(prompt_ids) + [1] * len(chosen_ids))
                all_rejected_resp_mask.append([0] * len(prompt_ids) + [1] * len(rejected_ids))
        if not all_chosen:
            return None  # no DPO pairs in this batch

        max_len = min(cap, max(len(s) for s in (*all_chosen, *all_rejected)))
        return {
            "dpo_chosen_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_chosen],
                dtype=torch.long,
            ),
            "dpo_chosen_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_chosen_resp_mask],
                dtype=torch.long,
            ),
            "dpo_rejected_input_ids": torch.tensor(
                [_pad_or_truncate(s, max_len, pad_id) for s in all_rejected],
                dtype=torch.long,
            ),
            "dpo_rejected_response_mask": torch.tensor(
                [_pad_or_truncate(m, max_len, 0) for m in all_rejected_resp_mask],
                dtype=torch.long,
            ),
        }

    # ----------------------------------------------------------------------
    # Tokenization helpers
    # ----------------------------------------------------------------------
    def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
        """Tokenize a trace by concatenating raw turn content; return (ids, response_mask).

        response_mask = 1 over assistant turns (the loss-bearing tokens for
        GRPO) and 0 over user/tool turns (prompt context). Turns with empty
        content are skipped, and no chat-template scaffolding is inserted.
        """
        all_ids: list[int] = []
        resp_mask: list[int] = []
        for turn in turns:
            if not turn.get("content"):
                continue
            ids = self._tokenize_text(turn["content"])
            mask_value = 1 if turn.get("role") == "assistant" else 0
            all_ids.extend(ids)
            resp_mask.extend([mask_value] * len(ids))
        return all_ids, resp_mask

    def _tokenize_text(self, text: str) -> list[int]:
        """Tokenize plain text via the tokenizer's __call__ (no special tokens)."""
        result = self.tokenizer(text, add_special_tokens=False)
        ids = result["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # HF tokenizers often return list[list[int]] when batch-shaped; flatten if so.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)

    def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
        """Tokenize a chat-formatted list of messages.

        Uses apply_chat_template when available. Otherwise it falls back to
        tokenizing the newline-joined message contents; None content counts
        as empty. The fallback covers three cases: stub tokenizers without
        the method, HF tokenizers whose chat_template is unset (HF raises
        ValueError), and protocol stubs whose method returns None.

        NOTE: HF tokenizers' `apply_chat_template(tokenize=True)` is not
        consistently typed across families. Some return `list[int]`, others
        a `BatchEncoding` (a dict-like with an `input_ids` key) — Qwen2.5
        returns the latter. Both shapes are handled here.
        """
        if not messages:
            return []
        try:
            raw = self.tokenizer.apply_chat_template(
                list(messages), tokenize=True, add_generation_prompt=False
            )
        except (AttributeError, NotImplementedError, TypeError, ValueError):
            raw = None
        if raw is None:
            text = "\n".join(m.get("content") or "" for m in messages)
            return self._tokenize_text(text)
        # BatchEncoding (Qwen2.5 etc.): extract input_ids and unwrap if batched.
        if hasattr(raw, "keys") and "input_ids" in raw:
            ids = raw["input_ids"]
        else:
            ids = raw
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        # If we got list[list[int]] (batch shape), unwrap the single example.
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return list(ids)
# Public API of this module; leading-underscore helpers are internal.
__all__ = [
    "ComposerDataCollator",
    "CollatorConfig",
    "TraceTurn",
    "TraceExample",
    "TokenizerLike",
]