# Source snapshot: commit a384097 by Codeseys —
#   "Wave 20: ModalSpawnExecutor — finish the Modal-backed serverless executor"
"""ModalSpawnExecutor — production Modal-backed serverless executor.

This is the v0-finished sibling of `ModalExecutor` (which remains a
skeleton per the Wave 18 contract). The skeleton class stays unchanged to
preserve the NotImplementedError contract pinned by
`test_skeleton_executors.py`; this class is the working alternative for
users who want real Modal execution.

Design choices vs the skeleton's docstring:

1. **User-provided `modal.Function` instead of internal app construction.**
   The skeleton showed a pattern where ModalExecutor builds its own
   `modal.App` and registers `run_replica` internally. That couples the
   executor to image/GPU/Volume choices the user actually wants to own.
   Instead, ModalSpawnExecutor takes a *pre-decorated* `modal.Function`
   from the caller. The user defines::

       @app.function(gpu="H100:4", image=my_image, volumes={"/vol": vol},
                     secrets=[modal.Secret.from_name("hf-token")],
                     timeout=4 * 3600)
       def run_replica(rendezvous_uri: str, world_size: int,
                       rank: int, **entrypoint_args):
           import os
           os.environ["REPLICA_RANK"] = str(rank)
           from composer_replication.diloco.serverless import (
               MockManager, ObjectStoreAllReduce,
           )
           store = ObjectStoreAllReduce(rendezvous_uri, rank=rank,
                                        world_size=world_size)
           manager = MockManager(store)
           # ... user's training loop with this manager ...

   and then constructs::

       executor = ModalSpawnExecutor(modal_function=run_replica)
       handles = executor.launch_replicas(
           n_replicas=4,
           entrypoint=run_replica,  # ignored — the function is bound
           entrypoint_args={"rendezvous_uri": "/vol/diloco/run42",
                            "world_size": 4},
       )

   Each replica is started as ``run_replica.spawn(rank=i, **entrypoint_args)``
   (minus any ``rank_env`` key).

2. **Rank as an explicit kwarg, not env-var indirection.** Modal Functions
   start with a clean env, so the rank-via-env pattern that
   LocalProcessExecutor uses is fragile here. Modal would need
   container-level env injection per call; `modal.Secret.from_dict` can do
   that, but it adds a round-trip per spawn. We pass rank as a kwarg to
   `.spawn(rank=i)` so it is plumbed through Modal's call args directly.

3. **Handle metadata carries the Modal `call_id`.** Each `ReplicaHandle`
   returned by `launch_replicas` stores the spawned call's `call_id` (plus
   a `spawn_ts` timestamp) in `metadata`. That id is enough to re-attach to
   the call with `modal.FunctionCall.from_id(call_id)`, e.g. after the
   launching process restarts. Unlike LocalProcessExecutor, which holds
   Process refs, no OS-level process state is involved. The executor
   instance does keep an in-memory cache of the spawned `FunctionCall`
   objects and of any results already collected.

References:
- modal-client 1.4.x docs on FunctionCall: https://modal.com/docs/reference/modal.FunctionCall
- ADR-005 (executor protocol design)
"""
from __future__ import annotations
import time
from typing import Any, Callable, Mapping
from composer_replication.diloco.serverless.executor import (
ReplicaHandle,
ServerlessExecutor,
)
class ModalSpawnExecutor:
    """Run replicas as parallel Modal Function spawns.

    Implements the `ServerlessExecutor` Protocol against Modal's
    `Function.spawn()` API. The user must provide a pre-decorated
    `modal.Function` (with `@app.function(...)` already applied) — see the
    module docstring for the expected signature.

    Args:
        modal_function: a `modal.Function` registered against a `modal.App`.
            Must accept at minimum `rank: int` plus the kwargs in
            `entrypoint_args`. Image / GPU / Volume / Secret / timeout are
            pinned on the decorator and the executor won't override them.
        deploy: if True, deploys the function's App (once) before the first
            spawn. Required when running outside a `modal run` context
            (e.g. from a regular Python script). Default False — assumes
            the caller is inside a `modal run` block where the app is
            already live.

    Raises:
        RuntimeError: if the `modal` client is not installed.
        TypeError: if `modal_function` lacks `.spawn` / `.remote` (i.e. it
            does not look like a `modal.Function`).

    Handles are matched to their Modal call by the `call_id` stored in
    `ReplicaHandle.metadata`. Calls spawned by this instance are served
    from an in-memory cache. A handle this instance does not track (e.g.
    one from before a process restart, or from an earlier
    `launch_replicas` call whose ranks were since reused) is re-attached
    with `modal.FunctionCall.from_id(call_id)`.
    """

    backend_name = "modal_spawn"
    supports_inter_replica_network = False  # Modal containers are isolated by default

    # Wall-clock budget used by `collect(timeout=None)`: one day.
    _DEFAULT_COLLECT_TIMEOUT_S = 86400

    def __init__(
        self,
        modal_function: Any,
        *,
        deploy: bool = False,
    ) -> None:
        try:
            import modal  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "ModalSpawnExecutor requires the modal client. Install with "
                "`pip install modal` and configure with `modal token new`. "
                f"Got: {e!r}"
            ) from e
        # Duck-type check — modal.Function objects expose .spawn / .remote /
        # ._app, which the user-supplied function will have if they used the
        # @app.function(...) decorator. We avoid `isinstance(_, modal.Function)`
        # to stay tolerant of modal-client minor-version changes that may
        # restructure the class.
        if not (hasattr(modal_function, "spawn") and hasattr(modal_function, "remote")):
            raise TypeError(
                f"modal_function must be a modal.Function (decorated via "
                f"`@app.function(...)`). Got {type(modal_function)!r} which "
                f"has no `.spawn()` method. "
                f"See ModalSpawnExecutor docstring for expected signature."
            )
        self.modal_function = modal_function
        self._deploy_requested = deploy
        self._deployed = False
        # Calls spawned by the most recent launch_replicas(), keyed by rank.
        # Entry keys: "fcall", "result" (cached terminal result dict or None),
        # "call_id", and optionally "cancel_requested".
        self._handles: dict[int, dict[str, Any]] = {}
        # Calls re-attached via modal.FunctionCall.from_id(), keyed by call_id.
        self._rehydrated: dict[str, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # Lifecycle
    # -----------------------------------------------------------------
    def _maybe_deploy(self) -> None:
        """Deploy the function's App once, if `deploy=True` was requested."""
        if not self._deploy_requested or self._deployed:
            return
        # `.deploy()` registers the App with Modal so spawn() works from
        # outside `modal run`. Fall back to `._app` (see the duck-type note
        # in __init__) when the public `.app` attribute is absent.
        app = getattr(self.modal_function, "app", None)
        if app is None:
            app = getattr(self.modal_function, "_app", None)
        if app is None:
            raise RuntimeError(
                "modal_function.app is None — can't deploy. The function "
                "must have been decorated against a real modal.App."
            )
        app.deploy()
        self._deployed = True

    # -----------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------
    @staticmethod
    def _result_dict(
        handle: ReplicaHandle,
        status: str,
        *,
        exit_code: int | None,
        error: str | None = None,
        result: Any = None,
    ) -> dict[str, Any]:
        """Build the per-replica result dict shared by poll() and collect()."""
        return {
            "rank": handle.rank,
            "status": status,
            "exit_code": exit_code,
            "error": error,
            "result": result,
            "call_id": handle.metadata.get("call_id"),
        }

    def _meta_for(self, handle: ReplicaHandle) -> dict[str, Any] | None:
        """Return the tracking entry for `handle`, or None if it can't be found.

        A rank-keyed entry is only used when its `call_id` matches the
        handle's, so a stale handle never reads a newer launch's call.
        Otherwise the call is re-attached via `modal.FunctionCall.from_id`
        and cached by call_id. Any failure while re-attaching yields None.
        """
        call_id = handle.metadata.get("call_id")
        meta = self._handles.get(handle.rank)
        if meta is not None and (call_id is None or meta.get("call_id") == call_id):
            return meta
        if call_id is None:
            return None
        meta = self._rehydrated.get(call_id)
        if meta is None:
            try:
                import modal

                fcall = modal.FunctionCall.from_id(call_id)
            except Exception:
                # modal unavailable, unknown id, auth failure, ...: treat the
                # handle as unknown, matching the no-metadata behaviour.
                return None
            meta = {"fcall": fcall, "result": None, "call_id": call_id}
            self._rehydrated[call_id] = meta
        return meta

    def _resolve(
        self,
        handle: ReplicaHandle,
        meta: dict[str, Any],
        timeout: float,
    ) -> dict[str, Any]:
        """Wait up to `timeout` seconds for the call and cache its terminal result.

        Raises:
            TimeoutError: the call is still running. Nothing is cached, so a
                later poll()/collect() queries Modal again.
        """
        try:
            value = meta["fcall"].get(timeout=timeout)
        except TimeoutError:
            raise
        except Exception as e:
            # User-code exceptions surface here as their original class;
            # OutputExpiredError and friends land here too.
            error = f"{type(e).__name__}: {e!r}"
            if meta.get("cancel_requested"):
                outcome = self._result_dict(handle, "cancelled", exit_code=None, error=error)
            else:
                outcome = self._result_dict(handle, "failed", exit_code=1, error=error)
        else:
            outcome = self._result_dict(handle, "succeeded", exit_code=0, result=value)
        meta["result"] = outcome
        return outcome

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Spawn N parallel Modal Function calls, one per rank.

        Note: `entrypoint` is **ignored** — the actual entrypoint is the
        `modal_function` passed to `__init__`. This keeps the executor
        Protocol-compatible while preserving the user's image/GPU
        decoration. `gpu` and `timeout` are similarly ignored (pinned on
        the function decorator).

        Each rank `i` is started as `modal_function.spawn(rank=i, **args)`,
        where `args` is `entrypoint_args` minus any `rank_env` key.

        Returns:
            One `ReplicaHandle` per rank, whose `metadata` holds the Modal
            `call_id` and a `spawn_ts` wall-clock timestamp.

        Raises:
            ValueError: if `n_replicas < 1`.
            RuntimeError: if any spawn fails. Siblings spawned so far get a
                best-effort cancel first.
        """
        del entrypoint, gpu, timeout  # pinned on the decorated function
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        self._maybe_deploy()
        # Strip rank_env if present — we use the explicit `rank` kwarg.
        spawn_kwargs = {k: v for k, v in entrypoint_args.items() if k != "rank_env"}
        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            try:
                fcall = self.modal_function.spawn(rank=rank, **spawn_kwargs)
            except Exception as e:
                # Best-effort cancel of the already-launched siblings; never let
                # cleanup mask the original spawn error.
                for prior in handles:
                    try:
                        self.cancel(prior)
                    except Exception:
                        pass
                raise RuntimeError(
                    f"ModalSpawnExecutor.launch_replicas failed at rank={rank} "
                    f"of {n_replicas} (already-launched siblings cancelled). "
                    f"Underlying error: {e!r}"
                ) from e
            call_id = fcall.object_id
            self._handles[rank] = {
                "fcall": fcall,
                "result": None,
                "call_id": call_id,
            }
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
                    metadata={
                        "call_id": call_id,
                        "spawn_ts": time.time(),
                    },
                )
            )
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll a Modal call's status without blocking.

        Modal's FunctionCall has no non-blocking status getter (the API is
        `.get(timeout=...)`), so we call `.get(timeout=0)` and treat
        TimeoutError as "running". Terminal outcomes are cached.

        Returns one of: "running" | "succeeded" | "failed" | "cancelled".
        Unknown handles report "cancelled". A call whose cancel() request
        went through and that then errors reports "cancelled".
        """
        meta = self._meta_for(handle)
        if meta is None:
            return "cancelled"
        if meta["result"] is not None:
            return meta["result"]["status"]
        try:
            return self._resolve(handle, meta, timeout=0)["status"]
        except TimeoutError:
            return "running"

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Describe where to find this call's logs, plus any captured error.

        The client API doesn't expose log-streaming directly in
        modal-client 1.4.x, so this returns the call's dashboard URL (or a
        templated fallback) and the last 2000 characters of any error
        recorded by poll()/collect(). `n_lines` is currently unused.
        """
        meta = self._meta_for(handle)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        call_id = handle.metadata.get("call_id", "<unknown>")
        try:
            dashboard_url = meta["fcall"].get_dashboard_url()
        except Exception:
            dashboard_url = (
                f"https://modal.com/apps/<workspace>/<env>/calls/{call_id}"
            )
        outcome = meta.get("result")
        if outcome:
            err = outcome.get("error") or "<no error>"
            return (
                f"[rank {handle.rank}] call_id={call_id}\n"
                f" Dashboard: {dashboard_url}\n"
                f" Result: {outcome['status']}\n"
                f" Error: {err[-2000:]}"
            )
        return (
            f"[rank {handle.rank}] call_id={call_id} (still running)\n"
            f" Dashboard: {dashboard_url}\n"
            f" Logs not streamable via client API in modal-client 1.4.x; "
            f"use the dashboard URL or `modal app logs <app-id>`."
        )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort cancel of a Modal call.

        Errors from Modal (already terminated, network blip, ...) are
        swallowed. A successful request is remembered so that a subsequent
        error from the call is reported as "cancelled" rather than "failed".
        """
        meta = self._meta_for(handle)
        if meta is None:
            return
        try:
            meta["fcall"].cancel()
        except Exception:
            return
        meta["cancel_requested"] = True

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Handles are waited on in order, each with the remaining share of
        one shared deadline, so the cumulative wall-clock is bounded by
        `timeout` (one day when None). Replicas still running at the
        deadline are reported with status "running". That entry is NOT
        cached, so a later poll()/collect() checks Modal again.
        """
        budget = timeout if timeout is not None else self._DEFAULT_COLLECT_TIMEOUT_S
        # monotonic: immune to wall-clock adjustments during long waits.
        deadline = time.monotonic() + budget
        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._meta_for(h)
            if meta is None:
                results.append(self._result_dict(
                    h,
                    "cancelled",
                    exit_code=None,
                    error="handle has no metadata (cancelled or unknown)",
                ))
                continue
            # Already collected by an earlier poll()/collect()
            if meta["result"] is not None:
                results.append(meta["result"])
                continue
            remaining = max(0.0, deadline - time.monotonic())
            try:
                results.append(self._resolve(h, meta, timeout=remaining))
            except TimeoutError as e:
                results.append(self._result_dict(
                    h,
                    "running",
                    exit_code=None,
                    error=f"TimeoutError after deadline: {e!r}",
                ))
        return results
# Public surface of this module: only the executor class.
__all__: list[str] = [
    "ModalSpawnExecutor",
]