"""Tests for EKSExecutor — the Kubernetes Indexed-Job-backed executor.

These tests exercise the executor's contract WITHOUT a live cluster and
WITHOUT the `kubernetes` client actually being installed. They:

* inject a fake `kubernetes` module into ``sys.modules`` so the executor's
  lazy ``from kubernetes import client`` / ``...client.exceptions`` calls
  resolve to recording stand-in V1* model classes (this is the k8s analogue
  of the modal test's ``_MockFunctionCall``), and
* pass mock ``batch_api`` / ``core_api`` via dependency injection (the
  constructor's ``batch_api=`` / ``core_api=`` args), so no config loading
  or cluster contact happens.

For real-cluster integration testing you would gate behind cluster
availability (e.g. ``config.load_kube_config()`` succeeding), exactly like
``test_modal_spawn_executor.py`` gates on ``_is_modal_installed()``.

Run: ``.venv/bin/python -m pytest -q``
"""

from __future__ import annotations

import sys
import types

import pytest

from composer_replication.diloco.serverless import EKSExecutor, ReplicaHandle
from composer_replication.diloco.serverless.eks import _expand_indexes


# ---------------------------------------------------------------------
# Fake `kubernetes` module — recording V1* model stand-ins + ApiException
# ---------------------------------------------------------------------


