Baladithya Balamurugan
Wave 2: 4 new modules (kill-switch, EKS/SageMaker executors, DockerSandbox) + B4/B7 completion
7a55e1e
Raw
History Blame Contribute Delete
23.2 kB
"""Tests for EKSExecutor — the Kubernetes Indexed-Job-backed executor.
These tests exercise the executor's contract WITHOUT a live cluster and
WITHOUT the `kubernetes` client actually being installed. They:
* inject a fake `kubernetes` module into ``sys.modules`` so the executor's
lazy ``from kubernetes import client`` / ``...client.exceptions`` calls
resolve to recording stand-in V1* model classes (this is the k8s analogue
of the modal test's ``_MockFunctionCall``), and
* pass mock ``batch_api`` / ``core_api`` via dependency injection (the
constructor's ``batch_api=`` / ``core_api=`` args), so no config loading or
cluster contact happens.
For real-cluster integration testing you would gate behind cluster
availability (e.g. ``config.load_kube_config()`` succeeding), exactly like
``test_modal_spawn_executor.py`` gates on ``_is_modal_installed()``.
Run: ``.venv/bin/python -m pytest <thisfile> -q``
"""
from __future__ import annotations
import sys
import types
import pytest
from composer_replication.diloco.serverless import EKSExecutor, ReplicaHandle
from composer_replication.diloco.serverless.eks import _expand_indexes
# ---------------------------------------------------------------------
# Fake `kubernetes` module — recording V1* model stand-ins + ApiException
# ---------------------------------------------------------------------
class _Rec:
    """Generic recording model: stores all ctor kwargs as attributes.

    Stands in for the kubernetes client's ``V1*`` model classes (V1Job,
    V1JobSpec, V1Container, V1EnvVar, ...). Every attr the executor sets is
    inspectable by tests. Mirrors how the modal mock records ``.spawn`` args.

    Any ordinary attribute that was never passed to the constructor reads as
    ``None`` so assertions never trip over ``AttributeError``. Dunder names
    are excluded from that fallback: answering protocol probes such as
    ``__setstate__`` or ``__getnewargs_ex__`` with ``None`` breaks
    ``copy``/``pickle`` and makes ``hasattr`` checks lie.
    """

    def __init__(self, **kwargs):
        # Record every keyword verbatim; nothing is defaulted here — absent
        # fields are handled lazily by ``__getattr__``.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup has already failed.
        if name.startswith("__") and name.endswith("__"):
            # Let special-method/protocol lookups fail the normal way.
            raise AttributeError(name)
        return None

    def __repr__(self):
        # Show the recorded fields so failing assertions are readable.
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"_Rec({fields})"
class _ApiException(Exception):  # noqa: N818 — mirrors kubernetes.client.exceptions.ApiException name
    """Test double for ``kubernetes.client.exceptions.ApiException``.

    Carries the ``status``, ``reason`` and ``body`` it was built with; the
    exception message only embeds the status.
    """

    def __init__(self, status=None, reason=None, body=None):
        message = f"ApiException(status={status})"
        super().__init__(message)
        self.status, self.reason, self.body = status, reason, body
