# Baladithya Balamurugan
# Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
# bd0c358
# Raw
# History Blame Contribute Delete
# 5.2 kB
"""Replica entrypoint — what each serverless replica runs.
This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env
var, sets up `ObjectStoreAllReduce` against the shared rendezvous URI,
wraps it in a `MockManager`, and hands it off to the user's training
function.

Usage from an executor:

    >>> executor.launch_replicas(
    ...     n_replicas=4,
    ...     entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
    ...     entrypoint_args={
    ...         "rendezvous_uri": "/tmp/run42/",
    ...         "world_size": 4,
    ...         "trainer_module": "my_project.trainer",
    ...         "trainer_fn": "train",
    ...         "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
    ...     },
    ... )

The entrypoint expects:
  - `REPLICA_RANK` env var set to the rank (0..world_size-1)
  - `rendezvous_uri`: fsspec URI for object-store rendezvous
  - `world_size`: total replicas
  - `trainer_module`, `trainer_fn`: importable path to the user's train fn
  - `trainer_kwargs`: dict passed to the user's train fn, plus injected
    `manager` (the `MockManager`), `rank`, and `world_size` kwargs
"""
from __future__ import annotations
import importlib
import os
from typing import Any
def main(
rendezvous_uri: str,
world_size: int,
trainer_module: str,
trainer_fn: str = "train",
trainer_kwargs: dict[str, Any] | None = None,
) -> Any:
"""Entrypoint executed inside each replica.
Args:
rendezvous_uri: fsspec URI (or local path) for the rendezvous
world_size: total replicas
trainer_module: importable Python module containing the user's
train function
trainer_fn: name of the function to call (default "train")
trainer_kwargs: kwargs passed to the train function
Returns:
Whatever the train function returns.
"""
from composer_replication.diloco.serverless.allreduce import (
MockManager,
ObjectStoreAllReduce,
)
rank_str = os.environ.get("REPLICA_RANK")
if rank_str is None:
raise RuntimeError(
"REPLICA_RANK env var not set. The serverless executor "
"should set this for each replica."
)
rank = int(rank_str)
if not (0 <= rank < world_size):
raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")
store = ObjectStoreAllReduce(
uri=rendezvous_uri,
rank=rank,
world_size=world_size,
)
manager = MockManager(store)
mod = importlib.import_module(trainer_module)
fn = getattr(mod, trainer_fn)
kwargs = dict(trainer_kwargs or {})
kwargs["manager"] = manager # injected
kwargs["rank"] = rank
kwargs["world_size"] = world_size
return fn(**kwargs)
if __name__ == "__main__":
    import argparse
    import json

    # Dual input contract (both backends supported):
    #   * argv — SageMakerExecutor / LocalProcessExecutor pass the run config as
    #     `--rendezvous/--world-size/--trainer-module` ContainerArguments.
    #   * env  — EKSExecutor (and any backend that prefers a pure-env contract,
    #     since k8s Indexed Jobs already inject REPLICA_RANK via the downward API)
    #     pass the SAME values as RENDEZVOUS_URI / WORLD_SIZE / TRAINER_MODULE
    #     env vars.
    # The argv flags are therefore NOT `required=True`: when absent we fall
    # back to the env vars, and only error if NEITHER source supplies a
    # mandatory field. This is the R3 fix — previously the argparse block
    # hard-required argv, so an EKS pod (env-only) crashed at arg-parsing.
    parser = argparse.ArgumentParser()
    parser.add_argument("--rendezvous", default=None)
    parser.add_argument("--world-size", type=int, default=None)
    parser.add_argument("--trainer-module", default=None)
    parser.add_argument("--trainer-fn", default=None)
    parser.add_argument("--trainer-kwargs-json", default=None)
    args = parser.parse_args()

    def _resolve(arg_val, env_key, *, required, cast=lambda x: x):
        """Return the argv value if given, else the (cast) env var value.

        Exits with a message when a required value is supplied by neither
        source, or when the env var value cannot be converted by ``cast``.
        Returns ``None`` for an optional value that is absent from both.
        """
        if arg_val is not None:
            return arg_val
        env_val = os.environ.get(env_key)
        if env_val is not None:
            try:
                return cast(env_val)
            except ValueError:
                # argparse already validates the argv form (e.g. type=int);
                # give the env form an equally clear failure, not a traceback.
                raise SystemExit(
                    f"replica_entrypoint: invalid value {env_val!r} for the "
                    f"{env_key} environment variable."
                ) from None
        if required:
            raise SystemExit(
                f"replica_entrypoint: missing '{env_key}' — supply it via the "
                f"argv flag or the {env_key} environment variable "
                f"(EKSExecutor uses env; SageMaker/Local use argv)."
            )
        return None

    rendezvous = _resolve(args.rendezvous, "RENDEZVOUS_URI", required=True)
    world_size = _resolve(args.world_size, "WORLD_SIZE", required=True, cast=int)
    trainer_module = _resolve(args.trainer_module, "TRAINER_MODULE", required=True)
    trainer_fn = _resolve(args.trainer_fn, "TRAINER_FN", required=False) or "train"
    kwargs_json = _resolve(
        args.trainer_kwargs_json, "TRAINER_KWARGS_JSON", required=False
    ) or "{}"

    try:
        trainer_kwargs = json.loads(kwargs_json)
    except json.JSONDecodeError as exc:
        # Report malformed kwargs the same way as a missing value.
        raise SystemExit(
            "replica_entrypoint: could not parse the trainer kwargs JSON "
            f"(--trainer-kwargs-json / TRAINER_KWARGS_JSON): {exc}"
        ) from None

    main(
        rendezvous_uri=rendezvous,
        world_size=world_size,
        trainer_module=trainer_module,
        trainer_fn=trainer_fn,
        trainer_kwargs=trainer_kwargs,
    )