"""Replica entrypoint — what each serverless replica runs.

This is the script invoked by `LocalProcessExecutor`, `ModalExecutor`,
`HFJobsExecutor`, etc. It learns its rank from the `REPLICA_RANK` env var,
sets up `ObjectStoreAllReduce` against the shared rendezvous URI, wraps it in
a `MockManager`, and hands it off to the user's training function.

Usage from an executor::

    executor.launch_replicas(
        n_replicas=4,
        entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
        entrypoint_args={
            "rendezvous_uri": "/tmp/run42/",
            "world_size": 4,
            "trainer_module": "my_project.trainer",
            "trainer_fn": "train",
            "trainer_kwargs": {"model_name": "Qwen/Qwen2.5-0.5B"},
        },
    )

The entrypoint expects:

- `REPLICA_RANK` env var set to the rank (0..world_size-1)
- `rendezvous_uri`: fsspec URI for object-store rendezvous
- `world_size`: total replicas (>= 1)
- `trainer_module`, `trainer_fn`: importable path to the user's train fn
- `trainer_kwargs`: dict passed to the user's train fn, plus the injected
  kwargs `manager` (the `MockManager`), `rank` and `world_size`, which
  override any same-named keys in `trainer_kwargs`

When run as a script, the same values are read from argv flags
(`--rendezvous`, `--world-size`, `--trainer-module`, `--trainer-fn`,
`--trainer-kwargs-json`) or, when a flag is absent, from the
`RENDEZVOUS_URI`, `WORLD_SIZE`, `TRAINER_MODULE`, `TRAINER_FN` and
`TRAINER_KWARGS_JSON` env vars.
"""

from __future__ import annotations

import argparse
import importlib
import json
import os
from collections.abc import Callable, Sequence
from typing import Any


def _load_trainer(module_name: str, fn_name: str) -> Callable[..., Any]:
    """Import *module_name* and return its callable attribute *fn_name*.

    Raises:
        ImportError: the module cannot be imported.
        AttributeError: the module has no attribute *fn_name*.
        TypeError: the attribute exists but is not callable.
    """
    mod = importlib.import_module(module_name)
    try:
        fn = getattr(mod, fn_name)
    except AttributeError as err:
        raise AttributeError(
            f"trainer module {module_name!r} has no attribute {fn_name!r}"
        ) from err
    if not callable(fn):
        raise TypeError(
            f"{module_name}.{fn_name} is not callable "
            f"(got {type(fn).__name__})"
        )
    return fn