# Names of the V1* model classes that get installed on the fake
# ``kubernetes.client`` module; each one is bound to ``_Rec``.
_V1_NAMES = [
    # Job-level objects
    "V1Job", "V1JobSpec", "V1ObjectMeta",
    # Pod template / container
    "V1PodTemplateSpec", "V1PodSpec", "V1Container",
    # Environment variables (literal and downward-API sourced)
    "V1EnvVar", "V1EnvVarSource", "V1ObjectFieldSelector",
    # Scheduling / resources / deletion
    "V1ResourceRequirements", "V1Toleration", "V1DeleteOptions",
]
@pytest.fixture
def fake_kubernetes(monkeypatch):
    """Register a stand-in ``kubernetes`` package in ``sys.modules``.

    The fake exposes:
    - kubernetes.client.<V1*> -> the recording ``_Rec`` class
    - kubernetes.client.exceptions.ApiException -> ``_ApiException``
    - kubernetes.client.BatchV1Api / CoreV1Api -> stubs that fail the test
      when called (tests are expected to inject their own apis)
    - kubernetes.config.load_incluster_config (raises ConfigException),
      load_kube_config (no-op) and ConfigException

    All ``sys.modules`` entries are set through ``monkeypatch``. Returns the
    top-level fake module.
    """

    class _ConfigException(Exception):  # noqa: N818 — mirrors kubernetes.config.ConfigException name
        pass

    def _batch_api_stub(*args, **kwargs):
        return pytest.fail("BatchV1Api should be injected")

    def _core_api_stub(*args, **kwargs):
        return pytest.fail("CoreV1Api should be injected")

    def _load_incluster_config(*args, **kwargs):
        raise _ConfigException("not in cluster")

    def _load_kube_config(*args, **kwargs):
        return None

    root_mod = types.ModuleType("kubernetes")
    client_mod = types.ModuleType("kubernetes.client")
    exceptions_mod = types.ModuleType("kubernetes.client.exceptions")
    config_mod = types.ModuleType("kubernetes.config")

    # kubernetes.client: recording models plus api classes that must not be hit.
    for model_name in _V1_NAMES:
        setattr(client_mod, model_name, _Rec)
    client_mod.BatchV1Api = _batch_api_stub
    client_mod.CoreV1Api = _core_api_stub

    # kubernetes.client.exceptions
    exceptions_mod.ApiException = _ApiException
    client_mod.exceptions = exceptions_mod

    # kubernetes.config
    config_mod.ConfigException = _ConfigException
    config_mod.load_incluster_config = _load_incluster_config
    config_mod.load_kube_config = _load_kube_config

    root_mod.client = client_mod
    root_mod.config = config_mod

    registrations = (
        ("kubernetes", root_mod),
        ("kubernetes.client", client_mod),
        ("kubernetes.client.exceptions", exceptions_mod),
        ("kubernetes.config", config_mod),
    )
    for dotted_name, module in registrations:
        monkeypatch.setitem(sys.modules, dotted_name, module)
    return root_mod
# ---------------------------------------------------------------------
# Mock BatchV1Api / CoreV1Api (the _MockBatchV1 the task asks for)
# ---------------------------------------------------------------------
class _MockBatchV1Api:
    """In-memory BatchV1Api double.

    Create and delete calls are appended to ``created_jobs`` /
    ``delete_calls``. Status reads return ``status_obj`` wrapped in a
    ``_Rec`` unless ``read_raises`` is set, in which case that exception is
    raised instead.
    """

    def __init__(self):
        self.created_jobs: list[tuple[str, object]] = []
        self.delete_calls: list[dict] = []
        # Optional: exception to raise on read (e.g. a 404 ApiException).
        self.read_raises: Exception | None = None
        # Exposed as ``.status`` on the object returned by
        # read_namespaced_job_status(); tests overwrite it per scenario.
        self.status_obj = _Rec(
            active=None,
            succeeded=None,
            failed=None,
            completed_indexes=None,
            failed_indexes=None,
            conditions=None,
        )

    def create_namespaced_job(self, namespace, body):
        """Record ``(namespace, body)`` and echo the body back."""
        record = (namespace, body)
        self.created_jobs.append(record)
        return body

    def read_namespaced_job_status(self, name, namespace):
        """Raise ``read_raises`` if set, else return a job wrapping ``status_obj``."""
        pending_error = self.read_raises
        if pending_error is None:
            return _Rec(status=self.status_obj)
        raise pending_error

    def delete_namespaced_job(self, name, namespace, body=None):
        """Record the delete, including the options carried by ``body``."""
        policy = getattr(body, "propagation_policy", None)
        grace = getattr(body, "grace_period_seconds", None)
        self.delete_calls.append(
            {
                "name": name,
                "namespace": namespace,
                "propagation_policy": policy,
                "grace_period_seconds": grace,
            }
        )
        return _Rec(status="Success")
