"""EKSExecutor: production Amazon EKS / Kubernetes-backed serverless executor.

This is the v0-finished k8s sibling of `ModalSpawnExecutor`. It implements the
`ServerlessExecutor` Protocol against the Kubernetes ``BatchV1Api`` using the
**single Indexed Job** topology recommended for gang-scheduled DiLoCo replicas.

Topology (the load-bearing design choice)
------------------------------------------
There are two ways to map N replicas onto k8s:

(A) ONE Indexed Job: ``completions=N, parallelism=N, completionMode='Indexed'``.
    The control plane assigns each pod a completion index 0..N-1 (exposed to
    the container as ``JOB_COMPLETION_INDEX``) which IS the rank. All pods
    share one rendezvous URI, are submitted as a single cohort, and a single
    delete cancels all of them.
(B) N separate non-indexed Jobs, one per rank.

`EKSExecutor` uses **(A)** because it is the better fit for DiLoCo: rank
assignment is free, the cohort is submitted as one object, and one delete
tears it down, which matches ``ObjectStoreAllReduce``'s all-or-nothing
barrier. Note that the default kube-scheduler still places the pods one by
one; strict all-or-nothing placement needs a gang-aware scheduler or
admission controller such as Kueue or Volcano.

Reconciling this with the per-replica ``ReplicaHandle`` contract:
``launch_replicas`` creates ONE Indexed Job but still returns N
``ReplicaHandle`` objects (``handles[i].rank == i``) whose ``metadata`` stores
the SHARED ``job_name``/``namespace`` plus that rank. This differs materially
from ``ModalSpawnExecutor``, where each handle is an independent
``FunctionCall``:

* ``poll(handle)`` reads the single Job status and checks whether
  ``handle.rank`` is in the run-length-compressed ``completed_indexes`` /
  ``failed_indexes`` strings.
* ``cancel(handle)`` on ANY handle deletes the WHOLE Job (intentional gang
  semantics: cancelling one rank tears down the whole replica cohort).

Rank plumbing
-------------
The repo's ``replica_entrypoint`` reads ``REPLICA_RANK``.
We bridge the k8s completion index to that env var via the downward API
rather than relying on the auto-injected ``JOB_COMPLETION_INDEX``::

    V1EnvVar(
        name="REPLICA_RANK",
        value_from=V1EnvVarSource(field_ref=V1ObjectFieldSelector(
            field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']")),
    )

so the unchanged entrypoint's ``REPLICA_RANK`` read just works. ``WORLD_SIZE``
is set as a literal env var.

S3 rendezvous via IRSA / Pod Identity
-------------------------------------
``EKSExecutor`` accepts ``service_account_name`` and references it on the
PodSpec; AWS credentials then reach the replica with ZERO code change inside
it:

* IRSA: the EKS pod-identity mutating webhook injects ``AWS_ROLE_ARN`` +
  ``AWS_WEB_IDENTITY_TOKEN_FILE`` (and a projected token volume) into the
  pod, and ``boto3``/``s3fs``/``fsspec`` pick up credentials through the
  web-identity provider.
* EKS Pod Identity: the webhook instead injects
  ``AWS_CONTAINER_CREDENTIALS_FULL_URI`` +
  ``AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE``, pointing at the node-local Pod
  Identity Agent, which recent AWS SDKs resolve through the
  container-credentials provider.

Either way the ``s3://`` rendezvous works out of the box. ``EKSExecutor`` only
REFERENCES a pre-configured ServiceAccount; it never creates IAM/OIDC
resources or Pod Identity associations.

Sandboxing (advanced, optional)
-------------------------------
``runtime_class_name`` references a pre-existing cluster-scoped
``RuntimeClass`` (e.g. one whose handler is ``runsc`` for gVisor, or ``kata``
for Kata Containers). It defaults to ``None``.

.. warning::
   Combining ``gpu`` with ``runtime_class_name`` is advanced and unverified.
   gVisor (runsc) needs ``nvproxy`` enabled and only supports a fixed
   allowlist of NVIDIA driver versions; Kata runs a microVM that caps
   CPU/memory and needs GPU passthrough (PCIe/IOMMU + NVIDIA Kata Manager +
   CDI). Do not combine the two without operator validation.

``EKSExecutor`` cannot create the RuntimeClass; it only references one that
already exists.
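
The two cluster objects referenced above must exist before launch and are
created by an operator, not by ``EKSExecutor``. As a minimal sketch, here
they are as plain manifest dicts; the names (``diloco-replica``,
``gvisor``), the namespace, and the role ARN are hypothetical placeholders,
the ``runsc`` handler assumes gVisor is installed on the nodes, and the
``eks.amazonaws.com/role-arn`` annotation applies to IRSA only (EKS Pod
Identity binds the role through an association instead).

```python
# IRSA: the ServiceAccount carries the IAM role ARN as an annotation.
# The account id and role name are placeholders, not a real role.
service_account = {
    "apiVersion": "v1",
    "kind": "ServiceAccount",
    "metadata": {
        "name": "diloco-replica",
        "namespace": "default",
        "annotations": {
            "eks.amazonaws.com/role-arn": "arn:aws:iam::111122223333:role/diloco-s3",
        },
    },
}

# Cluster-scoped RuntimeClass mapping a name to a node-level runtime handler.
runtime_class = {
    "apiVersion": "node.k8s.io/v1",
    "kind": "RuntimeClass",
    "metadata": {"name": "gvisor"},
    "handler": "runsc",
}

# EKSExecutor only references both objects by name.
executor_kwargs = {
    "service_account_name": service_account["metadata"]["name"],
    "runtime_class_name": runtime_class["metadata"]["name"],
}
```

Passing ``**executor_kwargs`` to ``EKSExecutor`` places both names on the
generated PodSpec.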

References
----------
- k8s Indexed Jobs: https://kubernetes.io/docs/tasks/job/indexed-parallel-processing-static/
- kubernetes-client/python ``job_crud`` example + ``V1JobSpec`` / ``V1JobStatus`` docs
- EKS IRSA: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
- ADR-005 (executor protocol design)
"""

from __future__ import annotations

import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# Logical GPU spec ("A100"/"H100") -> (gpu_count_string, node_selector merge).
# The Protocol's `gpu` arg is a logical name; map it to a concrete EKS node
# class + GPU count rather than passing the opaque string straight through.
_GPU_SPEC_TABLE: dict[str, tuple[str, dict[str, str]]] = {
    "A100": ("1", {"node.kubernetes.io/instance-type": "p4d.24xlarge"}),
    "H100": ("1", {"node.kubernetes.io/instance-type": "p5.48xlarge"}),
    "A10G": ("1", {"node.kubernetes.io/instance-type": "g5.xlarge"}),
    "T4": ("1", {"node.kubernetes.io/instance-type": "g4dn.xlarge"}),
}


def _expand_indexes(spec: str | None) -> set[int]:
    """Expand a run-length-compressed completion-index string to a set.

    The k8s ``V1JobStatus.completed_indexes`` / ``failed_indexes`` fields are
    strings like ``"1,3-5,7"`` (comma-separated singletons and ``a-b``
    ranges): ``_expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}``. An empty or
    ``None`` spec yields an empty set; malformed tokens are skipped.
    """
    out: set[int] = set()
    if not spec:
        return out
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            lo_s, _, hi_s = token.partition("-")
            try:
                lo, hi = int(lo_s), int(hi_s)
            except ValueError:
                continue
            if hi < lo:
                lo, hi = hi, lo
            out.update(range(lo, hi + 1))
        else:
            try:
                out.add(int(token))
            except ValueError:
                continue
    return out


class EKSExecutor:
    """Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS.

    Implements the `ServerlessExecutor` Protocol.
    ``launch_replicas`` creates ONE Indexed Job (``completions == parallelism ==
    n_replicas``, ``completionMode='Indexed'``) and returns N ``ReplicaHandle``
    objects that all share the same ``job_name``/``namespace`` (gang semantics).

    Args:
        image: container image that has ``composer_replication`` installed and
            runs the replica entrypoint.
        namespace: k8s namespace for the Job. Default ``"default"``.
        service_account_name: ServiceAccount to attach to the PodSpec for
            IRSA / EKS Pod Identity S3 access. ``EKSExecutor`` references it;
            it does NOT create it or any IAM/OIDC resources.
        node_selector: extra node selector merged with the GPU node selector
            (caller-supplied keys win on conflict).
        tolerations: PodSpec tolerations. If a GPU is requested and the caller
            did not supply tolerations, a NoSchedule toleration on
            ``gpu_resource_key`` (``nvidia.com/gpu`` by default, operator
            ``Exists``) is added automatically.
        runtime_class_name: optional name of a pre-existing RuntimeClass (e.g.
            ``"gvisor"`` / ``"kata"``). Default ``None``. See the module-level
            warning before combining with ``gpu``.
        command: container command. Defaults to the repo replica entrypoint
            module ``["python", "-m",
            "composer_replication.diloco.serverless.replica_entrypoint"]``.
        cpu_request / memory_request: PodSpec resource requests.
        ttl_seconds_after_finished: auto-delete the finished Job (and,
            cascadingly, its pods) after this many seconds. Default 3600.
        backoff_limit: Job retry budget. Default 0 (fail-fast: RL gangs
            usually do NOT want the k8s default of 6 retries).
        gpu_resource_key: the GPU resource key. Default ``"nvidia.com/gpu"``.
        run_id: optional run id used as the prefix of the generated Job name
            (default ``"diloco"``). Use lowercase alphanumerics and ``-`` only,
            and at most 54 characters: ``-`` plus 8 hex digits are appended,
            and the Job name is also a label value (capped at 63).
        batch_api / core_api: dependency-injected ``BatchV1Api`` /
            ``CoreV1Api`` instances. When ``None`` (the default), they are
            built lazily on first use via in-cluster or kube-config loading.
            Tests inject mocks.

    Raises:
        RuntimeError: if the ``kubernetes`` client is not installed and either
            ``batch_api`` or ``core_api`` was not injected (the package is
            then needed to load cluster config and build the client).
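
    As a rough illustration of what ``launch_replicas`` submits, the sketch
    below builds an equivalent Indexed Job manifest as a plain dict, in the
    camelCase form the API server stores. It is a hypothetical, simplified
    mirror of ``_build_job`` (``indexed_job_manifest`` is not part of this
    module) that leaves out resources, node selectors, and tolerations.

    ```python
    from typing import Any, Optional


    def indexed_job_manifest(
        job_name: str,
        image: str,
        n_replicas: int,
        *,
        timeout: int = 3600,
        backoff_limit: int = 0,
        ttl_seconds_after_finished: int = 3600,
        extra_env: Optional[dict[str, str]] = None,
    ) -> dict[str, Any]:
        # REPLICA_RANK comes from the per-pod completion-index annotation.
        env: list[dict[str, Any]] = [
            {
                "name": "REPLICA_RANK",
                "valueFrom": {
                    "fieldRef": {
                        "fieldPath": "metadata.annotations"
                        "['batch.kubernetes.io/job-completion-index']"
                    }
                },
            },
            {"name": "WORLD_SIZE", "value": str(n_replicas)},
        ]
        # Scalar kwargs become upper-cased literal env vars.
        for key, value in (extra_env or {}).items():
            env.append({"name": key.upper(), "value": str(value)})
        labels = {"app": "composer-diloco", "job-name": job_name}
        return {
            "apiVersion": "batch/v1",
            "kind": "Job",
            "metadata": {"name": job_name, "labels": labels},
            "spec": {
                # One pod per rank (indexes 0..N-1), all allowed to run at once.
                "completions": n_replicas,
                "parallelism": n_replicas,
                "completionMode": "Indexed",
                "backoffLimit": backoff_limit,
                "ttlSecondsAfterFinished": ttl_seconds_after_finished,
                "activeDeadlineSeconds": timeout,
                "template": {
                    "metadata": {"labels": labels},
                    "spec": {
                        "restartPolicy": "Never",
                        "containers": [
                            {"name": "replica", "image": image, "env": env}
                        ],
                    },
                },
            },
        }


    manifest = indexed_job_manifest(
        "diloco-1a2b3c4d",
        "example.com/diloco:latest",
        4,
        extra_env={"rendezvous_uri": "s3://example-bucket/run-1"},
    )
    ```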
    """

    backend_name = "eks"
    # No pod-to-pod networking is set up or relied on (no Service / headless
    # DNS); replicas rendezvous through S3 (ObjectStoreAllReduce).
    supports_inter_replica_network = False

    def __init__(
        self,
        image: str,
        *,
        namespace: str = "default",
        service_account_name: str | None = None,
        node_selector: dict[str, str] | None = None,
        tolerations: list[Any] | None = None,
        runtime_class_name: str | None = None,
        command: list[str] | None = None,
        cpu_request: str = "4",
        memory_request: str = "16Gi",
        ttl_seconds_after_finished: int = 3600,
        backoff_limit: int = 0,
        gpu_resource_key: str = "nvidia.com/gpu",
        run_id: str | None = None,
        batch_api: Any = None,
        core_api: Any = None,
    ) -> None:
        # `kubernetes` is only strictly required when we have to BUILD V1 model
        # objects ourselves (launch_replicas) or load cluster config (when an
        # api is not injected). Fail fast with a clear error if either api
        # still has to be built and the package is absent. When BOTH apis are
        # injected (tests, or callers with pre-built clients), a missing
        # top-level package is tolerated here and `client` is lazy-imported
        # per call.
        if batch_api is None or core_api is None:
            try:
                import kubernetes  # noqa: F401
            except ImportError as e:
                raise RuntimeError(
                    'EKSExecutor requires the kubernetes client: '
                    'pip install "kubernetes>=29" (or '
                    "`pip install -e .[serverless]`). Got: " + repr(e)
                ) from e

        self.image = image
        self.namespace = namespace
        self.service_account_name = service_account_name
        self.node_selector = dict(node_selector) if node_selector else None
        self.tolerations = list(tolerations) if tolerations else None
        self.runtime_class_name = runtime_class_name
        self.command = command or [
            "python",
            "-m",
            "composer_replication.diloco.serverless.replica_entrypoint",
        ]
        self.cpu_request = cpu_request
        self.memory_request = memory_request
        self.ttl_seconds_after_finished = ttl_seconds_after_finished
        self.backoff_limit = backoff_limit
        self.gpu_resource_key = gpu_resource_key
        self.run_id = run_id or "diloco"
        self._batch_api = batch_api
        self._core_api = core_api
        # rank -> {"job_name", "namespace", "result"} for the most recent
        # launch. Bookkeeping only: poll/collect read handle.metadata.
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # Lazy client construction (config loading only when not injected)
    # -----------------------------------------------------------------
    def _load_config(self) -> None:
        """Load k8s config: in-cluster first, else ~/.kube/config."""
        from kubernetes import config

        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()

    def _batch(self) -> Any:
        if self._batch_api is None:
            from kubernetes import client

            self._load_config()
            self._batch_api = client.BatchV1Api()
        return self._batch_api

    def _core(self) -> Any:
        if self._core_api is None:
            from kubernetes import client

            self._load_config()
            self._core_api = client.CoreV1Api()
        return self._core_api

    # -----------------------------------------------------------------
    # Job-spec construction
    # -----------------------------------------------------------------
    def _build_env(
        self, world_size: int, entrypoint_args: Mapping[str, Any]
    ) -> list[Any]:
        """Build the container env list, including the downward-API rank var."""
        from kubernetes import client

        env: list[Any] = [
            # REPLICA_RANK from the per-pod completion-index annotation via
            # the downward API. This bridges k8s indexing to the repo
            # entrypoint's REPLICA_RANK read with no entrypoint change.
            client.V1EnvVar(
                name="REPLICA_RANK",
                value_from=client.V1EnvVarSource(
                    field_ref=client.V1ObjectFieldSelector(
                        field_path=(
                            "metadata.annotations["
                            "'batch.kubernetes.io/job-completion-index']"
                        )
                    )
                ),
            ),
            client.V1EnvVar(name="WORLD_SIZE", value=str(world_size)),
        ]
        # rendezvous_uri (and any other scalar kwargs) are passed as literal
        # env so the entrypoint / user code can read them. `rank_env` is the
        # LocalProcessExecutor convention; drop it (same as ModalSpawnExecutor).
        for key, value in entrypoint_args.items():
            if key == "rank_env":
                continue
            if isinstance(value, (str, int, float, bool)):
                env.append(
                    client.V1EnvVar(name=key.upper(), value=str(value))
                )
        return env

    def _build_resources(
        self, gpu: str | None
    ) -> tuple[Any, dict[str, str], list[Any]]:
        """Build V1ResourceRequirements + (node_selector, tolerations) for GPU.

        Returns (resources, node_selector, tolerations). The GPU count is
        ALWAYS a STRING ('1', not int 1): the OpenAPI type for the limits map
        is dict[str, str], and an int can serialize wrongly or raise.
        """
        from kubernetes import client

        requests = {"cpu": self.cpu_request, "memory": self.memory_request}
        limits: dict[str, str] = {}
        node_selector: dict[str, str] = dict(self.node_selector or {})
        tolerations: list[Any] = list(self.tolerations or [])
        if gpu is not None:
            gpu_count, gpu_node_selector = _GPU_SPEC_TABLE.get(
                gpu, ("1", {})
            )
            limits[self.gpu_resource_key] = str(gpu_count)  # STRING, always.
            # Merge the mapped node selector under any caller-supplied one
            # (caller wins on key conflicts).
            for k, v in gpu_node_selector.items():
                node_selector.setdefault(k, v)
            # Auto-add the GPU NoSchedule toleration unless the caller overrode
            # tolerations explicitly.
            if not self.tolerations:
                tolerations.append(
                    client.V1Toleration(
                        key=self.gpu_resource_key,
                        operator="Exists",
                        effect="NoSchedule",
                    )
                )
        resources = client.V1ResourceRequirements(
            requests=requests,
            limits=limits or None,
        )
        return resources, node_selector, tolerations

    def _build_job(
        self,
        *,
        job_name: str,
        n_replicas: int,
        gpu: str | None,
        timeout: int,
        entrypoint_args: Mapping[str, Any],
    ) -> Any:
        """Assemble the full V1Job (Indexed) bottom-up."""
        from kubernetes import client

        env = self._build_env(n_replicas, entrypoint_args)
        resources, node_selector, tolerations = self._build_resources(gpu)
        container = client.V1Container(
            name="replica",
            image=self.image,
            command=list(self.command),
            env=env,
            resources=resources,
        )
        pod_spec = client.V1PodSpec(
            # Job pods must use Never or OnFailure; Never gives fail-fast RL
            # semantics (a failed container is not restarted in place).
            restart_policy="Never",
            containers=[container],
            service_account_name=self.service_account_name,
            node_selector=node_selector or None,
            tolerations=tolerations or None,
            runtime_class_name=self.runtime_class_name,
        )
        labels = {"app": "composer-diloco", "job-name": job_name}
        pod_template = client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(labels=labels),
            spec=pod_spec,
        )
        job_spec = client.V1JobSpec(
            template=pod_template,
            completions=n_replicas,
            parallelism=n_replicas,
            completion_mode="Indexed",
            backoff_limit=self.backoff_limit,
            ttl_seconds_after_finished=self.ttl_seconds_after_finished,
            active_deadline_seconds=timeout,
        )
        return client.V1Job(
            api_version="batch/v1",
            kind="Job",
            metadata=client.V1ObjectMeta(name=job_name, labels=labels),
            spec=job_spec,
        )

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------
    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Create ONE Indexed Job of N pods and return N rank-ordered handles.

        ``entrypoint`` is ignored whether it is a string or a Callable: k8s
        runs a container command, not an in-process callable, and that
        command is fixed at construction (``command`` ctor arg, defaulting to
        the repo entrypoint module). Scalar ``entrypoint_args`` are passed as
        upper-cased env vars so ``replica_entrypoint`` / user code can read
        them. ``gpu`` maps to a ``gpu_resource_key`` limit (``nvidia.com/gpu``
        by default) plus a node selector; ``timeout`` becomes the Job's
        ``active_deadline_seconds``, a hard wall-clock kill.
        """
        del entrypoint  # k8s runs a container command, not an in-process fn
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        job_name = f"{self.run_id}-{uuid.uuid4().hex[:8]}"
        job = self._build_job(
            job_name=job_name,
            n_replicas=n_replicas,
            gpu=gpu,
            timeout=timeout,
            entrypoint_args=entrypoint_args,
        )
        self._batch().create_namespaced_job(namespace=self.namespace, body=job)

        # Surfaced by _result_dict() as each rank's "result" (see R6 there).
        rendezvous_uri = entrypoint_args.get("rendezvous_uri")
        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
                    metadata={
                        "job_name": job_name,
                        "namespace": self.namespace,
                        "rank": rank,
                        "rendezvous_uri": rendezvous_uri,
                    },
                )
            )
            self._handles[rank] = {
                "job_name": job_name,
                "namespace": self.namespace,
                "result": None,
            }
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll this rank's status off the shared Indexed Job.

        Reads ``read_namespaced_job_status`` once, then maps the whole-job
        status to this rank, in order: ``rank in completed_indexes`` ->
        ``succeeded``; ``rank in failed_indexes`` -> ``failed``; a whole-Job
        ``Failed`` condition (e.g. DeadlineExceeded, BackoffLimitExceeded) ->
        ``failed``; ``active > 0`` -> ``running``; else ``pending``. A 404
        (Job deleted, whether cancelled or garbage-collected after
        ``ttl_seconds_after_finished``) -> ``cancelled``.

        k8s only populates ``failed_indexes`` when the Job sets
        ``backoffLimitPerIndex``, which this executor does not, so in
        practice a failure surfaces through the ``Failed`` condition and
        every not-yet-completed rank reports ``failed``.

        Returns one of: ``pending`` | ``running`` | ``succeeded`` |
        ``failed`` | ``cancelled``.
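
        As a standalone sketch of the same per-rank mapping, the helpers
        below apply it to every rank of one status snapshot. ``expand`` and
        ``rank_states`` are hypothetical (the real method reads a
        ``V1JobStatus`` and uses ``_expand_indexes``), and the simplified
        parser assumes well-formed index strings.

        ```python
        from typing import Optional


        def expand(spec: Optional[str]) -> set[int]:
            # "0,2-3" -> {0, 2, 3}; assumes well-formed k8s index strings.
            out: set[int] = set()
            for token in filter(None, (spec or "").split(",")):
                lo, _, hi = token.partition("-")
                out.update(range(int(lo), int(hi or lo) + 1))
            return out


        def rank_states(
            n: int,
            completed: Optional[str] = None,
            failed: Optional[str] = None,
            job_failed: bool = False,
            active: int = 0,
        ) -> list[str]:
            # Same precedence as poll(): completed index, then failed index or
            # whole-Job Failed condition, then active pods, else pending.
            done, bad = expand(completed), expand(failed)
            states = []
            for rank in range(n):
                if rank in done:
                    states.append("succeeded")
                elif rank in bad or job_failed:
                    states.append("failed")
                elif active > 0:
                    states.append("running")
                else:
                    states.append("pending")
            return states
        ```

        For example, ``rank_states(4, completed="0,2", active=2)`` returns
        ``["succeeded", "running", "succeeded", "running"]``: while any pod
        is active, every rank that has not completed reads as ``running``.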
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)
        try:
            status = self._batch().read_namespaced_job_status(
                name=job_name, namespace=namespace
            ).status
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                return "cancelled"
            raise

        completed = _expand_indexes(getattr(status, "completed_indexes", None))
        if rank in completed:
            return "succeeded"
        failed = _expand_indexes(getattr(status, "failed_indexes", None))
        if rank in failed:
            return "failed"
        # Whole-job terminal Failed (e.g. DeadlineExceeded / backoff) with no
        # per-index attribution -> treat this rank as failed.
        for cond in getattr(status, "conditions", None) or []:
            if (
                getattr(cond, "type", None) == "Failed"
                and getattr(cond, "status", None) == "True"
            ):
                return "failed"
        active = getattr(status, "active", None) or 0
        if active > 0:
            return "running"
        return "pending"

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent logs for this rank's pod.

        Finds the pod whose ``batch.kubernetes.io/job-completion-index``
        annotation (or label) equals the rank, then reads its log tail.
        Returns a placeholder string (rather than raising) when the pod has
        not started or the Job is gone, mirroring ``LocalProcessExecutor``.
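
        The two-stage pod lookup can be sketched over plain dicts standing in
        for ``V1Pod.metadata``. ``pick_pod`` and the dict shape are
        hypothetical, and the ``<job>-<index>-<suffix>`` name pattern is the
        same assumption the fallback below makes.

        ```python
        from typing import Any, Optional

        IDX_KEY = "batch.kubernetes.io/job-completion-index"


        def pick_pod(
            pods: list[dict[str, Any]], job_name: str, rank: int
        ) -> Optional[str]:
            # 1) Prefer an explicit completion-index annotation or label.
            for meta in pods:
                tagged = (
                    meta.get("annotations", {}).get(IDX_KEY),
                    meta.get("labels", {}).get(IDX_KEY),
                )
                if str(rank) in tagged:
                    return meta["name"]
            # 2) Fall back to the "<job>-<index>-" name prefix.
            prefix = f"{job_name}-{rank}-"
            for meta in pods:
                if meta["name"].startswith(prefix):
                    return meta["name"]
            return None
        ```

        The trailing ``-`` in the prefix is what keeps rank 1 from matching a
        rank-10 pod (``j-1-`` is not a prefix of ``j-10-...``).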
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)
        idx_key = "batch.kubernetes.io/job-completion-index"
        try:
            pods = self._core().list_namespaced_pod(
                namespace=namespace, label_selector=f"job-name={job_name}"
            )
        except ApiException:
            return f""

        pod_name = None
        for pod in getattr(pods, "items", None) or []:
            meta = getattr(pod, "metadata", None)
            annotations = getattr(meta, "annotations", None) or {}
            labels = getattr(meta, "labels", None) or {}
            if (
                annotations.get(idx_key) == str(rank)
                or labels.get(idx_key) == str(rank)
            ):
                pod_name = getattr(meta, "name", None)
                break
        if pod_name is None:
            # Fall back to the deterministic name prefix on k8s >= 1.28.
            prefix = f"{job_name}-{rank}-"
            for pod in getattr(pods, "items", None) or []:
                name = getattr(getattr(pod, "metadata", None), "name", "") or ""
                if name.startswith(prefix):
                    pod_name = name
                    break
        if pod_name is None:
            return f""

        try:
            return self._core().read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container="replica",
                tail_lines=n_lines,
            )
        except ApiException as e:
            if getattr(e, "status", None) in (400, 404):
                return f""
            raise

    def cancel(self, handle: ReplicaHandle) -> None:
        """Delete the WHOLE shared Indexed Job (gang teardown).

        Because ``EKSExecutor`` uses one shared Indexed Job, cancelling ANY
        rank tears down the entire replica cohort. This is intentional gang
        semantics for the DiLoCo all-reduce barrier: once one rank is
        cancelled the others cannot complete the barrier, so they should not
        be left spinning and burning GPU.

        Uses ``propagation_policy='Background'`` so the pods are cascadingly
        deleted (for a Job deleted through the API, the k8s default ORPHANS
        the pods, which would keep burning GPU: the exact failure mode for
        RL).

        Idempotent: a 404 (already deleted) or 409 (mid-deletion) is
        swallowed, and a handle without a ``job_name`` is a no-op, honoring
        the Protocol's "no exception if already terminated" contract.
        """
        from kubernetes import client
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata.get("job_name")
        namespace = handle.metadata.get("namespace", self.namespace)
        if not job_name:
            return  # unknown handle: no-op
        try:
            self._batch().delete_namespaced_job(
                name=job_name,
                namespace=namespace,
                body=client.V1DeleteOptions(
                    propagation_policy="Background",
                    grace_period_seconds=0,
                ),
            )
        except ApiException as e:
            # R5: swallow ONLY already-terminated signals (404 Not Found, 409
            # Conflict on a job mid-deletion). A genuinely unexpected API error
            # (403 forbidden, 500, malformed request) must NOT be reported as a
            # successful cancel: re-raise so a real teardown failure (leaking
            # GPU-burning pods) is visible rather than silently swallowed.
            if getattr(e, "status", None) in (404, 409):
                return  # already deleted / mid-deletion: idempotent no-op
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Poll until every rank reaches a terminal state or the deadline.

        Sleeps between polls (Job status is eventually consistent; do not
        hammer the API server). ``timeout=None`` caps the wait at 24 hours.
        Returns per-rank result dicts in ``handles`` order::

            {"rank", "status", "exit_code", "error", "job_name", "result"}

        ``exit_code`` is 0 for succeeded, 1 for failed, and ``None`` for
        running/pending/cancelled, matching the Protocol's documented shape.
        ``result`` is the rendezvous URI when known, else ``None``.
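
        The deadline-bounded loop can be exercised without a cluster by
        injecting the clock and the sleep. This is a simplified, hypothetical
        sketch (``collect_until`` and ``FakeClock`` are not this module's
        API) with the same shape: poll the non-terminal ranks, sleep at most
        the remaining budget, then re-poll stragglers once at the deadline.

        ```python
        from collections.abc import Callable

        TERMINAL = {"succeeded", "failed", "cancelled"}


        def collect_until(
            ranks: list[int],
            poll: Callable[[int], str],
            *,
            timeout: float,
            interval: float,
            clock: Callable[[], float],
            sleep: Callable[[float], None],
        ) -> list[str]:
            deadline = clock() + timeout
            states: dict[int, str] = {}
            pending = list(ranks)
            while pending and clock() < deadline:
                still_pending = []
                for rank in pending:
                    state = poll(rank)
                    if state in TERMINAL:
                        states[rank] = state
                    else:
                        still_pending.append(rank)
                pending = still_pending
                remaining = deadline - clock()
                if not pending or remaining <= 0:
                    break
                sleep(min(interval, remaining))  # never sleep past the deadline
            for rank in pending:  # stragglers: record their current state
                states[rank] = poll(rank)
            return [states[rank] for rank in ranks]


        class FakeClock:
            # Deterministic time source: sleep() just advances the counter.
            def __init__(self) -> None:
                self.t = 0.0

            def now(self) -> float:
                return self.t

            def sleep(self, dt: float) -> None:
                self.t += dt


        clock = FakeClock()


        def fake_poll(rank: int) -> str:
            # Rank 0 finishes once 5 s of fake time have passed; rank 1 never does.
            return "succeeded" if rank == 0 and clock.now() >= 5 else "running"


        result = collect_until(
            [0, 1], fake_poll, timeout=12, interval=5,
            clock=clock.now, sleep=clock.sleep,
        )
        ```

        Here ``result`` is ``["succeeded", "running"]`` and the fake clock
        stops at 12 s: two full 5 s sleeps, then a final 2 s sleep capped by
        the deadline.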
        """
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = float(self._collect_poll_interval())
        terminal = {"succeeded", "failed", "cancelled"}
        results_by_rank: dict[int, dict[str, Any]] = {}
        pending = list(handles)
        while pending and time.time() < deadline:
            still_pending: list[ReplicaHandle] = []
            for h in pending:
                state = self.poll(h)
                if state in terminal:
                    results_by_rank[h.rank] = self._result_dict(h, state)
                else:
                    still_pending.append(h)
            pending = still_pending
            if not pending:
                break
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(poll_interval, max(0.0, remaining)))

        # Any rank still non-terminal at the deadline: poll it once more and
        # report whatever state it is in now.
        for h in pending:
            state = self.poll(h)
            results_by_rank[h.rank] = self._result_dict(h, state)
        return [results_by_rank[h.rank] for h in handles]

    # -----------------------------------------------------------------
    # Internals
    # -----------------------------------------------------------------
    def _collect_poll_interval(self) -> float:
        """Seconds between collect() polls. Overridable in tests."""
        return 5.0

    @staticmethod
    def _result_dict(handle: ReplicaHandle, state: str) -> dict[str, Any]:
        exit_code = {"succeeded": 0, "failed": 1}.get(state, None)
        error = None
        if state == "failed":
            error = f"rank {handle.rank} reported failed by Job status"
        elif state == "cancelled":
            error = f"rank {handle.rank} Job no longer exists (cancelled)"
        elif state in ("running", "pending"):
            error = f"rank {handle.rank} not terminal at deadline (state={state})"
        return {
            "rank": handle.rank,
            "status": state,
            "exit_code": exit_code,
            "error": error,
            "job_name": handle.metadata.get("job_name"),
            # R6: cross-backend uniformity with Local/Modal/SageMaker collect()
            # shapes.
            # EKS replicas write their real output to the S3 rendezvous
            # (ObjectStoreAllReduce), not back through the k8s API, so the Job
            # status carries no in-band payload. The value is the rendezvous
            # URI when known (callers read the artifact from S3), else None.
            "result": handle.metadata.get("rendezvous_uri"),
        }


__all__ = ["EKSExecutor"]