Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
"""ModalSpawnExecutor — production Modal-backed serverless executor.

This is the v0-finished sibling of `ModalExecutor`, which remains a
skeleton per the Wave 18 contract. The skeleton class stays unchanged so
that the NotImplementedError contract pinned by `test_skeleton_executors.py`
still holds. This class is the working alternative for users who want real
Modal execution.

Design choices vs the skeleton's docstring:

1. **User-provided `modal.Function` instead of internal app construction.**
   The skeleton showed a pattern where ModalExecutor builds its own
   `modal.App` and registers `run_replica` internally. That couples the
   executor to image/GPU/Volume choices the user actually wants to own.
   Instead, ModalSpawnExecutor takes a *pre-decorated* `modal.Function`
   from the caller. The user defines:

       @app.function(gpu="H100:4", image=my_image, volumes={"/vol": vol},
                     secrets=[modal.Secret.from_name("hf-token")],
                     timeout=4*3600)
       def run_replica(rendezvous_uri: str, world_size: int,
                       rank: int, **entrypoint_args):
           import os
           os.environ["REPLICA_RANK"] = str(rank)
           from composer_replication.diloco.serverless import (
               MockManager, ObjectStoreAllReduce,
           )
           store = ObjectStoreAllReduce(rendezvous_uri, rank=rank,
                                        world_size=world_size)
           manager = MockManager(store)
           # ... user's training loop with this manager ...

   and then constructs:

       executor = ModalSpawnExecutor(modal_function=run_replica)
       handles = executor.launch_replicas(
           n_replicas=4,
           entrypoint=run_replica,  # ignored — the function is already bound
           entrypoint_args={"rendezvous_uri": "/vol/diloco/run42",
                            "world_size": 4},
       )

2. **Rank as an explicit kwarg, not env-var indirection.** Modal Functions
   start with a clean environment. The rank-via-env pattern that
   LocalProcessExecutor uses is therefore fragile here: Modal would need
   container-level env injection per call. `modal.Secret.from_dict` can
   provide that, but at the cost of an extra round-trip per spawn.
   Instead, rank is passed as a kwarg (`.spawn(rank=i)`), so it travels
   through Modal's call arguments directly.

3. **Handle metadata = `call_id`.** LocalProcessExecutor holds Process
   refs. Here, a ReplicaHandle carries only plain data: the Modal `call_id`
   and a `spawn_ts` timestamp. The executor also keeps an in-process cache
   of the `modal.FunctionCall` objects it spawned. However,
   `modal.FunctionCall.from_id(call_id)` can rebuild a call from the id
   alone. That is what is meant to let a run survive a restart of the
   launching process.

References:
    - modal-client 1.4.x docs on FunctionCall: https://modal.com/docs/reference/modal.FunctionCall
    - ADR-005 (executor protocol design)
"""
| from __future__ import annotations | |
| import time | |
| from typing import Any, Callable, Mapping | |
| from composer_replication.diloco.serverless.executor import ( | |
| ReplicaHandle, | |
| ServerlessExecutor, | |
| ) | |
class ModalSpawnExecutor:
    """Run replicas as parallel Modal Function spawns.

    Implements the `ServerlessExecutor` Protocol against Modal's
    `Function.spawn()` API. The user must provide a pre-decorated
    `modal.Function` (with `@app.function(...)` already applied). See the
    module docstring for the expected signature.

    Per-call state is cached in-process and keyed by the Modal `call_id`,
    not by rank, because ranks repeat on every `launch_replicas` call. If a
    handle's call is not in the cache (e.g. a handle persisted by a previous
    process), it is rehydrated on demand via
    `modal.FunctionCall.from_id(call_id)`.

    Args:
        modal_function: a `modal.Function` registered against a `modal.App`.
            Must accept at minimum `rank: int` plus the kwargs in
            `entrypoint_args`. Image / GPU / Volume / Secret / timeout
            are pinned on the decorator, and the executor won't override
            them.
        deploy: if True, calls `modal_function.app.deploy()` before the
            first spawn. Required when running outside a `modal run`
            context (e.g. from a regular Python script). Default False,
            which assumes the user is inside a `modal run` block where the
            app is already live.

    Raises:
        RuntimeError: if the `modal` client is not installed.
        TypeError: if `modal_function` lacks `.spawn` / `.remote` (a
            duck-typed check for a `modal.Function`).
    """

    backend_name = "modal_spawn"
    supports_inter_replica_network = False  # Modal containers are isolated by default

    def __init__(
        self,
        modal_function: Any,
        *,
        deploy: bool = False,
    ) -> None:
        try:
            import modal  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "ModalSpawnExecutor requires the modal client. Install with "
                "`pip install modal` and configure with `modal token new`. "
                f"Got: {e!r}"
            ) from e
        # Duck-type check: modal.Function objects expose .spawn / .remote
        # when built with the @app.function(...) decorator. We avoid
        # `isinstance(_, modal.Function)` to stay tolerant of modal-client
        # minor-version changes that may restructure the class.
        if not (hasattr(modal_function, "spawn") and hasattr(modal_function, "remote")):
            raise TypeError(
                f"modal_function must be a modal.Function (decorated via "
                f"`@app.function(...)`). Got {type(modal_function)!r} which "
                f"lacks a `.spawn()` and/or `.remote()` method. "
                f"See ModalSpawnExecutor docstring for expected signature."
            )
        self.modal_function = modal_function
        self._deploy_requested = deploy
        self._deployed = False
        # call_id -> {"fcall": FunctionCall, "result": dict | None, "cancelled": bool}
        self._handles: dict[str, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------
    @staticmethod
    def _new_entry(fcall: Any) -> dict[str, Any]:
        """Return a fresh cache entry for a spawned or rehydrated FunctionCall."""
        return {"fcall": fcall, "result": None, "cancelled": False}

    def _meta_for(self, handle: ReplicaHandle) -> dict[str, Any] | None:
        """Return the cache entry for `handle`, rehydrating it from its call_id.

        Returns None when the handle carries no call_id, or when the call
        cannot be rebuilt (modal not importable, unknown id, transport error).
        """
        call_id = handle.metadata.get("call_id")
        if call_id is None:
            return None
        meta = self._handles.get(call_id)
        if meta is None:
            try:
                import modal

                fcall = modal.FunctionCall.from_id(call_id)
            except Exception:
                # Unrecoverable handle: callers treat it as unknown/cancelled.
                return None
            meta = self._new_entry(fcall)
            self._handles[call_id] = meta
        return meta

    @staticmethod
    def _result_dict(
        handle: ReplicaHandle,
        status: str,
        *,
        exit_code: int | None,
        error: str | None = None,
        result: Any = None,
    ) -> dict[str, Any]:
        """Build the per-replica result dict shared by poll() and collect()."""
        return {
            "rank": handle.rank,
            "status": status,
            "exit_code": exit_code,
            "error": error,
            "result": result,
            "call_id": handle.metadata.get("call_id"),
        }

    def _terminal_error(
        self, handle: ReplicaHandle, meta: dict[str, Any], exc: BaseException
    ) -> dict[str, Any]:
        """Result dict for a call whose `.get()` raised a non-timeout error.

        If this executor successfully cancelled the call, the error is
        reported as "cancelled" rather than "failed".
        """
        if meta.get("cancelled"):
            status, exit_code = "cancelled", None
        else:
            status, exit_code = "failed", 1
        return self._result_dict(
            handle, status, exit_code=exit_code,
            error=f"{type(exc).__name__}: {exc!r}",
        )

    # -----------------------------------------------------------------
    # Lifecycle
    # -----------------------------------------------------------------
    def _maybe_deploy(self) -> None:
        """Deploy the function's App once, if `deploy=True` was requested."""
        if self._deploy_requested and not self._deployed:
            # `modal_function.app` exposes the underlying App. Calling
            # `.deploy()` registers it with Modal so spawn() works from
            # outside `modal run`.
            app = getattr(self.modal_function, "app", None)
            if app is None:
                raise RuntimeError(
                    "modal_function.app is None — can't deploy. The function "
                    "must have been decorated against a real modal.App."
                )
            app.deploy()
            self._deployed = True

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Spawn N parallel Modal Function calls, one per rank.

        `entrypoint` is **ignored**: the actual entrypoint is the
        `modal_function` passed to `__init__`. This keeps the executor
        Protocol-compatible while preserving the user's image/GPU
        decoration. `gpu` and `timeout` are ignored for the same reason
        (they are pinned on the function decorator).

        Each call receives `rank=<i>` plus `entrypoint_args` (minus any
        `rank_env` key).

        Raises:
            ValueError: if `n_replicas < 1`.
            RuntimeError: if any spawn fails. Already-launched siblings are
                cancelled (best-effort) and dropped from the cache first.
        """
        del entrypoint, gpu, timeout  # pinned on the decorated function
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        self._maybe_deploy()
        # Strip rank_env if present — rank is passed as an explicit kwarg.
        spawn_kwargs = {k: v for k, v in entrypoint_args.items()
                        if k != "rank_env"}
        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            try:
                fcall = self.modal_function.spawn(rank=rank, **spawn_kwargs)
            except Exception as e:
                # Best-effort cancel of already-launched siblings. Their cache
                # entries are dropped because the caller never gets the handles.
                for prior in handles:
                    try:
                        self.cancel(prior)
                    except Exception:
                        pass
                    self._handles.pop(prior.metadata.get("call_id"), None)
                raise RuntimeError(
                    f"ModalSpawnExecutor.launch_replicas failed at rank={rank} "
                    f"of {n_replicas} (already-launched siblings cancelled). "
                    f"Underlying error: {e!r}"
                ) from e
            call_id = fcall.object_id
            handles.append(ReplicaHandle(
                rank=rank,
                backend_name=self.backend_name,
                metadata={
                    "call_id": call_id,
                    "spawn_ts": time.time(),
                },
            ))
            self._handles[call_id] = self._new_entry(fcall)
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll a Modal call's status without blocking.

        Modal's FunctionCall has no non-blocking status getter (the API is
        `.get(timeout=...)`). So this calls `.get(timeout=0)` and treats
        `TimeoutError` as "running". Terminal outcomes are cached, so later
        poll()/collect() calls do not re-query Modal.

        Returns one of "running" | "succeeded" | "failed" | "cancelled".
        This implementation never returns "pending", because `.get()` does
        not distinguish queued calls from running ones. An unknown handle
        that cannot be rehydrated reports "cancelled".
        """
        meta = self._meta_for(handle)
        if meta is None:
            return "cancelled"
        if meta["result"] is not None:
            return meta["result"]["status"]
        try:
            # Returns immediately if done; raises TimeoutError otherwise.
            value = meta["fcall"].get(timeout=0)
        except TimeoutError:
            return "running"
        except Exception as e:
            # User-code exceptions (and e.g. OutputExpiredError) surface here.
            meta["result"] = self._terminal_error(handle, meta, e)
        else:
            meta["result"] = self._result_dict(
                handle, "succeeded", exit_code=0, result=value,
            )
        return meta["result"]["status"]

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Return a human-readable status report with a pointer to the logs.

        Modal exposes per-FunctionCall logs via the dashboard, and (per the
        author's note) modal-client 1.4.x has no log-streaming client API.
        So this returns the dashboard URL plus, once a terminal result is
        cached, its status and the last 2000 characters of its error.

        The URL comes from `get_dashboard_url()` when the FunctionCall
        provides it; otherwise it is a templated URL with placeholders.
        `n_lines` is accepted for Protocol compatibility and is unused.
        """
        meta = self._meta_for(handle)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        call_id = handle.metadata.get("call_id", "<unknown>")
        try:
            dashboard_url = meta["fcall"].get_dashboard_url()
        except Exception:
            dashboard_url = (
                f"https://modal.com/apps/<workspace>/<env>/calls/{call_id}"
            )
        result = meta.get("result")
        if result:
            err = result.get("error") or "<no error>"
            return (
                f"[rank {handle.rank}] call_id={call_id}\n"
                f"  Dashboard: {dashboard_url}\n"
                f"  Result: {result['status']}\n"
                f"  Error: {err[-2000:]}"
            )
        return (
            f"[rank {handle.rank}] call_id={call_id} (still running)\n"
            f"  Dashboard: {dashboard_url}\n"
            f"  Logs not streamable via client API in modal-client 1.4.x; "
            f"use the dashboard URL or `modal app logs <app-id>`."
        )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort cancel of a Modal call.

        A successful cancel is recorded, so that a later error from the
        call is reported as "cancelled" instead of "failed".
        """
        meta = self._meta_for(handle)
        if meta is None:
            return
        try:
            meta["fcall"].cancel()
        except Exception:
            # Already terminated, network blip, etc. — best-effort.
            return
        meta["cancelled"] = True

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Handles are waited on in order, each with the remaining time
        budget, so the cumulative wall-clock is bounded by `timeout`
        seconds. `None` means a 24-hour budget. The deadline uses a
        monotonic clock.

        A replica that misses the deadline is reported with status
        "running" but is NOT cached. A later poll() or collect() therefore
        checks it again instead of reporting a stale "running" forever.
        """
        budget = 86400 if timeout is None else timeout
        deadline = time.monotonic() + budget
        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._meta_for(h)
            if meta is None:
                results.append(self._result_dict(
                    h, "cancelled", exit_code=None,
                    error="handle has no metadata (cancelled or unknown)",
                ))
                continue
            # Already finished and cached by an earlier poll()/collect().
            if meta["result"] is not None:
                results.append(meta["result"])
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                value = meta["fcall"].get(timeout=remaining)
            except TimeoutError as e:
                # Non-terminal: report it, but leave meta["result"] unset.
                results.append(self._result_dict(
                    h, "running", exit_code=None,
                    error=f"TimeoutError after deadline: {e!r}",
                ))
                continue
            except Exception as e:
                meta["result"] = self._terminal_error(h, meta, e)
            else:
                meta["result"] = self._result_dict(
                    h, "succeeded", exit_code=0, result=value,
                )
            results.append(meta["result"])
        return results
# Public surface of this module: only the executor class is exported.
__all__ = [
    "ModalSpawnExecutor",
]