class _Rec:
    """Generic recording model: stores all ctor kwargs as attributes.

    Stands in for the kubernetes client's ``V1*`` model classes (V1Job,
    V1JobSpec, V1Container, V1EnvVar, ...). Every attr the executor sets is
    inspectable by tests. Mirrors how the modal mock records ``.spawn`` args.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __getattr__(self, name):
        # Only called when the attribute is genuinely absent.
        #
        # Dunder lookups must keep failing normally: answering ``None`` for
        # protocol probes such as ``__setstate__`` makes ``hasattr`` report
        # True and breaks ``copy.deepcopy`` / ``pickle`` ("'NoneType' object
        # is not callable").
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        # Default every other (optional model) field to None so attribute
        # access in assertions never raises AttributeError.
        return None


class _ApiException(Exception):  # noqa: N818 — mirrors kubernetes.client.exceptions.ApiException name
    """Stand-in for kubernetes.client.exceptions.ApiException."""

    def __init__(self, status=None, reason=None, body=None):
        super().__init__(f"ApiException(status={status})")
        self.status = status
        self.reason = reason
        self.body = body


# The set of V1* names the executor constructs. Each maps to _Rec.
_V1_NAMES = [
    "V1Job",
    "V1JobSpec",
    "V1ObjectMeta",
    "V1PodTemplateSpec",
    "V1PodSpec",
    "V1Container",
    "V1EnvVar",
    "V1EnvVarSource",
    "V1ObjectFieldSelector",
    "V1ResourceRequirements",
    "V1Toleration",
    "V1DeleteOptions",
]


@pytest.fixture
def fake_kubernetes(monkeypatch):
    """Install a fake `kubernetes` package into sys.modules for the test.

    Provides:
      - kubernetes.client.V1* -> recording _Rec classes
      - kubernetes.client.exceptions.ApiException
      - kubernetes.client.BatchV1Api / CoreV1Api (unused — apis are injected)
      - kubernetes.config.load_incluster_config / load_kube_config /
        ConfigException

    Returns the fake top-level ``kubernetes`` module. The ``sys.modules``
    entries are registered through ``monkeypatch`` so they are undone when
    the test finishes.
    """
    kubernetes = types.ModuleType("kubernetes")
    client = types.ModuleType("kubernetes.client")
    exceptions = types.ModuleType("kubernetes.client.exceptions")
    config = types.ModuleType("kubernetes.config")

    for name in _V1_NAMES:
        setattr(client, name, _Rec)

    # Default api classes (only hit if NOT injected — we always inject).
    client.BatchV1Api = lambda *a, **k: pytest.fail("BatchV1Api should be injected")
    client.CoreV1Api = lambda *a, **k: pytest.fail("CoreV1Api should be injected")

    exceptions.ApiException = _ApiException
    client.exceptions = exceptions

    class _ConfigException(Exception):  # noqa: N818 — mirrors kubernetes.config.ConfigException name
        pass

    def _load_incluster_config(*args, **kwargs):
        # The fake always reports "not running inside a cluster".
        raise _ConfigException("not in cluster")

    config.ConfigException = _ConfigException
    config.load_incluster_config = _load_incluster_config
    config.load_kube_config = lambda *a, **k: None

    kubernetes.client = client
    kubernetes.config = config

    monkeypatch.setitem(sys.modules, "kubernetes", kubernetes)
    monkeypatch.setitem(sys.modules, "kubernetes.client", client)
    monkeypatch.setitem(sys.modules, "kubernetes.client.exceptions", exceptions)
    monkeypatch.setitem(sys.modules, "kubernetes.config", config)
    return kubernetes


# ---------------------------------------------------------------------
# Mock BatchV1Api / CoreV1Api (the _MockBatchV1 the task asks for)
# ---------------------------------------------------------------------


class _MockBatchV1Api:
    """Records create/read-status/delete calls; returns a settable status."""

    def __init__(self):
        self.created_jobs: list[tuple[str, object]] = []
        self.delete_calls: list[dict] = []
        # status object returned by read_namespaced_job_status().status
        self.status_obj = _Rec(
            active=None,
            succeeded=None,
            failed=None,
            completed_indexes=None,
            failed_indexes=None,
            conditions=None,
        )
        # Optional: raise this ApiException on read (e.g. 404 -> cancelled)
        self.read_raises: Exception | None = None

    def create_namespaced_job(self, namespace, body):
        """Record the (namespace, body) pair and echo the body back."""
        self.created_jobs.append((namespace, body))
        return body

    def read_namespaced_job_status(self, name, namespace):
        """Raise ``read_raises`` if set, else wrap ``status_obj`` in a job."""
        if self.read_raises is not None:
            raise self.read_raises
        return _Rec(status=self.status_obj)

    def delete_namespaced_job(self, name, namespace, body=None):
        """Record the delete call, including the delete-options fields."""
        self.delete_calls.append(
            {
                "name": name,
                "namespace": namespace,
                "propagation_policy": getattr(body, "propagation_policy", None),
                "grace_period_seconds": getattr(body, "grace_period_seconds", None),
            }
        )
        return _Rec(status="Success")


class _MockCoreV1Api:
    """Canned list_namespaced_pod + read_namespaced_pod_log."""

    def __init__(self, pods=None, logs="line1\nline2\n"):
        self._pods = pods if pods is not None else []
        self._logs = logs
        self.log_calls: list[dict] = []
        self.list_calls: list[dict] = []
        # Optional: raise this exception from read_namespaced_pod_log.
        self.log_raises: Exception | None = None

    def list_namespaced_pod(self, namespace, label_selector=None):
        """Record the call and return a copy of the canned pod list."""
        self.list_calls.append({"namespace": namespace, "label_selector": label_selector})
        return _Rec(items=list(self._pods))

    def read_namespaced_pod_log(self, name, namespace, container=None, tail_lines=None):
        """Record the call, then raise ``log_raises`` or return canned logs."""
        self.log_calls.append(
            {
                "name": name,
                "namespace": namespace,
                "container": container,
                "tail_lines": tail_lines,
            }
        )
        if self.log_raises is not None:
            raise self.log_raises
        return self._logs


def _make_pod(name, rank):
    """Build a fake pod with the completion-index annotation set."""
    return _Rec(
        metadata=_Rec(
            name=name,
            annotations={"batch.kubernetes.io/job-completion-index": str(rank)},
            # Job name = pod name minus its last two "-"-separated parts.
            labels={"job-name": name.rsplit("-", 2)[0]},
        ),
        status=_Rec(phase="Running"),
    )


def _make_executor(fake_kubernetes, *, batch=None, core=None, **kwargs):
    """Build an EKSExecutor wired to mock apis.

    Returns ``(executor, batch_api, core_api)``. Fresh mocks are created for
    whichever api is not supplied; extra ``kwargs`` go to the constructor.
    """
    # Explicit None checks: an injected mock must never be swapped out just
    # because it happens to be falsy.
    if batch is None:
        batch = _MockBatchV1Api()
    if core is None:
        core = _MockCoreV1Api()
    ex = EKSExecutor(
        image="myrepo/composer-replica:latest",
        batch_api=batch,
        core_api=core,
        **kwargs,
    )
    # Speed up collect() loops in tests.
    ex._collect_poll_interval = lambda: 0.0
    return ex, batch, core


# ---------------------------------------------------------------------
# _expand_indexes — the run-length-range parser
# ---------------------------------------------------------------------


def test_expand_indexes_singletons_and_ranges():
    assert _expand_indexes("1,3-5,7") == {1, 3, 4, 5, 7}
    assert _expand_indexes("0") == {0}
    assert _expand_indexes("0-3") == {0, 1, 2, 3}
    assert _expand_indexes("") == set()
    assert _expand_indexes(None) == set()
    # Reversed range is tolerated.
    assert _expand_indexes("5-3") == {3, 4, 5}
    # Whitespace / junk tolerated.
    assert _expand_indexes(" 2 , 4-6 ") == {2, 4, 5, 6}


# ---------------------------------------------------------------------
# Construction / preconditions
# ---------------------------------------------------------------------


def test_missing_kubernetes_raises_runtime_error_when_no_api_injected():
    """With kubernetes absent AND no injected api, ctor must raise clearly.

    The import-guard path can ONLY be exercised when `kubernetes` is
    genuinely not importable in this interpreter. When it IS installed (e.g.
    via the `[eks]`/`[serverless]` extra in CI), the lazy import succeeds and
    the ctor legitimately does not raise — so skip rather than assert a false
    precondition.
    """
    import importlib.util

    if importlib.util.find_spec("kubernetes") is not None:
        pytest.skip(
            "kubernetes is importable in this interpreter; "
            "the absent-path cannot be exercised"
        )
    with pytest.raises(RuntimeError, match="kubernetes"):
        EKSExecutor(image="x")


def test_construction_with_injected_apis_does_not_need_kubernetes():
    """When both apis are injected, ctor must not require the kubernetes import."""
    batch = _MockBatchV1Api()
    core = _MockCoreV1Api()
    ex = EKSExecutor(image="img", batch_api=batch, core_api=core)
    assert ex.backend_name == "eks"
    assert ex.supports_inter_replica_network is False
    assert ex.image == "img"


# ---------------------------------------------------------------------
# launch_replicas — N handles, indexed-job spec correctness
# ---------------------------------------------------------------------


def test_launch_returns_n_rank_ordered_handles(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    handles = ex.launch_replicas(
        n_replicas=4,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://b/run42/", "world_size": 4},
    )
    assert len(handles) == 4
    for i, h in enumerate(handles):
        assert isinstance(h, ReplicaHandle)
        assert h.rank == i
        assert h.backend_name == "eks"
        assert h.metadata["rank"] == i
        # ALL handles share the same job_name / namespace (gang).
        assert h.metadata["job_name"] == handles[0].metadata["job_name"]
        assert h.metadata["namespace"] == "default"
    # Exactly ONE job was created (single Indexed Job topology).
    assert len(batch.created_jobs) == 1


def test_launch_creates_indexed_job_spec(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=3,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 3},
    )
    ns, job = batch.created_jobs[0]
    assert ns == "default"
    assert job.api_version == "batch/v1"
    assert job.kind == "Job"
    spec = job.spec
    assert spec.completions == 3
    assert spec.parallelism == 3
    assert spec.completion_mode == "Indexed"
    assert spec.backoff_limit == 0
    assert spec.ttl_seconds_after_finished == 3600
    # active_deadline_seconds == timeout (default 3600 here).
    assert spec.active_deadline_seconds == 3600
    # restart_policy Never (required for Indexed jobs).
    assert spec.template.spec.restart_policy == "Never"


def test_launch_rank_env_uses_downward_api_field_ref(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://b/r/", "world_size": 2},
    )
    _, job = batch.created_jobs[0]
    env = job.spec.template.spec.containers[0].env
    by_name = {e.name: e for e in env}
    # REPLICA_RANK from the downward-API annotation (NOT a literal value).
    rr = by_name["REPLICA_RANK"]
    assert rr.value is None
    field_ref = rr.value_from.field_ref
    assert (
        field_ref.field_path
        == "metadata.annotations['batch.kubernetes.io/job-completion-index']"
    )
    # WORLD_SIZE is a literal string.
    assert by_name["WORLD_SIZE"].value == "2"
    # rendezvous_uri passed through as an upper-cased literal env var.
    assert by_name["RENDEZVOUS_URI"].value == "s3://b/r/"


def test_launch_strips_rank_env_kwarg(fake_kubernetes):
    """`rank_env` is the LocalProcessExecutor convention — must not become env."""
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rank_env": "REPLICA_RANK", "rendezvous_uri": "s3://x/"},
    )
    _, job = batch.created_jobs[0]
    env_names = {e.name for e in job.spec.template.spec.containers[0].env}
    assert "RANK_ENV" not in env_names
    assert "RENDEZVOUS_URI" in env_names


def test_launch_gpu_limit_is_string(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        gpu="A100",
    )
    _, job = batch.created_jobs[0]
    container = job.spec.template.spec.containers[0]
    limits = container.resources.limits
    assert limits["nvidia.com/gpu"] == "1"
    # MUST be a string, not an int.
    assert isinstance(limits["nvidia.com/gpu"], str)
    # GPU node selector merged in.
    node_selector = job.spec.template.spec.node_selector
    assert node_selector["node.kubernetes.io/instance-type"] == "p4d.24xlarge"
    # GPU NoSchedule toleration auto-added.
    tols = job.spec.template.spec.tolerations
    assert any(
        t.key == "nvidia.com/gpu" and t.effect == "NoSchedule" for t in tols
    )


def test_launch_cpu_only_omits_gpu_limit(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=2,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        gpu=None,
    )
    _, job = batch.created_jobs[0]
    limits = job.spec.template.spec.containers[0].resources.limits
    # No GPU -> no nvidia.com/gpu key at all (limits is None or empty).
    assert not limits or "nvidia.com/gpu" not in limits


def test_launch_passes_service_account_and_runtime_class(fake_kubernetes):
    ex, batch, _ = _make_executor(
        fake_kubernetes,
        service_account_name="diloco-irsa-sa",
        runtime_class_name="gvisor",
    )
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
    )
    _, job = batch.created_jobs[0]
    pod_spec = job.spec.template.spec
    assert pod_spec.service_account_name == "diloco-irsa-sa"
    assert pod_spec.runtime_class_name == "gvisor"


def test_launch_timeout_becomes_active_deadline(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=1,
        entrypoint="ignored",
        entrypoint_args={"rendezvous_uri": "s3://x/"},
        timeout=7200,
    )
    _, job = batch.created_jobs[0]
    assert job.spec.active_deadline_seconds == 7200


def test_launch_uses_default_entrypoint_command(fake_kubernetes):
    ex, batch, _ = _make_executor(fake_kubernetes)
    ex.launch_replicas(
        n_replicas=1, entrypoint="ignored", entrypoint_args={"rendezvous_uri": "s3://x/"}
    )
    _, job = batch.created_jobs[0]
    cmd = job.spec.template.spec.containers[0].command
    assert cmd == [
        "python",
        "-m",
        "composer_replication.diloco.serverless.replica_entrypoint",
    ]


def test_launch_rejects_zero_or_negative(fake_kubernetes):
    ex, _, _ = _make_executor(fake_kubernetes)
    with pytest.raises(ValueError, match="n_replicas"):
        ex.launch_replicas(n_replicas=0, entrypoint="x", entrypoint_args={})
    with pytest.raises(ValueError, match="n_replicas"):
        ex.launch_replicas(n_replicas=-1, entrypoint="x", entrypoint_args={})


# ---------------------------------------------------------------------
# poll — state mapping from completed/failed indexes + active count
# ---------------------------------------------------------------------


def _launch_two(fake_kubernetes, batch=None, core=None):
    """Build an executor and launch replicas on it.

    NOTE: despite the name, this launches FOUR replicas (ranks 0-3); the
    poll/cancel/collect tests below index handles up to ``handles[3]``.

    Returns ``(executor, batch_api, core_api, handles)``.
    """
    ex, batch, core = _make_executor(fake_kubernetes, batch=batch, core=core)
    handles = ex.launch_replicas(
        n_replicas=4, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://x/"}
    )
    return ex, batch, core, handles


def test_poll_pending_when_nothing_active(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.status_obj = _Rec(active=0, completed_indexes=None, failed_indexes=None)
    assert ex.poll(handles[0]) == "pending"


def test_poll_running_when_active(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
    assert ex.poll(handles[2]) == "running"


def test_poll_succeeded_when_rank_in_completed_indexes(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # completed_indexes "0,2-3" -> ranks {0,2,3} succeeded; rank 1 still running.
    batch.status_obj = _Rec(
        active=1, completed_indexes="0,2-3", failed_indexes=None
    )
    assert ex.poll(handles[0]) == "succeeded"
    assert ex.poll(handles[2]) == "succeeded"
    assert ex.poll(handles[3]) == "succeeded"
    assert ex.poll(handles[1]) == "running"


def test_poll_failed_when_rank_in_failed_indexes(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.status_obj = _Rec(
        active=0, completed_indexes="0", failed_indexes="1,3"
    )
    assert ex.poll(handles[1]) == "failed"
    assert ex.poll(handles[3]) == "failed"
    assert ex.poll(handles[0]) == "succeeded"


def test_poll_failed_on_whole_job_failed_condition(fake_kubernetes):
    """DeadlineExceeded etc.: a Failed condition with no per-index info -> failed."""
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.status_obj = _Rec(
        active=0,
        completed_indexes=None,
        failed_indexes=None,
        conditions=[_Rec(type="Failed", status="True", reason="DeadlineExceeded")],
    )
    assert ex.poll(handles[0]) == "failed"


def test_poll_cancelled_on_404(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.read_raises = _ApiException(status=404)
    assert ex.poll(handles[0]) == "cancelled"


def test_poll_reraises_non_404_api_exception(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    batch.read_raises = _ApiException(status=500)
    with pytest.raises(_ApiException):
        ex.poll(handles[0])


# ---------------------------------------------------------------------
# cancel — Background propagation on the shared job, idempotent
# ---------------------------------------------------------------------


def test_cancel_uses_background_propagation_on_shared_job(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    ex.cancel(handles[2])
    assert len(batch.delete_calls) == 1
    call = batch.delete_calls[0]
    assert call["propagation_policy"] == "Background"
    assert call["grace_period_seconds"] == 0
    # Cancelling ANY rank deletes the WHOLE shared job (gang semantics).
    assert call["name"] == handles[0].metadata["job_name"]
    assert call["namespace"] == "default"


def test_cancel_swallows_404(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)

    def _raise_404(name, namespace, body=None):
        raise _ApiException(status=404)

    batch.delete_namespaced_job = _raise_404
    # Must NOT raise (already deleted == success per the Protocol).
    ex.cancel(handles[0])


def test_cancel_unknown_handle_is_noop(fake_kubernetes):
    ex, batch, _, _ = _launch_two(fake_kubernetes)
    fake = ReplicaHandle(rank=99, backend_name="eks", metadata={})
    ex.cancel(fake)  # no job_name in metadata -> no-op, no delete call
    assert len(batch.delete_calls) == 0


# ---------------------------------------------------------------------
# stream_logs — find pod by completion-index annotation
# ---------------------------------------------------------------------


def test_stream_logs_reads_pod_for_rank(fake_kubernetes):
    pods = [
        _make_pod("diloco-abcd1234-0-xyz", 0),
        _make_pod("diloco-abcd1234-1-xyz", 1),
    ]
    core = _MockCoreV1Api(pods=pods, logs="hello from rank 1\n")
    ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
    out = ex.stream_logs(handles[1], n_lines=50)
    assert out == "hello from rank 1\n"
    # Read the right pod, container 'replica', tail_lines honored.
    last = core.log_calls[-1]
    assert last["name"] == "diloco-abcd1234-1-xyz"
    assert last["container"] == "replica"
    assert last["tail_lines"] == 50


def test_stream_logs_placeholder_when_pod_missing(fake_kubernetes):
    core = _MockCoreV1Api(pods=[])  # no pods yet
    ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
    out = ex.stream_logs(handles[0])
    assert "rank 0" in out
    assert "not started" in out or "no logs" in out


def test_stream_logs_placeholder_on_400(fake_kubernetes):
    pods = [_make_pod("diloco-abcd1234-0-xyz", 0)]
    core = _MockCoreV1Api(pods=pods)
    core.log_raises = _ApiException(status=400)  # pod not started yet
    ex, _, _, handles = _launch_two(fake_kubernetes, core=core)
    out = ex.stream_logs(handles[0])
    assert "rank 0" in out


# ---------------------------------------------------------------------
# collect — per-rank result dicts in handles order
# ---------------------------------------------------------------------


def test_collect_returns_terminal_results_in_order(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # All four ranks done: 0-2 succeeded, 3 failed.
    batch.status_obj = _Rec(
        active=0, completed_indexes="0-2", failed_indexes="3"
    )
    results = ex.collect(handles, timeout=5)
    assert len(results) == 4
    for i, r in enumerate(results):
        assert r["rank"] == i
        assert r["job_name"] == handles[0].metadata["job_name"]
    assert results[0]["status"] == "succeeded" and results[0]["exit_code"] == 0
    assert results[1]["status"] == "succeeded"
    assert results[2]["status"] == "succeeded"
    assert results[3]["status"] == "failed" and results[3]["exit_code"] == 1
    assert results[3]["error"] is not None


def test_collect_returns_non_terminal_state_at_deadline(fake_kubernetes):
    ex, batch, _, handles = _launch_two(fake_kubernetes)
    # Never finishes: active stays > 0.
    batch.status_obj = _Rec(active=4, completed_indexes=None, failed_indexes=None)
    results = ex.collect(handles, timeout=0)  # immediate deadline
    assert len(results) == 4
    for r in results:
        assert r["status"] in ("running", "pending")
        assert r["exit_code"] is None