class _MockCoreV1Api:
    """In-memory CoreV1Api double with canned pods and log text.

    ``list_namespaced_pod`` always returns the configured pods (the label
    selector is recorded but not applied). ``read_namespaced_pod_log``
    records the call and then either raises ``log_raises`` or returns the
    canned log string.
    """

    def __init__(self, pods=None, logs="line1\nline2\n"):
        self._pods = [] if pods is None else pods
        self._logs = logs
        self.list_calls: list[dict] = []
        self.log_calls: list[dict] = []
        self.log_raises: Exception | None = None

    def list_namespaced_pod(self, namespace, label_selector=None):
        """Record the query and return a copy of the canned pod list."""
        query = {"namespace": namespace, "label_selector": label_selector}
        self.list_calls.append(query)
        return _Rec(items=list(self._pods))

    def read_namespaced_pod_log(self, name, namespace, container=None, tail_lines=None):
        """Record the request; raise ``log_raises`` if set, else return the logs."""
        request = {
            "name": name,
            "namespace": namespace,
            "container": container,
            "tail_lines": tail_lines,
        }
        # The call is recorded before any configured failure is raised.
        self.log_calls.append(request)
        pending_error = self.log_raises
        if pending_error is None:
            return self._logs
        raise pending_error
def _make_pod(name, rank):
    """Return a fake running pod annotated with its completion index.

    The ``job-name`` label is ``name`` with its last two ``-``-separated
    segments removed; the completion-index annotation is ``str(rank)``.
    """
    job_name = name.rsplit("-", 2)[0]
    annotations = {"batch.kubernetes.io/job-completion-index": str(rank)}
    metadata = _Rec(name=name, annotations=annotations, labels={"job-name": job_name})
    return _Rec(metadata=metadata, status=_Rec(phase="Running"))
def _make_executor(fake_kubernetes, *, batch=None, core=None, **kwargs):
    """Build an ``EKSExecutor`` wired to mock apis.

    Args:
        fake_kubernetes: the fixture value; not used in the body, but taking
            it means callers have requested the ``fake_kubernetes`` fixture.
        batch: batch api double to inject; a fresh ``_MockBatchV1Api`` is
            created when ``None``.
        core: core api double to inject; a fresh ``_MockCoreV1Api`` is
            created when ``None``.
        **kwargs: forwarded to the ``EKSExecutor`` constructor.

    Returns:
        ``(executor, batch, core)`` — the executor plus the api doubles it
        was constructed with.
    """
    # Compare against None explicitly: an injected double that happens to be
    # falsy must not be silently swapped for a fresh mock.
    if batch is None:
        batch = _MockBatchV1Api()
    if core is None:
        core = _MockCoreV1Api()
    ex = EKSExecutor(
        image="myrepo/composer-replica:latest",
        batch_api=batch,
        core_api=core,
        **kwargs,
    )
    # Speed up collect() loops in tests.
    ex._collect_poll_interval = lambda: 0.0
    return ex, batch, core
# ---------------------------------------------------------------------
# _expand_indexes — the run-length-range parser
# ---------------------------------------------------------------------
def test_expand_indexes_singletons_and_ranges():
    """``_expand_indexes`` handles singletons, ranges, empties and sloppy input."""
    cases = [
        ("1,3-5,7", {1, 3, 4, 5, 7}),
        ("0", {0}),
        ("0-3", {0, 1, 2, 3}),
        ("", set()),
        (None, set()),
        # Reversed range is tolerated.
        ("5-3", {3, 4, 5}),
        # Surrounding whitespace is tolerated.
        (" 2 , 4-6 ", {2, 4, 5, 6}),
    ]
    for raw, expected in cases:
        assert _expand_indexes(raw) == expected
# ---------------------------------------------------------------------
# Construction / preconditions
# ---------------------------------------------------------------------
def test_missing_kubernetes_raises_runtime_error_when_no_api_injected():
    """Without `kubernetes` and without injected apis, the ctor raises RuntimeError.

    This path is only reachable when `kubernetes` really cannot be imported
    in the running interpreter. If it can be found, the test is skipped
    instead of asserting a precondition that does not hold.
    """
    from importlib.util import find_spec

    kubernetes_importable = find_spec("kubernetes") is not None
    if kubernetes_importable:
        pytest.skip("kubernetes is importable in this interpreter; the absent-path cannot be exercised")

    with pytest.raises(RuntimeError, match="kubernetes"):
        EKSExecutor(image="x")
