"""SageMakerExecutor — production boto3-backed serverless executor.

This is a fully-working cloud adapter (the sibling of `ModalSpawnExecutor`,
not the loud-failing `modal.py` / `hf_jobs.py` skeletons). It implements the
`ServerlessExecutor` Protocol against Amazon SageMaker Training Jobs via the
boto3 low-level `sagemaker` client.

Design choices
--------------

1. **N independent single-instance jobs, NOT one multi-instance job.**
   SageMaker's *native* distributed training (``ResourceConfig.InstanceCount > 1``)
   groups instances into ONE job with an in-cluster NCCL/MPI fabric wired via
   ``/opt/ml/input/config/resourceconfig.json``. That is the WRONG model for
   DiLoCo replicas — it would couple replicas through SageMaker's intra-job
   network and break the "each replica is an independent DiLoCo worker that
   syncs only through S3" design. So ``launch_replicas`` submits N **separate**
   training jobs, each with ``ResourceConfig.InstanceCount == 1``, tagged with
   ``REPLICA_RANK=i`` / ``WORLD_SIZE=N`` via the ``Environment`` map. This
   mirrors ``ModalSpawnExecutor`` spawning N independent Modal calls.

2. **Same S3 ``ObjectStoreAllReduce`` rendezvous — DiLoCo math untouched.**
   Cross-replica communication is EXCLUSIVELY the object-store rendezvous; the
   executor passes ``rendezvous_uri`` (an ``s3://...`` URI) through to
   ``replica_entrypoint.py`` unchanged. ``allreduce.py`` / ``MockManager`` /
   ``make_diloco_outer_loop`` / the trainer all stay byte-for-byte identical.

3. **Stateless after launch; rank via ``Environment``.** Handle metadata is the
   ``training_job_name`` (plus submit timestamp). ``replica_entrypoint.py``
   already reads ``REPLICA_RANK`` from ``os.environ``, so the cleanest channel
   is the ``Environment`` map (string->string, max 100 entries, value <= 512
   chars). The container command is baked into the image entrypoint and the
   rendezvous args are passed via ``AlgorithmSpecification.ContainerArguments``.

4. **``supports_inter_replica_network = False``.** Separate single-instance
   training jobs have no mutual network path by design — they rendezvous only
   through S3. (SageMaker's algo-N container fabric and
   ``EnableInterContainerTrafficEncryption`` only exist WITHIN a single
   multi-instance job, which this design deliberately does not use.)
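
Illustration (a minimal sketch, not this module's code path): choice 1 viewed
as data. The helper name ``build_replica_requests`` and the ``demo-`` job names
are hypothetical; the request keys are the ones ``launch_replicas`` sends. Each
of the N dicts pins ``InstanceCount`` to 1 and carries its own ``REPLICA_RANK``.

```python
from typing import Any


def build_replica_requests(
    n_replicas: int, image_uri: str, role_arn: str, output_s3_path: str
) -> list[dict[str, Any]]:
    # Hypothetical helper: one create_training_job request per replica.
    built = []
    for rank in range(n_replicas):
        built.append({
            # Placeholder naming scheme, for illustration only.
            "TrainingJobName": f"demo-r{rank:04d}",
            "AlgorithmSpecification": {
                "TrainingImage": image_uri,
                "TrainingInputMode": "File",
            },
            "RoleArn": role_arn,
            "OutputDataConfig": {"S3OutputPath": output_s3_path},
            "ResourceConfig": {
                "InstanceType": "ml.g5.2xlarge",
                "InstanceCount": 1,  # never > 1: no intra-job fabric
                "VolumeSizeInGB": 100,
            },
            "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
            "Environment": {
                "REPLICA_RANK": str(rank),
                "WORLD_SIZE": str(n_replicas),
            },
            "EnableNetworkIsolation": False,
        })
    return built


requests = build_replica_requests(
    3, "<image-uri>", "<role-arn>", "s3://<bucket>/out"
)
```

Replicas built this way differ only in job name and ``REPLICA_RANK``, which is
why the S3 rendezvous is the only channel they can share.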

Load-bearing gotcha — ``EnableNetworkIsolation`` MUST stay ``False``
--------------------------------------------------------------------
When ``EnableNetworkIsolation=True`` the training *container* has no outbound
network access. SageMaker's host-side processes still stage input channels and
ship CloudWatch logs, but the container itself cannot make S3 GET/PUT calls.
``ObjectStoreAllReduce`` needs live S3 PUT+GET every outer round, so network
isolation would silently dead-lock the allreduce poll loop until its timeout.
This executor pins ``EnableNetworkIsolation=False`` (the API default) and never
exposes it as a knob. The rendezvous bucket access must instead be granted on
the execution ``RoleArn`` — the SageMaker analog of EKS IRSA.
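
Illustration (a sketch under stated assumptions): what a rendezvous-scoped
policy on the execution role could look like. ``rendezvous_policy`` is a
hypothetical helper, and the exact S3 actions ``ObjectStoreAllReduce`` needs
are an assumption here (object GET and PUT, plus listing the bucket); adjust
them to what the allreduce code actually calls.

```python
import json
from urllib.parse import urlparse


def rendezvous_policy(rendezvous_uri: str) -> dict:
    # Hypothetical helper: derive bucket and key prefix from an s3:// URI.
    parsed = urlparse(rendezvous_uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"expected an s3:// URI, got {rendezvous_uri!r}")
    bucket = parsed.netloc
    prefix = parsed.path.strip("/")
    bucket_arn = f"arn:aws:s3:::{bucket}"
    object_arn = f"{bucket_arn}/{prefix}/*" if prefix else f"{bucket_arn}/*"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Assumed: replicas PUT their shard and GET the others'.
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": object_arn,
            },
            {
                # Assumed: the poll loop may list keys under the prefix.
                "Effect": "Allow",
                "Action": "s3:ListBucket",
                "Resource": bucket_arn,
            },
        ],
    }


policy = rendezvous_policy("s3://my-bucket/runs/demo")
print(json.dumps(policy, indent=2))
```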

HyperPod <-> EKS 1:1 control-plane mapping (recommended hybrid)
---------------------------------------------------------------
Per the SageMaker docs: *"The high-level architecture of Amazon EKS support in
HyperPod involves a 1-to-1 mapping between an EKS cluster (control plane) and a
HyperPod cluster (worker nodes) within a VPC."*
(https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html)

Consequence for this repo's hybrid: "use HyperPod for the inner GRPO trainer"
does NOT mean leaving EKS — it means attaching a HyperPod-managed
(auto-recovering, deep-health-checked, PyTorch-job auto-resume) node-group to
the SAME EKS control plane that runs the outer loop. A future ``EKSExecutor``
(kubernetes client, Indexed Jobs) therefore targets both plain Karpenter GPU
nodes AND HyperPod nodes transparently. ``SageMakerExecutor`` (ephemeral
Training Jobs via boto3) is the SEPARATE bursty-fallback inner-loop path for
when you don't want a persistent cluster: Training Jobs suit periodic /
smaller-model / pay-per-use runs; HyperPod suits continuous / large-model /
persistent runs. Both share the IDENTICAL S3 rendezvous, so a run can move
between them with zero trainer / loss / DiLoCo changes.

References
----------
- create_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/create_training_job.html
- describe_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/describe_training_job.html
- stop_training_job: https://docs.aws.amazon.com/boto3/latest/reference/services/sagemaker/client/stop_training_job.html
- network isolation: https://repost.aws/knowledge-center/sagemaker-access-network-isolation
- HyperPod-EKS: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-eks.html
- ADR-005 (executor protocol design)
"""
from __future__ import annotations

import json
import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# SageMaker TrainingJobStatus -> Protocol status vocabulary.
# describe_training_job's TrainingJobStatus is EXACTLY one of:
#   'InProgress' | 'Completed' | 'Failed' | 'Stopping' | 'Stopped'.
# We map Stopping -> 'running' (transient; still terminating, so collect()
# keeps waiting) and Stopped -> 'cancelled'.
_STATUS_MAP = {
    "InProgress": "running",
    "Completed": "succeeded",
    "Failed": "failed",
    "Stopping": "running",
    "Stopped": "cancelled",
}

# SecondaryStatus values that mean "queued / not yet executing user code" —
# used to refine an InProgress job into the Protocol's 'pending'.
_PENDING_SECONDARY = frozenset(
    {"Starting", "Pending", "LaunchingMLInstances", "PreparingTrainingStack"}
)

# Abstract Protocol GPU strings -> SageMaker instance types.
_GPU_INSTANCE_MAP = {
    "A100": "ml.p4d.24xlarge",
    "H100": "ml.p5.48xlarge",
    "H200": "ml.p5e.48xlarge",
    "B200": "ml.p6-b200.48xlarge",
    "L40S": "ml.g6e.12xlarge",
    "A10G": "ml.g5.2xlarge",
    "L4": "ml.g6.2xlarge",
}

_CLOUDWATCH_LOG_GROUP = "/aws/sagemaker/TrainingJobs"


class SageMakerExecutor:
    """Run replicas as N independent SageMaker Training Jobs.

    Implements the `ServerlessExecutor` Protocol against the boto3
    ``sagemaker`` client. Each replica is one single-instance training job;
    cross-replica communication happens only through the shared S3
    ``ObjectStoreAllReduce`` rendezvous.

    Args:
        role_arn: IAM execution role SageMaker assumes for the job. Must grant
            S3 access to the rendezvous + output buckets (the SageMaker analog
            of EKS IRSA). The caller's credentials need ``iam:PassRole`` on it.
        image_uri: ECR image URI for the training container. The image must
            bake an entrypoint that runs
            ``python -m composer_replication.diloco.serverless.replica_entrypoint``
            (this executor also passes ``ContainerEntrypoint`` explicitly so a
            generic image works too).
        output_s3_path: ``s3://...`` prefix for ``OutputDataConfig.S3OutputPath``
            (model artifacts / failure output).
        instance_type: fallback SageMaker instance type used when ``gpu`` is
            neither a known GPU name nor a literal ``ml.*`` instance type
            (e.g. ``"ml.g5.2xlarge"``). ``gpu=None`` at launch uses
            ``cpu_instance_type`` instead.
        cpu_instance_type: instance type used when ``gpu`` is ``None`` (CPU
            smoke tests). Default ``"ml.m5.xlarge"``.
        volume_size_gb: ``ResourceConfig.VolumeSizeInGB`` per job.
        run_id: prefix for generated training-job names. Defaults to a short
            random token so names are unique per region+account.
        region: AWS region for the lazily-constructed boto3 clients. ``None``
            uses the ambient boto3 default-region resolution.
        sagemaker_client: inject a pre-built ``boto3.client('sagemaker')`` (or a
            mock) instead of constructing one. Used by tests.
        logs_client: inject a pre-built ``boto3.client('logs')`` (or a mock).

    Raises:
        RuntimeError: if boto3 is not installed and no client was injected.
    """

    backend_name = "sagemaker"
    # Separate single-instance jobs have no mutual network path — S3 only.
    supports_inter_replica_network = False

    def __init__(
        self,
        *,
        role_arn: str,
        image_uri: str,
        output_s3_path: str,
        instance_type: str = "ml.g5.2xlarge",
        cpu_instance_type: str = "ml.m5.xlarge",
        volume_size_gb: int = 100,
        run_id: str | None = None,
        region: str | None = None,
        sagemaker_client: Any = None,
        logs_client: Any = None,
    ) -> None:
        self.role_arn = role_arn
        self.image_uri = image_uri
        self.output_s3_path = output_s3_path
        self.instance_type = instance_type
        self.cpu_instance_type = cpu_instance_type
        self.volume_size_gb = volume_size_gb
        self.run_id = run_id or f"diloco-{uuid.uuid4().hex[:8]}"
        self._region = region

        # Lazy boto3 — only constructed if the caller didn't inject a client.
        # This keeps `import composer_replication.diloco.serverless` free of a
        # hard boto3 dependency (boto3 lives in the optional [aws] extra), and
        # lets tests inject a _MockSMClient with zero AWS calls.
        if sagemaker_client is None:
            sagemaker_client = self._make_boto3_client("sagemaker")
        self._client = sagemaker_client
        self._logs_client = logs_client  # built lazily on first stream_logs()

        # rank -> {"job_name": str, "result": dict | None}
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # boto3 plumbing (lazy)
    # -----------------------------------------------------------------

    def _make_boto3_client(self, service: str) -> Any:
        try:
            import boto3
        except ImportError as e:
            raise RuntimeError(
                "SageMakerExecutor requires boto3. Install with "
                "`pip install -e .[aws]` (or `pip install boto3`). "
                f"Got: {e!r}"
            ) from e
        if self._region is not None:
            return boto3.client(service, region_name=self._region)
        return boto3.client(service)

    def _map_gpu(self, gpu: str | None) -> str:
        """Translate the Protocol's abstract gpu string to an instance type.

        ``gpu=None`` -> ``cpu_instance_type`` (smoke tests). A name found in
        ``_GPU_INSTANCE_MAP`` is translated; a literal ``ml.*`` instance type
        is returned unchanged; anything else falls back to ``instance_type``.
        """
        if gpu is None:
            return self.cpu_instance_type
        if gpu in _GPU_INSTANCE_MAP:
            return _GPU_INSTANCE_MAP[gpu]
        # Caller may have passed a literal "ml.*" instance type.
        if gpu.startswith("ml."):
            return gpu
        return self.instance_type

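    def _job_name(self, rank: int) -> str:
        """Build a unique training-job name (<= 63 chars).

        Pattern required by the API: ``[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
        ``run_id`` is assumed to contain only alphanumerics and hyphens. If
        the name would exceed 63 characters, the ``run_id`` prefix is
        shortened so the rank and timestamp suffix, which make the name
        unique, always survive.
        """
        suffix = f"-r{rank:04d}-{int(time.time())}"
        # Trim the prefix, not the suffix, and avoid a doubled hyphen.
        prefix = self.run_id[: 63 - len(suffix)].rstrip("-")
        return f"{prefix}{suffix}"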

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Submit N independent single-instance SageMaker Training Jobs.

        Args:
            n_replicas: number of replicas (= number of training jobs).
            entrypoint: ignored — the container command is baked into the
                image / passed as ``ContainerEntrypoint``. Kept for Protocol
                compatibility.
            entrypoint_args: must contain ``rendezvous_uri`` (``s3://...``) and
                ``trainer_module``. Optional: ``trainer_fn`` (default
                ``"train"``), ``trainer_kwargs`` (dict, JSON-encoded into the
                container args). The conventional ``rank_env`` key (from
                ``LocalProcessExecutor``) is ignored — rank goes through the
                ``Environment`` map instead.
            gpu: abstract GPU spec mapped to an instance type via ``_map_gpu``.
                ``None`` -> CPU instance.
            timeout: ``StoppingCondition.MaxRuntimeInSeconds`` per job.

        Returns:
            ``list[ReplicaHandle]`` of length ``n_replicas`` in rank order
            (``handles[i].rank == i``).
        """
        del entrypoint  # container command is baked / passed explicitly

        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")

        rendezvous_uri = entrypoint_args.get("rendezvous_uri")
        if not rendezvous_uri:
            raise ValueError(
                "entrypoint_args must include 'rendezvous_uri' (the s3:// "
                "ObjectStoreAllReduce rendezvous prefix)."
            )
        trainer_module = entrypoint_args.get("trainer_module")
        if not trainer_module:
            raise ValueError(
                "entrypoint_args must include 'trainer_module' (importable "
                "module path of the user's train function)."
            )
        trainer_fn = entrypoint_args.get("trainer_fn", "train")
        trainer_kwargs = entrypoint_args.get("trainer_kwargs", {})

        instance_type = self._map_gpu(gpu)

        # Container args: each element is a SINGLE token (StackOverflow
        # 77994925 — `['--world-size', '4']` NOT `['--world-size 4']`).
        container_args = [
            "--rendezvous", str(rendezvous_uri),
            "--world-size", str(n_replicas),
            "--trainer-module", str(trainer_module),
            "--trainer-fn", str(trainer_fn),
            "--trainer-kwargs-json", json.dumps(trainer_kwargs),
        ]

        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            job_name = self._job_name(rank)
            request = {
                "TrainingJobName": job_name,
                "AlgorithmSpecification": {
                    "TrainingImage": self.image_uri,
                    "TrainingInputMode": "File",
                    "ContainerEntrypoint": [
                        "python", "-m",
                        "composer_replication.diloco.serverless.replica_entrypoint",
                    ],
                    "ContainerArguments": container_args,
                },
                "RoleArn": self.role_arn,
                # InputDataConfig intentionally omitted — the replica pulls
                # data via its own code / the S3 rendezvous, not SM channels.
                "OutputDataConfig": {"S3OutputPath": self.output_s3_path},
                "ResourceConfig": {
                    "InstanceType": instance_type,
                    "InstanceCount": 1,
                    "VolumeSizeInGB": self.volume_size_gb,
                },
                "StoppingCondition": {"MaxRuntimeInSeconds": int(timeout)},
                # REPLICA_RANK / WORLD_SIZE injected as container env vars;
                # replica_entrypoint.py reads os.environ['REPLICA_RANK'].
                "Environment": {
                    "REPLICA_RANK": str(rank),
                    "WORLD_SIZE": str(n_replicas),
                    "RENDEZVOUS_URI": str(rendezvous_uri),
                },
                # MUST stay False — True severs the container's S3 access and
                # dead-locks the allreduce poll loop. See module docstring.
                "EnableNetworkIsolation": False,
            }
            try:
                self._client.create_training_job(**request)
            except Exception as e:
                # Best-effort stop of already-launched siblings, then raise.
                for prior in handles:
                    try:
                        self.cancel(prior)
                    except Exception:
                        pass
                raise RuntimeError(
                    f"SageMakerExecutor.launch_replicas failed at rank={rank} "
                    f"of {n_replicas} (already-launched siblings stopped). "
                    f"Underlying error: {e!r}"
                ) from e

            handle = ReplicaHandle(
                rank=rank,
                backend_name=self.backend_name,
                metadata={
                    "training_job_name": job_name,
                    "submit_ts": time.time(),
                },
            )
            self._handles[rank] = {"job_name": job_name, "result": None}
            handles.append(handle)

        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll a training job's status.

        Returns one of: ``"pending"`` | ``"running"`` | ``"succeeded"`` |
        ``"failed"`` | ``"cancelled"``.

        Maps ``describe_training_job``'s ``TrainingJobStatus`` via
        ``_STATUS_MAP``; refines ``InProgress`` to ``"pending"`` while the job
        is still queued (``SecondaryStatus`` in ``_PENDING_SECONDARY``). A
        vanished job (``ResourceNotFound``) is treated as ``"cancelled"``.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return "cancelled"
        if meta["result"] is not None:
            return meta["result"]["status"]

        job_name = meta["job_name"]
        try:
            resp = self._client.describe_training_job(TrainingJobName=job_name)
        except Exception as e:
            if self._is_resource_not_found(e):
                return "cancelled"
            raise

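        sm_status = resp.get("TrainingJobStatus", "InProgress")

        if sm_status == "InProgress":
            if resp.get("SecondaryStatus") in _PENDING_SECONDARY:
                return "pending"
            return "running"

        # 'Stopping' (and any status this module does not know) is not
        # terminal yet: report it without caching so later polls re-check.
        if sm_status not in ("Completed", "Failed", "Stopped"):
            return _STATUS_MAP.get(sm_status, "running")

        # Terminal: cache a result dict so collect()/repeat-poll are cheap.
        meta["result"] = self._terminal_result(handle.rank, sm_status, resp)
        return meta["result"]["status"]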

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent CloudWatch logs for this replica's training job.

        SageMaker writes container stdout/stderr to the
        ``/aws/sagemaker/TrainingJobs`` log group, stream
        ``<job-name>/algo-<n>-<epoch>``. We discover the exact stream name by
        prefix then read the tail. Falls back to a CloudWatch console pointer
        on any error (mirrors ModalSpawnExecutor's dashboard-URL fallback).
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return f"<replica {handle.rank}: no metadata>"
        job_name = meta["job_name"]

        try:
            logs = self._logs()
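            prefix = f"{job_name}/"
            # CloudWatch Logs rejects logStreamNamePrefix combined with
            # orderBy="LastEventTime", so keep the default LogStreamName
            # ordering. A single-instance job normally has one stream.
            streams = logs.describe_log_streams(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamNamePrefix=prefix,
                limit=1,
            )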
            stream_list = streams.get("logStreams", [])
            if not stream_list:
                return (
                    f"[rank {handle.rank}] job={job_name}: no CloudWatch log "
                    f"stream yet (job pending / not started)."
                )
            stream_name = stream_list[0]["logStreamName"]
            events = logs.get_log_events(
                logGroupName=_CLOUDWATCH_LOG_GROUP,
                logStreamName=stream_name,
                limit=n_lines,
                startFromHead=False,
            )
            lines = [e.get("message", "") for e in events.get("events", [])]
            body = "\n".join(lines) if lines else "<no log events>"
            return f"[rank {handle.rank}] job={job_name} stream={stream_name}\n{body}"
        except Exception as e:
            region = self._region or "<region>"
            url = (
                f"https://{region}.console.aws.amazon.com/cloudwatch/home"
                f"?region={region}#logsV2:log-groups/log-group/"
                f"$252Faws$252Fsagemaker$252FTrainingJobs"
            )
            return (
                f"[rank {handle.rank}] job={job_name}: log fetch failed "
                f"({type(e).__name__}: {e!r}).\n  CloudWatch console: {url}"
            )

    def cancel(self, handle: ReplicaHandle) -> None:
        """Best-effort stop of a training job.

        Calls ``stop_training_job`` (SIGTERM + 120s grace), swallowing
        ``ResourceNotFound`` and "already terminal" ``ValidationException`` so
        the contract — "no exception if already terminated" — holds.
        """
        meta = self._handles.get(handle.rank)
        if meta is None:
            return
        try:
            self._client.stop_training_job(TrainingJobName=meta["job_name"])
        except Exception as e:
            # R5: swallow only "already terminated" signals: a vanished job
            # (ResourceNotFound) or an already-Completed/Stopped job (boto3
            # raises ValidationException for "cannot stop a job in status X").
            # Other errors (AccessDenied, throttling that outlived retries)
            # propagate rather than masquerade as a successful cancel. Caveat:
            # the helpers match on the ValidationException code, so a
            # malformed request is swallowed as well.
            if self._is_resource_not_found(e) or self._is_already_terminal(e):
                return
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Block until all replicas finish; return per-replica result dicts.

        Polls ``describe_training_job`` per handle until the job reaches a
        terminal status (``Completed`` / ``Failed`` / ``Stopped``) or the
        shared deadline elapses. Returns results aligned to the input handle
        order (Protocol contract; mirrors ``LocalProcessExecutor.collect``).

        Each result dict has at least
        ``{"rank", "status", "exit_code", "error"}``.
        """
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = 30.0
        results: list[dict[str, Any]] = []

        for h in handles:
            meta = self._handles.get(h.rank)
            if meta is None:
                results.append({
                    "rank": h.rank,
                    "status": "cancelled",
                    "exit_code": None,
                    "error": "handle has no metadata (cancelled or unknown)",
                    "result": None,
                    "training_job_name": h.metadata.get("training_job_name"),
                })
                continue

            # Already cached by an earlier poll()/collect().
            if meta["result"] is not None:
                results.append(meta["result"])
                continue

            job_name = meta["job_name"]
            result_dict: dict[str, Any] | None = None
            while True:
                try:
                    resp = self._client.describe_training_job(
                        TrainingJobName=job_name
                    )
                except Exception as e:
                    if self._is_resource_not_found(e):
                        result_dict = {
                            "rank": h.rank,
                            "status": "cancelled",
                            "exit_code": None,
                            "error": "training job not found (deleted?)",
                            "result": None,
                            "training_job_name": job_name,
                        }
                        break
                    raise

                sm_status = resp.get("TrainingJobStatus", "InProgress")
                if sm_status in ("Completed", "Failed", "Stopped"):
                    result_dict = self._terminal_result(h.rank, sm_status, resp)
                    break

                if time.time() >= deadline:
                    result_dict = {
                        "rank": h.rank,
                        "status": "running",
                        "exit_code": None,
                        "error": "timeout before terminal",
                        "result": None,
                        "training_job_name": job_name,
                    }
                    break

                # Sleep, but never overrun the deadline.
                time.sleep(min(poll_interval, max(0.0, deadline - time.time())))

            # Cache only terminal results (not the timeout 'running' sentinel,
            # so a later collect() can re-check the job).
            if result_dict["status"] in ("succeeded", "failed", "cancelled"):
                meta["result"] = result_dict
            results.append(result_dict)

        return results

    # -----------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------

    def _logs(self) -> Any:
        """Lazily build the CloudWatch Logs client (separate from sagemaker)."""
        if self._logs_client is None:
            self._logs_client = self._make_boto3_client("logs")
        return self._logs_client

    @staticmethod
    def _terminal_result(
        rank: int, sm_status: str, resp: Mapping[str, Any]
    ) -> dict[str, Any]:
        """Build a result dict from a terminal describe_training_job response."""
        mapped = _STATUS_MAP.get(sm_status, "failed")
        if sm_status == "Completed":
            exit_code: int | None = 0
            error = None
        elif sm_status == "Stopped":
            exit_code = None
            error = resp.get("FailureReason")
        else:  # Failed
            exit_code = 1
            error = resp.get("FailureReason") or "training job failed"
        artifacts = resp.get("ModelArtifacts", {}) or {}
        return {
            "rank": rank,
            "status": mapped,
            "exit_code": exit_code,
            "error": error,
            "result": artifacts.get("S3ModelArtifacts"),
            "training_job_name": resp.get("TrainingJobName"),
        }

    def _is_resource_not_found(self, exc: Exception) -> bool:
        """True if ``exc`` means the training job does not exist.

        Handles both the typed client exception
        (``client.exceptions.ResourceNotFound``) and a generic botocore
        ``ClientError`` whose error code is ``ResourceNotFound``, or
        ``ValidationException`` with a "not found" message. Works whether a
        real boto3 client or a mock is in use.
        """
        rnf = getattr(getattr(self._client, "exceptions", None),
                      "ResourceNotFound", None)
        # Guard with isinstance(rnf, type): a mock attribute is not a class.
        if isinstance(rnf, type) and isinstance(exc, rnf):
            return True
        # Generic botocore ClientError fallback.
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            err = resp.get("Error", {})
            code = err.get("Code", "")
            if code == "ResourceNotFound":
                return True
            # Other ValidationExceptions (bad job name, bad parameters) are
            # real errors, so require the message to name a missing resource.
            if code == "ValidationException":
                return "not found" in str(err.get("Message", "")).lower()
        return False

    def _is_already_terminal(self, exc: Exception) -> bool:
        """True if ``exc`` signals "cannot stop a job in status X".

        ``stop_training_job`` raises a ``ValidationException`` when the job is
        already Completed/Failed/Stopped, which is an idempotent no-op for
        cancel(). Any ClientError with code ``ValidationException`` is
        accepted, so a malformed stop request is swallowed too. As a fallback
        for mocks that raise a plain Exception, the message text is matched.
        """
        resp = getattr(exc, "response", None)
        if isinstance(resp, Mapping):
            err = resp.get("Error", {})
            if err.get("Code") == "ValidationException":
                return True
        msg = str(exc).lower()
        already_done = "already" in msg and any(
            word in msg for word in ("stopped", "complete", "terminal")
        )
        return "cannot be stopped" in msg or already_done


__all__ = ["SageMakerExecutor"]