def main(
    rendezvous_uri: str,
    world_size: int,
    trainer_module: str,
    trainer_fn: str = "train",
    trainer_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Entrypoint executed inside each replica.

    Validates this replica's rank, resolves the user's train function, builds
    an `ObjectStoreAllReduce` + `MockManager` for this rank and calls the
    train function with them.

    Args:
        rendezvous_uri: fsspec URI (or local path) for the rendezvous
        world_size: total replicas; must be >= 1
        trainer_module: importable Python module containing the user's
            train function
        trainer_fn: name of the function to call (default "train")
        trainer_kwargs: kwargs passed to the train function. `manager`,
            `rank` and `world_size` are injected on top of these and
            override any same-named keys.

    Returns:
        Whatever the train function returns.

    Raises:
        RuntimeError: `REPLICA_RANK` is not set.
        ValueError: `world_size` < 1, or `REPLICA_RANK` is not an integer
            in ``[0, world_size)``.
        ImportError: `trainer_module` (or the allreduce module) cannot be
            imported.
        AttributeError: `trainer_module` has no attribute `trainer_fn`.
        TypeError: `trainer_fn` names a non-callable attribute.
    """
    # Without this check a zero/negative world_size surfaces as the confusing
    # "REPLICA_RANK=0 not in [0, 0)".
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")

    rank_str = os.environ.get("REPLICA_RANK")
    if rank_str is None:
        raise RuntimeError(
            "REPLICA_RANK env var not set. The serverless executor "
            "should set this for each replica."
        )
    try:
        rank = int(rank_str)
    except ValueError as err:
        raise ValueError(
            f"REPLICA_RANK={rank_str!r} is not an integer"
        ) from err
    if not (0 <= rank < world_size):
        raise ValueError(f"REPLICA_RANK={rank} not in [0, {world_size})")

    # Resolve the user's function before any rendezvous setup so that a typo
    # in trainer_module / trainer_fn fails fast.
    fn = _load_trainer(trainer_module, trainer_fn)

    # Imported lazily, after validation, so configuration errors surface
    # before the allreduce backend is loaded.
    from composer_replication.diloco.serverless.allreduce import (
        MockManager,
        ObjectStoreAllReduce,
    )

    store = ObjectStoreAllReduce(
        uri=rendezvous_uri,
        rank=rank,
        world_size=world_size,
    )
    manager = MockManager(store)

    kwargs = dict(trainer_kwargs or {})
    # Injected last: these override any same-named user-supplied kwargs.
    kwargs["manager"] = manager
    kwargs["rank"] = rank
    kwargs["world_size"] = world_size
    return fn(**kwargs)


def _resolve(
    arg_val: Any,
    flag: str,
    env_key: str,
    *,
    required: bool,
    cast: Callable[[str], Any] = str,
) -> Any:
    """Return *arg_val* if given, else env var *env_key* (cast), else None.

    An unset or empty env var counts as "not supplied".

    Raises:
        SystemExit: *required* is true and neither source supplies a value,
            or the env value cannot be converted by *cast*.
    """
    if arg_val is not None:
        # argv wins; argparse has already applied any `type=` conversion.
        return arg_val
    env_val = os.environ.get(env_key)
    if env_val:
        try:
            return cast(env_val)
        except ValueError as err:
            raise SystemExit(
                f"replica_entrypoint: invalid {env_key}={env_val!r}: {err}"
            ) from err
    if required:
        raise SystemExit(
            f"replica_entrypoint: missing '{env_key}' — supply it via "
            f"{flag} or the {env_key} environment variable "
            f"(EKSExecutor uses env; SageMaker/Local use argv)."
        )
    return None


def _parse_cli(argv: Sequence[str] | None = None) -> dict[str, Any]:
    """Build `main()` kwargs from argv flags, falling back to env vars.

    Dual input contract (both backends supported):

    * argv — SageMakerExecutor / LocalProcessExecutor pass the run config as
      `--rendezvous/--world-size/--trainer-module` ContainerArguments.
    * env — EKSExecutor (and any backend that prefers a pure-env contract,
      since k8s Indexed Jobs already inject REPLICA_RANK via the downward
      API) pass the SAME values as RENDEZVOUS_URI / WORLD_SIZE /
      TRAINER_MODULE env vars.

    The argv flags are therefore NOT `required=True`: when absent we fall
    back to the env vars, and only error if NEITHER source supplies a
    mandatory field. This is the R3 fix — previously the argparse block
    hard-required argv, so an EKS pod (env-only) crashed at arg-parsing.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Keyword arguments for `main`.

    Raises:
        SystemExit: a required field is missing, an env value cannot be
            parsed, or the trainer kwargs are not a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Serverless replica entrypoint."
    )
    parser.add_argument("--rendezvous", default=None,
                        help="rendezvous URI (env: RENDEZVOUS_URI)")
    parser.add_argument("--world-size", type=int, default=None,
                        help="total replicas (env: WORLD_SIZE)")
    parser.add_argument("--trainer-module", default=None,
                        help="module with the train fn (env: TRAINER_MODULE)")
    parser.add_argument("--trainer-fn", default=None,
                        help="train fn name, default 'train' (env: TRAINER_FN)")
    parser.add_argument("--trainer-kwargs-json", default=None,
                        help="JSON object of train fn kwargs "
                             "(env: TRAINER_KWARGS_JSON)")
    args = parser.parse_args(argv)

    rendezvous = _resolve(
        args.rendezvous, "--rendezvous", "RENDEZVOUS_URI", required=True
    )
    world_size = _resolve(
        args.world_size, "--world-size", "WORLD_SIZE", required=True, cast=int
    )
    trainer_module = _resolve(
        args.trainer_module, "--trainer-module", "TRAINER_MODULE",
        required=True,
    )
    trainer_fn = _resolve(
        args.trainer_fn, "--trainer-fn", "TRAINER_FN", required=False
    ) or "train"
    kwargs_json = _resolve(
        args.trainer_kwargs_json, "--trainer-kwargs-json",
        "TRAINER_KWARGS_JSON", required=False,
    ) or "{}"

    try:
        trainer_kwargs = json.loads(kwargs_json)
    except json.JSONDecodeError as err:
        raise SystemExit(
            f"replica_entrypoint: trainer kwargs are not valid JSON: {err}"
        ) from err
    # `main` splats these into the train fn, so only a JSON object makes sense.
    if not isinstance(trainer_kwargs, dict):
        raise SystemExit(
            "replica_entrypoint: trainer kwargs must be a JSON object, got "
            f"{type(trainer_kwargs).__name__}"
        )

    return {
        "rendezvous_uri": rendezvous,
        "world_size": world_size,
        "trainer_module": trainer_module,
        "trainer_fn": trainer_fn,
        "trainer_kwargs": trainer_kwargs,
    }


if __name__ == "__main__":
    main(**_parse_cli())