# Last commit 7d9dbbc by Baladithya Balamurugan:
# "Wave 3 cleanup: close deferred-LOW review items R5/R6/R11"
"""EKSExecutor — production Amazon EKS / Kubernetes-backed serverless executor.

This is the v0-finished k8s sibling of `ModalSpawnExecutor`. It implements
the `ServerlessExecutor` Protocol against the Kubernetes ``BatchV1Api`` using
the **single Indexed Job** topology recommended for gang-scheduled DiLoCo
replicas.

Topology (the load-bearing design choice)
------------------------------------------

There are two ways to map N replicas onto k8s:

(A) ONE Indexed Job: ``completions=N, parallelism=N,
    completionMode='Indexed'``. The control plane assigns each pod a
    ``JOB_COMPLETION_INDEX`` in 0..N-1, which IS the rank. All pods share
    one rendezvous URI, all N are created from a single Job object, and a
    single delete cancels the whole cohort. (The default kube-scheduler
    still places the pods one at a time; strict all-or-nothing gang
    scheduling needs an add-on such as Kueue or Volcano.)
(B) N separate non-indexed Jobs, one per rank.
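
A minimal sketch of option (A) as a plain manifest dict, assuming the
``batch/v1`` Job field names. The helper ``indexed_job_manifest``, the Job
name, and the image are hypothetical placeholders; EKSExecutor itself builds
the equivalent ``V1Job`` objects rather than a dict.

```python
# Hypothetical helper: the option (A) topology as a plain batch/v1 manifest
# dict (the shape `kubectl apply` accepts), not EKSExecutor's V1Job objects.
def indexed_job_manifest(name: str, image: str, n: int) -> dict:
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            "completions": n,  # N successful pods in total ...
            "parallelism": n,  # ... all allowed to run at the same time
            "completionMode": "Indexed",  # pod i gets completion index i
            "backoffLimit": 0,  # fail-fast, as in EKSExecutor's default
            "template": {
                "spec": {
                    "restartPolicy": "Never",
                    "containers": [{"name": "replica", "image": image}],
                }
            },
        },
    }


manifest = indexed_job_manifest("diloco-1a2b3c4d", "example.com/replica:latest", 4)
```

Option (B) would instead create N non-indexed Jobs from a template like this,
set each rank explicitly, and need N deletes to cancel the cohort.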

`EKSExecutor` uses **(A)** because it is the better fit for DiLoCo: rank
assignment is free, and the whole cohort is created from (and torn down with)
one object, which matches ``ObjectStoreAllReduce``'s all-or-nothing barrier.
To reconcile this with the per-replica ``ReplicaHandle`` contract,
``launch_replicas`` creates ONE Indexed Job but still returns N
``ReplicaHandle`` objects (``handles[i].rank == i``) whose ``metadata``
stores the SHARED ``job_name``/``namespace`` plus that rank.

This is materially different from ``ModalSpawnExecutor``, where each handle
is an independent ``FunctionCall``:

* ``poll(handle)`` reads the single Job status and checks whether
  ``handle.rank`` is in the range-compressed ``completed_indexes`` /
  ``failed_indexes`` strings (e.g. ``"1,3-5,7"``).
* ``cancel(handle)`` on ANY handle deletes the WHOLE Job (intentional gang
  semantics: cancelling one rank tears down the whole replica cohort).
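
The per-rank mapping that ``poll`` performs can be sketched on plain values.
``rank_state`` is a hypothetical stand-in that takes the ``V1JobStatus``
index strings, a flag for a Job-level ``Failed`` condition, and the ``active``
count, so it runs without a cluster; it is not the executor's code.

```python
from __future__ import annotations


def rank_state(
    rank: int,
    completed_indexes: str | None,
    failed_indexes: str | None,
    job_failed: bool,
    active: int,
) -> str:
    # Hypothetical stand-in for poll(): same precedence, plain inputs.
    def expand(spec: str | None) -> set[int]:
        # "1,3-5,7" -> {1, 3, 4, 5, 7}; None or "" -> empty set
        out: set[int] = set()
        for token in (spec or "").split(","):
            token = token.strip()
            if not token:
                continue
            lo, _, hi = token.partition("-")
            out.update(range(int(lo), int(hi or lo) + 1))
        return out

    if rank in expand(completed_indexes):
        return "succeeded"
    if rank in expand(failed_indexes) or job_failed:
        return "failed"
    return "running" if active > 0 else "pending"


# One Job status read, fanned out to every rank.
status = {
    "completed_indexes": "0,2-3",
    "failed_indexes": None,
    "job_failed": False,
    "active": 2,
}
print([rank_state(r, **status) for r in range(5)])
# ['succeeded', 'running', 'succeeded', 'succeeded', 'running']
```

Because every handle shares one Job, a single status read is enough to answer
for all ranks.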

Rank plumbing
-------------

The repo's ``replica_entrypoint`` reads ``REPLICA_RANK``. We bridge the k8s
completion index to that env var via the downward API rather than relying on
the auto-injected ``JOB_COMPLETION_INDEX``::

    V1EnvVar(
        name="REPLICA_RANK",
        value_from=V1EnvVarSource(field_ref=V1ObjectFieldSelector(
            field_path="metadata.annotations['batch.kubernetes.io/job-completion-index']")),
    )

so the unchanged entrypoint's ``REPLICA_RANK`` read just works. ``WORLD_SIZE``
is set as a literal env var.
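
The same two env entries in plain-manifest form, plus a reader showing that
inside the pod the kubelet has already resolved the ``fieldRef`` to an
ordinary string. ``rank_env`` and ``read_rank`` are hypothetical illustrative
names, not the repo's ``replica_entrypoint``.

```python
import os
from collections.abc import Mapping


# Hypothetical helper: REPLICA_RANK / WORLD_SIZE as plain-manifest env
# entries, equivalent to the V1EnvVar objects above.
def rank_env(world_size: int) -> list[dict]:
    return [
        {
            "name": "REPLICA_RANK",
            "valueFrom": {
                "fieldRef": {
                    "fieldPath": "metadata.annotations['batch.kubernetes.io/job-completion-index']"
                }
            },
        },
        # Literal value; env var values are always strings.
        {"name": "WORLD_SIZE", "value": str(world_size)},
    ]


# Hypothetical reader: in the running pod both variables are plain strings.
def read_rank(environ: Mapping[str, str] = os.environ) -> tuple[int, int]:
    return int(environ["REPLICA_RANK"]), int(environ["WORLD_SIZE"])


print(read_rank({"REPLICA_RANK": "2", "WORLD_SIZE": "4"}))  # (2, 4)
```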

S3 rendezvous via IRSA / Pod Identity
-------------------------------------

``EKSExecutor`` accepts ``service_account_name`` and references it on the
PodSpec. With IRSA, the EKS mutating webhook then injects ``AWS_ROLE_ARN`` +
``AWS_WEB_IDENTITY_TOKEN_FILE`` (and a projected token volume) into the pod;
with EKS Pod Identity, EKS instead injects
``AWS_CONTAINER_CREDENTIALS_FULL_URI`` +
``AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE``. Either way
``boto3``/``s3fs``/``fsspec`` pick up credentials through the default
credential chain with ZERO code change inside the replica, so the ``s3://``
rendezvous works out of the box (Pod Identity needs a reasonably recent AWS
SDK). ``EKSExecutor`` only REFERENCES a ServiceAccount that is already wired
up (IRSA role annotation or Pod Identity association); it never creates
IAM/OIDC resources.
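
A hypothetical diagnostic for checking which mechanism a replica actually
sees: IRSA exposes ``AWS_ROLE_ARN`` / ``AWS_WEB_IDENTITY_TOKEN_FILE``, while
EKS Pod Identity exposes ``AWS_CONTAINER_CREDENTIALS_FULL_URI``. The function
name ``aws_credential_source`` is illustrative, not part of the repo, and it
only inspects env var names; it makes no AWS calls.

```python
# Hypothetical diagnostic: name the credential mechanism visible in an
# environment mapping. It inspects variable names only; no AWS calls.
def aws_credential_source(environ: dict[str, str]) -> str:
    if "AWS_ROLE_ARN" in environ and "AWS_WEB_IDENTITY_TOKEN_FILE" in environ:
        return "irsa"  # web-identity provider (IRSA webhook)
    if "AWS_CONTAINER_CREDENTIALS_FULL_URI" in environ:
        return "pod-identity"  # container credential provider (EKS Pod Identity)
    return "none"
```

Inside a replica, ``aws_credential_source(dict(os.environ))`` returning
``"none"`` points at the ServiceAccount setup rather than at replica code.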

Sandboxing (advanced, optional)
-------------------------------

``runtime_class_name`` names a pre-existing cluster-scoped ``RuntimeClass``
(for example a gVisor class whose handler is ``runsc``, or a Kata class such
as ``kata``). It defaults to ``None``.

.. warning::
    Combining ``gpu`` with ``runtime_class_name`` is advanced and unverified.
    gVisor (runsc) needs ``nvproxy`` enabled and only supports a fixed
    allowlist of NVIDIA driver families; Kata runs a microVM that caps
    CPU/memory and needs GPU passthrough (PCIe/IOMMU + NVIDIA Kata Manager +
    CDI). Do not silently combine the two without operator validation.
    ``EKSExecutor`` cannot create the RuntimeClass; it only references one
    that already exists.
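
A hypothetical pre-flight guard a caller could run before constructing the
executor, to make the warning above enforceable. ``check_sandbox_combo`` and
its ``operator_validated`` flag are illustrative names; EKSExecutor itself
performs no such check.

```python
from __future__ import annotations


# Hypothetical guard (EKSExecutor does not do this): refuse the unverified
# gpu + RuntimeClass combination unless an operator has signed off on it.
def check_sandbox_combo(
    gpu: str | None,
    runtime_class_name: str | None,
    operator_validated: bool = False,
) -> None:
    if gpu is not None and runtime_class_name is not None and not operator_validated:
        raise ValueError(
            f"gpu={gpu!r} with runtime_class_name={runtime_class_name!r} is "
            "unverified; pass operator_validated=True only after cluster validation"
        )
```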

References
----------

- k8s Indexed Jobs: https://kubernetes.io/docs/tasks/job/indexed-parallel-processing-static/
- kubernetes-client/python job_crud example + V1JobSpec / V1JobStatus docs
- EKS IRSA: https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
- ADR-005 (executor protocol design)
"""

from __future__ import annotations

import time
import uuid
from collections.abc import Callable, Mapping
from typing import Any

from composer_replication.diloco.serverless.executor import (
    ReplicaHandle,
)

# Logical GPU spec ("A100"/"H100") -> (gpu_count_string, node_selector merge).
# The Protocol's `gpu` arg is a logical name; map it to a concrete EKS node
# class + GPU count rather than passing the opaque string straight through.
# p4d.24xlarge / p5.48xlarge nodes carry 8 GPUs each, so with a count of "1"
# several replicas can pack onto one node. Names missing from this table fall
# back to ("1", {}) in `_build_resources`: one GPU, no instance-type pin.
_GPU_SPEC_TABLE: dict[str, tuple[str, dict[str, str]]] = {
    "A100": ("1", {"node.kubernetes.io/instance-type": "p4d.24xlarge"}),
    "H100": ("1", {"node.kubernetes.io/instance-type": "p5.48xlarge"}),
    "A10G": ("1", {"node.kubernetes.io/instance-type": "g5.xlarge"}),
    "T4": ("1", {"node.kubernetes.io/instance-type": "g4dn.xlarge"}),
}


def _expand_indexes(spec: str | None) -> set[int]:
    """Expand a range-compressed completion-index string to a set.

    The k8s ``V1JobStatus.completed_indexes`` / ``failed_indexes`` fields are
    strings like ``"1,3-5,7"`` (comma-separated singletons and ``a-b`` ranges).
    ``_expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}``. Empty/None -> empty set;
    malformed tokens are skipped rather than raising.
    """
    out: set[int] = set()
    if not spec:
        return out
    for token in spec.split(","):
        token = token.strip()
        if not token:
            continue
        if "-" in token:
            lo_s, _, hi_s = token.partition("-")
            try:
                lo, hi = int(lo_s), int(hi_s)
            except ValueError:
                continue
            if hi < lo:
                lo, hi = hi, lo
            out.update(range(lo, hi + 1))
        else:
            try:
                out.add(int(token))
            except ValueError:
                continue
    return out


class EKSExecutor:
    """Run N DiLoCo replicas as a single Kubernetes Indexed Job on EKS.

    Implements the `ServerlessExecutor` Protocol. ``launch_replicas`` creates
    ONE Indexed Job (``completions == parallelism == n_replicas``,
    ``completionMode='Indexed'``) and returns N ``ReplicaHandle`` objects that
    all share the same ``job_name``/``namespace`` (gang semantics).

    Args:
        image: container image that has ``composer_replication`` installed and
            runs the replica entrypoint.
        namespace: k8s namespace for the Job. Default ``"default"``.
        service_account_name: ServiceAccount to attach to the PodSpec for IRSA /
            EKS Pod Identity S3 access. ``EKSExecutor`` references it; it does
            NOT create it or any IAM/OIDC resources.
        node_selector: extra node selector merged with the GPU node selector
            (caller-supplied keys win on conflicts).
        tolerations: PodSpec tolerations. If GPU is requested and the caller did
            not supply tolerations, a NoSchedule toleration on
            ``gpu_resource_key`` (``nvidia.com/gpu`` by default) is added
            automatically.
        runtime_class_name: optional pre-existing RuntimeClass (e.g. ``"gvisor"``
            / ``"kata"``). Default ``None``. See the module-level warning before
            combining with ``gpu``.
        command: container command. Defaults to the repo replica entrypoint
            module ``["python", "-m",
            "composer_replication.diloco.serverless.replica_entrypoint"]``.
        cpu_request / memory_request: PodSpec resource requests.
        ttl_seconds_after_finished: auto-delete the finished Job (and its pods,
            cascadingly) after this many seconds. Default 3600.
        backoff_limit: Job retry budget. Default 0 (fail-fast: RL gangs usually
            do NOT want the k8s default of 6 retries).
        gpu_resource_key: the GPU resource key. Default ``"nvidia.com/gpu"``.
        run_id: optional run id baked into the generated Job name. Default
            ``"diloco"``; it should use only lowercase alphanumerics and ``-``
            because it becomes part of a k8s object name.
        batch_api / core_api: dependency-injected ``BatchV1Api`` / ``CoreV1Api``
            instances. When ``None`` (the default), they are built lazily on
            first use via in-cluster or kube-config loading. Tests inject mocks.

    Raises:
        RuntimeError: if the ``kubernetes`` client is not installed and
            ``batch_api`` / ``core_api`` were not BOTH injected (the client is
            then needed to load cluster config and build the missing API).
    """

    backend_name = "eks"
    # Replicas never talk to each other directly: no headless Service or
    # stable pod DNS is set up for peer traffic (k8s does not isolate pods by
    # default, but nothing wires them together), so rendezvous goes through
    # S3 (ObjectStoreAllReduce).
    supports_inter_replica_network = False

    def __init__(
        self,
        image: str,
        *,
        namespace: str = "default",
        service_account_name: str | None = None,
        node_selector: dict[str, str] | None = None,
        tolerations: list[Any] | None = None,
        runtime_class_name: str | None = None,
        command: list[str] | None = None,
        cpu_request: str = "4",
        memory_request: str = "16Gi",
        ttl_seconds_after_finished: int = 3600,
        backoff_limit: int = 0,
        gpu_resource_key: str = "nvidia.com/gpu",
        run_id: str | None = None,
        batch_api: Any = None,
        core_api: Any = None,
    ) -> None:
        # `kubernetes` is needed to load cluster config and build an API client
        # whenever batch_api or core_api is not injected, so fail fast with a
        # clear message in that case. When BOTH are injected (tests, or callers
        # with pre-built clients) the check is skipped; note that the Protocol
        # methods still import V1 models / ApiException from `kubernetes` at
        # call time, so the package (or a stub of it) must be importable then.
        if batch_api is None or core_api is None:
            try:
                import kubernetes  # noqa: F401
            except ImportError as e:
                raise RuntimeError(
                    'EKSExecutor requires the kubernetes client: '
                    'pip install "kubernetes>=29" (or '
                    "`pip install -e .[serverless]`). Got: " + repr(e)
                ) from e
        self.image = image
        self.namespace = namespace
        self.service_account_name = service_account_name
        self.node_selector = dict(node_selector) if node_selector else None
        self.tolerations = list(tolerations) if tolerations else None
        self.runtime_class_name = runtime_class_name
        self.command = command or [
            "python",
            "-m",
            "composer_replication.diloco.serverless.replica_entrypoint",
        ]
        self.cpu_request = cpu_request
        self.memory_request = memory_request
        self.ttl_seconds_after_finished = ttl_seconds_after_finished
        self.backoff_limit = backoff_limit
        self.gpu_resource_key = gpu_resource_key
        self.run_id = run_id or "diloco"
        self._batch_api = batch_api
        self._core_api = core_api
        # rank -> {"job_name", "namespace", "result"}. Bookkeeping only:
        # poll/collect read handle.metadata, not this map, and because it is
        # keyed by rank alone a second launch_replicas call overwrites it.
        self._handles: dict[int, dict[str, Any]] = {}

    # -----------------------------------------------------------------
    # Lazy client construction (config loading only when not injected)
    # -----------------------------------------------------------------

    def _load_config(self) -> None:
        """Load k8s config: in-cluster first, then ~/.kube/config.

        Called once per lazily built API client (so at most twice).
        """
        from kubernetes import config

        try:
            config.load_incluster_config()
        except config.ConfigException:
            config.load_kube_config()

    def _batch(self) -> Any:
        if self._batch_api is None:
            from kubernetes import client

            self._load_config()
            self._batch_api = client.BatchV1Api()
        return self._batch_api

    def _core(self) -> Any:
        if self._core_api is None:
            from kubernetes import client

            self._load_config()
            self._core_api = client.CoreV1Api()
        return self._core_api

    # -----------------------------------------------------------------
    # Job-spec construction
    # -----------------------------------------------------------------

    def _build_env(
        self, world_size: int, entrypoint_args: Mapping[str, Any]
    ) -> list[Any]:
        """Build the container env list, including the downward-API rank var."""
        from kubernetes import client

        env: list[Any] = [
            # REPLICA_RANK from the per-pod completion-index annotation via the
            # downward API. This bridges k8s indexing to the repo entrypoint's
            # REPLICA_RANK read with no entrypoint change.
            client.V1EnvVar(
                name="REPLICA_RANK",
                value_from=client.V1EnvVarSource(
                    field_ref=client.V1ObjectFieldSelector(
                        field_path=(
                            "metadata.annotations["
                            "'batch.kubernetes.io/job-completion-index']"
                        )
                    )
                ),
            ),
            client.V1EnvVar(name="WORLD_SIZE", value=str(world_size)),
        ]
        # rendezvous_uri (and any other scalar kwargs) are passed as literal
        # env vars so the entrypoint / user code can read them; non-scalar
        # values (lists, dicts, None) are skipped. `rank_env` is the
        # LocalProcessExecutor convention, so drop it (same as
        # ModalSpawnExecutor).
        for key, value in entrypoint_args.items():
            if key == "rank_env":
                continue
            if isinstance(value, (str, int, float, bool)):
                env.append(
                    client.V1EnvVar(name=key.upper(), value=str(value))
                )
        return env

    def _build_resources(
        self, gpu: str | None
    ) -> tuple[Any, dict[str, str], list[Any]]:
        """Build V1ResourceRequirements + (node_selector, tolerations) for GPU.

        Returns (resources, node_selector, tolerations). The GPU count is
        ALWAYS a STRING ('1', not int 1): the OpenAPI type for the limits map
        is dict[str, str] and an int can serialize wrong or raise.
        """
        from kubernetes import client

        requests = {"cpu": self.cpu_request, "memory": self.memory_request}
        limits: dict[str, str] = {}
        node_selector: dict[str, str] = dict(self.node_selector or {})
        tolerations: list[Any] = list(self.tolerations or [])
        if gpu is not None:
            # Unknown logical names fall back to one GPU with no
            # instance-type pin.
            gpu_count, gpu_node_selector = _GPU_SPEC_TABLE.get(
                gpu, ("1", {})
            )
            # STRING, always.
            limits[self.gpu_resource_key] = str(gpu_count)
            # Merge the mapped node selector under any caller-supplied one
            # (caller wins on key conflicts).
            for k, v in gpu_node_selector.items():
                node_selector.setdefault(k, v)
            # Auto-add the GPU NoSchedule toleration unless the caller
            # supplied tolerations explicitly.
            if not self.tolerations:
                tolerations.append(
                    client.V1Toleration(
                        key=self.gpu_resource_key,
                        operator="Exists",
                        effect="NoSchedule",
                    )
                )
        # Only the GPU limit is set; k8s defaults a missing request to the
        # limit, so the GPU request equals the limit.
        resources = client.V1ResourceRequirements(
            requests=requests,
            limits=limits or None,
        )
        return resources, node_selector, tolerations

    def _build_job(
        self,
        *,
        job_name: str,
        n_replicas: int,
        gpu: str | None,
        timeout: int,
        entrypoint_args: Mapping[str, Any],
    ) -> Any:
        """Assemble the full V1Job (Indexed) bottom-up."""
        from kubernetes import client

        env = self._build_env(n_replicas, entrypoint_args)
        resources, node_selector, tolerations = self._build_resources(gpu)
        container = client.V1Container(
            name="replica",
            image=self.image,
            command=list(self.command),
            env=env,
            resources=resources,
        )
        pod_spec = client.V1PodSpec(
            # Jobs accept only "Never" or "OnFailure". "Never" means a failed
            # container is not restarted in place; any retry is a fresh pod
            # counted against backoff_limit (0 by default: fail-fast RL).
            restart_policy="Never",
            containers=[container],
            service_account_name=self.service_account_name,
            node_selector=node_selector or None,
            tolerations=tolerations or None,
            runtime_class_name=self.runtime_class_name,
        )
        labels = {"app": "composer-diloco", "job-name": job_name}
        pod_template = client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(labels=labels),
            spec=pod_spec,
        )
        job_spec = client.V1JobSpec(
            template=pod_template,
            completions=n_replicas,
            parallelism=n_replicas,
            completion_mode="Indexed",
            backoff_limit=self.backoff_limit,
            ttl_seconds_after_finished=self.ttl_seconds_after_finished,
            # Hard wall-clock limit for the whole Job, counted from its start.
            active_deadline_seconds=timeout,
        )
        return client.V1Job(
            api_version="batch/v1",
            kind="Job",
            metadata=client.V1ObjectMeta(name=job_name, labels=labels),
            spec=job_spec,
        )

    # -----------------------------------------------------------------
    # ServerlessExecutor Protocol
    # -----------------------------------------------------------------

    def launch_replicas(
        self,
        n_replicas: int,
        entrypoint: str | Callable[..., Any],
        entrypoint_args: Mapping[str, Any],
        *,
        gpu: str | None = None,
        timeout: int = 3600,
    ) -> list[ReplicaHandle]:
        """Create ONE Indexed Job of N pods and return N rank-ordered handles.

        ``entrypoint`` is ignored, whether it is a module string or a Callable:
        k8s runs a container command, not an in-process callable, and that
        command is fixed at construction (``command`` ctor arg, defaulting to
        the repo entrypoint module). Scalar ``entrypoint_args`` are passed as
        upper-cased env vars so ``replica_entrypoint`` / user code can read
        them; ``rendezvous_uri``, if given, is also recorded in each handle's
        ``metadata``. ``gpu`` maps to a GPU limit on ``gpu_resource_key``
        (``nvidia.com/gpu`` by default) plus a node selector; ``timeout``
        becomes the Job's ``active_deadline_seconds`` hard wall-clock kill.

        Raises:
            ValueError: if ``n_replicas < 1``.
        """
        del entrypoint  # k8s runs a container command, not an in-process fn
        if n_replicas < 1:
            raise ValueError(f"n_replicas must be >= 1, got {n_replicas}")
        # Job names must be valid DNS-1123 names (and, as the job-name label
        # value, at most 63 characters); run_id is not sanitized here. The
        # random suffix makes name collisions between launches unlikely.
        job_name = f"{self.run_id}-{uuid.uuid4().hex[:8]}"
        job = self._build_job(
            job_name=job_name,
            n_replicas=n_replicas,
            gpu=gpu,
            timeout=timeout,
            entrypoint_args=entrypoint_args,
        )
        self._batch().create_namespaced_job(namespace=self.namespace, body=job)
        # R6: recorded so _result_dict can surface it as collect()'s "result".
        rendezvous_uri = entrypoint_args.get("rendezvous_uri")
        handles: list[ReplicaHandle] = []
        for rank in range(n_replicas):
            handles.append(
                ReplicaHandle(
                    rank=rank,
                    backend_name=self.backend_name,
                    metadata={
                        "job_name": job_name,
                        "namespace": self.namespace,
                        "rank": rank,
                        "rendezvous_uri": rendezvous_uri,
                    },
                )
            )
            self._handles[rank] = {
                "job_name": job_name,
                "namespace": self.namespace,
                "result": None,
            }
        return handles

    def poll(self, handle: ReplicaHandle) -> str:
        """Poll this rank's status off the shared Indexed Job.

        Reads ``read_namespaced_job_status`` once, then maps the whole-job
        status to this rank, in order: ``rank in completed_indexes`` ->
        ``succeeded``; ``rank in failed_indexes`` -> ``failed``; a Job-level
        ``Failed=True`` condition (e.g. ``DeadlineExceeded`` or an exhausted
        ``backoff_limit``) -> ``failed``; ``active > 0`` -> ``running``; else
        ``pending``. A 404 (Job deleted/cancelled) -> ``cancelled``.

        ``failed_indexes`` is only populated when per-index backoff
        (``backoffLimitPerIndex``) is configured, which this executor does not
        set, so in practice failures surface through the ``Failed`` condition.
        A finished Job is also garbage-collected after
        ``ttl_seconds_after_finished``; a later poll then returns
        ``cancelled`` even for a rank that succeeded.

        Returns one of: ``pending`` | ``running`` | ``succeeded`` | ``failed`` |
        ``cancelled``.
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)
        try:
            status = self._batch().read_namespaced_job_status(
                name=job_name, namespace=namespace
            ).status
        except ApiException as e:
            if getattr(e, "status", None) == 404:
                return "cancelled"
            raise
        completed = _expand_indexes(getattr(status, "completed_indexes", None))
        if rank in completed:
            return "succeeded"
        failed = _expand_indexes(getattr(status, "failed_indexes", None))
        if rank in failed:
            return "failed"
        # Whole-job terminal Failed (e.g. DeadlineExceeded / backoff) with no
        # per-index attribution -> treat this rank as failed.
        for cond in getattr(status, "conditions", None) or []:
            if (
                getattr(cond, "type", None) == "Failed"
                and getattr(cond, "status", None) == "True"
            ):
                return "failed"
        active = getattr(status, "active", None) or 0
        if active > 0:
            return "running"
        return "pending"

    def stream_logs(self, handle: ReplicaHandle, *, n_lines: int = 200) -> str:
        """Read recent logs for this rank's pod.

        Finds the pod whose ``batch.kubernetes.io/job-completion-index``
        annotation (or label) equals the rank, then reads its log tail. Returns
        a placeholder string (rather than raising) when the pod has not started
        or the Job is gone, mirroring ``LocalProcessExecutor``.
        """
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata["job_name"]
        namespace = handle.metadata["namespace"]
        rank = handle.metadata.get("rank", handle.rank)
        idx_key = "batch.kubernetes.io/job-completion-index"
        try:
            pods = self._core().list_namespaced_pod(
                namespace=namespace, label_selector=f"job-name={job_name}"
            )
        except ApiException:
            return f"<rank {rank}: job not found / no pods yet>"
        pod_name = None
        for pod in getattr(pods, "items", None) or []:
            meta = getattr(pod, "metadata", None)
            annotations = getattr(meta, "annotations", None) or {}
            labels = getattr(meta, "labels", None) or {}
            if (
                annotations.get(idx_key) == str(rank)
                or labels.get(idx_key) == str(rank)
            ):
                pod_name = getattr(meta, "name", None)
                break
        if pod_name is None:
            # Fall back to the deterministic name prefix on k8s >= 1.28.
            prefix = f"{job_name}-{rank}-"
            for pod in getattr(pods, "items", None) or []:
                name = getattr(getattr(pod, "metadata", None), "name", "") or ""
                if name.startswith(prefix):
                    pod_name = name
                    break
        if pod_name is None:
            return f"<rank {rank}: pod not started / no logs yet>"
        try:
            return self._core().read_namespaced_pod_log(
                name=pod_name,
                namespace=namespace,
                container="replica",
                tail_lines=n_lines,
            )
        except ApiException as e:
            if getattr(e, "status", None) in (400, 404):
                return f"<rank {rank}: pod not started / no logs yet>"
            raise

    def cancel(self, handle: ReplicaHandle) -> None:
        """Delete the WHOLE shared Indexed Job (gang teardown).

        Because ``EKSExecutor`` uses one shared Indexed Job, cancelling ANY rank
        tears down the entire replica cohort. This is intentional gang semantics
        for the DiLoCo all-reduce barrier: a single straggler being cancelled
        should not leave the rest spinning and burning GPU.

        Uses ``propagation_policy='Background'`` so the pods are cascadingly
        deleted (the API default for Jobs ORPHANS the pods, which would keep
        burning GPU: the exact failure mode for RL). Idempotent: a 404 (already
        deleted) or 409 (mid-deletion) is swallowed, and a handle without a
        ``job_name`` is a no-op, honoring the Protocol's "no exception if
        already terminated" contract. Any other API error is re-raised.
        """
        from kubernetes import client
        from kubernetes.client.exceptions import ApiException

        job_name = handle.metadata.get("job_name")
        namespace = handle.metadata.get("namespace", self.namespace)
        if not job_name:
            return  # unknown handle: no-op
        try:
            self._batch().delete_namespaced_job(
                name=job_name,
                namespace=namespace,
                body=client.V1DeleteOptions(
                    propagation_policy="Background",
                    grace_period_seconds=0,
                ),
            )
        except ApiException as e:
            # R5: swallow ONLY already-terminated signals (404 Not Found, 409
            # Conflict on a job mid-deletion). A genuinely unexpected API error
            # (403 forbidden, 500, malformed request) must NOT be reported as a
            # successful cancel: re-raise so a real teardown failure (leaking
            # GPU-burning pods) is visible rather than silently swallowed.
            if getattr(e, "status", None) in (404, 409):
                return  # already deleted / mid-deletion: idempotent no-op
            raise

    def collect(
        self,
        handles: list[ReplicaHandle],
        *,
        timeout: int | None = None,
    ) -> list[dict[str, Any]]:
        """Poll until every rank reaches a terminal state or the deadline.

        Sleeps between polls (Job status is eventually consistent, so do not
        hammer the API server). ``timeout=None`` means a 24-hour deadline.
        Returns per-rank result dicts in ``handles`` order::

            {"rank", "status", "exit_code", "error", "job_name", "result"}

        ``exit_code`` is 0 for succeeded, 1 for failed, ``None`` for
        running/pending/cancelled, matching the Protocol's documented shape.
        ``result`` is the S3 rendezvous URI when known, else ``None``.
        """
        deadline = time.time() + (timeout if timeout is not None else 86400)
        poll_interval = float(self._collect_poll_interval())
        terminal = {"succeeded", "failed", "cancelled"}
        results_by_rank: dict[int, dict[str, Any]] = {}
        pending = list(handles)
        while pending and time.time() < deadline:
            still_pending: list[ReplicaHandle] = []
            for h in pending:
                state = self.poll(h)
                if state in terminal:
                    results_by_rank[h.rank] = self._result_dict(h, state)
                else:
                    still_pending.append(h)
            pending = still_pending
            if not pending:
                break
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(poll_interval, max(0.0, remaining)))
        # Any rank still non-terminal at the deadline: poll it once more and
        # report that state.
        for h in pending:
            state = self.poll(h)
            results_by_rank[h.rank] = self._result_dict(h, state)
        return [results_by_rank[h.rank] for h in handles]

    # -----------------------------------------------------------------
    # Internals
    # -----------------------------------------------------------------

    def _collect_poll_interval(self) -> float:
        """Seconds between collect() polls. Overridable in tests."""
        return 5.0

    @staticmethod
    def _result_dict(handle: ReplicaHandle, state: str) -> dict[str, Any]:
        exit_code = {"succeeded": 0, "failed": 1}.get(state, None)
        error = None
        if state == "failed":
            error = f"rank {handle.rank} reported failed by Job status"
        elif state == "cancelled":
            error = f"rank {handle.rank} Job no longer exists (cancelled)"
        elif state in ("running", "pending"):
            error = f"rank {handle.rank} not terminal at deadline (state={state})"
        return {
            "rank": handle.rank,
            "status": state,
            "exit_code": exit_code,
            "error": error,
            "job_name": handle.metadata.get("job_name"),
            # R6: cross-backend uniformity with Local/Modal/SageMaker collect()
            # shapes. EKS replicas write their real output to the S3 rendezvous
            # (ObjectStoreAllReduce), not back through the k8s API, so the Job
            # status carries no in-band payload: the value is the rendezvous
            # URI when known (callers read the artifact from S3), else None.
            "result": handle.metadata.get("rendezvous_uri"),
        }


__all__ = ["EKSExecutor"]