"""Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker). The executor is exercised with an INJECTED mock boto3 sagemaker client (the `sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS credentials — mirroring the _MockFunctionCall pattern in test_modal_spawn_executor.py and the _MockBatchV1Api pattern in test_eks_executor.py. Closes the test-coverage gap left when the SageMakerExecutor was first written without a test module (caught during Wave-2 integration, 2026-06-09). """ from __future__ import annotations import importlib.util import pytest from composer_replication.diloco.serverless import SageMakerExecutor from composer_replication.diloco.serverless.executor import ReplicaHandle # --------------------------------------------------------------------- # Mock boto3 sagemaker client # --------------------------------------------------------------------- class _MockSMClient: """Records create/stop calls and serves a scripted status per job name.""" def __init__(self): self.created: list[dict] = [] self.stopped: list[str] = [] # job_name -> (TrainingJobStatus, SecondaryStatus) self._status: dict[str, tuple[str, str]] = {} self.raise_not_found_on: set[str] = set() def create_training_job(self, **request): self.created.append(request) # default a newly-created job to InProgress/Starting (== pending) self._status[request["TrainingJobName"]] = ("InProgress", "Starting") return {"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{request['TrainingJobName']}"} def describe_training_job(self, TrainingJobName): # noqa: N803 (boto3 casing) if TrainingJobName in self.raise_not_found_on: raise _ResourceNotFoundError(f"job {TrainingJobName} not found") status, secondary = self._status.get(TrainingJobName, ("InProgress", "Training")) return { "TrainingJobName": TrainingJobName, "TrainingJobStatus": status, "SecondaryStatus": secondary, "TrainingJobArn": f"arn:aws:sagemaker:::training-job/{TrainingJobName}", } def stop_training_job(self, TrainingJobName): # noqa: N803 self.stopped.append(TrainingJobName) # test helper def set_status(self, job_name, status, secondary="Completed"): self._status[job_name] = (status, secondary) class _ResourceNotFoundError(Exception): """Stand-in for botocore ResourceNotFound (the executor matches on name/text).""" def __init__(self, msg): super().__init__(msg) # botocore-style response shape some impls check self.response = {"Error": {"Code": "ResourceNotFound", "Message": msg}} def _make_executor(client=None): return SageMakerExecutor( image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest", role_arn="arn:aws:iam::123:role/SMRole", output_s3_path="s3://bucket/out/", region="us-east-1", sagemaker_client=client or _MockSMClient(), ) _VALID_ARGS = { "rendezvous_uri": "s3://bucket/rendezvous/run1/", "trainer_module": "my_pkg.trainer", } # --------------------------------------------------------------------- # Construction # --------------------------------------------------------------------- def test_backend_identity(): ex = _make_executor() assert ex.backend_name == "sagemaker" assert ex.supports_inter_replica_network is False def test_missing_boto3_raises_when_no_client_injected(): """The import-guard path only fires when boto3 is genuinely absent.""" if importlib.util.find_spec("boto3") is not None: pytest.skip("boto3 importable; absent-path cannot be exercised") with pytest.raises(RuntimeError, match="boto3"): SageMakerExecutor( image_uri="x", role_arn="r", output_s3_path="s3://b/o/", ) def test_construction_with_injected_client_needs_no_boto3(): ex = _make_executor() assert ex is not None # --------------------------------------------------------------------- # launch_replicas # --------------------------------------------------------------------- def test_launch_returns_rank_ordered_handles(): client = _MockSMClient() ex = _make_executor(client) handles = ex.launch_replicas(3, entrypoint="ignored", entrypoint_args=_VALID_ARGS) assert len(handles) == 3 assert [h.rank for h in handles] == [0, 1, 2] assert all(isinstance(h, ReplicaHandle) and h.backend_name == "sagemaker" for h in handles) assert len(client.created) == 3 def test_launch_injects_rank_world_size_and_rendezvous_env(): client = _MockSMClient() ex = _make_executor(client) ex.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS) for rank, req in enumerate(client.created): env = req["Environment"] assert env["REPLICA_RANK"] == str(rank) assert env["WORLD_SIZE"] == "2" assert env["RENDEZVOUS_URI"] == _VALID_ARGS["rendezvous_uri"] # network isolation MUST stay False (else S3 rendezvous deadlocks) assert req["EnableNetworkIsolation"] is False assert req["OutputDataConfig"]["S3OutputPath"] == "s3://bucket/out/" assert req["ResourceConfig"]["InstanceCount"] == 1 def test_launch_validates_n_replicas(): ex = _make_executor() with pytest.raises(ValueError, match="n_replicas"): ex.launch_replicas(0, entrypoint="x", entrypoint_args=_VALID_ARGS) def test_launch_requires_rendezvous_and_trainer_module(): ex = _make_executor() with pytest.raises(ValueError, match="rendezvous_uri"): ex.launch_replicas(1, entrypoint="x", entrypoint_args={"trainer_module": "m"}) with pytest.raises(ValueError, match="trainer_module"): ex.launch_replicas(1, entrypoint="x", entrypoint_args={"rendezvous_uri": "s3://b/r/"}) def test_launch_partial_failure_stops_siblings_and_raises(): class _FailingClient(_MockSMClient): def create_training_job(self, **request): if len(self.created) >= 2: # 3rd create fails raise RuntimeError("ThrottlingException") return super().create_training_job(**request) client = _FailingClient() ex = _make_executor(client) with pytest.raises(RuntimeError, match="rank=2"): ex.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS) # the two already-launched siblings were best-effort stopped assert len(client.stopped) == 2 # --------------------------------------------------------------------- # poll status mapping # --------------------------------------------------------------------- def test_poll_status_mapping(): client = _MockSMClient() ex = _make_executor(client) handles = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS) h = handles[0] job = client.created[0]["TrainingJobName"] client.set_status(job, "InProgress", "Starting") assert ex.poll(h) == "pending" client.set_status(job, "InProgress", "Training") assert ex.poll(h) == "running" client.set_status(job, "Completed") assert ex.poll(h) == "succeeded" def test_poll_failed_and_stopped(): client = _MockSMClient() ex = _make_executor(client) h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] job = client.created[0]["TrainingJobName"] client.set_status(job, "Failed") assert ex.poll(h) == "failed" client2 = _MockSMClient() ex2 = _make_executor(client2) h2 = ex2.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] job2 = client2.created[0]["TrainingJobName"] client2.set_status(job2, "Stopped") assert ex2.poll(h2) == "cancelled" def test_poll_vanished_job_is_cancelled(): client = _MockSMClient() ex = _make_executor(client) h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] client.raise_not_found_on.add(client.created[0]["TrainingJobName"]) assert ex.poll(h) == "cancelled" def test_poll_unknown_handle_is_cancelled(): ex = _make_executor() orphan = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={}) assert ex.poll(orphan) == "cancelled" # --------------------------------------------------------------------- # cancel # --------------------------------------------------------------------- def test_cancel_calls_stop_training_job(): client = _MockSMClient() ex = _make_executor(client) h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] ex.cancel(h) assert client.stopped == [client.created[0]["TrainingJobName"]] def test_cancel_swallows_errors(): class _RaisingStop(_MockSMClient): def stop_training_job(self, TrainingJobName): # noqa: N803 raise _ResourceNotFoundError("already terminal") client = _RaisingStop() ex = _make_executor(client) h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] ex.cancel(h) # must not raise # unknown handle must also be a no-op ex.cancel(ReplicaHandle(rank=42, backend_name="sagemaker", metadata={})) def test_cancel_reraises_unexpected_error(): """R5: a genuinely unexpected error (not already-terminated) must propagate, not be silently swallowed as a successful cancel.""" class _BoomClient(_MockSMClient): def stop_training_job(self, TrainingJobName): # noqa: N803 raise RuntimeError("AccessDeniedException: not authorized") client = _BoomClient() ex = _make_executor(client) h = ex.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0] with pytest.raises(RuntimeError, match="AccessDenied"): ex.cancel(h)