Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 8,585 Bytes
"""Validate the full ingestion -> adapter -> collator -> SDPO data path against
REAL Claude Code session logs, and report the live SDPO alignment ratio.
Why this exists
---------------
The synthetic fixture in `spikes/007-real-trace-ingestion/fixtures/` proves the
pipeline works on hand-built data. This script proves it on REAL traces — long
tool outputs, multi-block content, thinking blocks, genuinely weird tool errors
— which is where the Wave 19 chat-template drift bug (residual ~33%
misalignment) actually bit. Wave 21's `_build_chat_aligned_mask` fix is verified
here at the population level.
What it measures
----------------
* ingestion yield (states emitted, error sites detected)
* structural vs string-only error flagging (the Wave 21 TOOL_ERROR_TAG fix —
structural should dominate; string-only should be ~0)
* SDPO alignment ratio: fraction of in-loss `sdpo_loss_mask` positions where
student token id == teacher token id. ~100% means the mask lands exactly on
content tokens; <95% means chat-template drift has regressed.
Usage
-----
    python examples/validate_real_trace_alignment/run.py \
        [--projects-dir ~/.claude/projects] \
        [--max-sessions 8] [--model Qwen/Qwen2.5-0.5B-Instruct] \
        [--pass-threshold 0.95] [--strip-thinking]
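Requires a real chat-template tokenizer (transformers plus a locally cached
instruct model) and at least one local Claude Code session containing
`is_error: true`. Exits 0 on PASS (alignment at or above the pass threshold,
95% by default), 1 on FAIL, and 2 when nothing could be measured (for example,
the projects dir is missing or no error-bearing sessions were found).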
"""
from __future__ import annotations
import argparse
import os
import sys
import traceback
from pathlib import Path
def _discover_error_sessions(projects_dir: Path, limit: int) -> list[Path]:
    """Find session JSONLs that contain at least one errored tool result.

    Recursively scans ``projects_dir`` for ``*.jsonl`` files, skipping subagent
    logs (file names starting with ``agent-``). A file qualifies when its text
    contains ``"is_error":true`` or ``"is_error": true``; this is a plain
    substring match, not a JSON parse.

    Args:
        projects_dir: Root directory to search.
        limit: Maximum number of paths to return. Zero or negative yields ``[]``.

    Returns:
        Up to ``limit`` qualifying paths, smallest file first (faster to
        process, still representative). Files of equal size are ordered by
        path so the selection does not depend on directory traversal order.
    """
    if limit <= 0:
        # A negative slice bound would silently mean "all but the last N".
        return []

    markers = ('"is_error":true', '"is_error": true')
    hits: list[tuple[int, Path]] = []
    for p in projects_dir.rglob("*.jsonl"):
        if p.name.startswith("agent-"):
            continue
        try:
            # Stream line by line: the markers contain no newline, so this is
            # equivalent to scanning the whole text, but it stops at the first
            # match and never holds a full session log in memory.
            with p.open(encoding="utf-8", errors="ignore") as fh:
                found = any(m in line for line in fh for m in markers)
            if found:
                # stat() stays inside the try: the file may vanish or become
                # unreadable between the scan and the size lookup.
                hits.append((p.stat().st_size, p))
        except OSError:
            # Discovery is best-effort; skip files that cannot be read.
            continue

    # Tuple sort: size first, then path as a deterministic tie-break.
    hits.sort()
    return [p for _, p in hits[:limit]]
def main() -> int:
    """Run the real-trace validation and print a summary report.

    Parses the CLI flags, discovers error-bearing sessions under
    ``--projects-dir``, ingests each one, collates up to four error examples
    per session, and accumulates how many in-loss ``sdpo_loss_mask`` positions
    carry the same token id in the student and teacher inputs.

    Returns:
        0 when the alignment ratio reaches ``--pass-threshold`` and no session
        crashed; 1 when the ratio is below the threshold or any session
        crashed; 2 when nothing could be measured for setup reasons (projects
        dir missing, no error-bearing sessions, tokenizer unavailable or
        lacking a chat template, or no in-loss positions in a crash-free run).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--projects-dir", default=str(Path.home() / ".claude" / "projects"))
    ap.add_argument("--max-sessions", type=int, default=8)
    ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--pass-threshold", type=float, default=0.95)
    ap.add_argument(
        "--strip-thinking",
        action="store_true",
        help="Strip [THINKING] blocks. DEFAULT IS FALSE for SDPO: on real "
        "Claude Code traces the error-recovery turn is frequently pure "
        "thinking, so stripping it empties ~67%% of error sites and the SDPO "
        "channel sees no signal. Keep thinking for hint-distillation.",
    )
    args = ap.parse_args()

    # Default to offline Hugging Face access; setdefault leaves any value the
    # caller already exported untouched.
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    # Heavy / project imports are deferred until after argument parsing.
    from transformers import AutoTokenizer

    from composer_replication.ingestion import ClaudeCodeIngester
    from composer_replication.ingestion.trace_examples import (
        TOOL_ERROR_TAG,
        claude_states_to_trace_examples,
    )
    from composer_replication.trainer.data_collator import (
        CollatorConfig,
        ComposerDataCollator,
    )

    projects_dir = Path(args.projects_dir).expanduser()
    if not projects_dir.exists():
        print(f"projects dir not found: {projects_dir}")
        return 2
    sessions = _discover_error_sessions(projects_dir, args.max_sessions)
    if not sessions:
        print(f"no error-bearing sessions under {projects_dir}")
        return 2

    try:
        tok = AutoTokenizer.from_pretrained(args.model)
    except (OSError, ValueError) as e:
        # Offline mode is on by default, so an uncached model fails here.
        # Report it as a setup problem (exit 2) instead of letting the
        # traceback end the process with status 1, which this script reserves
        # for an alignment FAIL.
        print(f"could not load tokenizer for {args.model}: {e!r}")
        print(
            " (offline mode is the default; cache the model first or export "
            "HF_HUB_OFFLINE=0 TRANSFORMERS_OFFLINE=0 to allow downloads)"
        )
        return 2
    if not getattr(tok, "chat_template", None):
        print(f"{args.model} has no chat template; pick an -Instruct model")
        return 2

    def hint_gen(kind, _meta):
        # Fixed hint text; only the error kind is interpolated.
        return f"Recover from the {kind}: re-check the path/args before retrying."

    cfg = CollatorConfig(hint_generator=hint_gen, enable_replay_dpo=False, max_seq_len=8192)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)

    tot_states = tot_err_sites = 0
    tot_aligned = tot_inloss = 0
    n_struct = n_string_only = 0
    n_empty_recovery = n_nonempty_recovery = 0
    sessions_with_sdpo = 0
    crashes: list[tuple[str, str]] = []

    for path in sessions:
        label = path.name[:18]
        try:
            ing = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=args.strip_thinking)
            states = list(ing.ingest(path))

            # Structural flag (tool_error is True) vs. tag found only in the
            # message text, counted over user messages.
            for s in states:
                for m in s["messages"]:
                    if m.get("role") != "user":
                        continue
                    if m.get("tool_error") is True:
                        n_struct += 1
                    elif isinstance(m.get("content"), str) and TOOL_ERROR_TAG in m["content"]:
                        n_string_only += 1

            examples = claude_states_to_trace_examples(states)

            # Count empty vs non-empty recovery content among detected error turns.
            for ex in examples:
                for t in ex["turns"]:
                    if t.get("tool_error"):
                        if (t.get("content") or "").strip():
                            n_nonempty_recovery += 1
                        else:
                            n_empty_recovery += 1

            err_examples = [
                ex for ex in examples if any(t.get("tool_error") for t in ex["turns"])
            ]
            tot_states += len(states)
            tot_err_sites += sum(
                sum(1 for t in ex["turns"] if t.get("tool_error")) for ex in examples
            )

            if err_examples:
                # Only the first four error examples of a session are collated.
                batch = collator(err_examples[:4])
                if "sdpo_loss_mask" in batch:
                    sessions_with_sdpo += 1
                    s_in = batch["input_ids"]
                    t_in = batch["ctx_teacher_input_ids"]
                    m_in = batch["sdpo_loss_mask"]
                    # NOTE(review): assumes these are row-indexable tensors
                    # whose student/teacher rows share one length — confirm
                    # against the collator.
                    for row in range(s_in.shape[0]):
                        il = m_in[row] == 1
                        n_in = int(il.sum().item())
                        if n_in == 0:
                            continue
                        tot_aligned += int((s_in[row][il] == t_in[row][il]).sum().item())
                        tot_inloss += n_in
            print(f" OK {label}: {len(states):4d} states, {len(err_examples):3d} err-examples")
        except Exception as e:  # noqa: BLE001 — report-and-continue is the point
            crashes.append((path.name, repr(e)))
            print(f" CRASH {label}: {e!r}")
            traceback.print_exc()

    print("\n" + "=" * 64)
    print("REAL-TRACE PIPELINE VALIDATION")
    print("=" * 64)
    print(f" sessions processed: {len(sessions) - len(crashes)}/{len(sessions)}")
    print(f" total states emitted: {tot_states}")
    print(f" total error sites: {tot_err_sites}")
    print(f" structural-flagged users: {n_struct}")
    print(f" string-tag-only users: {n_string_only} (Wave 21: should be ~0)")
    _tot_recovery = n_empty_recovery + n_nonempty_recovery
    if _tot_recovery:
        pct_empty = 100 * n_empty_recovery / _tot_recovery
        print(
            f" empty-recovery sites: {n_empty_recovery}/{_tot_recovery} "
            f"({pct_empty:.0f}%) — these fire NO SDPO signal"
        )
        if args.strip_thinking and pct_empty > 30:
            print(
                " ⚠ high empty-recovery rate with --strip-thinking: the recovery "
                "turns are pure [THINKING]. Re-run WITHOUT --strip-thinking to "
                "recover SDPO signal on these sites."
            )
    print(f" sessions firing SDPO: {sessions_with_sdpo}")

    if not tot_inloss:
        print(" no in-loss positions measured — cannot assess alignment")
        if crashes:
            # A crash counts as FAIL on the normal path below; keep that true
            # when the crashes are the reason nothing was measured.
            print(f" {len(crashes)} crash(es): {[c[0] for c in crashes]}")
            return 1
        return 2

    ratio = tot_aligned / tot_inloss
    print(f" SDPO alignment (REAL): {tot_aligned}/{tot_inloss} = {100 * ratio:.1f}%")
    ok = ratio >= args.pass_threshold and not crashes
    print(f" RESULT: {'PASS ✅' if ok else 'FAIL ❌'} (threshold {100*args.pass_threshold:.0f}%)")
    if crashes:
        print(f" {len(crashes)} crash(es): {[c[0] for c in crashes]}")
    return 0 if ok else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|