Codeseys's picture
Wave 21b: skip zero-signal SDPO on empty-recovery error turns + real-trace validation
d61036a
Raw
History Blame Contribute Delete
8.59 kB
"""Validate the full ingestion -> adapter -> collator -> SDPO data path against
REAL Claude Code session logs, and report the live SDPO alignment ratio.
Why this exists
---------------
The synthetic fixture in `spikes/007-real-trace-ingestion/fixtures/` proves the
pipeline works on hand-built data. This script proves it on REAL traces — long
tool outputs, multi-block content, thinking blocks, genuinely weird tool errors
— which is where the Wave 19 chat-template drift bug (residual ~33%
misalignment) actually bit. Wave 21's `_build_chat_aligned_mask` fix is verified
here at the population level.
What it measures
----------------
* ingestion yield (states emitted, error sites detected)
* structural vs string-only error flagging (the Wave 21 TOOL_ERROR_TAG fix —
structural should dominate; string-only should be ~0)
* SDPO alignment ratio: fraction of in-loss `sdpo_loss_mask` positions where
student token id == teacher token id. ~100% means the mask lands exactly on
content tokens; <95% means chat-template drift has regressed.
Usage
-----
python examples/validate_real_trace_alignment/run.py \
[--projects-dir ~/.claude/projects] \
[--max-sessions 8] [--model Qwen/Qwen2.5-0.5B-Instruct] \
[--pass-threshold 0.95] [--strip-thinking]
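Requires a real chat-template tokenizer (transformers + a locally cached
-Instruct model; HF offline mode is enabled by default) and at least one local
Claude Code session containing `is_error: true`. Exits 0 on PASS (alignment at
or above --pass-threshold, default 95%, with no crashed sessions), 1 on FAIL,
and 2 when nothing could be measured (missing projects dir, no error-bearing
sessions, a tokenizer without a chat template, or no in-loss SDPO positions).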
"""
from __future__ import annotations
import argparse
import os
import sys
import traceback
from pathlib import Path
def _discover_error_sessions(projects_dir: Path, limit: int) -> list[Path]:
"""Find session JSONLs that contain at least one is_error:true tool_result,
skipping subagent (`agent-*`) files. Returns up to `limit`, smallest first
(faster to process, still representative)."""
hits: list[tuple[int, Path]] = []
for p in projects_dir.rglob("*.jsonl"):
if p.name.startswith("agent-"):
continue
try:
text = p.read_text(encoding="utf-8", errors="ignore")
except OSError:
continue
if '"is_error":true' in text or '"is_error": true' in text:
hits.append((p.stat().st_size, p))
hits.sort(key=lambda t: t[0])
return [p for _, p in hits[:limit]]
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--projects-dir", default=str(Path.home() / ".claude" / "projects"))
ap.add_argument("--max-sessions", type=int, default=8)
ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
ap.add_argument("--pass-threshold", type=float, default=0.95)
ap.add_argument(
"--strip-thinking",
action="store_true",
help="Strip [THINKING] blocks. DEFAULT IS FALSE for SDPO: on real "
"Claude Code traces the error-recovery turn is frequently pure "
"thinking, so stripping it empties ~67%% of error sites and the SDPO "
"channel sees no signal. Keep thinking for hint-distillation.",
)
args = ap.parse_args()
os.environ.setdefault("HF_HUB_OFFLINE", "1")
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
from transformers import AutoTokenizer
from composer_replication.ingestion import ClaudeCodeIngester
from composer_replication.ingestion.trace_examples import (
TOOL_ERROR_TAG,
claude_states_to_trace_examples,
)
from composer_replication.trainer.data_collator import (
CollatorConfig,
ComposerDataCollator,
)
projects_dir = Path(args.projects_dir).expanduser()
if not projects_dir.exists():
print(f"projects dir not found: {projects_dir}")
return 2
sessions = _discover_error_sessions(projects_dir, args.max_sessions)
if not sessions:
print(f"no error-bearing sessions under {projects_dir}")
return 2
tok = AutoTokenizer.from_pretrained(args.model)
if not getattr(tok, "chat_template", None):
print(f"{args.model} has no chat template; pick an -Instruct model")
return 2
def hint_gen(kind, _meta):
return f"Recover from the {kind}: re-check the path/args before retrying."
cfg = CollatorConfig(hint_generator=hint_gen, enable_replay_dpo=False, max_seq_len=8192)
collator = ComposerDataCollator(tokenizer=tok, config=cfg)
tot_states = tot_err_sites = 0
tot_aligned = tot_inloss = 0
n_struct = n_string_only = 0
n_empty_recovery = n_nonempty_recovery = 0
sessions_with_sdpo = 0
crashes: list[tuple[str, str]] = []
for path in sessions:
label = path.name[:18]
try:
ing = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=args.strip_thinking)
states = list(ing.ingest(path))
for s in states:
for m in s["messages"]:
if m.get("role") != "user":
continue
if m.get("tool_error") is True:
n_struct += 1
elif isinstance(m.get("content"), str) and TOOL_ERROR_TAG in m["content"]:
n_string_only += 1
examples = claude_states_to_trace_examples(states)
# Count empty vs non-empty recovery content among detected error turns.
for ex in examples:
for t in ex["turns"]:
if t.get("tool_error"):
if (t.get("content") or "").strip():
n_nonempty_recovery += 1
else:
n_empty_recovery += 1
err_examples = [
ex for ex in examples if any(t.get("tool_error") for t in ex["turns"])
]
tot_states += len(states)
tot_err_sites += sum(
sum(1 for t in ex["turns"] if t.get("tool_error")) for ex in examples
)
if err_examples:
batch = collator(err_examples[:4])
if "sdpo_loss_mask" in batch:
sessions_with_sdpo += 1
s_in = batch["input_ids"]
t_in = batch["ctx_teacher_input_ids"]
m_in = batch["sdpo_loss_mask"]
for row in range(s_in.shape[0]):
il = m_in[row] == 1
if int(il.sum()) == 0:
continue
tot_aligned += int((s_in[row][il] == t_in[row][il]).sum().item())
tot_inloss += int(il.sum().item())
print(f" OK {label}: {len(states):4d} states, {len(err_examples):3d} err-examples")
except Exception as e: # noqa: BLE001 — report-and-continue is the point
crashes.append((path.name, repr(e)))
print(f" CRASH {label}: {e!r}")
traceback.print_exc()
print("\n" + "=" * 64)
print("REAL-TRACE PIPELINE VALIDATION")
print("=" * 64)
print(f" sessions processed: {len(sessions) - len(crashes)}/{len(sessions)}")
print(f" total states emitted: {tot_states}")
print(f" total error sites: {tot_err_sites}")
print(f" structural-flagged users: {n_struct}")
print(f" string-tag-only users: {n_string_only} (Wave 21: should be ~0)")
_tot_recovery = n_empty_recovery + n_nonempty_recovery
if _tot_recovery:
pct_empty = 100 * n_empty_recovery / _tot_recovery
print(
f" empty-recovery sites: {n_empty_recovery}/{_tot_recovery} "
f"({pct_empty:.0f}%) — these fire NO SDPO signal"
)
if args.strip_thinking and pct_empty > 30:
print(
" ⚠ high empty-recovery rate with --strip-thinking: the recovery "
"turns are pure [THINKING]. Re-run WITHOUT --strip-thinking to "
"recover SDPO signal on these sites."
)
print(f" sessions firing SDPO: {sessions_with_sdpo}")
if not tot_inloss:
print(" no in-loss positions measured — cannot assess alignment")
return 2
ratio = tot_aligned / tot_inloss
print(f" SDPO alignment (REAL): {tot_aligned}/{tot_inloss} = {100 * ratio:.1f}%")
ok = ratio >= args.pass_threshold and not crashes
print(f" RESULT: {'PASS ✅' if ok else 'FAIL ❌'} (threshold {100*args.pass_threshold:.0f}%)")
if crashes:
print(f" {len(crashes)} crash(es): {[c[0] for c in crashes]}")
return 0 if ok else 1
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    raise SystemExit(main())