"""Validate the full ingestion -> adapter -> collator -> SDPO data path against REAL Claude Code session logs, and report the live SDPO alignment ratio. Why this exists --------------- The synthetic fixture in `spikes/007-real-trace-ingestion/fixtures/` proves the pipeline works on hand-built data. This script proves it on REAL traces — long tool outputs, multi-block content, thinking blocks, genuinely weird tool errors — which is where the Wave 19 chat-template drift bug (residual ~33% misalignment) actually bit. Wave 21's `_build_chat_aligned_mask` fix is verified here at the population level. What it measures ---------------- * ingestion yield (states emitted, error sites detected) * structural vs string-only error flagging (the Wave 21 TOOL_ERROR_TAG fix — structural should dominate; string-only should be ~0) * SDPO alignment ratio: fraction of in-loss `sdpo_loss_mask` positions where student token id == teacher token id. ~100% means the mask lands exactly on content tokens; <95% means chat-template drift has regressed. Usage ----- python examples/validate_real_trace_alignment/run.py \ [--projects-dir ~/.claude/projects] \ [--max-sessions 8] [--model Qwen/Qwen2.5-0.5B-Instruct] Requires a real chat-template tokenizer (transformers + a cached/instruct model) and at least one local Claude Code session containing `is_error: true`. Exits 0 on PASS (>=95% alignment), 1 on FAIL, 2 if no error-bearing sessions were found. """ from __future__ import annotations import argparse import os import sys import traceback from pathlib import Path def _discover_error_sessions(projects_dir: Path, limit: int) -> list[Path]: """Find session JSONLs that contain at least one is_error:true tool_result, skipping subagent (`agent-*`) files. Returns up to `limit`, smallest first (faster to process, still representative).""" hits: list[tuple[int, Path]] = [] for p in projects_dir.rglob("*.jsonl"): if p.name.startswith("agent-"): continue try: text = p.read_text(encoding="utf-8", errors="ignore") except OSError: continue if '"is_error":true' in text or '"is_error": true' in text: hits.append((p.stat().st_size, p)) hits.sort(key=lambda t: t[0]) return [p for _, p in hits[:limit]] def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--projects-dir", default=str(Path.home() / ".claude" / "projects")) ap.add_argument("--max-sessions", type=int, default=8) ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct") ap.add_argument("--pass-threshold", type=float, default=0.95) ap.add_argument( "--strip-thinking", action="store_true", help="Strip [THINKING] blocks. DEFAULT IS FALSE for SDPO: on real " "Claude Code traces the error-recovery turn is frequently pure " "thinking, so stripping it empties ~67%% of error sites and the SDPO " "channel sees no signal. Keep thinking for hint-distillation.", ) args = ap.parse_args() os.environ.setdefault("HF_HUB_OFFLINE", "1") os.environ.setdefault("TRANSFORMERS_OFFLINE", "1") from transformers import AutoTokenizer from composer_replication.ingestion import ClaudeCodeIngester from composer_replication.ingestion.trace_examples import ( TOOL_ERROR_TAG, claude_states_to_trace_examples, ) from composer_replication.trainer.data_collator import ( CollatorConfig, ComposerDataCollator, ) projects_dir = Path(args.projects_dir).expanduser() if not projects_dir.exists(): print(f"projects dir not found: {projects_dir}") return 2 sessions = _discover_error_sessions(projects_dir, args.max_sessions) if not sessions: print(f"no error-bearing sessions under {projects_dir}") return 2 tok = AutoTokenizer.from_pretrained(args.model) if not getattr(tok, "chat_template", None): print(f"{args.model} has no chat template; pick an -Instruct model") return 2 def hint_gen(kind, _meta): return f"Recover from the {kind}: re-check the path/args before retrying." cfg = CollatorConfig(hint_generator=hint_gen, enable_replay_dpo=False, max_seq_len=8192) collator = ComposerDataCollator(tokenizer=tok, config=cfg) tot_states = tot_err_sites = 0 tot_aligned = tot_inloss = 0 n_struct = n_string_only = 0 n_empty_recovery = n_nonempty_recovery = 0 sessions_with_sdpo = 0 crashes: list[tuple[str, str]] = [] for path in sessions: label = path.name[:18] try: ing = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=args.strip_thinking) states = list(ing.ingest(path)) for s in states: for m in s["messages"]: if m.get("role") != "user": continue if m.get("tool_error") is True: n_struct += 1 elif isinstance(m.get("content"), str) and TOOL_ERROR_TAG in m["content"]: n_string_only += 1 examples = claude_states_to_trace_examples(states) # Count empty vs non-empty recovery content among detected error turns. for ex in examples: for t in ex["turns"]: if t.get("tool_error"): if (t.get("content") or "").strip(): n_nonempty_recovery += 1 else: n_empty_recovery += 1 err_examples = [ ex for ex in examples if any(t.get("tool_error") for t in ex["turns"]) ] tot_states += len(states) tot_err_sites += sum( sum(1 for t in ex["turns"] if t.get("tool_error")) for ex in examples ) if err_examples: batch = collator(err_examples[:4]) if "sdpo_loss_mask" in batch: sessions_with_sdpo += 1 s_in = batch["input_ids"] t_in = batch["ctx_teacher_input_ids"] m_in = batch["sdpo_loss_mask"] for row in range(s_in.shape[0]): il = m_in[row] == 1 if int(il.sum()) == 0: continue tot_aligned += int((s_in[row][il] == t_in[row][il]).sum().item()) tot_inloss += int(il.sum().item()) print(f" OK {label}: {len(states):4d} states, {len(err_examples):3d} err-examples") except Exception as e: # noqa: BLE001 — report-and-continue is the point crashes.append((path.name, repr(e))) print(f" CRASH {label}: {e!r}") traceback.print_exc() print("\n" + "=" * 64) print("REAL-TRACE PIPELINE VALIDATION") print("=" * 64) print(f" sessions processed: {len(sessions) - len(crashes)}/{len(sessions)}") print(f" total states emitted: {tot_states}") print(f" total error sites: {tot_err_sites}") print(f" structural-flagged users: {n_struct}") print(f" string-tag-only users: {n_string_only} (Wave 21: should be ~0)") _tot_recovery = n_empty_recovery + n_nonempty_recovery if _tot_recovery: pct_empty = 100 * n_empty_recovery / _tot_recovery print( f" empty-recovery sites: {n_empty_recovery}/{_tot_recovery} " f"({pct_empty:.0f}%) — these fire NO SDPO signal" ) if args.strip_thinking and pct_empty > 30: print( " ⚠ high empty-recovery rate with --strip-thinking: the recovery " "turns are pure [THINKING]. Re-run WITHOUT --strip-thinking to " "recover SDPO signal on these sites." ) print(f" sessions firing SDPO: {sessions_with_sdpo}") if not tot_inloss: print(" no in-loss positions measured — cannot assess alignment") return 2 ratio = tot_aligned / tot_inloss print(f" SDPO alignment (REAL): {tot_aligned}/{tot_inloss} = {100 * ratio:.1f}%") ok = ratio >= args.pass_threshold and not crashes print(f" RESULT: {'PASS ✅' if ok else 'FAIL ❌'} (threshold {100*args.pass_threshold:.0f}%)") if crashes: print(f" {len(crashes)} crash(es): {[c[0] for c in crashes]}") return 0 if ok else 1 if __name__ == "__main__": sys.exit(main())