def test_construction_with_injected_apis_does_not_need_kubernetes():
    """Injecting both apis lets the ctor succeed without the fake package fixture."""
    executor = EKSExecutor(
        image="img",
        batch_api=_MockBatchV1Api(),
        core_api=_MockCoreV1Api(),
    )
    observed = (
        executor.backend_name,
        executor.supports_inter_replica_network,
        executor.image,
    )
    assert observed[0] == "eks"
    assert observed[1] is False
    assert observed[2] == "img"
# ---------------------------------------------------------------------
# launch_replicas — N handles, indexed-job spec correctness
# ---------------------------------------------------------------------
def test_launch_returns_n_rank_ordered_handles(fake_kubernetes):
    """Launching 4 replicas yields 4 rank-ordered handles backed by one job."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    launch_args = {"rendezvous_uri": "s3://b/run42/", "world_size": 4}
    handles = ex.launch_replicas(
        n_replicas=4, entrypoint="ignored", entrypoint_args=launch_args
    )

    assert len(handles) == 4
    first = handles[0]
    for expected_rank, handle in enumerate(handles):
        assert isinstance(handle, ReplicaHandle)
        assert handle.rank == expected_rank
        assert handle.backend_name == "eks"
        meta = handle.metadata
        assert meta["rank"] == expected_rank
        # Every handle points at the same job (gang) in the same namespace.
        assert meta["job_name"] == first.metadata["job_name"]
        assert meta["namespace"] == "default"

    # Single Indexed Job topology: exactly one create call.
    assert len(batch.created_jobs) == 1
def test_launch_creates_indexed_job_spec(fake_kubernetes):
    """The created job is a batch/v1 Indexed Job with the expected spec fields."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=3,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 3},
    )

    namespace, job = batch.created_jobs[0]
    assert namespace == "default"
    assert job.api_version == "batch/v1"
    assert job.kind == "Job"

    job_spec = job.spec
    expected_fields = {
        "completions": 3,
        "parallelism": 3,
        "completion_mode": "Indexed",
        "backoff_limit": 0,
        "ttl_seconds_after_finished": 3600,
        # No timeout was passed, so the deadline observed here is 3600.
        "active_deadline_seconds": 3600,
    }
    for field_name, expected in expected_fields.items():
        assert getattr(job_spec, field_name) == expected

    # restart_policy Never (required for Indexed jobs).
    assert job_spec.template.spec.restart_policy == "Never"
def test_launch_rank_env_uses_downward_api_field_ref(fake_kubernetes):
    """REPLICA_RANK comes from a field ref; other env vars are literals."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 2},
    )

    _, job = batch.created_jobs[0]
    container = job.spec.template.spec.containers[0]
    env_by_name = {}
    for var in container.env:
        env_by_name[var.name] = var

    # REPLICA_RANK carries no literal value; it is sourced from the
    # completion-index annotation via the downward API.
    rank_var = env_by_name["REPLICA_RANK"]
    assert rank_var.value is None
    expected_path = "metadata.annotations['batch.kubernetes.io/job-completion-index']"
    assert rank_var.value_from.field_ref.field_path == expected_path

    # WORLD_SIZE is a literal string.
    assert env_by_name["WORLD_SIZE"].value == "2"
    # rendezvous_uri is passed through under its upper-cased name.
    assert env_by_name["RENDEZVOUS_URI"].value == "s3://b/r/"
def test_launch_strips_rank_env_kwarg(fake_kubernetes):
    """The `rank_env` entrypoint arg must not surface as a RANK_ENV env var."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    launch_args = {"rank_env": "REPLICA_RANK", "rendezvous_uri": "s3://x/"}
    ex.launch_replicas(n_replicas=1, entrypoint="ignored", entrypoint_args=launch_args)

    _, job = batch.created_jobs[0]
    container = job.spec.template.spec.containers[0]
    names = {var.name for var in container.env}
    assert "RANK_ENV" not in names
    assert "RENDEZVOUS_URI" in names
