"""EKSExecutor — production Amazon EKS / Kubernetes-backed serverless executor.

This is the v0-finished k8s sibling of `ModalSpawnExecutor`. It implements
the `ServerlessExecutor` Protocol against the Kubernetes ``BatchV1Api`` using
the **single Indexed Job** topology recommended for gang-scheduled DiLoCo
replicas.

Topology (the load-bearing design choice)
------------------------------------------
There are two ways to map N replicas onto k8s:

  (A) ONE Indexed Job: ``completions=N, parallelism=N,
      completionMode='Indexed'``. The control plane assigns each pod a
      ``JOB_COMPLETION_INDEX`` 0..N-1 which IS the rank, all pods share one
      rendezvous URI, the Job controller creates the N pods together, and a
      single delete cancels the whole cohort.
  (B) N separate non-indexed Jobs, one per rank.

`EKSExecutor` uses **(A)** because it is the better fit for DiLoCo: rank
assignment is free, the cohort is created and torn down as one object, and
that matches ``ObjectStoreAllReduce``'s all-or-nothing barrier. Note that a
plain Job does not by itself guarantee all-or-nothing pod placement: the
default scheduler still binds pods one at a time, so strict gang scheduling
needs a gang-aware scheduler or queueing layer on the cluster. The
reconciliation with the per-replica ``ReplicaHandle`` contract: ``launch_replicas``
creates ONE Indexed Job but still returns N ``ReplicaHandle`` objects
(``handles[i].rank == i``) whose ``metadata`` stores the SHARED
``job_name``/``namespace`` plus that rank.
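
As a rough illustration of topology (A), the sketch below builds the
equivalent Job manifest as a plain dict and derives the N per-rank handle
records that point at it. The helper names (``indexed_job_manifest``,
``handle_records``), the job name and the image string are hypothetical; the
executor itself builds ``V1Job`` model objects rather than dicts.

```python
from typing import Any


def indexed_job_manifest(job_name: str, n_replicas: int, image: str) -> dict[str, Any]:
    # One Job for the whole cohort: completions == parallelism == N.
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": job_name},
        "spec": {
            "completions": n_replicas,
            "parallelism": n_replicas,
            "completionMode": "Indexed",
            "backoffLimit": 0,
            "template": {
                "spec": {
                    "restartPolicy": "Never",
                    "containers": [{"name": "replica", "image": image}],
                }
            },
        },
    }


def handle_records(job_name: str, namespace: str, n_replicas: int) -> list[dict[str, Any]]:
    # N records, one per rank, all pointing at the SAME Job.
    return [
        {"job_name": job_name, "namespace": namespace, "rank": rank}
        for rank in range(n_replicas)
    ]


manifest = indexed_job_manifest("diloco-demo", 4, "example.invalid/replica:latest")
handles = handle_records("diloco-demo", "default", 4)
```

Every record carries the same ``job_name``, which is why an operation on one
handle can affect the whole cohort.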

This is materially different from ``ModalSpawnExecutor`` where each handle is
an independent ``FunctionCall``:

  * ``poll(handle)`` reads the single Job status and checks whether
    ``handle.rank`` is in the run-length-compressed ``completed_indexes`` /
    ``failed_indexes`` strings.
  * ``cancel(handle)`` on ANY handle deletes the WHOLE Job (intentional gang
    semantics — cancelling one rank tears down the whole replica cohort).
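
The per-rank mapping can be sketched without a cluster. This is a simplified
stand-in for ``poll`` under stated assumptions: ``expand`` and ``rank_state``
are hypothetical names, malformed tokens are not handled, and the Job-level
``Failed`` condition and the 404 case are left out.

```python
def expand(spec):
    # Expand "1,3-5,7" into {1, 3, 4, 5, 7}; empty or None gives an empty set.
    out = set()
    for token in (spec or "").split(","):
        token = token.strip()
        if not token:
            continue
        lo, sep, hi = token.partition("-")
        out.update(range(int(lo), int(hi if sep else lo) + 1))
    return out


def rank_state(rank, completed_indexes, failed_indexes, active):
    # Precedence: completed first, then failed, then the Job's active pod count.
    if rank in expand(completed_indexes):
        return "succeeded"
    if rank in expand(failed_indexes):
        return "failed"
    return "running" if active > 0 else "pending"


# Five ranks sharing one Job status: 0, 2, 3 done, 1 failed, one pod still active.
states = [rank_state(r, "0,2-3", "1", active=1) for r in range(5)]
```

Note that ``active`` is a Job-wide count, so any rank that is neither completed
nor failed reports ``running`` while at least one pod of the Job is active.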

Rank plumbing
-------------
The repo's ``replica_entrypoint`` reads ``REPLICA_RANK``. We bridge the k8s
completion-index to that env var via the downward API rather than relying on
the auto-injected ``JOB_COMPLETION_INDEX``::

    V1EnvVar(
        name="REPLICA_RANK",
        value_from=V1EnvVarSource(field_ref=V1ObjectFieldSelector(
            field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']")),
    )

so the unchanged entrypoint's ``REPLICA_RANK`` read just works. ``WORLD_SIZE``
is set as a literal env var.
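
On the replica side, consuming the two variables might look like the sketch
below. ``read_rank`` is a hypothetical helper written for illustration, not
the repo's ``replica_entrypoint``; it only assumes that both variables arrive
as decimal strings.

```python
import os


def read_rank(env=None):
    # Parse the two env vars that the Job sets on every pod.
    env = os.environ if env is None else env
    rank = int(env["REPLICA_RANK"])
    world_size = int(env["WORLD_SIZE"])
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside 0..{world_size - 1}")
    return rank, world_size


# The downward API delivers the completion index as a string, e.g. "2".
rank, world_size = read_rank({"REPLICA_RANK": "2", "WORLD_SIZE": "4"})
```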

S3 rendezvous via IRSA / Pod Identity
-------------------------------------
``EKSExecutor`` accepts ``service_account_name`` and references it on the
PodSpec. With IRSA, the EKS pod identity webhook injects ``AWS_ROLE_ARN`` +
``AWS_WEB_IDENTITY_TOKEN_FILE`` (and a projected token volume) into the pod;
with EKS Pod Identity, the injected variables are
``AWS_CONTAINER_CREDENTIALS_FULL_URI`` +
``AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE`` instead. Either way
``boto3``/``s3fs``/``fsspec`` pick up credentials through the default
credential chain with ZERO code change inside the replica, so the ``s3://``
rendezvous works out of the box. ``EKSExecutor`` only REFERENCES a
ServiceAccount that is already annotated (IRSA) or associated (Pod Identity)
with a role; it never creates IAM/OIDC resources.
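
When the rendezvous fails with an access error, a quick preflight inside the
pod can show whether credentials were injected at all. The sketch below is a
hypothetical helper (``credential_mode`` is not part of this module); it
assumes the variable names that IRSA (``AWS_ROLE_ARN``,
``AWS_WEB_IDENTITY_TOKEN_FILE``) and EKS Pod Identity
(``AWS_CONTAINER_CREDENTIALS_FULL_URI``) set on the container.

```python
def credential_mode(env):
    # Report which injected credential variables are visible to the replica.
    if env.get("AWS_ROLE_ARN") and env.get("AWS_WEB_IDENTITY_TOKEN_FILE"):
        return "irsa"
    if env.get("AWS_CONTAINER_CREDENTIALS_FULL_URI"):
        return "pod-identity"
    return "none"


# Example environments; the ARN and paths are placeholders, not real values.
irsa_env = {
    "AWS_ROLE_ARN": "arn:aws:iam::000000000000:role/example",
    "AWS_WEB_IDENTITY_TOKEN_FILE": "/var/run/secrets/example/token",
}
mode = credential_mode(irsa_env)
```

In a real pod the argument would be ``os.environ``; a result of ``"none"``
usually means the ServiceAccount was not set up for either mechanism.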

Sandboxing (advanced, optional)
-------------------------------
``runtime_class_name`` references a pre-existing cluster-scoped ``RuntimeClass``
(for example one whose handler is ``runsc`` for gVisor or ``kata`` for Kata
Containers; the object's name is whatever the operator chose). It defaults to
``None``.

.. warning::
   Combining ``gpu`` with ``runtime_class_name`` is advanced and unverified.
   gVisor (runsc) needs ``nvproxy`` enabled and only supports a fixed allowlist
   of NVIDIA driver families; Kata runs a microVM that caps CPU/mem and needs
   GPU passthrough (PCIe/IOMMU + NVIDIA Kata Manager + CDI). Do not silently
   combine the two without operator validation. ``EKSExecutor`` cannot create
   the RuntimeClass — it only references one that already exists.

References
----------
- k8s Indexed Jobs: https://kubernetes.io/docs/tasks/job/indexed-parallel-processing-static/
- kubernetes-client/python job_crud example + V1JobSpec / V1JobStatus docs
- EKS IRSA: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
- ADR-005 (executor protocol design)
"""
from __future__ import annotations

import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# Logical GPU spec ("A100"/"H100") -> (gpu_count_string, node_selector merge).
# The Protocol's `gpu` arg is a logical name; map it to a concrete EKS node
# class + GPU count rather than passing the opaque string straight through.
_GPU_SPEC_TABLE: dict[str, tuple[str, dict[str, str]]] = {
    "A100": ("1", {"node.kubernetes.io/instance-type": "p4d.24xlarge"}),
    "H100": ("1", {"node.kubernetes.io/instance-type": "p5.48xlarge"}),
    "A10G": ("1", {"node.kubernetes.io/instance-type": "g5.xlarge"}),
    "T4": ("1", {"node.kubernetes.io/instance-type": "g4dn.xlarge"}),
}


def _expand_indexes(spec: str | None) -> set[int]:
    """Expand a run-length-compressed completion-index string to a set.

    The k8s ``V1JobStatus.completed_indexes`` / ``failed_indexes`` fields are
    strings like ``"1,3-5,7"`` (comma-separated singletons and ``a-b`` ranges).
    ``_expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}``. Empty/None -> empty set.
    """
    out: set[int] = set()
    if not spec:
        return out
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            lo_s, _, hi_s = token.partition("-")
            try:
                lo, hi = int(lo_s), int(hi_s)
            except ValueError:
                continue
            if hi < lo:
                lo, hi = hi, lo
            out.update(range(lo, hi + 1))
        else:
            try:
                out.add(int(token))
            except ValueError:
                continue
    return out


class EKSExecutor:
    """Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS.

    Implements the `ServerlessExecutor` Protocol. ``launch_replicas`` creates
    ONE Indexed Job (``completions == parallelism == n_replicas``,
    ``completionMode='Indexed'``) and returns N ``ReplicaHandle`` objects that
    all share the same ``job_name``/``namespace`` (gang semantics).

    Args:
        image: container image that has ``composer_replication`` installed and
            runs the replica entrypoint.
        namespace: k8s namespace for the Job. Default ``"default"``.
        service_account_name: ServiceAccount to attach to the PodSpec for IRSA /
            EKS Pod Identity S3 access. ``EKSExecutor`` references it; it does
            NOT create it or any IAM/OIDC resources.
        node_selector: extra node selector merged into the GPU node selector.
        tolerations: PodSpec tolerations. If GPU is requested and the caller did
            not supply tolerations, the standard ``nvidia.com/gpu`` NoSchedule
            toleration is added automatically.
        runtime_class_name: optional pre-existing RuntimeClass (e.g. ``"gvisor"``
            / ``"kata"``). Default ``None``. See the module-level warning before
            combining with ``gpu``.
        command: container command. Defaults to the repo replica entrypoint
            module ``["python", "-m",
            "composer_replication.diloco.serverless.replica_entrypoint"]``.
        cpu_request / memory_request: PodSpec resource requests.
        ttl_seconds_after_finished: auto-delete the finished Job (and its pods,
            cascadingly) after this many seconds. Default 3600.
        backoff_limit: Job retry budget. Default 0 (fail-fast — RL gangs usually
            do NOT want the k8s default of 6 retries).
        gpu_resource_key: the GPU resource key. Default ``"nvidia.com/gpu"``.
        run_id: optional run id baked into the generated Job name.
        batch_api / core_api: dependency-injected ``BatchV1Api`` / ``CoreV1Api``
            instances. When ``None`` (the default), they are built lazily on
            first use via in-cluster or kube-config loading. Tests inject mocks.

    Raises:
        RuntimeError: if the ``kubernetes`` client is not installed and at least
            one of ``batch_api`` / ``core_api`` was not injected (the package is
            needed to load cluster config and build the missing client).
    """

    backend_name = "eks"
    # Pods are network-isolated by default; rendezvous is S3 (ObjectStoreAllReduce).
    supports_inter_replica_network = False

    def __init__(
        self,
        image: str,
        *,
        namespace: str = "default",
        service_account_name: str | None = None,
        node_selector: dict[str, str] | None = None,
        tolerations: list[Any] | None = None,
        runtime_class_name: str | None = None,
        command: list[str] | None = None,
        cpu_request: str = "4",
        memory_request: str = "16Gi",
        ttl_seconds_after_finished: int = 3600,
        backoff_limit: int = 0,
        gpu_resource_key: str = "nvidia.com/gpu",
        run_id: str | None = None,
        batch_api: Any = None,
        core_api: Any = None,
    ) -> None:
        # `kubernetes` is only strictly required when we have to BUILD V1 model
        # objects ourselves (launch_replicas) or load cluster config (when an
        # api is missing). Fail early with a clear error only when at least one
        # of batch_api / core_api was not injected and the package is absent.
        # When BOTH apis are injected (tests, or callers that pre-built
        # clients), construction tolerates a missing top-level package; the
        # methods import from `kubernetes` lazily, per call.
        if batch_api is None or core_api is None:
            try:
                import kubernetes  # noqa: F401
            except ImportError as e:
                raise RuntimeError(
                    'EKSExecutor requires the kubernetes client: '
                    'pip install "kubernetes>=29" (or '
                    "`pip install -e .[serverless]`). Got: " + repr(e)
                ) from e

        self.image = image
        self.namespace = namespace
        self.service_account_name = service_account_name
        self.node_selector = dict(node_selector) if node_selector else None
        self.tolerations = list(tolerations) if tolerations else None
        self.runtime_class_name = runtime_class_name
        self.command = command or [
            "python",
            "-m",
            "composer_replication.diloco.serverless.replica_entrypoint",
        ]
        self.cpu_request = cpu_request
        self.memory_request = memory_request
        self.ttl_seconds_after_finished = ttl_seconds_after_finished
        self.backoff_limit = backoff_limit
        self.gpu_resource_key = gpu_resource_key
        self.run_id = run_id or "diloco"

        self._batch_api = batch_api
        self._core_api = core_api
        # rank -> {"job_name", "namespace", "result"}; lets poll/collect cache.
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # Lazy client construction (config loading only when not injected)
    # -----------------------------------------------------------------

    def _load_config(self) -> None:
        """Load k8s config: in-cluster first, then the local kube-config.

        Called each time a missing client is built, so it may run twice.
        """
        from kubernetes import config

        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()

    def _batch(self) -> Any:
        if self._batch_api is None:
            from kubernetes import client

            self._load_config()
            self._batch_api = client.BatchV1Api()
        return self._batch_api

    def _core(self) -> Any:
        if self._core_api is None:
            from kubernetes import client

            self._load_config()
            self._core_api = client.CoreV1Api()
        return self._core_api

    # -----------------------------------------------------------------
    # Job-spec construction
    # -----------------------------------------------------------------

    def _build_env(
        self, world_size: int, entrypoint_args: Mapping[str, Any]
    ) -> list[Any]:
        """Build the container env list, including the downward-API rank var."""
        from kubernetes import client

        env: list[Any] = [
            # REPLICA_RANK from the per-pod completion-index annotation via the
            # downward API — bridges k8s indexing to the repo entrypoint's
            # REPLICA_RANK read with no entrypoint change.
            client.V1EnvVar(
                name="REPLICA_RANK",
                value_from=client.V1EnvVarSource(
                    field_ref=client.V1ObjectFieldSelector(
                        field_path=(
                            "metadata.annotations["
                            "'batch.kubernetes.io/job-completion-index']"
                        )
                    )
                ),
            ),
            client.V1EnvVar(name="WORLD_SIZE", value=str(world_size)),
        ]
        # rendezvous_uri (and any other scalar kwargs) passed as literal env so
        # the entrypoint / user code can read them. `rank_env` is the
        # LocalProcessExecutor convention — drop it (same as ModalSpawnExecutor).
        for key, value in entrypoint_args.items():
            if key == "rank_env":
                continue
            if isinstance(value, (str, int, float, bool)):
                env.append(
                    client.V1EnvVar(name=key.upper(), value=str(value))
                )
        return env

    def _build_resources(
        self, gpu: str | None
    ) -> tuple[Any, dict[str, str], list[Any]]:
        """Build V1ResourceRequirements + (node_selector, tolerations) for GPU.

        Returns (resources, node_selector, tolerations). The GPU count is
        ALWAYS a STRING ('1', not int 1) — the OpenAPI type for the limits map
        is dict[str, str] and an int can serialize wrong or raise.
        """
        from kubernetes import client

        requests = {"cpu": self.cpu_request, "memory": self.memory_request}
        limits: dict[str, str] = {}
        node_selector: dict[str, str] = dict(self.node_selector or {})
        tolerations: list[Any] = list(self.tolerations or [])

        if gpu is not None:
            # Unknown logical names fall back to one GPU with no node pinning.
            gpu_count, gpu_node_selector = _GPU_SPEC_TABLE.get(
                gpu, ("1", {})
            )
            # STRING, always.
            limits[self.gpu_resource_key] = str(gpu_count)
            # Merge the mapped node selector under any caller-supplied one
            # (caller wins on key conflicts).
            for k, v in gpu_node_selector.items():
                node_selector.setdefault(k, v)
            # Auto-add the GPU NoSchedule toleration unless the caller overrode
            # tolerations explicitly.
            if not self.tolerations:
                tolerations.append(
                    client.V1Toleration(
                        key=self.gpu_resource_key,
                        operator="Exists",
                        effect="NoSchedule",
                    )
                )

        resources = client.V1ResourceRequirements(
            requests=requests,
            limits=limits or None,
        )
        return resources, node_selector, tolerations

    def _build_job(
        self,
        *,
        job_name: str,
        n_replicas: int,
        gpu: str | None,
        timeout: int,
        entrypoint_args: Mapping[str, Any],
    ) -> Any:
        """Assemble the full V1Job (Indexed) bottom-up."""
        from kubernetes import client

        env = self._build_env(n_replicas, entrypoint_args)
        resources, node_selector, tolerations = self._build_resources(gpu)

        container = client.V1Container(
            name="replica",
            image=self.image,
            command=list(self.command),
            env=env,
            resources=resources,
        )

        pod_spec = client.V1PodSpec(
            # Jobs accept only Never/OnFailure; Never keeps RL runs fail-fast.
            restart_policy="Never",
            containers=[container],
            service_account_name=self.service_account_name,
            node_selector=node_selector or None,
            tolerations=tolerations or None,
            runtime_class_name=self.runtime_class_name,
        )

        labels = {"app": "composer-diloco", "job-name": job_name}
        pod_template = client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(labels=labels),
            spec=pod_spec,
        )

        job_spec = client.V1JobSpec(
            template=pod_template,
            completions=n_replicas,
            parallelism=n_replicas,
            completion_mode="Indexed",
            backoff_limit=self.backoff_limit,
            ttl_seconds_after_finished=self.ttl_seconds_after_finished,
            active_deadline_seconds=timeout,
        )

        return client.V1Job(
            api_version="batch/v1",
            kind="Job",
            metadata=client.V1ObjectMeta(name=job_name, labels=labels),
            spec=job_spec,
        )

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Create ONE Indexed Job of N pods and return N rank-ordered handles.

        ``entrypoint`` is always ignored, whether it is a string or a Callable
        (k8s runs a container command, not an in-process callable); the
        container command is fixed at construction (``command`` ctor arg). The
        repo entrypoint module is the default. Scalar ``entrypoint_args``
        (str/int/float/bool) are passed as upper-cased env vars so
        ``replica_entrypoint`` / user code can read them; non-scalar values and
        ``rank_env`` are dropped. ``gpu`` maps to a ``gpu_resource_key`` limit
        (default ``nvidia.com/gpu``) + node selector; ``timeout`` becomes the
        Job's ``active_deadline_seconds`` hard wall-clock kill.
        """
        del entrypoint  # k8s runs a container command, not an in-process fn

        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")

        job_name = f"{self.run_id}-{uuid.uuid4().hex[:8]}"
        job = self._build_job(
            job_name=job_name,
            n_replicas=n_replicas,
            gpu=gpu,
            timeout=timeout,
            entrypoint_args=entrypoint_args,
        )

        self._batch().create_namespaced_job(namespace=self.namespace, body=job)

        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
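                    metadata={
                        "job_name": job_name,
                        "namespace": self.namespace,
                        "rank": rank,
                        # Read back by _result_dict as the "result" value.
                        "rendezvous_uri": entrypoint_args.get("rendezvous_uri"),
                    },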
                )
            )
            self._handles[rank] = {
                "job_name": job_name,
                "namespace": self.namespace,
                "result": None,
            }
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll this rank's status off the shared Indexed Job.

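        Reads ``read_namespaced_job_status`` once, then maps the whole-job
        status to this rank, in order: ``rank in completed_indexes`` ->
        ``succeeded``; ``rank in failed_indexes`` -> ``failed``; a Job-level
        ``Failed`` condition (e.g. ``DeadlineExceeded``) -> ``failed``;
        ``active > 0`` -> ``running``; else ``pending``. A 404 (Job deleted,
        whether by ``cancel`` or by the TTL controller after
        ``ttl_seconds_after_finished``) -> ``cancelled``.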

        Returns one of: ``pending`` | ``running`` | ``succeeded`` | ``failed`` |
        ``cancelled``.
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)

        try:
            status = self._batch().read_namespaced_job_status(
                name=job_name, namespace=namespace
            ).status
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                return "cancelled"
            raise

        completed = _expand_indexes(getattr(status, "completed_indexes", None))
        if rank in completed:
            return "succeeded"

        failed = _expand_indexes(getattr(status, "failed_indexes", None))
        if rank in failed:
            return "failed"

        # Whole-job terminal Failed (e.g. DeadlineExceeded / backoff) with no
        # per-index attribution -> treat this rank as failed.
        for cond in (getattr(status, "conditions", None) or []):
            if (
                getattr(cond, "type", None) == "Failed"
                and getattr(cond, "status", None) == "True"
            ):
                return "failed"

        active = getattr(status, "active", None) or 0
        if active > 0:
            return "running"
        return "pending"

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent logs for this rank's pod.

        Finds the pod whose ``batch.kubernetes.io/job-completion-index``
        annotation (or label) equals the rank, then reads its log tail. Returns
        a placeholder string (rather than raising) when the pod has not started
        or the Job is gone — mirrors ``LocalProcessExecutor``.
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)
        idx_key = "batch.kubernetes.io/job-completion-index"

        try:
            pods = self._core().list_namespaced_pod(
                namespace=namespace, label_selector=f"job-name={job_name}"
            )
        except ApiException:
            return f"<rank {rank}: job not found / no pods yet>"

        pod_name = None
        for pod in getattr(pods, "items", None) or []:
            meta = getattr(pod, "metadata", None)
            annotations = getattr(meta, "annotations", None) or {}
            labels = getattr(meta, "labels", None) or {}
            if annotations.get(idx_key) == str(rank) or labels.get(idx_key) == str(rank):
                pod_name = getattr(meta, "name", None)
                break

        if pod_name is None:
            # Fall back to the Indexed Job pod-name pattern
            # `<job-name>-<index>-<suffix>` (the index label only exists on
            # k8s >= 1.28).
            prefix = f"{job_name}-{rank}-"
            for pod in getattr(pods, "items", None) or []:
                name = getattr(getattr(pod, "metadata", None), "name", "") or ""
                if name.startswith(prefix):
                    pod_name = name
                    break

        if pod_name is None:
            return f"<rank {rank}: pod not started / no logs yet>"

        try:
            return self._core().read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container="replica",
                tail_lines=n_lines,
            )
        except ApiException as e:
            if getattr(e, "status", None) in (400, 404):
                return f"<rank {rank}: pod not started / no logs yet>"
            raise

    def cancel(self, handle: ReplicaHandle) -> None:
        """Delete the WHOLE shared Indexed Job (gang teardown).

        Because ``EKSExecutor`` uses one shared Indexed Job, cancelling ANY rank
        tears down the entire replica cohort — intentional gang semantics for
        the DiLoCo all-reduce barrier (a single straggler being cancelled should
        not leave the rest spinning and burning GPU).

        Uses ``propagation_policy='Background'`` so the pods are cascadingly
        deleted (the k8s default ORPHANS pods, which would keep burning GPU,
        the exact failure mode for RL). Idempotent: a 404 (already deleted) or
        409 (deletion already in progress) is swallowed, and an unknown handle
        never raises, honoring the Protocol's "no exception if already
        terminated" contract. Any other API error is re-raised.
        """
        from kubernetes import client
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata.get("job_name")
        namespace = handle.metadata.get("namespace", self.namespace)
        if not job_name:
            return  # unknown handle — no-op

        try:
            self._batch().delete_namespaced_job(
                name=job_name,
                namespace=namespace,
                body=client.V1DeleteOptions(
                    propagation_policy="Background",
                    grace_period_seconds=0,
                ),
            )
        except ApiException as e:
            # R5: swallow ONLY already-terminated signals (404 Not Found, 409
            # Conflict on a job mid-deletion). A genuinely unexpected API error
            # (403 forbidden, 500, malformed request) must NOT be reported as a
            # successful cancel — re-raise so a real teardown failure (leaking
            # GPU-burning pods) is visible rather than silently swallowed.
            if getattr(e, "status", None) in (404, 409):
                return  # already deleted / mid-deletion — idempotent no-op
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Poll until every rank reaches a terminal state or the deadline.

        Sleeps between polls (Job status is eventually consistent — do not
        hammer the API server). Returns per-rank result dicts in handles order::

            {"rank", "status", "exit_code", "error", "job_name", "result"}

        ``exit_code`` is 0 for succeeded, 1 for failed, ``None`` for
        running/pending/cancelled, matching the Protocol's documented shape.
        ``result`` is the rendezvous URI from the handle metadata when present,
        else ``None``.
        """
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = float(self._collect_poll_interval())
        terminal = {"succeeded", "failed", "cancelled"}
        results_by_rank: dict[int, dict[str, Any]] = {}

        pending = list(handles)
        while pending and time.time() < deadline:
            still_pending: list[ReplicaHandle] = []
            for h in pending:
                state = self.poll(h)
                if state in terminal:
                    results_by_rank[h.rank] = self._result_dict(h, state)
                else:
                    still_pending.append(h)
            pending = still_pending
            if not pending:
                break
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(poll_interval, max(0.0, remaining)))

        # Any rank still non-terminal at the deadline -> report its last state.
        for h in pending:
            state = self.poll(h)
            results_by_rank[h.rank] = self._result_dict(h, state)

        return [results_by_rank[h.rank] for h in handles]

    # -----------------------------------------------------------------
    # Internals
    # -----------------------------------------------------------------

    def _collect_poll_interval(self) -> float:
        """Seconds between collect() polls. Overridable in tests."""
        return 5.0

    @staticmethod
    def _result_dict(handle: ReplicaHandle, state: str) -> dict[str, Any]:
        exit_code = {"succeeded": 0, "failed": 1}.get(state)
        error = None
        if state == "failed":
            error = f"rank {handle.rank} reported failed by Job status"
        elif state == "cancelled":
            error = f"rank {handle.rank} Job no longer exists (cancelled)"
        elif state in ("running", "pending"):
            error = f"rank {handle.rank} not terminal at deadline (state={state})"
        return {
            "rank": handle.rank,
            "status": state,
            "exit_code": exit_code,
            "error": error,
            "job_name": handle.metadata.get("job_name"),
            # R6: cross-backend uniformity with Local/Modal/SageMaker collect()
            # shapes. EKS replicas write their real output to the S3 rendezvous
            # (ObjectStoreAllReduce), not back through the k8s API, so the Job
            # status carries no in-band payload — the value is the rendezvous
            # URI when known (callers read the artifact from S3), else None.
            "result": handle.metadata.get("rendezvous_uri"),
        }


__all__ = ["EKSExecutor"]