Baladithya Balamurugan
Wave 3 cleanup: close deferred-LOW review items R5/R6/R11
7d9dbbc
"""SageMakerExecutor — production boto3-backed serverless executor.

This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`,
not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the
`ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the
boto3 low-level `sagemaker` client.

Design choices
--------------
1. **N independent single-instance jobs, NOT one multi-instance job.**
   SageMaker's *native* distributed training (``ResourceConfig.InstanceCount > 1``)
   groups instances into ONE job with an in-cluster NCCL/MPI fabric wired via
   ``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for
   DiLoCo replicas: it would couple replicas through SageMaker's intra-job
   network and break the "each replica is an independent DiLoCo worker that
   syncs only through S3" design. So ``launch_replicas`` submits N **separate**
   training jobs, each with ``ResourceConfig.InstanceCount == 1``, tagged with
   ``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment`` map. This
   mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls.
2. **Same S3 ``ObjectStoreAllReduce`` rendezvous; DiLoCo math untouched.**
   Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the
   executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to
   ``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` /
   ``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical.
3. **Stateless after launch; rank via ``Environment``.** Handle metadata is the
   ``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py``
   already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel
   is the ``Environment`` map (string->string, max 100 entries, value <= 512
   chars). The container command is baked into the image entrypoint and the
   rendezvous args are passed via ``AlgorithmSpecification.ContainerArguments``.
4. **``supports_inter_replica_network = False``.** Separate single-instance
   training jobs have no mutual network path by design: they rendezvous only
   through S3. (SageMaker's algo-N container fabric and
   ``EnableInterContainerTrafficEncryption`` only exist WITHIN a single
   multi-instance job, which this design deliberately does not use.)
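
As an illustration of design choices 1 and 3, here is a minimal sketch of the
per-replica ``Environment`` map, with the size limits quoted above checked
up front. ``replica_environment`` and the example bucket are hypothetical names
used only for this sketch; they are not part of this module.

```python
MAX_ENV_ENTRIES = 100        # limit quoted in design choice 3
MAX_ENV_VALUE_CHARS = 512    # limit quoted in design choice 3


def replica_environment(rank: int, world_size: int, rendezvous_uri: str) -> dict:
    # Hypothetical helper: build the Environment map for ONE single-instance job.
    env = {
        "REPLICA_RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "RENDEZVOUS_URI": rendezvous_uri,
    }
    if len(env) > MAX_ENV_ENTRIES:
        raise ValueError("too many Environment entries")
    for key, value in env.items():
        if len(value) > MAX_ENV_VALUE_CHARS:
            raise ValueError(f"Environment value for {key} is too long")
    return env


# One map per independent job; only the rank differs between replicas.
envs = [
    replica_environment(rank, 4, "s3://example-bucket/rendezvous")
    for rank in range(4)
]
```

Each map would go into its own ``create_training_job`` request with
``InstanceCount == 1``, so no replica ever shares a job with another.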

Load-bearing gotcha: ``EnableNetworkIsolation`` MUST stay ``False``
-------------------------------------------------------------------
When ``EnableNetworkIsolation=True`` the training *container* has no outbound
network access. SageMaker's host-side processes still stage input channels and
ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls.
``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network
isolation would silently stall the allreduce poll loop until it times out.
This executor pins ``EnableNetworkIsolation=False`` (the API default) and never
exposes it as a knob. Access to the rendezvous bucket must be granted on the
execution ``RoleArn``, which is the SageMaker analog of EKS IRSA.
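
A minimal sketch of what such a grant on the execution role could look like,
expressed as an IAM policy document built in Python. The helper name, bucket
and prefix are hypothetical, and the exact action list is an assumption: this
docstring only establishes that the replicas need S3 GET and PUT, so check
whether your rendezvous also lists or deletes objects before reusing it.

```python
import json


def rendezvous_policy(bucket: str, prefix: str) -> dict:
    # Hypothetical helper: S3 permissions scoped to the rendezvous prefix.
    prefix = prefix.strip("/")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Object reads and writes under the rendezvous prefix only.
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": f"arn:aws:s3:::{bucket}/{prefix}/*",
            },
            {
                # Assumed: listing is needed to poll for peer objects.
                "Effect": "Allow",
                "Action": "s3:ListBucket",
                "Resource": f"arn:aws:s3:::{bucket}",
                "Condition": {"StringLike": {"s3:prefix": [f"{prefix}/*"]}},
            },
        ],
    }


policy_json = json.dumps(
    rendezvous_policy("example-bucket", "rendezvous/run-1"), indent=2
)
```

The resulting JSON would be attached to the role passed as ``role_arn``; the
output bucket needs a similar statement.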

HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid)
---------------------------------------------------------------
Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in
HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and a
HyperPod cluster (worker nodes) within a VPC."*
(https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html)

Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer"
does NOT mean leaving EKS. It means attaching a HyperPod-managed
(auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to
the SAME EKS control plane that runs the outer loop. A future ``EKSExecutor``
(kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU
nodes AND HyperPod nodes transparently. ``SageMakerExecutor`` (ephemeral
Training Jobs via boto3) is the SEPARATE bursty-fallback inner-loop path for
when you don't want a persistent cluster: Training Jobs suit periodic /
smaller-model / pay-per-use runs; HyperPod suits continuous / large-model /
persistent runs. Both share the IDENTICAL S3 rendezvous, so a run can move
between them with zero trainer / loss / DiLoCo changes.

References
----------
- create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html
- describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html
- stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html
- network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation
- HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html
- ADR-005 (executor protocol design)
"""

from __future__ import annotations

import json
import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# SageMaker TrainingJobStatus -> Protocol status vocabulary.
# describe_training_job's TrainingJobStatus is EXACTLY one of:
#   'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'.
# We map Stopping -> 'running' (transient; still terminating, so collect()
# keeps waiting) and Stopped -> 'cancelled'.
_STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}

# SecondaryStatus values that mean "queued / not yet executing user code";
# used to refine an InProgress job into the Protocol's 'pending'.
_PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)

# Abstract Protocol GPU strings -> SageMaker instance types.
_GPU_INSTANCE_MAP = {
    "A100": "ml.p4d.24xlarge",
    "H100": "ml.p5.48xlarge",
    "H200": "ml.p5e.48xlarge",
    "B200": "ml.p6-b200.48xlarge",
    "L40S": "ml.g6e.12xlarge",
    "A10G": "ml.g5.2xlarge",
    "L4": "ml.g6.2xlarge",
}

_CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs"


class SageMakerExecutor:
    """Run replicas as N independent SageMaker Training Jobs.

    Implements the `ServerlessExecutor` Protocol against the boto3
    ``sagemaker`` client. Each replica is one single-instance training job;
    cross-replica communication happens only through the shared S3
    ``ObjectStoreAllReduce`` rendezvous.

    Args:
        role_arn: IAM execution role SageMaker assumes for the job. Must grant
            S3 access to the rendezvous + output buckets (the SageMaker analog
            of EKS IRSA). The caller's credentials need ``iam:PassRole`` on it.
        image_uri: ECR image URI for the training container. The image must
            be able to run
            ``python -m composer_replication.diloco.serverless.replica_entrypoint``.
            This executor passes that command as ``ContainerEntrypoint``
            explicitly, so the image does not have to bake it in as its own
            entrypoint.
        output_s3_path: ``s3://...`` prefix for ``OutputDataConfig.S3OutputPath``
            (model artifacts / failure output).
        instance_type: default SageMaker instance type when ``gpu`` is not
            mapped (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch falls
            back to ``cpu_instance_type``.
        cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU
            smoke tests). Default ``"ml.m5.xlarge"``.
        volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job.
        run_id: prefix for generated training-job names. Defaults to a short
            random token so names are unique per region+account.
        region: AWS region for the lazily-constructed boto3 clients. ``None``
            uses the ambient boto3 default-region resolution.
        sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or a
            mock) instead of constructing one. Used by tests.
        logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock).

    Raises:
        RuntimeError: if boto3 is not installed and no client was injected.
    """

    backend_name = "sagemaker"
    # Separate single-instance jobs have no mutual network path: S3 only.
    supports_inter_replica_network = False

    def __init__(
        self,
        *,
        role_arn: str,
        image_uri: str,
        output_s3_path: str,
        instance_type: str = "ml.g5.2xlarge",
        cpu_instance_type: str = "ml.m5.xlarge",
        volume_size_gb: int = 100,
        run_id: str | None = None,
        region: str | None = None,
        sagemaker_client: Any = None,
        logs_client: Any = None,
    ) -> None:
        self.role_arn = role_arn
        self.image_uri = image_uri
        self.output_s3_path = output_s3_path
        self.instance_type = instance_type
        self.cpu_instance_type = cpu_instance_type
        self.volume_size_gb = volume_size_gb
        self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}"
        self._region = region
        # Lazy boto3: only constructed if the caller didn't inject a client.
        # This keeps `import composer_replication.diloco.serverless` free of a
        # hard boto3 dependency (boto3 lives in the optional [aws] extra), and
        # lets tests inject a _MockSMClient with zero AWS calls.
        if sagemaker_client is None:
            sagemaker_client = self._make_boto3_client("sagemaker")
        self._client = sagemaker_client
        self._logs_client = logs_client  # built lazily on first stream_logs()
        # rank -> {"job_name": str, "result": dict | None}
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # boto3 plumbing (lazy)
    # -----------------------------------------------------------------
    def _make_boto3_client(self, service: str) -> Any:
        try:
            import boto3
        except ImportError as e:
            raise RuntimeError(
                "SageMakerExecutor requires boto3. Install with "
                "`pip install -e .[aws]` (or `pip install boto3`). "
                f"Got: {e!r}"
            ) from e
        if self._region is not None:
            return boto3.client(service, region_name=self._region)
        return boto3.client(service)

    def _map_gpu(self, gpu: str | None) -> str:
        """Translate the Protocol's abstract gpu string to an instance type.

        Resolution order: ``gpu=None`` -> ``cpu_instance_type`` (smoke tests);
        a key of ``_GPU_INSTANCE_MAP`` -> the mapped instance type; a literal
        ``"ml.*"`` string -> returned unchanged; anything else -> the
        constructor's default ``instance_type``.
        """
        if gpu is None:
            return self.cpu_instance_type
        if gpu in _GPU_INSTANCE_MAP:
            return _GPU_INSTANCE_MAP[gpu]
        # Caller may have passed a literal "ml.*" instance type.
        if gpu.startswith("ml."):
            return gpu
        return self.instance_type
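
    def _job_name(self, rank: int) -> str:
        """Build a training-job name of at most 63 characters.

        Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
        ``run_id`` is used as-is, so it must itself match that pattern. When
        the name would be too long, ``run_id`` is truncated rather than the
        rank/timestamp suffix, because the suffix is what tells jobs apart.
        """
        suffix = f"-r{rank:04d}-{int(time.time())}"
        prefix = self.run_id[: 63 - len(suffix)].rstrip("-")
        return f"{prefix}{suffix}"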

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Submit N independent single-instance SageMaker Training Jobs.

        Args:
            n_replicas: number of replicas (= number of training jobs).
            entrypoint: ignored. The container command is baked into the
                image / passed as ``ContainerEntrypoint``. Kept for Protocol
                compatibility.
            entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``) and
                ``trainer_module``. Optional: ``trainer_fn`` (default
                ``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the
                container args). The conventional ``rank_env`` key (from
                ``LocalProcessExecutor``) is ignored; rank goes through the
                ``Environment`` map instead.
            gpu: abstract GPU spec mapped to an instance type via ``_map_gpu``.
                ``None`` -> CPU instance.
            timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job.

        Returns:
            ``list[ReplicaHandle]`` of length ``n_replicas`` in rank order
            (``handles[i].rank == i``).
        """
        del entrypoint  # container command is baked / passed explicitly
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        rendezvous_uri = entrypoint_args.get("rendezvous_uri")
        if not rendezvous_uri:
            raise ValueError(
                "entrypoint_args must include 'rendezvous_uri' (the s3:// "
                "ObjectStoreAllReduce rendezvous prefix)."
            )
        trainer_module = entrypoint_args.get("trainer_module")
        if not trainer_module:
            raise ValueError(
                "entrypoint_args must include 'trainer_module' (importable "
                "module path of the user's train function)."
            )
        trainer_fn = entrypoint_args.get("trainer_fn", "train")
        trainer_kwargs = entrypoint_args.get("trainer_kwargs", {})
        instance_type = self._map_gpu(gpu)

        # Container args: each element is a SINGLE token (StackOverflow
        # 77994925: `['--world-size', '4']` NOT `['--world-size 4']`).
        container_args = [
            "--rendezvous", str(rendezvous_uri),
            "--world-size", str(n_replicas),
            "--trainer-module", str(trainer_module),
            "--trainer-fn", str(trainer_fn),
            "--trainer-kwargs-json", json.dumps(trainer_kwargs),
        ]

        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            job_name = self._job_name(rank)
            request = {
                "TrainingJobName": job_name,
                "AlgorithmSpecification": {
                    "TrainingImage": self.image_uri,
                    "TrainingInputMode": "File",
                    "ContainerEntrypoint": [
                        "python", "-m",
                        "composer_replication.diloco.serverless.replica_entrypoint",
                    ],
                    "ContainerArguments": container_args,
                },
                "RoleArn": self.role_arn,
                # InputDataConfig intentionally omitted: the replica pulls
                # data via its own code / the S3 rendezvous, not SM channels.
                "OutputDataConfig": {"S3OutputPath": self.output_s3_path},
                "ResourceConfig": {
                    "InstanceType": instance_type,
                    "InstanceCount": 1,
                    "VolumeSizeInGB": self.volume_size_gb,
                },
                "StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)},
                # REPLICA_RANK / WORLD_SIZE injected as container env vars;
                # replica_entrypoint.py reads os.environ['REPLICA_RANK'].
                "Environment": {
                    "REPLICA_RANK": str(rank),
                    "WORLD_SIZE": str(n_replicas),
                    "RENDEZVOUS_URI": str(rendezvous_uri),
                },
                # MUST stay False: True severs the container's S3 access and
                # stalls the allreduce poll loop. See module docstring.
                "EnableNetworkIsolation": False,
            }
            try:
                self._client.create_training_job(**request)
            except Exception as e:
                # Best-effort stop of already-launched siblings, then raise.
                for prior in handles:
                    try:
                        self.cancel(prior)
                    except Exception:
                        pass
                raise RuntimeError(
                    f"SageMakerExecutor.launch_replicas failed at rank={rank} "
                    f"of {n_replicas} (already-launched siblings stopped). "
                    f"Underlying error: {e!r}"
                ) from e
            handle = ReplicaHandle(
                rank=rank,
                backend_name=self.backend_name,
                metadata={
                    "training_job_name": job_name,
                    "submit_ts": time.time(),
                },
            )
            self._handles[rank] = {"job_name": job_name, "result": None}
            handles.append(handle)
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll a training job's status.

        Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` |
        ``"failed"`` | ``"cancelled"``.

        Maps ``describe_training_job``'s ``TrainingJobStatus`` via
        ``_STATUS_MAP``; refines ``InProgress`` to ``"pending"`` while the job
        is still queued (``SecondaryStatus`` in ``_PENDING_SECONDARY``). A
        vanished job (``ResourceNotFound``) is treated as ``"cancelled"``.
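
The mapping described above can be sketched as a standalone function. This is
only an illustration: ``protocol_status`` is a hypothetical name, and the two
tables are copies of this module's ``_STATUS_MAP`` and ``_PENDING_SECONDARY``
so that the snippet runs on its own.

```python
STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}
PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)


def protocol_status(describe_response):
    # Hypothetical helper: translate one describe_training_job response.
    sm_status = describe_response.get("TrainingJobStatus", "InProgress")
    if sm_status == "InProgress":
        # Queued jobs are reported as 'pending', everything else as 'running'.
        if describe_response.get("SecondaryStatus") in PENDING_SECONDARY:
            return "pending"
        return "running"
    return STATUS_MAP.get(sm_status, "running")


queued = protocol_status(
    {"TrainingJobStatus": "InProgress", "SecondaryStatus": "Starting"}
)
training = protocol_status(
    {"TrainingJobStatus": "InProgress", "SecondaryStatus": "Training"}
)
stopped = protocol_status({"TrainingJobStatus": "Stopped"})
```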
"""
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        if meta["result"] is not None:
            return meta["result"]["status"]
        job_name = meta["job_name"]
        try:
            resp = self._client.describe_training_job(TrainingJobName=job_name)
        except Exception as e:
            if self._is_resource_not_found(e):
                return "cancelled"
            raise
        sm_status = resp.get("TrainingJobStatus", "InProgress")
        mapped = _STATUS_MAP.get(sm_status, "running")
        if sm_status == "InProgress":
            if resp.get("SecondaryStatus") in _PENDING_SECONDARY:
                return "pending"
            return "running"
        if sm_status not in ("Completed", "Failed", "Stopped"):
            # 'Stopping' (or an unrecognized status) is transient: report it
            # but do not cache it, otherwise later poll()/collect() calls
            # would keep returning a stale 'running' result.
            return mapped
        # Terminal: cache a result dict so collect()/repeat-poll are cheap.
        meta["result"] = self._terminal_result(handle.rank, sm_status, resp)
        return mapped

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent CloudWatch logs for this replica's training job.

        SageMaker writes container stdout/stderr to the
        ``/aws/sagemaker/TrainingJobs`` log group, stream
        ``<job-name>/algo-<n>-<epoch>``. We discover the exact stream name by
        prefix then read the tail. Falls back to a CloudWatch console pointer
        on any error (mirrors ModalSpawnExecutor's dashboard-URL fallback).
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        job_name = meta["job_name"]
        try:
            logs = self._logs()
            prefix = f"{job_name}/"
            # describe_log_streams does not accept logStreamNamePrefix together
            # with orderBy="LastEventTime", so list by name (the default
            # order). A single-instance job normally has exactly one stream.
            streams = logs.describe_log_streams(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamNamePrefix=prefix,
                limit=1,
            )
            stream_list = streams.get("logStreams", [])
            if not stream_list:
                return (
                    f"[rank {handle.rank}] job={job_name}: no CloudWatch log "
                    f"stream yet (job pending / not started)."
                )
            stream_name = stream_list[0]["logStreamName"]
            events = logs.get_log_events(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamName=stream_name,
                limit=n_lines,
                startFromHead=False,
            )
            lines = [e.get("message", "") for e in events.get("events", [])]
            body = "\n".join(lines) if lines else "<no log events>"
            return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}"
        except Exception as e:
            region = self._region or "<region>"
            url = (
                f"https://{region}.console.aws.amazon.com/cloudwatch/home"
                f"?region={region}#logsV2:log-groups/log-group/"
                f"$252Faws$252Fsagemaker$252FTrainingJobs"
            )
            return (
                f"[rank {handle.rank}] job={job_name}: log fetch failed "
                f"({type(e).__name__}: {e!r}).\n CloudWatch console: {url}"
            )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort stop of a training job.

        Calls ``stop_training_job`` (SIGTERM + 120s grace), swallowing
        ``ResourceNotFound`` and "already terminal" ``ValidationException`` so
        the contract ("no exception if already terminated") holds.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        try:
            self._client.stop_training_job(TrainingJobName=meta["job_name"])
        except Exception as e:
            # R5: swallow ONLY already-terminated signals: a vanished job
            # (ResourceNotFound) or an already-Completed/Stopped job (boto3
            # raises ValidationException for "cannot stop a job in status X").
            # A genuinely unexpected error (AccessDenied, throttling that
            # outlived retries) must propagate rather than masquerade as a
            # successful cancel. Note that _is_already_terminal accepts any
            # ValidationException, so a request rejected as invalid for
            # another reason is swallowed as well.
            if self._is_resource_not_found(e) or self._is_already_terminal(e):
                return
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Polls ``describe_training_job`` per handle until the job reaches a
        terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the
        shared deadline elapses. Returns results aligned to the input handle
        order (Protocol contract; mirrors ``LocalProcessExecutor.collect``).
        Each result dict has at least
        ``{"rank", "status", "exit_code", "error"}``.
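
The deadline-bounded wait described above can be sketched on its own, with a
fake ``describe`` callable standing in for the boto3 client. The helper name
``wait_for_terminal`` is hypothetical, and ``time.monotonic`` is a choice made
for this sketch only (it is unaffected by wall-clock adjustments).

```python
import time


def wait_for_terminal(describe, job_name, timeout_s, poll_interval_s=30.0):
    # Hypothetical helper: poll until a terminal status or the deadline.
    # Returns the last TrainingJobStatus that was observed.
    deadline = time.monotonic() + timeout_s
    while True:
        status = describe(job_name)["TrainingJobStatus"]
        if status in ("Completed", "Failed", "Stopped"):
            return status
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return status  # timed out while still non-terminal
        # Sleep, but never past the deadline.
        time.sleep(min(poll_interval_s, remaining))


# Fake describe call that reports completion on the third poll.
responses = iter(["InProgress", "InProgress", "Completed"])
final = wait_for_terminal(
    lambda name: {"TrainingJobStatus": next(responses)},
    "demo-job",
    timeout_s=5.0,
    poll_interval_s=0.0,
)
```

Because the deadline is computed once and shared, a slow first replica leaves
less waiting time for the replicas checked after it.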
"""
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = 30.0
        results: list[dict[str, Any]] = []
        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                results.append({
                    "rank": h.rank,
                    "status": "cancelled",
                    "exit_code": None,
                    "error": "handle has no metadata (cancelled or unknown)",
                    "result": None,
                    "training_job_name": h.metadata.get("training_job_name"),
                })
                continue
            # Already cached by an earlier poll()/collect().
            if meta["result"] is not None:
                results.append(meta["result"])
                continue
            job_name = meta["job_name"]
            result_dict: dict[str, Any] | None = None
            while True:
                try:
                    resp = self._client.describe_training_job(
                        TrainingJobName=job_name
                    )
                except Exception as e:
                    if self._is_resource_not_found(e):
                        result_dict = {
                            "rank": h.rank,
                            "status": "cancelled",
                            "exit_code": None,
                            "error": "training job not found (deleted?)",
                            "result": None,
                            "training_job_name": job_name,
                        }
                        break
                    raise
                sm_status = resp.get("TrainingJobStatus", "InProgress")
                if sm_status in ("Completed", "Failed", "Stopped"):
                    result_dict = self._terminal_result(h.rank, sm_status, resp)
                    break
                if time.time() >= deadline:
                    result_dict = {
                        "rank": h.rank,
                        "status": "running",
                        "exit_code": None,
                        "error": "timeout before terminal",
                        "result": None,
                        "training_job_name": job_name,
                    }
                    break
                # Sleep, but never overrun the deadline.
                time.sleep(min(poll_interval, max(0.0, deadline - time.time())))
            # Cache only terminal results (not the timeout 'running' sentinel,
            # so a later collect() can re-check the job).
            if result_dict["status"] in ("succeeded", "failed", "cancelled"):
                meta["result"] = result_dict
            results.append(result_dict)
        return results

    # -----------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------
    def _logs(self) -> Any:
        """Lazily build the CloudWatch Logs client (separate from sagemaker)."""
        if self._logs_client is None:
            self._logs_client = self._make_boto3_client("logs")
        return self._logs_client

    @staticmethod
    def _terminal_result(
        rank: int, sm_status: str, resp: Mapping[str, Any]
    ) -> dict[str, Any]:
        """Build a result dict from a terminal describe_training_job response."""
        mapped = _STATUS_MAP.get(sm_status, "failed")
        if sm_status == "Completed":
            exit_code: int | None = 0
            error = None
        elif sm_status == "Stopped":
            exit_code = None
            error = resp.get("FailureReason")
        else:  # Failed
            exit_code = 1
            error = resp.get("FailureReason") or "training job failed"
        artifacts = resp.get("ModelArtifacts", {}) or {}
        return {
            "rank": rank,
            "status": mapped,
            "exit_code": exit_code,
            "error": error,
            "result": artifacts.get("S3ModelArtifacts"),
            "training_job_name": resp.get("TrainingJobName"),
        }

    def _is_resource_not_found(self, exc: Exception) -> bool:
        """True if ``exc`` signals that the training job does not exist.

        Handles both the typed client exception
        (``client.exceptions.ResourceNotFound``) and a generic botocore
        ``ClientError`` whose error code is ``ResourceNotFound``, or
        ``ValidationException`` with a "not found" message. This keeps the
        check robust whether a real boto3 client or a mock is in use.
        """
        rnf = getattr(getattr(self._client, "exceptions", None),
                      "ResourceNotFound", None)
        if rnf is not None and isinstance(exc, rnf):
            return True
        # Generic botocore ClientError fallback.
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            err = resp.get("Error", {})
            code = err.get("Code", "")
            if code == "ResourceNotFound":
                return True
            # Only a ValidationException that names a missing resource counts;
            # other validation errors must propagate to the caller.
            message = str(err.get("Message", "")).lower()
            if code == "ValidationException" and "not found" in message:
                return True
        return False

    def _is_already_terminal(self, exc: Exception) -> bool:
        """True if ``exc`` is the boto3 "cannot stop a job in status X" error.

        ``stop_training_job`` raises a ``ValidationException`` when the job is
        already Completed/Failed/Stopped. That is an idempotent no-op for
        cancel(), distinct from a genuinely unexpected error. Any ClientError
        with code ``ValidationException`` is accepted; as a fallback the
        message text is matched (robust to a mock raising a plain Exception
        whose message carries the phrase). Caveat: a ``ValidationException``
        raised for any other reason is accepted too.
        """
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            err = resp.get("Error", {})
            if err.get("Code") == "ValidationException":
                return True
        msg = str(exc).lower()
        # Explicit parentheses: `and` binds tighter than `or`.
        return "cannot be stopped" in msg or (
            "already" in msg
            and ("stopped" in msg or "complete" in msg or "terminal" in msg)
        )


__all__ = ["SageMakerExecutor"]