def test_launch_gpu_limit_is_string(fake_kubernetes):
    """A GPU launch sets a string GPU limit, an instance-type selector and a toleration."""
    gpu_resource = "nvidia.com/gpu"
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        gpu="A100",
    )

    _, job = batch.created_jobs[0]
    pod_spec = job.spec.template.spec

    gpu_limit = pod_spec.containers[0].resources.limits[gpu_resource]
    assert gpu_limit == "1"
    # MUST be a string, not an int.
    assert isinstance(gpu_limit, str)

    # GPU node selector merged in.
    selector = pod_spec.node_selector
    assert selector["node.kubernetes.io/instance-type"] == "p4d.24xlarge"

    # GPU NoSchedule toleration auto-added.
    gpu_tolerations = [
        tol
        for tol in pod_spec.tolerations
        if tol.key == gpu_resource and tol.effect == "NoSchedule"
    ]
    assert gpu_tolerations
def test_launch_cpu_only_omits_gpu_limit(fake_kubernetes):
    """With ``gpu=None`` the container must not carry an ``nvidia.com/gpu`` limit.

    A missing ``resources`` object, a missing ``limits`` mapping and an empty
    ``limits`` mapping are all accepted as "no GPU limit".
    """
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        gpu=None,
    )
    _, job = batch.created_jobs[0]
    resources = job.spec.template.spec.containers[0].resources
    # ``resources`` itself may be None for a CPU-only container; normalise
    # every "absent" shape to an empty mapping before the membership check.
    limits = getattr(resources, "limits", None) or {}
    assert "nvidia.com/gpu" not in limits
def test_launch_passes_service_account_and_runtime_class(fake_kubernetes):
    """Executor-level service account and runtime class land on the pod spec."""
    executor_options = {
        "service_account_name": "diloco-irsa-sa",
        "runtime_class_name": "gvisor",
    }
    ex, batch, _ = _make_executor(fake_kubernetes, **executor_options)
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
    )

    _, job = batch.created_jobs[0]
    template_spec = job.spec.template.spec
    assert template_spec.service_account_name == "diloco-irsa-sa"
    assert template_spec.runtime_class_name == "gvisor"
def test_launch_timeout_becomes_active_deadline(fake_kubernetes):
    """The ``timeout`` launch argument is used as ``active_deadline_seconds``."""
    requested_timeout = 7200
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        timeout=requested_timeout,
    )
    created_job = batch.created_jobs[0][1]
    assert created_job.spec.active_deadline_seconds == 7200
def test_launch_uses_default_entrypoint_command(fake_kubernetes):
    """The container command is the replica entrypoint module run via ``python -m``."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
    )

    _, job = batch.created_jobs[0]
    container = job.spec.template.spec.containers[0]
    expected_command = [
        "python",
        "-m",
        "composer_replication.diloco.serverless.replica_entrypoint",
    ]
    assert container.command == expected_command
def test_launch_rejects_zero_or_negative(fake_kubernetes):
    """``n_replicas`` of 0 or -1 is rejected with a ValueError naming the argument."""
    ex, _, _ = _make_executor(fake_kubernetes)
    for bad_count in (0, -1):
        with pytest.raises(ValueError, match="n_replicas"):
            ex.launch_replicas(n_replicas=bad_count, entrypoint="x", entrypoint_args={})
# ---------------------------------------------------------------------
# poll — state mapping from completed/failed indexes + active count
# ---------------------------------------------------------------------
def _launch_two(fake_kubernetes, batch=None, core=None, *, n_replicas=4):
    """Create an executor with mock apis and launch a gang of replicas.

    NOTE: despite the name, this launches ``n_replicas`` replicas, which
    defaults to 4 (ranks 0-3); existing tests index handles up to rank 3.

    Args:
        fake_kubernetes: the fixture value, forwarded to ``_make_executor``.
        batch: optional batch api double to inject.
        core: optional core api double to inject.
        n_replicas: number of replicas to launch (keyword-only).

    Returns:
        ``(executor, batch, core, handles)``.
    """
    ex, batch, core = _make_executor(fake_kubernetes, batch=batch, core=core)
    handles = ex.launch_replicas(
        n_replicas=n_replicas,
        entrypoint="x",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
    )
    return ex, batch, core, handles
def test_poll_pending_when_nothing_active(fake_kubernetes):
    """Zero active pods and no completed/failed indexes polls as "pending"."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    idle_status = _Rec(active=0, completed_indexes=None, failed_indexes=None)
    batch.status_obj = idle_status
    state = ex.poll(handles[0])
    assert state == "pending"
