"""ModalSpawnExecutor — production Modal-backed serverless executor. This is the v0-finished sibling of `ModalExecutor` (which remains a skeleton per Wave 18 contract). The skeleton class stays unchanged to preserve `test_skeleton_executors.py`'s pinned NotImplementedError contract; this class is the working alternative for users who want real Modal execution. Design choices vs the skeleton's docstring: 1. **User-provided `modal.Function` instead of internal app construction.** The skeleton showed a pattern where ModalExecutor builds its own `modal.App` and registers `run_replica` internally. That couples the executor to image/GPU/Volume choices the user actually wants to own. Instead, ModalSpawnExecutor takes a *pre-decorated* `modal.Function` from the caller — the user defines: @app.function(gpu="H100:4", image=my_image, volumes={"/vol": vol}, secrets=[modal.Secret.from_name("hf-token")], timeout=4*3600) def run_replica(rendezvous_uri: str, world_size: int, rank: int, **entrypoint_args): import os os.environ["REPLICA_RANK"] = str(rank) from composer_replication.diloco.serverless import ( MockManager, ObjectStoreAllReduce, ) store = ObjectStoreAllReduce(rendezvous_uri, rank=rank, world_size=world_size) manager = MockManager(store) # ... user's training loop with this manager ... then constructs: executor = ModalSpawnExecutor(modal_function=run_replica) handles = executor.launch_replicas( n_replicas=4, entrypoint=run_replica, # ignored — function is bound entrypoint_args={"rendezvous_uri": "/vol/diloco/run42", "world_size": 4}, ) 2. **Rank as explicit kwarg, not env-var indirection.** Modal Functions start with a clean env, so the rank-via-env pattern that LocalProcessExecutor uses is fragile here (Modal would need container-level env injection per call, which `modal.Secret.from_dict` does but adds a round-trip per spawn). We pass rank as a kwarg to `.spawn(rank=i)` so it's plumbed through Modal's call args directly. 3. **Handle metadata = `call_id`, no in-process state.** Unlike LocalProcessExecutor (which holds Process refs), this executor is stateless after launch — handles are reconstructed via `modal.FunctionCall.from_id(call_id)` for poll/cancel/collect. Lets the executor survive process restart mid-run. References: - modal-client 1.4.x docs on FunctionCall: https://modal.com/docs/reference/modal.FunctionCall - ADR-005 (executor protocol design) """ from __future__ import annotations import time from typing import Any, Callable, Mapping from composer_replication.diloco.serverless.executor import ( ReplicaHandle, ServerlessExecutor, ) class ModalSpawnExecutor: """Run replicas as parallel Modal Function spawns. Implements the `ServerlessExecutor` Protocol against Modal's `Function.spawn()` API. The user must provide a pre-decorated `modal.Function` (with `@app.function(...)` already applied) — see module docstring for the expected signature. Args: modal_function: a `modal.Function` registered against a `modal.App`. Must accept at minimum `rank: int` plus the kwargs in `entrypoint_args`. Image / GPU / Volume / Secret / timeout are pinned on the decorator and the executor won't override them. deploy: if True, calls `modal_function.app.deploy()` before spawning. Required when running outside a `modal run` context (e.g. from a regular Python script). Default False — assumes the user is inside a `modal run` block where the app is already live. Raises: RuntimeError: if `modal` client is not installed. TypeError: if `modal_function` is not a `modal.Function`. """ backend_name = "modal_spawn" supports_inter_replica_network = False # Modal containers are isolated by default def __init__( self, modal_function: Any, *, deploy: bool = False, ) -> None: try: import modal # noqa: F401 except ImportError as e: raise RuntimeError( "ModalSpawnExecutor requires the modal client. Install with " "`pip install modal` and configure with `modal token new`. " f"Got: {e!r}" ) # Duck-type check — modal.Function objects expose .spawn / .remote / # ._app, which the user-supplied function will have if they used the # @app.function(...) decorator. We avoid `isinstance(_, modal.Function)` # to stay tolerant of modal-client minor-version changes that may # restructure the class. if not (hasattr(modal_function, "spawn") and hasattr(modal_function, "remote")): raise TypeError( f"modal_function must be a modal.Function (decorated via " f"`@app.function(...)`). Got {type(modal_function)!r} which " f"has no `.spawn()` method. " f"See ModalSpawnExecutor docstring for expected signature." ) self.modal_function = modal_function self._deploy_requested = deploy self._deployed = False self._handles: dict[int, dict[str, Any]] = {} # ----------------------------------------------------------------- # Lifecycle # ----------------------------------------------------------------- def _maybe_deploy(self) -> None: if self._deploy_requested and not self._deployed: # `modal_function.app` exposes the underlying App. Calling # `.deploy()` registers it with Modal so spawn() works from # outside `modal run`. app = getattr(self.modal_function, "app", None) if app is None: raise RuntimeError( "modal_function.app is None — can't deploy. The function " "must have been decorated against a real modal.App." ) app.deploy() self._deployed = True # ----------------------------------------------------------------- # ServerlessExecutor Protocol # ----------------------------------------------------------------- def launch_replicas( self, n_replicas: int, entrypoint: str | Callable[..., Any], entrypoint_args: Mapping[str, Any], *, gpu: str | None = None, timeout: int = 3600, ) -> list[ReplicaHandle]: """Spawn N parallel Modal Function calls. Note: `entrypoint` is **ignored** — the actual entrypoint is the `modal_function` passed to `__init__`. This keeps the executor Protocol-compatible while preserving the user's image/GPU decoration. `gpu` and `timeout` are similarly ignored (pinned on the function decorator). """ del entrypoint, gpu, timeout # pinned on the decorated function if n_replicas < 1: raise ValueError(f"n_replicas must be >= 1, got {n_replicas}") self._maybe_deploy() # Strip rank_env if present — we use explicit `rank` kwarg. spawn_kwargs = {k: v for k, v in entrypoint_args.items() if k != "rank_env"} handles: list[ReplicaHandle] = [] for rank in range(n_replicas): try: fcall = self.modal_function.spawn(rank=rank, **spawn_kwargs) except Exception as e: # Best-effort cancel any already-launched siblings for prior in handles: try: self.cancel(prior) except Exception: pass raise RuntimeError( f"ModalSpawnExecutor.launch_replicas failed at rank={rank} " f"of {n_replicas} (already-launched siblings cancelled). " f"Underlying error: {e!r}" ) from e handle = ReplicaHandle( rank=rank, backend_name=self.backend_name, metadata={ "call_id": fcall.object_id, "spawn_ts": time.time(), }, ) self._handles[rank] = { "fcall": fcall, "result": None, } handles.append(handle) return handles def poll(self, handle: ReplicaHandle) -> str: """Poll a Modal call's status. Modal's FunctionCall doesn't expose a non-blocking status getter directly (the API is `.get(timeout=...)`), so we poll by trying `.get(timeout=0)` and treating Timeout/Pending as "running". Returns one of: "pending" | "running" | "succeeded" | "failed" | "cancelled". """ meta = self._handles.get(handle.rank) if meta is None: return "cancelled" # If we already collected this one, return cached result if meta["result"] is not None: return meta["result"]["status"] import modal from modal.exception import OutputExpiredError fcall = meta["fcall"] # Re-hydrate to get fresh state try: # `.get(timeout=0)` returns immediately if done; raises TimeoutError otherwise. result_value = fcall.get(timeout=0) meta["result"] = { "rank": handle.rank, "status": "succeeded", "exit_code": 0, "error": None, "result": result_value, "call_id": handle.metadata.get("call_id"), } return "succeeded" except TimeoutError: return "running" except OutputExpiredError as e: meta["result"] = { "rank": handle.rank, "status": "failed", "exit_code": 1, "error": f"OutputExpiredError: {e!r}", "result": None, "call_id": handle.metadata.get("call_id"), } return "failed" except Exception as e: # User-code exception bubbles up here as the original exception class meta["result"] = { "rank": handle.rank, "status": "failed", "exit_code": 1, "error": f"{type(e).__name__}: {e!r}", "result": None, "call_id": handle.metadata.get("call_id"), } return "failed" def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: """Read recent Modal logs for this call. Modal exposes per-FunctionCall logs via the dashboard URL. The client API doesn't expose log-streaming directly in 1.4.x, so we return a pointer to the dashboard URL plus any captured error from poll(). """ meta = self._handles.get(handle.rank) if meta is None: return f"" call_id = handle.metadata.get("call_id", "") try: dashboard_url = meta["fcall"].get_dashboard_url() except Exception: dashboard_url = ( f"https://modal.com/apps///calls/{call_id}" ) if meta.get("result"): err = meta["result"].get("error") or "" return ( f"[rank {handle.rank}] call_id={call_id}\n" f" Dashboard: {dashboard_url}\n" f" Result: {meta['result']['status']}\n" f" Error: {err[-2000:] if err else ''}" ) return ( f"[rank {handle.rank}] call_id={call_id} (still running)\n" f" Dashboard: {dashboard_url}\n" f" Logs not streamable via client API in modal-client 1.4.x; " f"use the dashboard URL or `modal app logs `." ) def cancel(self, handle: ReplicaHandle) -> None: """Best-effort cancel of a Modal call.""" meta = self._handles.get(handle.rank) if meta is None: return try: meta["fcall"].cancel() except Exception: # Already terminated, network blip, etc. — best-effort. pass def collect( self, handles: list[ReplicaHandle], *, timeout: int | None = None, ) -> list[dict[str, Any]]: """Block until all replicas finish; return per-replica result dicts. Modal's `.get(timeout=...)` blocks until the call completes or the timeout elapses. We iterate handles, calling `.get()` with the remaining time budget, so the cumulative wall-clock is bounded by `timeout`. """ deadline = time.time() + (timeout if timeout is not None else 86400) results: list[dict[str, Any]] = [] for h in handles: meta = self._handles.get(h.rank) if meta is None: results.append({ "rank": h.rank, "status": "cancelled", "exit_code": None, "error": "handle has no metadata (cancelled or unknown)", "result": None, "call_id": h.metadata.get("call_id"), }) continue # Already collected by an earlier poll() if meta["result"] is not None: results.append(meta["result"]) continue remaining = max(0.0, deadline - time.time()) try: result_value = meta["fcall"].get(timeout=remaining) result_dict = { "rank": h.rank, "status": "succeeded", "exit_code": 0, "error": None, "result": result_value, "call_id": h.metadata.get("call_id"), } except TimeoutError as e: result_dict = { "rank": h.rank, "status": "running", "exit_code": None, "error": f"TimeoutError after deadline: {e!r}", "result": None, "call_id": h.metadata.get("call_id"), } except Exception as e: result_dict = { "rank": h.rank, "status": "failed", "exit_code": 1, "error": f"{type(e).__name__}: {e!r}", "result": None, "call_id": h.metadata.get("call_id"), } meta["result"] = result_dict results.append(result_dict) return results __all__ = ["ModalSpawnExecutor"]