File size: 8,585 Bytes
d61036a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""Validate the full ingestion -> adapter -> collator -> SDPO data path against
REAL Claude Code session logs, and report the live SDPO alignment ratio.

Why this exists
---------------
The synthetic fixture in `spikes/007-real-trace-ingestion/fixtures/` proves the
pipeline works on hand-built data. This script proves it on REAL traces — long
tool outputs, multi-block content, thinking blocks, genuinely weird tool errors
— which is where the Wave 19 chat-template drift bug (residual ~33%
misalignment) actually bit. Wave 21's `_build_chat_aligned_mask` fix is verified
here at the population level.

What it measures
----------------
  * ingestion yield (states emitted, error sites detected)
  * structural vs string-only error flagging (the Wave 21 TOOL_ERROR_TAG fix —
    structural should dominate; string-only should be ~0)
  * SDPO alignment ratio: fraction of in-loss `sdpo_loss_mask` positions where
    student token id == teacher token id. ~100% means the mask lands exactly on
    content tokens; <95% means chat-template drift has regressed.

Usage
-----
    python examples/validate_real_trace_alignment/run.py \
        [--projects-dir ~/.claude/projects] \
        [--max-sessions 8] [--model Qwen/Qwen2.5-0.5B-Instruct] \
        [--pass-threshold 0.95] [--strip-thinking]

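Requires a real chat-template tokenizer (transformers plus a locally cached
-Instruct model; Hugging Face offline mode is enabled unless HF_HUB_OFFLINE /
TRANSFORMERS_OFFLINE are already set) and at least one local Claude Code session
containing `is_error: true`. Exit codes: 0 on PASS (alignment >= --pass-threshold,
default 95%, and no session crashed); 1 on FAIL; 2 when there is nothing to
measure (projects dir missing, no error-bearing sessions found, tokenizer without
a chat template, or no in-loss SDPO positions).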
"""
from __future__ import annotations

import argparse
import os
import sys
import traceback
from pathlib import Path


def _discover_error_sessions(projects_dir: Path, limit: int) -> list[Path]:
    """Find session JSONLs that contain at least one ``is_error: true`` tool_result.

    Subagent transcripts (files whose name starts with ``agent-``) are skipped.
    Each candidate is scanned line by line and the scan stops at the first
    match, so large error-bearing logs are never read into memory whole.

    Args:
        projects_dir: Root directory searched recursively for ``*.jsonl`` files.
        limit: Maximum number of paths to return; values < 1 yield ``[]``.

    Returns:
        Up to ``limit`` matching paths, smallest file first (faster to process,
        still representative). Equal sizes are ordered by path so the selection
        is deterministic regardless of filesystem iteration order.
    """
    if limit < 1:
        # Guard explicitly: a negative slice bound would otherwise return
        # "all but the last N" hits instead of nothing.
        return []

    # Match both the compact and the space-separated JSON serialization.
    needles = ('"is_error":true', '"is_error": true')
    hits: list[tuple[int, Path]] = []
    for p in projects_dir.rglob("*.jsonl"):
        if p.name.startswith("agent-"):
            continue
        try:
            with p.open(encoding="utf-8", errors="ignore") as fh:
                # Neither needle contains a newline, so a per-line check is
                # equivalent to searching the whole file, and any() stops early.
                if not any(n in line for line in fh for n in needles):
                    continue
            size = p.stat().st_size
        except OSError:
            # Unreadable, vanished mid-scan, or a directory named *.jsonl.
            continue
        hits.append((size, p))

    hits.sort(key=lambda hit: (hit[0], str(hit[1])))
    return [p for _, p in hits[:limit]]


def main() -> int:
    """Run the real-trace pipeline validation and return a process exit code.

    For each discovered error-bearing session this:

    * ingests it with ``ClaudeCodeIngester``;
    * converts the states with ``claude_states_to_trace_examples``;
    * collates up to ``--examples-per-session`` error-bearing examples with
      ``ComposerDataCollator``;
    * measures SDPO alignment: among positions where ``sdpo_loss_mask == 1``,
      the fraction where ``input_ids`` equals ``ctx_teacher_input_ids``.

    A session's counts are merged into the run totals only if the whole session
    is processed without raising. A crashed session therefore never leaves
    partial counts in the report, such as structural-flag counts for states
    that were never added to the state total.

    Returns:
        0 on PASS (alignment >= ``--pass-threshold`` and no session crashed),
        1 on FAIL, 2 when there is nothing to measure (missing projects dir, no
        error-bearing sessions, tokenizer without a chat template, or no
        in-loss positions). Invalid CLI values exit via ``argparse.error``.
    """
    from collections import Counter

    ap = argparse.ArgumentParser()
    ap.add_argument("--projects-dir", default=str(Path.home() / ".claude" / "projects"))
    ap.add_argument("--max-sessions", type=int, default=8)
    ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--pass-threshold", type=float, default=0.95)
    ap.add_argument(
        "--examples-per-session",
        type=int,
        default=4,
        help="Max error-bearing examples per session passed to the collator "
        "as one batch.",
    )
    ap.add_argument(
        "--strip-thinking",
        action="store_true",
        help="Strip [THINKING] blocks. DEFAULT IS FALSE for SDPO: on real "
        "Claude Code traces the error-recovery turn is frequently pure "
        "thinking, so stripping it empties ~67%% of error sites and the SDPO "
        "channel sees no signal. Keep thinking for hint-distillation.",
    )
    args = ap.parse_args()
    if args.max_sessions < 1:
        ap.error("--max-sessions must be >= 1")
    if args.examples_per_session < 1:
        ap.error("--examples-per-session must be >= 1")
    if not 0.0 <= args.pass_threshold <= 1.0:
        ap.error("--pass-threshold must be in [0, 1]")

    # Set before importing transformers so offline mode is in effect when the
    # library initializes; setdefault keeps any value the caller exported.
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    from transformers import AutoTokenizer

    from composer_replication.ingestion import ClaudeCodeIngester
    from composer_replication.ingestion.trace_examples import (
        TOOL_ERROR_TAG,
        claude_states_to_trace_examples,
    )
    from composer_replication.trainer.data_collator import (
        CollatorConfig,
        ComposerDataCollator,
    )

    projects_dir = Path(args.projects_dir).expanduser()
    if not projects_dir.exists():
        print(f"projects dir not found: {projects_dir}")
        return 2

    sessions = _discover_error_sessions(projects_dir, args.max_sessions)
    if not sessions:
        print(f"no error-bearing sessions under {projects_dir}")
        return 2

    tok = AutoTokenizer.from_pretrained(args.model)
    if not getattr(tok, "chat_template", None):
        print(f"{args.model} has no chat template; pick an -Instruct model")
        return 2

    def hint_gen(kind, _meta):
        return f"Recover from the {kind}: re-check the path/args before retrying."

    cfg = CollatorConfig(hint_generator=hint_gen, enable_replay_dpo=False, max_seq_len=8192)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)

    # Run-wide totals. Keys: states, err_sites, struct, string_only,
    # empty_recovery, nonempty_recovery, sdpo_sessions, aligned, inloss.
    totals: Counter[str] = Counter()
    crashes: list[tuple[str, str]] = []

    for path in sessions:
        label = path.name[:18]
        stats: Counter[str] = Counter()  # this session only; merged on success
        try:
            ing = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=args.strip_thinking)
            states = list(ing.ingest(path))
            stats["states"] += len(states)

            # Structural (tool_error flag) vs string-tag-only error flagging.
            for s in states:
                for m in s["messages"]:
                    if m.get("role") != "user":
                        continue
                    if m.get("tool_error") is True:
                        stats["struct"] += 1
                    elif isinstance(m.get("content"), str) and TOOL_ERROR_TAG in m["content"]:
                        stats["string_only"] += 1

            # One pass over the turns: error-site count, empty vs non-empty
            # content on error turns, and which examples carry any error.
            examples = claude_states_to_trace_examples(states)
            err_examples = []
            for ex in examples:
                n_err = 0
                for t in ex["turns"]:
                    if not t.get("tool_error"):
                        continue
                    n_err += 1
                    if (t.get("content") or "").strip():
                        stats["nonempty_recovery"] += 1
                    else:
                        stats["empty_recovery"] += 1
                if n_err:
                    err_examples.append(ex)
                stats["err_sites"] += n_err

            if err_examples:
                batch = collator(err_examples[: args.examples_per_session])
                if "sdpo_loss_mask" in batch:
                    stats["sdpo_sessions"] += 1
                    s_in = batch["input_ids"]
                    # NOTE(review): assumes the teacher ids are padded to the
                    # same shape as input_ids so one mask indexes both — confirm
                    # against the collator.
                    t_in = batch["ctx_teacher_input_ids"]
                    m_in = batch["sdpo_loss_mask"]
                    for row in range(s_in.shape[0]):
                        il = m_in[row] == 1
                        n_in = int(il.sum().item())
                        if n_in == 0:
                            continue
                        stats["aligned"] += int((s_in[row][il] == t_in[row][il]).sum().item())
                        stats["inloss"] += n_in
        except Exception as e:  # noqa: BLE001 — report-and-continue is the point
            crashes.append((path.name, repr(e)))
            print(f"  CRASH {label}: {e!r}")
            traceback.print_exc()
        else:
            totals.update(stats)
            print(f"  OK    {label}: {len(states):4d} states, {len(err_examples):3d} err-examples")

    print("\n" + "=" * 64)
    print("REAL-TRACE PIPELINE VALIDATION")
    print("=" * 64)
    print(f"  sessions processed:       {len(sessions) - len(crashes)}/{len(sessions)}")
    print(f"  total states emitted:     {totals['states']}")
    print(f"  total error sites:        {totals['err_sites']}")
    print(f"  structural-flagged users: {totals['struct']}")
    print(f"  string-tag-only users:    {totals['string_only']}  (Wave 21: should be ~0)")
    n_empty = totals["empty_recovery"]
    n_recovery = n_empty + totals["nonempty_recovery"]
    if n_recovery:
        pct_empty = 100 * n_empty / n_recovery
        print(
            f"  empty-recovery sites:     {n_empty}/{n_recovery} "
            f"({pct_empty:.0f}%) — these fire NO SDPO signal"
        )
        if args.strip_thinking and pct_empty > 30:
            print(
                "    ⚠ high empty-recovery rate with --strip-thinking: the recovery "
                "turns are pure [THINKING]. Re-run WITHOUT --strip-thinking to "
                "recover SDPO signal on these sites."
            )
    print(f"  sessions firing SDPO:     {totals['sdpo_sessions']}")

    n_inloss = totals["inloss"]
    if not n_inloss:
        print("  no in-loss positions measured — cannot assess alignment")
        return 2
    n_aligned = totals["aligned"]
    ratio = n_aligned / n_inloss
    print(f"  SDPO alignment (REAL):    {n_aligned}/{n_inloss} = {100 * ratio:.1f}%")
    ok = ratio >= args.pass_threshold and not crashes
    print(f"  RESULT: {'PASS ✅' if ok else 'FAIL ❌'} (threshold {100*args.pass_threshold:.0f}%)")
    if crashes:
        print(f"  {len(crashes)} crash(es): {[c[0] for c in crashes]}")
    return 0 if ok else 1


if __name__ == "__main__":
    # main() returns the process status (0 PASS / 1 FAIL / 2 nothing measured).
    _exit_code = main()
    sys.exit(_exit_code)