def test_poll_running_when_active(fake_kubernetes):
    """A non-zero active count with no terminal indexes polls as "running"."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    busy_status = _Rec(active=4, completed_indexes=None, failed_indexes=None)
    batch.status_obj = busy_status
    state = ex.poll(handles[2])
    assert state == "running"
def test_poll_succeeded_when_rank_in_completed_indexes(fake_kubernetes):
    """Ranks listed in ``completed_indexes`` poll as succeeded; others keep running."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # "0,2-3" covers ranks {0, 2, 3}; rank 1 is the one still active.
    batch.status_obj = _Rec(active=1, completed_indexes="0,2-3", failed_indexes=None)

    for done_rank in (0, 2, 3):
        assert ex.poll(handles[done_rank]) == "succeeded"
    assert ex.poll(handles[1]) == "running"
def test_poll_failed_when_rank_in_failed_indexes(fake_kubernetes):
    """Ranks listed in ``failed_indexes`` poll as failed; completed ones as succeeded."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.status_obj = _Rec(active=0, completed_indexes="0", failed_indexes="1,3")

    for failed_rank in (1, 3):
        assert ex.poll(handles[failed_rank]) == "failed"
    assert ex.poll(handles[0]) == "succeeded"
def test_poll_failed_on_whole_job_failed_condition(fake_kubernetes):
    """A job-level Failed condition (e.g. DeadlineExceeded) without per-index data -> failed."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    failed_condition = _Rec(type="Failed", status="True", reason="DeadlineExceeded")
    batch.status_obj = _Rec(
        active=0,
        completed_indexes=None,
        failed_indexes=None,
        conditions=[failed_condition],
    )
    state = ex.poll(handles[0])
    assert state == "failed"
def test_poll_cancelled_on_404(fake_kubernetes):
    """A 404 from the status read is reported as "cancelled"."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    not_found = _ApiException(status=404)
    batch.read_raises = not_found
    state = ex.poll(handles[0])
    assert state == "cancelled"
def test_poll_reraises_non_404_api_exception(fake_kubernetes):
    """A non-404 ApiException (here 500) from the status read propagates out of poll()."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    server_error = _ApiException(status=500)
    batch.read_raises = server_error
    target = handles[0]
    with pytest.raises(_ApiException):
        ex.poll(target)
# ---------------------------------------------------------------------
# cancel — Background propagation on the shared job, idempotent
# ---------------------------------------------------------------------
def test_cancel_uses_background_propagation_on_shared_job(fake_kubernetes):
    """cancel() issues one Background, zero-grace delete against the shared job."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    ex.cancel(handles[2])

    assert len(batch.delete_calls) == 1
    (delete_call,) = batch.delete_calls
    assert delete_call["propagation_policy"] == "Background"
    assert delete_call["grace_period_seconds"] == 0
    # Cancelling ANY rank deletes the WHOLE shared job (gang semantics).
    shared_job_name = handles[0].metadata["job_name"]
    assert delete_call["name"] == shared_job_name
    assert delete_call["namespace"] == "default"
def test_cancel_swallows_404(fake_kubernetes):
    """cancel() must not raise when the delete call answers 404."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)

    def _delete_already_gone(name, namespace, body=None):
        raise _ApiException(status=404)

    # Shadow the mock's delete method on this instance only.
    batch.delete_namespaced_job = _delete_already_gone

    # Must NOT raise (already deleted == success per the Protocol).
    ex.cancel(handles[0])
