| """SageMakerExecutor — production boto3-backed serverless executor. | |
| This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`, | |
| not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the | |
| `ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the | |
| boto3 low-level `sagemaker` client. | |
| Design choices | |
| -------------- | |
| 1. **N independent single-instance jobs, NOT one multi-instance job.** | |
| SageMaker's *native* distributed training (``ResourceConfig.InstanceCount > 1``) | |
| groups instances into ONE job with an in-cluster NCCL/MPI fabric wired via | |
| ``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for | |
| DiLoCo replicas — it would couple replicas through SageMaker's intra-job | |
| network and break the "each replica is an independent DiLoCo worker that | |
| syncs only through S3" design. So ``launch_replicas`` submits N **separate** | |
| training jobs, each with ``ResourceConfig.InstanceCount == 1``, tagged with | |
| ``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment`` map. This | |
| mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls. | |
| 2. **Same S3 ``ObjectStoreAllReduce`` rendezvous — DiLoCo math untouched.** | |
| Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the | |
| executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to | |
| ``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` / | |
| ``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical. | |
| 3. **Stateless after launch; rank via ``Environment``.** Handle metadata is the | |
| ``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py`` | |
| already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel | |
| is the ``Environment`` map (string->string, max 100 entries, value <= 512 | |
| chars). The container command is set via | |
| ``AlgorithmSpecification.ContainerEntrypoint`` and the rendezvous args are | |
| passed via ``AlgorithmSpecification.ContainerArguments``. | |
| 4. **``supports_inter_replica_network = False``.** Separate single-instance | |
| training jobs have no mutual network path by design — they rendezvous only | |
| through S3. (SageMaker's algo-N container fabric and | |
| ``EnableInterContainerTrafficEncryption`` only exist WITHIN a single | |
| multi-instance job, which this design deliberately does not use.) | |
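The rank channel from design choice 3 can be sketched from the replica side. This is a hypothetical illustration, not the actual `replica_entrypoint.py` (which is not shown here): the function name `read_replica_identity` and its validation rules are assumptions, and only the `REPLICA_RANK` / `WORLD_SIZE` variable names come from the text above.

```python
import os
from collections.abc import Mapping


def read_replica_identity(env: Mapping[str, str] = os.environ) -> tuple[int, int]:
    # Hypothetical helper: parse the rank / world-size strings that the
    # executor places in the training job's Environment map.
    rank = int(env["REPLICA_RANK"])
    world_size = int(env["WORLD_SIZE"])
    if world_size < 1:
        raise ValueError(f"WORLD_SIZE must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(f"REPLICA_RANK {rank} outside [0, {world_size})")
    return rank, world_size


if __name__ == "__main__":
    # Simulate the Environment map of replica 2 in a 4-replica run.
    print(read_replica_identity({"REPLICA_RANK": "2", "WORLD_SIZE": "4"}))
```

Because every value in the `Environment` map is a string, the integer conversion and range check happen on the replica side; failing fast here is cheaper than discovering a bad rank during the first S3 rendezvous round.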
| Load-bearing gotcha — ``EnableNetworkIsolation`` MUST stay ``False`` | |
| -------------------------------------------------------------------- | |
| When ``EnableNetworkIsolation=True`` the training *container* has no outbound | |
| network access. SageMaker's host-side processes still stage input channels and | |
| ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls. | |
| ``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network | |
| isolation would silently dead-lock the allreduce poll loop until its timeout. | |
| This executor pins ``EnableNetworkIsolation=False`` (the API default) and never | |
| exposes it as a knob. The rendezvous bucket access must instead be granted on | |
| the execution ``RoleArn`` — the SageMaker analog of EKS IRSA. | |
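As a rough sketch of what granting rendezvous access on the execution role could look like, the snippet below builds an IAM policy document as a plain dict. The helper name `rendezvous_policy`, the bucket and prefix values, and the exact set of S3 actions that `ObjectStoreAllReduce` needs are all assumptions; adjust them to the real access pattern.

```python
import json
from typing import Any


def rendezvous_policy(bucket: str, prefix: str) -> dict[str, Any]:
    # Hypothetical helper: policy scoped to a single S3 prefix.
    prefix = prefix.strip("/")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Object-level reads and writes under the rendezvous prefix.
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": f"arn:aws:s3:::{bucket}/{prefix}/*",
            },
            {
                # Listing is a bucket-level action, scoped here by prefix.
                "Effect": "Allow",
                "Action": "s3:ListBucket",
                "Resource": f"arn:aws:s3:::{bucket}",
                "Condition": {"StringLike": {"s3:prefix": [f"{prefix}/*"]}},
            },
        ],
    }


if __name__ == "__main__":
    # "my-diloco-bucket" and "runs/demo" are placeholder values.
    print(json.dumps(rendezvous_policy("my-diloco-bucket", "runs/demo"), indent=2))
```

The resulting document would be attached to the role passed as `role_arn`; the caller that submits the jobs still needs `iam:PassRole` on that role separately.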
| HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid) | |
| --------------------------------------------------------------- | |
| Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in | |
| HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and a | |
| HyperPod cluster (worker nodes) within a VPC."* | |
| (https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html) | |
| Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer" | |
| does NOT mean leaving EKS — it means attaching a HyperPod-managed | |
| (auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to | |
| the SAME EKS control plane that runs the outer loop. A future ``EKSExecutor`` | |
| (kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU | |
| nodes AND HyperPod nodes transparently. ``SageMakerExecutor`` (ephemeral | |
| Training Jobs via boto3) is the SEPARATE bursty-fallback inner-loop path for | |
| when you don't want a persistent cluster: Training Jobs suit periodic / | |
| smaller-model / pay-per-use runs; HyperPod suits continuous / large-model / | |
| persistent runs. Both share the IDENTICAL S3 rendezvous, so a run can move | |
| between them with zero trainer / loss / DiLoCo changes. | |
| References | |
| ---------- | |
| - create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html | |
| - describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html | |
| - stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html | |
| - network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation | |
| - HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html | |
| - ADR-005 (executor protocol design) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import time | |
| import uuid | |
| from collections.abc import Callable, Mapping | |
| from typing import Any | |
| from composer_replication.diloco.serverless.executor import ( | |
| ReplicaHandle, | |
| ) | |
| # SageMaker TrainingJobStatus -> Protocol status vocabulary. | |
| # describe_training_job's TrainingJobStatus is EXACTLY one of: | |
| # 'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'. | |
| # We map Stopping -> 'running' (transient; still terminating, so collect() | |
| # keeps waiting) and Stopped -> 'cancelled'. | |
| _STATUS_MAP = { | |
| "InProgress": "running", | |
| "Completed": "succeeded", | |
| "Failed": "failed", | |
| "Stopping": "running", | |
| "Stopped": "cancelled", | |
| } | |
| # SecondaryStatus values that mean "queued / not yet executing user code" — | |
| # used to refine an InProgress job into the Protocol's 'pending'. | |
| _PENDING_SECONDARY = frozenset( | |
| {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"} | |
| ) | |
| # Abstract Protocol GPU strings -> SageMaker instance types. | |
| _GPU_INSTANCE_MAP = { | |
| "A100": "ml.p4d.24xlarge", | |
| "H100": "ml.p5.48xlarge", | |
| "H200": "ml.p5e.48xlarge", | |
| "B200": "ml.p6-b200.48xlarge", | |
| "L40S": "ml.g6e.12xlarge", | |
| "A10G": "ml.g5.2xlarge", | |
| "L4": "ml.g6.2xlarge", | |
| } | |
| _CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs" | |
| class SageMakerExecutor: | |
| """Run replicas as N independent SageMaker Training Jobs. | |
| Implements the `ServerlessExecutor` Protocol against the boto3 | |
| ``sagemaker`` client. Each replica is one single-instance training job; | |
| cross-replica communication happens only through the shared S3 | |
| ``ObjectStoreAllReduce`` rendezvous. | |
| Args: | |
| role_arn: IAM execution role SageMaker assumes for the job. Must grant | |
| S3 access to the rendezvous + output buckets (the SageMaker analog of | |
| EKS IRSA). The caller's credentials need ``iam:PassRole`` on it. | |
| image_uri: ECR image URI for the training container. The image must | |
| bake an entrypoint that runs | |
| ``python -m composer_replication.diloco.serverless.replica_entrypoint`` | |
| (this executor also passes ``ContainerEntrypoint`` explicitly so a | |
| generic image works too). | |
| output_s3_path: ``s3://...`` prefix for ``OutputDataConfig.S3OutputPath`` | |
| (model artifacts / failure output). | |
| instance_type: default SageMaker instance type when ``gpu`` is not | |
| mapped (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch falls | |
| back to ``cpu_instance_type``. | |
| cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU | |
| smoke tests). Default ``"ml.m5.xlarge"``. | |
| volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job. | |
| run_id: prefix for generated training-job names. Defaults to a short | |
| random token so names are unique per region+account. | |
| region: AWS region for the lazily-constructed boto3 clients. ``None`` | |
| uses the ambient boto3 default-region resolution. | |
| sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or a | |
| mock) instead of constructing one. Used by tests. | |
| logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock). | |
| Raises: | |
| RuntimeError: if boto3 is not installed and no client was injected. | |
| """ | |
| backend_name = "sagemaker" | |
| # Separate single-instance jobs have no mutual network path — S3 only. | |
| supports_inter_replica_network = False | |
| def __init__( | |
| self, | |
| *, | |
| role_arn: str, | |
| image_uri: str, | |
| output_s3_path: str, | |
| instance_type: str = "ml.g5.2xlarge", | |
| cpu_instance_type: str = "ml.m5.xlarge", | |
| volume_size_gb: int = 100, | |
| run_id: str | None = None, | |
| region: str | None = None, | |
| sagemaker_client: Any = None, | |
| logs_client: Any = None, | |
| ) -> None: | |
| self.role_arn = role_arn | |
| self.image_uri = image_uri | |
| self.output_s3_path = output_s3_path | |
| self.instance_type = instance_type | |
| self.cpu_instance_type = cpu_instance_type | |
| self.volume_size_gb = volume_size_gb | |
| self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}" | |
| self._region = region | |
| # Lazy boto3 — only constructed if the caller didn't inject a client. | |
| # This keeps `import composer_replication.diloco.serverless` free of a | |
| # hard boto3 dependency (boto3 lives in the optional [aws] extra), and | |
| # lets tests inject a _MockSMClient with zero AWS calls. | |
| if sagemaker_client is None: | |
| sagemaker_client = self._make_boto3_client("sagemaker") | |
| self._client = sagemaker_client | |
| self._logs_client = logs_client # built lazily on first stream_logs() | |
| # rank -> {"job_name": str, "result": dict | None} | |
| self._handles: dict[int, dict[str, Any]] = {} | |
| # ----------------------------------------------------------------- | |
| # boto3 plumbing (lazy) | |
| # ----------------------------------------------------------------- | |
| def _make_boto3_client(self, service: str) -> Any: | |
| try: | |
| import boto3 | |
| except ImportError as e: | |
| raise RuntimeError( | |
| "SageMakerExecutor requires boto3. Install with " | |
| "`pip install -e .[aws]` (or `pip install boto3`). " | |
| f"Got: {e!r}" | |
| ) from e | |
| if self._region is not None: | |
| return boto3.client(service, region_name=self._region) | |
| return boto3.client(service) | |
| def _map_gpu(self, gpu: str | None) -> str: | |
| """Translate the Protocol's abstract gpu string to an instance type. | |
| ``gpu=None`` -> ``cpu_instance_type`` (smoke tests). A literal | |
| SageMaker instance type (any string starting with ``ml.``) is returned | |
| unchanged; any other unrecognized string falls back to ``instance_type``. | |
| """ | |
| if gpu is None: | |
| return self.cpu_instance_type | |
| if gpu in _GPU_INSTANCE_MAP: | |
| return _GPU_INSTANCE_MAP[gpu] | |
| # Caller may have passed a literal "ml.*" instance type. | |
| if gpu.startswith("ml."): | |
| return gpu | |
| return self.instance_type | |
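| def _job_name(self, rank: int) -> str: | |
| """Build a unique, regex-safe training-job name (<= 63 chars). | |
| Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``. | |
| Only the ``run_id`` prefix is truncated, so the rank + timestamp suffix | |
| that keeps names unique always survives a long ``run_id``. | |
| """ | |
| suffix = f"-r{rank:04d}-{int(time.time())}" | |
| return f"{self.run_id[:63 - len(suffix)]}{suffix}" | |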
| # ----------------------------------------------------------------- | |
| # ServerlessExecutor Protocol | |
| # ----------------------------------------------------------------- | |
| def launch_replicas( | |
| self, | |
| n_replicas: int, | |
| entrypoint: str | Callable[..., Any], | |
| entrypoint_args: Mapping[str, Any], | |
| *, | |
| gpu: str | None = None, | |
| timeout: int = 3600, | |
| ) -> list[ReplicaHandle]: | |
| """Submit N independent single-instance SageMaker Training Jobs. | |
| Args: | |
| n_replicas: number of replicas (= number of training jobs). | |
| entrypoint: ignored — the container command is baked into the | |
| image / passed as ``ContainerEntrypoint``. Kept for Protocol | |
| compatibility. | |
| entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``) and | |
| ``trainer_module``. Optional: ``trainer_fn`` (default | |
| ``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the | |
| container args). The conventional ``rank_env`` key (from | |
| ``LocalProcessExecutor``) is ignored — rank goes through the | |
| ``Environment`` map instead. | |
| gpu: abstract GPU spec mapped to an instance type via ``_map_gpu``. | |
| ``None`` -> CPU instance. | |
| timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job. | |
| Returns: | |
| ``list[ReplicaHandle]`` of length ``n_replicas`` in rank order | |
| (``handles[i].rank == i``). | |
| """ | |
| del entrypoint # container command is baked / passed explicitly | |
| if n_replicas < 1: | |
| raise ValueError(f"n_replicas must be >= 1, got {n_replicas}") | |
| rendezvous_uri = entrypoint_args.get("rendezvous_uri") | |
| if not rendezvous_uri: | |
| raise ValueError( | |
| "entrypoint_args must include 'rendezvous_uri' (the s3:// " | |
| "ObjectStoreAllReduce rendezvous prefix)." | |
| ) | |
| trainer_module = entrypoint_args.get("trainer_module") | |
| if not trainer_module: | |
| raise ValueError( | |
| "entrypoint_args must include 'trainer_module' (importable " | |
| "module path of the user's train function)." | |
| ) | |
| trainer_fn = entrypoint_args.get("trainer_fn", "train") | |
| trainer_kwargs = entrypoint_args.get("trainer_kwargs", {}) | |
| instance_type = self._map_gpu(gpu) | |
| # Container args: each element is a SINGLE token (StackOverflow | |
| # 77994925 — `['--world-size', '4']` NOT `['--world-size 4']`). | |
| container_args = [ | |
| "--rendezvous", str(rendezvous_uri), | |
| "--world-size", str(n_replicas), | |
| "--trainer-module", str(trainer_module), | |
| "--trainer-fn", str(trainer_fn), | |
| "--trainer-kwargs-json", json.dumps(trainer_kwargs), | |
| ] | |
| handles: list[ReplicaHandle] = [] | |
| for rank in range(n_replicas): | |
| job_name = self._job_name(rank) | |
| request = { | |
| "TrainingJobName": job_name, | |
| "AlgorithmSpecification": { | |
| "TrainingImage": self.image_uri, | |
| "TrainingInputMode": "File", | |
| "ContainerEntrypoint": [ | |
| "python", "-m", | |
| "composer_replication.diloco.serverless.replica_entrypoint", | |
| ], | |
| "ContainerArguments": container_args, | |
| }, | |
| "RoleArn": self.role_arn, | |
| # InputDataConfig intentionally omitted — the replica pulls | |
| # data via its own code / the S3 rendezvous, not SM channels. | |
| "OutputDataConfig": {"S3OutputPath": self.output_s3_path}, | |
| "ResourceConfig": { | |
| "InstanceType": instance_type, | |
| "InstanceCount": 1, | |
| "VolumeSizeInGB": self.volume_size_gb, | |
| }, | |
| "StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)}, | |
| # REPLICA_RANK / WORLD_SIZE injected as container env vars; | |
| # replica_entrypoint.py reads os.environ['REPLICA_RANK']. | |
| "Environment": { | |
| "REPLICA_RANK": str(rank), | |
| "WORLD_SIZE": str(n_replicas), | |
| "RENDEZVOUS_URI": str(rendezvous_uri), | |
| }, | |
| # MUST stay False — True severs the container's S3 access and | |
| # dead-locks the allreduce poll loop. See module docstring. | |
| "EnableNetworkIsolation": False, | |
| } | |
| try: | |
| self._client.create_training_job(**request) | |
| except Exception as e: | |
| # Best-effort stop of already-launched siblings, then raise. | |
| for prior in handles: | |
| try: | |
| self.cancel(prior) | |
| except Exception: | |
| pass | |
| raise RuntimeError( | |
| f"SageMakerExecutor.launch_replicas failed at rank={rank} " | |
| f"of {n_replicas} (already-launched siblings stopped). " | |
| f"Underlying error: {e!r}" | |
| ) from e | |
| handle = ReplicaHandle( | |
| rank=rank, | |
| backend_name=self.backend_name, | |
| metadata={ | |
| "training_job_name": job_name, | |
| "submit_ts": time.time(), | |
| }, | |
| ) | |
| self._handles[rank] = {"job_name": job_name, "result": None} | |
| handles.append(handle) | |
| return handles | |
| def poll(self, handle: ReplicaHandle) -> str: | |
| """Poll a training job's status. | |
| Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` | | |
| ``"failed"`` | ``"cancelled"``. | |
| Maps ``describe_training_job``'s ``TrainingJobStatus`` via | |
| ``_STATUS_MAP``; refines ``InProgress`` to ``"pending"`` while the job | |
| is still queued (``SecondaryStatus`` in ``_PENDING_SECONDARY``). A | |
| vanished job (``ResourceNotFound``) is treated as ``"cancelled"``. | |
| """ | |
| meta = self._handles.get(handle.rank) | |
| if meta is None: | |
| return "cancelled" | |
| if meta["result"] is not None: | |
| return meta["result"]["status"] | |
| job_name = meta["job_name"] | |
| try: | |
| resp = self._client.describe_training_job(TrainingJobName=job_name) | |
| except Exception as e: | |
| if self._is_resource_not_found(e): | |
| return "cancelled" | |
| raise | |
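| sm_status = resp.get("TrainingJobStatus", "InProgress") | |
| mapped = _STATUS_MAP.get(sm_status, "running") | |
| if sm_status == "InProgress": | |
| if resp.get("SecondaryStatus") in _PENDING_SECONDARY: | |
| return "pending" | |
| return "running" | |
| if mapped == "running": | |
| # 'Stopping' (or an unrecognized status) is transient: report it as | |
| # running but do NOT cache it, or a later poll could never see the | |
| # job reach 'cancelled'. | |
| return "running" | |
| # Terminal: cache a result dict so collect()/repeat-poll are cheap. | |
| meta["result"] = self._terminal_result(handle.rank, sm_status, resp) | |
| return mapped | |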
| def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str: | |
| """Read recent CloudWatch logs for this replica's training job. | |
| SageMaker writes container stdout/stderr to the | |
| ``/aws/sagemaker/TrainingJobs`` log group, stream | |
| ``<job-name>/algo-<n>-<epoch>``. We discover the exact stream name by | |
| prefix then read the tail. Falls back to a CloudWatch console pointer | |
| on any error (mirrors ModalSpawnExecutor's dashboard-URL fallback). | |
| """ | |
| meta = self._handles.get(handle.rank) | |
| if meta is None: | |
| return f"<replica {handle.rank}: no metadata>" | |
| job_name = meta["job_name"] | |
| try: | |
| logs = self._logs() | |
| prefix = f"{job_name}/" | |
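| # CloudWatch Logs rejects logStreamNamePrefix combined with | |
| # orderBy="LastEventTime", so order by stream name instead. | |
| # Descending name order favours the largest epoch suffix. | |
| streams = logs.describe_log_streams( | |
| logGroupName=_CLOUDWATCH_LOG_GROUP, | |
| logStreamNamePrefix=prefix, | |
| orderBy="LogStreamName", | |
| descending=True, | |
| limit=1, | |
| ) | |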
| stream_list = streams.get("logStreams", []) | |
| if not stream_list: | |
| return ( | |
| f"[rank {handle.rank}] job={job_name}: no CloudWatch log " | |
| f"stream yet (job pending / not started)." | |
| ) | |
| stream_name = stream_list[0]["logStreamName"] | |
| events = logs.get_log_events( | |
| logGroupName=_CLOUDWATCH_LOG_GROUP, | |
| logStreamName=stream_name, | |
| limit=n_lines, | |
| startFromHead=False, | |
| ) | |
| lines = [e.get("message", "") for e in events.get("events", [])] | |
| body = "\n".join(lines) if lines else "<no log events>" | |
| return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}" | |
| except Exception as e: | |
| region = self._region or "<region>" | |
| url = ( | |
| f"https://{region}.console.aws.amazon.com/cloudwatch/home" | |
| f"?region={region}#logsV2:log-groups/log-group/" | |
| f"$252Faws$252Fsagemaker$252FTrainingJobs" | |
| ) | |
| return ( | |
| f"[rank {handle.rank}] job={job_name}: log fetch failed " | |
| f"({type(e).__name__}: {e!r}).\n CloudWatch console: {url}" | |
| ) | |
| def cancel(self, handle: ReplicaHandle) -> None: | |
| """Best-effort stop of a training job. | |
| Calls ``stop_training_job`` (SIGTERM + 120s grace), swallowing | |
| ``ResourceNotFound`` and "already terminal" ``ValidationException`` so | |
| the contract — "no exception if already terminated" — holds. | |
| """ | |
| meta = self._handles.get(handle.rank) | |
| if meta is None: | |
| return | |
| try: | |
| self._client.stop_training_job(TrainingJobName=meta["job_name"]) | |
| except Exception as e: | |
| # R5: swallow ONLY already-terminated signals — a vanished job | |
| # (ResourceNotFound) or an already-Completed/Stopped job (boto3 | |
| # raises ValidationException for "cannot stop a job in status X"). | |
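| # A genuinely unexpected error (AccessDenied, throttling that | |
| # outlived retries) must propagate rather than masquerade as a | |
| # successful cancel. Caveat: _is_already_terminal accepts every | |
| # ValidationException, so a server-rejected malformed request is | |
| # swallowed as well. | |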
| if self._is_resource_not_found(e) or self._is_already_terminal(e): | |
| return | |
| raise | |
| def collect( | |
| self, | |
| handles: list[ReplicaHandle], | |
| *, | |
| timeout: int | None = None, | |
| ) -> list[dict[str, Any]]: | |
| """Block until all replicas finish; return per-replica result dicts. | |
| Polls ``describe_training_job`` per handle until the job reaches a | |
| terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the | |
| shared deadline elapses. Returns results aligned to the input handle | |
| order (Protocol contract; mirrors ``LocalProcessExecutor.collect``). | |
| Each result dict has at least | |
| ``{"rank", "status", "exit_code", "error"}``. | |
| """ | |
| deadline = time.time() + (timeout if timeout is not None else 86400) | |
| poll_interval = 30.0 | |
| results: list[dict[str, Any]] = [] | |
| for h in handles: | |
| meta = self._handles.get(h.rank) | |
| if meta is None: | |
| results.append({ | |
| "rank": h.rank, | |
| "status": "cancelled", | |
| "exit_code": None, | |
| "error": "handle has no metadata (cancelled or unknown)", | |
| "result": None, | |
| "training_job_name": h.metadata.get("training_job_name"), | |
| }) | |
| continue | |
| # Already cached by an earlier poll()/collect(). | |
| if meta["result"] is not None: | |
| results.append(meta["result"]) | |
| continue | |
| job_name = meta["job_name"] | |
| result_dict: dict[str, Any] | None = None | |
| while True: | |
| try: | |
| resp = self._client.describe_training_job( | |
| TrainingJobName=job_name | |
| ) | |
| except Exception as e: | |
| if self._is_resource_not_found(e): | |
| result_dict = { | |
| "rank": h.rank, | |
| "status": "cancelled", | |
| "exit_code": None, | |
| "error": "training job not found (deleted?)", | |
| "result": None, | |
| "training_job_name": job_name, | |
| } | |
| break | |
| raise | |
| sm_status = resp.get("TrainingJobStatus", "InProgress") | |
| if sm_status in ("Completed", "Failed", "Stopped"): | |
| result_dict = self._terminal_result(h.rank, sm_status, resp) | |
| break | |
| if time.time() >= deadline: | |
| result_dict = { | |
| "rank": h.rank, | |
| "status": "running", | |
| "exit_code": None, | |
| "error": "timeout before terminal", | |
| "result": None, | |
| "training_job_name": job_name, | |
| } | |
| break | |
| # Sleep, but never overrun the deadline. | |
| time.sleep(min(poll_interval, max(0.0, deadline - time.time()))) | |
| # Cache only terminal results (not the timeout 'running' sentinel, | |
| # so a later collect() can re-check the job). | |
| if result_dict["status"] in ("succeeded", "failed", "cancelled"): | |
| meta["result"] = result_dict | |
| results.append(result_dict) | |
| return results | |
| # ----------------------------------------------------------------- | |
| # Helpers | |
| # ----------------------------------------------------------------- | |
| def _logs(self) -> Any: | |
| """Lazily build the CloudWatch Logs client (separate from sagemaker).""" | |
| if self._logs_client is None: | |
| self._logs_client = self._make_boto3_client("logs") | |
| return self._logs_client | |
| @staticmethod | |
| def _terminal_result( | |
| rank: int, sm_status: str, resp: Mapping[str, Any] | |
| ) -> dict[str, Any]: | |
| """Build a result dict from a terminal describe_training_job response.""" | |
| mapped = _STATUS_MAP.get(sm_status, "failed") | |
| if sm_status == "Completed": | |
| exit_code: int | None = 0 | |
| error = None | |
| elif sm_status == "Stopped": | |
| exit_code = None | |
| error = resp.get("FailureReason") | |
| else: # Failed | |
| exit_code = 1 | |
| error = resp.get("FailureReason") or "training job failed" | |
| artifacts = resp.get("ModelArtifacts", {}) or {} | |
| return { | |
| "rank": rank, | |
| "status": mapped, | |
| "exit_code": exit_code, | |
| "error": error, | |
| "result": artifacts.get("S3ModelArtifacts"), | |
| "training_job_name": resp.get("TrainingJobName"), | |
| } | |
| def _is_resource_not_found(self, exc: Exception) -> bool: | |
| """True if ``exc`` is the boto3 ResourceNotFound for the sagemaker client. | |
| Handles both the typed client exception | |
| (``client.exceptions.ResourceNotFound``) and a generic botocore | |
| ``ClientError`` whose error code is ``ResourceNotFound`` / | |
| ``ValidationException`` naming a missing job — robust across whether a | |
| real boto3 client or a mock is in use. | |
| """ | |
| rnf = getattr(getattr(self._client, "exceptions", None), | |
| "ResourceNotFound", None) | |
| if rnf is not None and isinstance(exc, rnf): | |
| return True | |
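| # Generic botocore ClientError fallback. A bare ValidationException is | |
| # not enough on its own: only count it when the message names a | |
| # missing resource. | |
| resp = getattr(exc, "response", None) | |
| if isinstance(resp, Mapping): | |
| err = resp.get("Error", {}) | |
| code = err.get("Code", "") | |
| if code == "ResourceNotFound": | |
| return True | |
| if code == "ValidationException": | |
| return "not found" in str(err.get("Message", "")).lower() | |
| return False | |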
| def _is_already_terminal(self, exc: Exception) -> bool: | |
| """True if ``exc`` is the boto3 "cannot stop a job in status X" error. | |
| ``stop_training_job`` raises a ``ValidationException`` when the job is | |
| already Completed/Failed/Stopped — that is an idempotent no-op for | |
| cancel(), distinct from a genuinely unexpected error. Matched on the | |
| ClientError code or, failing that, on the message text (so a mock raising | |
| a plain Exception whose message carries the phrase is recognized too). | |
| """ | |
| resp = getattr(exc, "response", None) | |
| if isinstance(resp, Mapping): | |
| err = resp.get("Error", {}) | |
| if err.get("Code") == "ValidationException": | |
| return True | |
| msg = str(exc).lower() | |
| return "cannot be stopped" in msg or ( | |
| "already" in msg | |
| and ("stopped" in msg or "complete" in msg or "terminal" in msg) | |
| ) | |
| __all__ = ["SageMakerExecutor"] | |
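The constructor accepts an injected `sagemaker_client`, so the executor can be exercised without touching AWS. A minimal in-memory stand-in might look like the following. The class name `FakeSageMakerClient`, its `finish` hook, and its behaviour are assumptions for illustration (the repository's own test mock is not shown here). It mimics only the three client methods this module calls, plus the `exceptions.ResourceNotFound` attribute that `_is_resource_not_found` looks up.

```python
from types import SimpleNamespace
from typing import Any


class ResourceNotFound(Exception):
    # Stand-in for the typed exception a real boto3 client exposes.
    pass


class FakeSageMakerClient:
    # Hypothetical in-memory double for boto3.client("sagemaker").

    def __init__(self) -> None:
        self.exceptions = SimpleNamespace(ResourceNotFound=ResourceNotFound)
        self.jobs: dict[str, dict[str, Any]] = {}

    def create_training_job(self, **request: Any) -> dict[str, Any]:
        name = request["TrainingJobName"]
        self.jobs[name] = {
            "TrainingJobName": name,
            "TrainingJobStatus": "InProgress",
            "SecondaryStatus": "Starting",
            "request": request,  # kept so tests can inspect what was sent
        }
        return {"TrainingJobArn": f"arn:aws:sagemaker:fake:000000000000:training-job/{name}"}

    def describe_training_job(self, *, TrainingJobName: str) -> dict[str, Any]:
        if TrainingJobName not in self.jobs:
            raise ResourceNotFound(TrainingJobName)
        return dict(self.jobs[TrainingJobName])

    def stop_training_job(self, *, TrainingJobName: str) -> None:
        if TrainingJobName not in self.jobs:
            raise ResourceNotFound(TrainingJobName)
        job = self.jobs[TrainingJobName]
        job["TrainingJobStatus"] = "Stopped"
        job["SecondaryStatus"] = "Stopped"

    def finish(self, name: str, status: str = "Completed") -> None:
        # Test hook (not a boto3 method): force a job into a terminal status.
        self.jobs[name]["TrainingJobStatus"] = status
        self.jobs[name]["SecondaryStatus"] = status


if __name__ == "__main__":
    client = FakeSageMakerClient()
    client.create_training_job(TrainingJobName="demo-r0000")
    client.finish("demo-r0000")
    print(client.describe_training_job(TrainingJobName="demo-r0000")["TrainingJobStatus"])
```

Passing `sagemaker_client=FakeSageMakerClient()` to `SageMakerExecutor` should then let `launch_replicas`, `poll`, and `cancel` run entirely in memory, with `finish` standing in for the passage of time on the SageMaker side.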