"""SageMakerExecutor — production boto3-backed serverless executor.

This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`,
not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the
`ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the
boto3 low-level `sagemaker` client.

Design choices
--------------

1. **N independent single-instance jobs, NOT one multi-instance job.**
   SageMaker's *native* distributed training
   (``ResourceConfig.InstanceCount > 1``) groups instances into ONE job with
   an in-cluster NCCL/MPI fabric wired via
   ``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for
   DiLoCo replicas — it would couple replicas through SageMaker's intra-job
   network and break the "each replica is an independent DiLoCo worker that
   syncs only through S3" design. So ``launch_replicas`` submits N
   **separate** training jobs, each with ``ResourceConfig.InstanceCount == 1``,
   tagged with ``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment``
   map. This mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls.

2. **Same S3 ``ObjectStoreAllReduce`` rendezvous — DiLoCo math untouched.**
   Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the
   executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to
   ``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` /
   ``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical.

3. **Stateless after launch; rank via ``Environment``.** Handle metadata is
   the ``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py``
   already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel
   is the ``Environment`` map (string->string, max 100 entries, value <= 512
   chars). The container command is set explicitly via
   ``AlgorithmSpecification.ContainerEntrypoint`` (so a generic image works),
   and the rendezvous args are passed via
   ``AlgorithmSpecification.ContainerArguments``.

4.
   **``supports_inter_replica_network = False``.** Separate single-instance
   training jobs have no mutual network path by design; they rendezvous only
   through S3. (SageMaker's algo-N container fabric and
   ``EnableInterContainerTrafficEncryption`` only exist WITHIN a single
   multi-instance job, which this design deliberately does not use.)

Load-bearing gotcha — ``EnableNetworkIsolation`` MUST stay ``False``
--------------------------------------------------------------------

When ``EnableNetworkIsolation=True`` the training *container* has no outbound
network access. SageMaker's host-side processes still stage input channels and
ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls.
``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network
isolation would make the allreduce poll loop hang silently until it times
out. This executor pins ``EnableNetworkIsolation=False`` (the API default) and
never exposes it as a knob. Access to the rendezvous bucket is granted through
the execution ``RoleArn`` instead (the SageMaker analog of EKS IRSA).

HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid)
---------------------------------------------------------------

Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in
HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and
a HyperPod cluster (worker nodes) within a VPC."*
(https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html)

Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer"
does NOT mean leaving EKS. It means attaching a HyperPod-managed
(auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to
the SAME EKS control plane that runs the outer loop. A future ``EKSExecutor``
(kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU
nodes AND HyperPod nodes transparently.
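
Design choices 1-3 and the network-isolation gotcha reduce to one
``CreateTrainingJob`` request shape per replica. The pure function below is a
minimal sketch of that shape. ``build_requests``, the image URI, the role ARN
and the bucket names are hypothetical placeholders, and the function is not
part of this module. It shows N single-instance requests that differ only in
name and ``REPLICA_RANK``:

```python
def build_requests(n_replicas, image_uri, role_arn, output_s3_path,
                   rendezvous_uri, instance_type):
    # Hypothetical helper: one CreateTrainingJob request dict per replica.
    args = ['--rendezvous', rendezvous_uri, '--world-size', str(n_replicas)]
    requests = []
    for rank in range(n_replicas):
        requests.append({
            'TrainingJobName': f'demo-r{rank:04d}',
            'AlgorithmSpecification': {
                'TrainingImage': image_uri,
                'TrainingInputMode': 'File',
                # One token per list element, never '--world-size 4'.
                'ContainerArguments': list(args),
            },
            'RoleArn': role_arn,
            'OutputDataConfig': {'S3OutputPath': output_s3_path},
            # Design choice 1: every replica is its own single-instance job.
            'ResourceConfig': {'InstanceType': instance_type,
                               'InstanceCount': 1,
                               'VolumeSizeInGB': 100},
            'StoppingCondition': {'MaxRuntimeInSeconds': 3600},
            # Design choice 3: the rank travels through the Environment map.
            'Environment': {'REPLICA_RANK': str(rank),
                            'WORLD_SIZE': str(n_replicas)},
            # Gotcha: True would cut the container off from the S3 rendezvous.
            'EnableNetworkIsolation': False,
        })
    return requests


reqs = build_requests(4, 'example.ecr/image:tag',
                      'arn:aws:iam::123456789012:role/Example',
                      's3://example-bucket/out', 's3://example-bucket/rdv',
                      'ml.g5.2xlarge')
print(len(reqs), reqs[2]['Environment'])
# prints: 4 {'REPLICA_RANK': '2', 'WORLD_SIZE': '4'}
```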

``SageMakerExecutor`` (ephemeral Training Jobs via boto3) is the SEPARATE
bursty-fallback inner-loop path for when you don't want a persistent cluster:
Training Jobs suit periodic / smaller-model / pay-per-use runs; HyperPod suits
continuous / large-model / persistent runs. Both share the IDENTICAL S3
rendezvous, so a run can move between them with zero trainer / loss / DiLoCo
changes.

References
----------

- create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html
- describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html
- stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html
- network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation
- HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html
- ADR-005 (executor protocol design)
"""

from __future__ import annotations

import json
import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# SageMaker TrainingJobStatus -> Protocol status vocabulary.
# describe_training_job's TrainingJobStatus is EXACTLY one of:
# 'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'.
# We map Stopping -> 'running' (transient; still terminating, so collect()
# keeps waiting) and Stopped -> 'cancelled'.
_STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}

# SecondaryStatus values that mean "queued / not yet executing user code" —
# used to refine an InProgress job into the Protocol's 'pending'.
_PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)

# Abstract Protocol GPU strings -> SageMaker instance types.
_GPU_INSTANCE_MAP = {
    "A100": "ml.p4d.24xlarge",
    "H100": "ml.p5.48xlarge",
    "H200": "ml.p5e.48xlarge",
    "B200": "ml.p6-b200.48xlarge",
    "L40S": "ml.g6e.12xlarge",
    "A10G": "ml.g5.2xlarge",
    "L4": "ml.g6.2xlarge",
}

_CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs"


class SageMakerExecutor:
    """Run replicas as N independent SageMaker Training Jobs.

    Implements the `ServerlessExecutor` Protocol against the boto3
    ``sagemaker`` client. Each replica is one single-instance training job;
    cross-replica communication happens only through the shared S3
    ``ObjectStoreAllReduce`` rendezvous.

    Args:
        role_arn: IAM execution role SageMaker assumes for the job. Must grant
            S3 access to the rendezvous + output buckets (the SageMaker analog
            of EKS IRSA). The caller's credentials need ``iam:PassRole`` on it.
        image_uri: ECR image URI for the training container. The image must
            have ``composer_replication`` importable by ``python``. This
            executor sets ``ContainerEntrypoint`` to ``python -m
            composer_replication.diloco.serverless.replica_entrypoint``
            explicitly, so the image does not need to bake that entrypoint in.
        output_s3_path: ``s3://...`` prefix for
            ``OutputDataConfig.S3OutputPath`` (model artifacts / failure
            output).
        instance_type: fallback SageMaker instance type used when ``gpu`` is
            neither a key of ``_GPU_INSTANCE_MAP`` nor a literal ``"ml.*"``
            type (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch uses
            ``cpu_instance_type`` instead.
        cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU
            smoke tests). Default ``"ml.m5.xlarge"``.
        volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job.
        run_id: prefix for generated training-job names; must contain only
            letters, digits and hyphens. Defaults to a short random token so
            names are unique per region+account.
        region: AWS region for the lazily-constructed boto3 clients. ``None``
            uses the ambient boto3 default-region resolution.
        sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or
            a mock) instead of constructing one. Used by tests.
        logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock).

    Raises:
        RuntimeError: if boto3 is not installed and no client was injected.
    """

    backend_name = "sagemaker"
    # Separate single-instance jobs have no mutual network path — S3 only.
    supports_inter_replica_network = False

    def __init__(
        self,
        *,
        role_arn: str,
        image_uri: str,
        output_s3_path: str,
        instance_type: str = "ml.g5.2xlarge",
        cpu_instance_type: str = "ml.m5.xlarge",
        volume_size_gb: int = 100,
        run_id: str | None = None,
        region: str | None = None,
        sagemaker_client: Any = None,
        logs_client: Any = None,
    ) -> None:
        self.role_arn = role_arn
        self.image_uri = image_uri
        self.output_s3_path = output_s3_path
        self.instance_type = instance_type
        self.cpu_instance_type = cpu_instance_type
        self.volume_size_gb = volume_size_gb
        self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}"
        self._region = region
        # Lazy boto3 — only constructed if the caller didn't inject a client.
        # This keeps `import composer_replication.diloco.serverless` free of a
        # hard boto3 dependency (boto3 lives in the optional [aws] extra), and
        # lets tests inject a _MockSMClient with zero AWS calls.
        if sagemaker_client is None:
            sagemaker_client = self._make_boto3_client("sagemaker")
        self._client = sagemaker_client
        self._logs_client = logs_client  # built lazily on first stream_logs()
        # rank -> {"job_name": str, "result": dict | None}
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # boto3 plumbing (lazy)
    # -----------------------------------------------------------------

    def _make_boto3_client(self, service: str) -> Any:
        try:
            import boto3
        except ImportError as e:
            raise RuntimeError(
                "SageMakerExecutor requires boto3. Install with "
                "`pip install -e .[aws]` (or `pip install boto3`). "
                f"Got: {e!r}"
            ) from e
        if self._region is not None:
            return boto3.client(service, region_name=self._region)
        return boto3.client(service)

    def _map_gpu(self, gpu: str | None) -> str:
        """Translate the Protocol's abstract gpu string to an instance type.

        ``gpu=None`` -> ``cpu_instance_type`` (smoke tests). A key of
        ``_GPU_INSTANCE_MAP`` maps to its instance type, and a literal
        SageMaker instance type (``"ml.*"``) is passed through unchanged. Any
        other unrecognized string falls back to ``instance_type``.
        """
        if gpu is None:
            return self.cpu_instance_type
        if gpu in _GPU_INSTANCE_MAP:
            return _GPU_INSTANCE_MAP[gpu]
        # Caller may have passed a literal "ml.*" instance type.
        if gpu.startswith("ml."):
            return gpu
        return self.instance_type

    def _job_name(self, rank: int) -> str:
        """Build a unique, regex-safe training-job name (<= 63 chars).

        Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
        Only the ``run_id`` prefix is truncated, so the rank/timestamp suffix
        that keeps sibling job names distinct always survives.
        """
        suffix = f"-r{rank:04d}-{int(time.time())}"
        # Strip trailing hyphens so the truncated prefix still ends in an
        # alphanumeric character, as the pattern requires.
        prefix = self.run_id[: 63 - len(suffix)].rstrip("-")
        return f"{prefix}{suffix}"

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Submit N independent single-instance SageMaker Training Jobs.

        Args:
            n_replicas: number of replicas (= number of training jobs).
            entrypoint: ignored; the container command is always set via
                ``ContainerEntrypoint``. Kept for Protocol compatibility.
            entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``)
                and ``trainer_module``. Optional: ``trainer_fn`` (default
                ``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the
                container args). The conventional ``rank_env`` key (from
                ``LocalProcessExecutor``) is ignored; rank goes through the
                ``Environment`` map instead.
            gpu: abstract GPU spec mapped to an instance type via
                ``_map_gpu``. ``None`` -> CPU instance.
            timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job.

        Returns:
            ``list[ReplicaHandle]`` of length ``n_replicas`` in rank order
            (``handles[i].rank == i``).

        Raises:
            ValueError: if ``n_replicas < 1`` or a required
                ``entrypoint_args`` key is missing.
            RuntimeError: if any ``create_training_job`` call fails; a stop
                is requested for every replica already launched first.
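
        Example: ``ContainerArguments`` is a flat argv with one token per
        element, so a stock ``argparse`` parser can consume it and recover
        ``trainer_kwargs`` from its JSON. This is a minimal sketch. It assumes
        a hypothetical replica-side parser (the real ``replica_entrypoint.py``
        parser may differ), and the bucket and module names are placeholders:

        ```python
        import argparse
        import json

        # Same token layout as the container_args list built by this method.
        entrypoint_args = {
            'rendezvous_uri': 's3://example-bucket/rendezvous',
            'trainer_module': 'my_pkg.train',
            'trainer_kwargs': {'lr': 0.001, 'steps': 10},
        }
        container_args = [
            '--rendezvous', entrypoint_args['rendezvous_uri'],
            '--world-size', str(4),
            '--trainer-module', entrypoint_args['trainer_module'],
            '--trainer-fn', entrypoint_args.get('trainer_fn', 'train'),
            '--trainer-kwargs-json', json.dumps(entrypoint_args['trainer_kwargs']),
        ]

        # Hypothetical replica-side parser.
        parser = argparse.ArgumentParser()
        parser.add_argument('--rendezvous', required=True)
        parser.add_argument('--world-size', type=int, required=True)
        parser.add_argument('--trainer-module', required=True)
        parser.add_argument('--trainer-fn', default='train')
        parser.add_argument('--trainer-kwargs-json', default='{}')
        ns = parser.parse_args(container_args)
        print(ns.world_size, json.loads(ns.trainer_kwargs_json))
        # prints: 4 {'lr': 0.001, 'steps': 10}
        ```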
        """
        del entrypoint  # the container command is set via ContainerEntrypoint
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        rendezvous_uri = entrypoint_args.get("rendezvous_uri")
        if not rendezvous_uri:
            raise ValueError(
                "entrypoint_args must include 'rendezvous_uri' (the s3:// "
                "ObjectStoreAllReduce rendezvous prefix)."
            )
        trainer_module = entrypoint_args.get("trainer_module")
        if not trainer_module:
            raise ValueError(
                "entrypoint_args must include 'trainer_module' (importable "
                "module path of the user's train function)."
            )
        trainer_fn = entrypoint_args.get("trainer_fn", "train")
        trainer_kwargs = entrypoint_args.get("trainer_kwargs", {})
        instance_type = self._map_gpu(gpu)

        # Container args: each element is a SINGLE token (StackOverflow
        # 77994925 — `['--world-size', '4']` NOT `['--world-size 4']`).
        container_args = [
            "--rendezvous", str(rendezvous_uri),
            "--world-size", str(n_replicas),
            "--trainer-module", str(trainer_module),
            "--trainer-fn", str(trainer_fn),
            "--trainer-kwargs-json", json.dumps(trainer_kwargs),
        ]

        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            job_name = self._job_name(rank)
            request = {
                "TrainingJobName": job_name,
                "AlgorithmSpecification": {
                    "TrainingImage": self.image_uri,
                    "TrainingInputMode": "File",
                    "ContainerEntrypoint": [
                        "python",
                        "-m",
                        "composer_replication.diloco.serverless.replica_entrypoint",
                    ],
                    "ContainerArguments": container_args,
                },
                "RoleArn": self.role_arn,
                # InputDataConfig intentionally omitted — the replica pulls
                # data via its own code / the S3 rendezvous, not SM channels.
                "OutputDataConfig": {"S3OutputPath": self.output_s3_path},
                "ResourceConfig": {
                    "InstanceType": instance_type,
                    "InstanceCount": 1,
                    "VolumeSizeInGB": self.volume_size_gb,
                },
                "StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)},
                # REPLICA_RANK / WORLD_SIZE injected as container env vars;
                # replica_entrypoint.py reads os.environ['REPLICA_RANK'].
                "Environment": {
                    "REPLICA_RANK": str(rank),
                    "WORLD_SIZE": str(n_replicas),
                    "RENDEZVOUS_URI": str(rendezvous_uri),
                },
                # MUST stay False — True severs the container's S3 access and
                # dead-locks the allreduce poll loop. See module docstring.
                "EnableNetworkIsolation": False,
            }
            try:
                self._client.create_training_job(**request)
            except Exception as e:
                # Best-effort stop of already-launched siblings, then raise.
                for prior in handles:
                    try:
                        self.cancel(prior)
                    except Exception:
                        pass
                raise RuntimeError(
                    f"SageMakerExecutor.launch_replicas failed at rank={rank} "
                    f"of {n_replicas} (stop requested for already-launched "
                    f"siblings). Underlying error: {e!r}"
                ) from e
            handle = ReplicaHandle(
                rank=rank,
                backend_name=self.backend_name,
                metadata={
                    "training_job_name": job_name,
                    "submit_ts": time.time(),
                },
            )
            self._handles[rank] = {"job_name": job_name, "result": None}
            handles.append(handle)
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll a training job's status.

        Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` |
        ``"failed"`` | ``"cancelled"``. Maps ``describe_training_job``'s
        ``TrainingJobStatus`` via ``_STATUS_MAP``; refines ``InProgress`` to
        ``"pending"`` while the job is still queued (``SecondaryStatus`` in
        ``_PENDING_SECONDARY``). A vanished job (``ResourceNotFound``) is
        treated as ``"cancelled"``.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        if meta["result"] is not None:
            return meta["result"]["status"]
        job_name = meta["job_name"]
        try:
            resp = self._client.describe_training_job(TrainingJobName=job_name)
        except Exception as e:
            if self._is_resource_not_found(e):
                return "cancelled"
            raise
        sm_status = resp.get("TrainingJobStatus", "InProgress")
        mapped = _STATUS_MAP.get(sm_status, "running")
        if sm_status == "InProgress":
            if resp.get("SecondaryStatus") in _PENDING_SECONDARY:
                return "pending"
            return "running"
        if sm_status not in ("Completed", "Failed", "Stopped"):
            # 'Stopping' is transient: report it as running but do NOT cache
            # it, or later poll()/collect() calls would never see the final
            # status.
            return mapped
        # Terminal — cache a result dict so collect()/repeat-poll are cheap.
        meta["result"] = self._terminal_result(handle.rank, sm_status, resp)
        return mapped

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent CloudWatch logs for this replica's training job.

        SageMaker writes container stdout/stderr to the
        ``/aws/sagemaker/TrainingJobs`` log group, in a stream whose name
        starts with ``<training_job_name>/algo-``. We list streams by that
        prefix, pick the most recently active one, and read its tail. Falls
        back to a CloudWatch console pointer on any error (mirrors
        ModalSpawnExecutor's dashboard-URL fallback).
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f""
        job_name = meta["job_name"]
        try:
            logs = self._logs()
            prefix = f"{job_name}/"
            # CloudWatch Logs rejects logStreamNamePrefix combined with
            # orderBy="LastEventTime", so list by prefix (default
            # LogStreamName order) and pick the newest stream client-side.
            streams = logs.describe_log_streams(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamNamePrefix=prefix,
            )
            stream_list = streams.get("logStreams", [])
            if not stream_list:
                return (
                    f"[rank {handle.rank}] job={job_name}: no CloudWatch log "
                    f"stream yet (job pending / not started)."
                )
            latest = max(
                stream_list, key=lambda s: s.get("lastEventTimestamp", 0)
            )
            stream_name = latest["logStreamName"]
            events = logs.get_log_events(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamName=stream_name,
                limit=n_lines,
                startFromHead=False,
            )
            lines = [e.get("message", "") for e in events.get("events", [])]
            body = "\n".join(lines) if lines else ""
            return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}"
        except Exception as e:
            region = self._region or ""
            url = (
                f"https://{region}.console.aws.amazon.com/cloudwatch/home"
                f"?region={region}#logsV2:log-groups/log-group/"
                f"$252Faws$252Fsagemaker$252FTrainingJobs"
            )
            return (
                f"[rank {handle.rank}] job={job_name}: log fetch failed "
                f"({type(e).__name__}: {e!r}).\n CloudWatch console: {url}"
            )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort stop of a training job.

        Calls ``stop_training_job`` (SIGTERM + 120s grace). It swallows only
        "job not found" errors and the ``ValidationException`` raised for a
        job that is already terminal, so the contract "no exception if
        already terminated" holds while unexpected errors still propagate.
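
        Example: the swallow-only-terminal rule can be sketched with a
        stand-in exception. ``FakeClientError`` and
        ``is_idempotent_stop_error`` are hypothetical names, not botocore or
        this module. The quoted AWS message texts are illustrative
        placeholders, not verbatim service output:

        ```python
        class FakeClientError(Exception):
            # Stand-in for botocore's ClientError: it carries a .response dict.
            def __init__(self, code, message):
                super().__init__(f'An error occurred ({code}): {message}')
                self.response = {'Error': {'Code': code, 'Message': message}}


        def is_idempotent_stop_error(exc):
            # Hypothetical classifier: True only for 'job gone' or
            # 'job already terminal'; everything else should propagate.
            err = getattr(exc, 'response', {}).get('Error', {})
            code = err.get('Code')
            msg = err.get('Message', '').lower()
            if code == 'ResourceNotFound':
                return True
            if code != 'ValidationException':
                return False
            terminal = ('completed', 'failed', 'stopped')
            return 'not found' in msg or (
                'status' in msg and any(word in msg for word in terminal))


        for exc in (FakeClientError('ValidationException', 'Requested resource not found.'),
                    FakeClientError('ValidationException', 'Job is in status Completed.'),
                    FakeClientError('AccessDeniedException', 'Not authorized.')):
            print(exc.response['Error']['Code'], is_idempotent_stop_error(exc))
        ```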
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        try:
            self._client.stop_training_job(TrainingJobName=meta["job_name"])
        except Exception as e:
            # R5: swallow ONLY already-terminated signals — a vanished job
            # (ResourceNotFound) or an already-Completed/Stopped job (boto3
            # raises ValidationException for "cannot stop a job in status X").
            # A genuinely unexpected error (AccessDenied, throttling that
            # outlived retries, malformed request) must propagate rather than
            # masquerade as a successful cancel.
            if self._is_resource_not_found(e) or self._is_already_terminal(e):
                return
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Polls ``describe_training_job`` per handle until the job reaches a
        terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the
        shared deadline elapses. ``timeout`` is one deadline for all handles
        (24 hours when ``None``). A job still running at the deadline is
        reported with ``status="running"`` and ``error="timeout before
        terminal"`` and is not cached, so a later ``collect()`` re-checks it.
        Returns results aligned to the input handle order (Protocol contract;
        mirrors ``LocalProcessExecutor.collect``). Each result dict has at
        least ``{"rank", "status", "exit_code", "error"}``.
        """
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = 30.0
        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                results.append({
                    "rank": h.rank,
                    "status": "cancelled",
                    "exit_code": None,
                    "error": "handle has no metadata (cancelled or unknown)",
                    "result": None,
                    "training_job_name": h.metadata.get("training_job_name"),
                })
                continue
            # Already cached by an earlier poll()/collect().
            if meta["result"] is not None:
                results.append(meta["result"])
                continue
            job_name = meta["job_name"]
            result_dict: dict[str, Any] | None = None
            while True:
                try:
                    resp = self._client.describe_training_job(
                        TrainingJobName=job_name
                    )
                except Exception as e:
                    if self._is_resource_not_found(e):
                        result_dict = {
                            "rank": h.rank,
                            "status": "cancelled",
                            "exit_code": None,
                            "error": "training job not found (deleted?)",
                            "result": None,
                            "training_job_name": job_name,
                        }
                        break
                    raise
                sm_status = resp.get("TrainingJobStatus", "InProgress")
                if sm_status in ("Completed", "Failed", "Stopped"):
                    result_dict = self._terminal_result(h.rank, sm_status, resp)
                    break
                if time.time() >= deadline:
                    result_dict = {
                        "rank": h.rank,
                        "status": "running",
                        "exit_code": None,
                        "error": "timeout before terminal",
                        "result": None,
                        "training_job_name": job_name,
                    }
                    break
                # Sleep, but never overrun the deadline.
                time.sleep(min(poll_interval, max(0.0, deadline - time.time())))
            # Cache only terminal results (not the timeout 'running' sentinel,
            # so a later collect() can re-check the job).
            if result_dict["status"] in ("succeeded", "failed", "cancelled"):
                meta["result"] = result_dict
            results.append(result_dict)
        return results

    # -----------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------

    def _logs(self) -> Any:
        """Lazily build the CloudWatch Logs client (separate from sagemaker)."""
        if self._logs_client is None:
            self._logs_client = self._make_boto3_client("logs")
        return self._logs_client

    @staticmethod
    def _terminal_result(
        rank: int, sm_status: str, resp: Mapping[str, Any]
    ) -> dict[str, Any]:
        """Build a result dict from a terminal describe_training_job response."""
        mapped = _STATUS_MAP.get(sm_status, "failed")
        if sm_status == "Completed":
            exit_code: int | None = 0
            error = None
        elif sm_status == "Stopped":
            exit_code = None
            error = resp.get("FailureReason")
        else:  # Failed
            exit_code = 1
            error = resp.get("FailureReason") or "training job failed"
        artifacts = resp.get("ModelArtifacts", {}) or {}
        return {
            "rank": rank,
            "status": mapped,
            "exit_code": exit_code,
            "error": error,
            "result": artifacts.get("S3ModelArtifacts"),
            "training_job_name": resp.get("TrainingJobName"),
        }

    def _is_resource_not_found(self, exc: Exception) -> bool:
        """True if ``exc`` is the boto3 ResourceNotFound for the sagemaker client.

        Handles both the typed client exception
        (``client.exceptions.ResourceNotFound``) and a generic botocore
        ``ClientError`` whose error code is ``ResourceNotFound``, or a
        ``ValidationException`` whose message names a missing resource. This
        works whether a real boto3 client or a mock is in use.
        """
        rnf = getattr(
            getattr(self._client, "exceptions", None), "ResourceNotFound", None
        )
        # isinstance() needs a real class; a mock attribute is not one.
        if isinstance(rnf, type) and isinstance(exc, rnf):
            return True
        # Generic botocore ClientError fallback.
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            err = resp.get("Error", {})
            code = err.get("Code", "")
            if code == "ResourceNotFound":
                return True
            # SageMaker commonly reports a missing job as a
            # ValidationException ("Requested resource not found"). Treat it
            # as not-found only when the message says so; any other
            # validation error must propagate.
            if code == "ValidationException":
                return "not found" in str(err.get("Message", "")).lower()
        return False

    def _is_already_terminal(self, exc: Exception) -> bool:
        """True if ``exc`` is the boto3 "cannot stop a job in status X" error.

        ``stop_training_job`` raises a ``ValidationException`` when the job is
        already Completed/Failed/Stopped. That is an idempotent no-op for
        cancel(), distinct from a genuinely unexpected error. Matched on the
        ClientError code plus message text naming a terminal state
        (``str(exc)`` also covers a mock raising a plain Exception whose
        message carries the phrase).
        """
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            code = resp.get("Error", {}).get("Code")
            if code != "ValidationException":
                # AccessDenied, throttling, etc. are never an idempotent no-op.
                return False
        msg = str(exc).lower()
        terminal_words = ("complete", "failed", "stopped", "terminal")
        return "cannot be stopped" in msg or (
            ("already" in msg or "status" in msg)
            and any(word in msg for word in terminal_words)
        )


__all__ = ["SageMakerExecutor"]