def test_cancel_unknown_handle_is_noop(fake_kubernetes):
    """Cancelling a handle whose metadata has no job_name issues no delete call."""
    ex, batch, _, _ = _launch_two(fake_kubernetes)
    orphan = ReplicaHandle(rank=99, backend_name="eks", metadata={})
    ex.cancel(orphan)
    assert len(batch.delete_calls) == 0
# ---------------------------------------------------------------------
# stream_logs — find pod by completion-index annotation
# ---------------------------------------------------------------------
def test_stream_logs_reads_pod_for_rank(fake_kubernetes):
    """stream_logs() reads the log of the pod whose completion index is the rank.

    Two fake pods (ranks 0 and 1) are offered by the core api double; asking
    for rank 1 must read the rank-1 pod's 'replica' container with the
    requested ``n_lines`` as ``tail_lines``.
    """
    pods = [
        _make_pod("diloco-abcd1234-0-xyz", 0),
        _make_pod("diloco-abcd1234-1-xyz", 1),
    ]
    core = _MockCoreV1Api(pods=pods, logs="hello from rank 1\n")
    # The returned core double is the injected ``core``; no second name needed.
    ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
    out = ex.stream_logs(handles[1], n_lines=50)
    assert out == "hello from rank 1\n"
    # Read the right pod, container 'replica', tail_lines honored.
    last = core.log_calls[-1]
    assert last["name"] == "diloco-abcd1234-1-xyz"
    assert last["container"] == "replica"
    assert last["tail_lines"] == 50
def test_stream_logs_placeholder_when_pod_missing(fake_kubernetes):
    """With no pods listed, stream_logs() returns a placeholder naming the rank."""
    empty_core = _MockCoreV1Api(pods=[])  # no pods yet
    ex, _, _, handles = _launch_two(fake_kubernetes, core=empty_core)
    text = ex.stream_logs(handles[0])
    assert "rank 0" in text
    assert any(marker in text for marker in ("not started", "no logs"))
def test_stream_logs_placeholder_on_400(fake_kubernetes):
    """A 400 from the log read yields a placeholder naming the rank, not an error."""
    core = _MockCoreV1Api(pods=[_make_pod("diloco-abcd1234-0-xyz", 0)])
    # 400 stands for "pod not started yet" in this scenario.
    core.log_raises = _ApiException(status=400)
    ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
    text = ex.stream_logs(handles[0])
    assert "rank 0" in text
# ---------------------------------------------------------------------
# collect — per-rank result dicts in handles order
# ---------------------------------------------------------------------
def test_collect_returns_terminal_results_in_order(fake_kubernetes):
    """collect() returns one result dict per handle, in rank order, with final states."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # All four ranks done: 0-2 succeeded, 3 failed.
    batch.status_obj = _Rec(active=0, completed_indexes="0-2", failed_indexes="3")

    results = ex.collect(handles, timeout=5)
    assert len(results) == 4

    shared_job = handles[0]
    for position, result in enumerate(results):
        assert result["rank"] == position
        assert result["job_name"] == shared_job.metadata["job_name"]

    first, second, third, fourth = results
    assert first["status"] == "succeeded"
    assert first["exit_code"] == 0
    assert second["status"] == "succeeded"
    assert third["status"] == "succeeded"
    assert fourth["status"] == "failed"
    assert fourth["exit_code"] == 1
    assert fourth["error"] is not None
def test_collect_returns_non_terminal_state_at_deadline(fake_kubernetes):
    """With timeout=0 and work still active, collect() reports non-terminal states."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # Never finishes: active stays > 0.
    batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)

    results = ex.collect(handles, timeout=0)  # immediate deadline
    assert len(results) == 4
    non_terminal = ("running", "pending")
    for result in results:
        assert result["status"] in non_terminal
        assert result["exit_code"] is None