Baladithya Balamurugan
Wave 3 cleanup: close deferred-LOW review items R5/R6/R11
7d9dbbc
Raw
History Blame Contribute Delete
9.74 kB
"""Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker).
The executor is exercised with an INJECTED mock boto3 sagemaker client (the
`sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS
credentials — mirroring the _MockFunctionCall pattern in
test_modal_spawn_executor.py and the _MockBatchV1Api pattern in
test_eks_executor.py.
Closes the test-coverage gap left when the SageMakerExecutor was first written
without a test module (caught during Wave-2 integration, 2026-06-09).
"""
from __future__ import annotations
import importlib.util
import pytest
from composer_replication.diloco.serverless import SageMakerExecutor
from composer_replication.diloco.serverless.executor import ReplicaHandle
# ---------------------------------------------------------------------
# Mock boto3 sagemaker client
# ---------------------------------------------------------------------
class _MockSMClient:
"""Records create/stop calls and serves a scripted status per job name."""
def __init__(self):
self.created: list[dict] = []
self.stopped: list[str] = []
# job_name -> (TrainingJobStatus, SecondaryStatus)
self._status: dict[str, tuple[str, str]] = {}
self.raise_not_found_on: set[str] = set()
def create_training_job(self, **request):
self.created.append(request)
# default a newly-created job to InProgress/Starting (== pending)
self._status[request["TrainingJobName"]] = ("InProgress", "Starting")
return {"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{request['TrainingJobName']}"}
def describe_training_job(self, TrainingJobName): # noqa: N803 (boto3 casing)
if TrainingJobName in self.raise_not_found_on:
raise _ResourceNotFoundError(f"job {TrainingJobName} not found")
status, secondary = self._status.get(TrainingJobName, ("InProgress", "Training"))
return {
"TrainingJobName": TrainingJobName,
"TrainingJobStatus": status,
"SecondaryStatus": secondary,
"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{TrainingJobName}",
}
def stop_training_job(self, TrainingJobName): # noqa: N803
self.stopped.append(TrainingJobName)
# test helper
def set_status(self, job_name, status, secondary="Completed"):
self._status[job_name] = (status, secondary)
class _ResourceNotFoundError(Exception):
"""Stand-in for botocore ResourceNotFound (the executor matches on name/text)."""
def __init__(self, msg):
super().__init__(msg)
# botocore-style response shape some impls check
self.response = {"Error": {"Code": "ResourceNotFound", "Message": msg}}
def _make_executor(client=None):
    """Build a SageMakerExecutor wired to an injected sagemaker client.

    Args:
        client: the boto3-shaped sagemaker client to inject. A fresh
            ``_MockSMClient`` is used only when this is omitted (``None``).
            Any other value, including one that happens to be falsy, is
            injected as-is.

    Returns:
        A SageMakerExecutor with fixed test image, role, S3 output path
        and region.
    """
    if client is None:
        client = _MockSMClient()
    return SageMakerExecutor(
        image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest",
        role_arn="arn:aws:iam::123:role/SMRole",
        output_s3_path="s3://bucket/out/",
        region="us-east-1",
        sagemaker_client=client,
    )
_VALID_ARGS = {
"rendezvous_uri": "s3://bucket/rendezvous/run1/",
"trainer_module": "my_pkg.trainer",
}
# ---------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------
def test_backend_identity():
    """The executor identifies as 'sagemaker' and offers no inter-replica network."""
    executor = _make_executor()
    assert executor.backend_name == "sagemaker"
    assert executor.supports_inter_replica_network is False
def test_missing_boto3_raises_when_no_client_injected():
    """Without an injected client, a host lacking boto3 gets a RuntimeError.

    The import-guard path only fires when boto3 is genuinely absent, so the
    test skips itself on any host where boto3 can be found.
    """
    boto3_available = importlib.util.find_spec("boto3") is not None
    if boto3_available:
        pytest.skip("boto3 importable; absent-path cannot be exercised")
    with pytest.raises(RuntimeError, match="boto3"):
        SageMakerExecutor(image_uri="x", role_arn="r", output_s3_path="s3://b/o/")
def test_construction_with_injected_client_needs_no_boto3(monkeypatch):
    """With a client injected, construction succeeds even when boto3 is unimportable.

    The previous ``assert ex is not None`` could never fail, because a
    constructor call cannot return None. This version actually hides boto3
    first.
    """
    import sys

    # A None entry in sys.modules makes `import boto3` raise ImportError and
    # makes importlib.util.find_spec("boto3") return None, so boto3 looks absent.
    # monkeypatch restores the original entry after the test.
    monkeypatch.setitem(sys.modules, "boto3", None)
    ex = _make_executor()
    assert ex.backend_name == "sagemaker"
# ---------------------------------------------------------------------
# launch_replicas
# ---------------------------------------------------------------------
def test_launch_returns_rank_ordered_handles():
    """Launching N replicas yields N rank-ordered sagemaker handles and N create calls."""
    sm = _MockSMClient()
    launched = _make_executor(sm).launch_replicas(
        3, entrypoint="ignored", entrypoint_args=_VALID_ARGS
    )
    assert len(launched) == 3
    assert [handle.rank for handle in launched] == list(range(3))
    for handle in launched:
        assert isinstance(handle, ReplicaHandle)
        assert handle.backend_name == "sagemaker"
    assert len(sm.created) == 3
def test_launch_injects_rank_world_size_and_rendezvous_env():
    """Each replica's job request carries its rank, the world size and the rendezvous URI.

    Also pins the request fields the S3 rendezvous relies on: network isolation
    off, the configured S3 output path, and one instance per job.
    """
    client = _MockSMClient()
    ex = _make_executor(client)
    ex.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
    # Without this, the per-request loop below would pass vacuously if no
    # jobs were created.
    assert len(client.created) == 2
    # SageMaker rejects duplicate training-job names, so each replica needs its own.
    assert len({req["TrainingJobName"] for req in client.created}) == 2
    for rank, req in enumerate(client.created):
        env = req["Environment"]
        assert env["REPLICA_RANK"] == str(rank)
        assert env["WORLD_SIZE"] == "2"
        assert env["RENDEZVOUS_URI"] == _VALID_ARGS["rendezvous_uri"]
        # network isolation MUST stay False (else S3 rendezvous deadlocks)
        assert req["EnableNetworkIsolation"] is False
        assert req["OutputDataConfig"]["S3OutputPath"] == "s3://bucket/out/"
        assert req["ResourceConfig"]["InstanceCount"] == 1
def test_launch_validates_n_replicas():
    """Requesting zero replicas is rejected with a ValueError mentioning n_replicas."""
    executor = _make_executor()
    with pytest.raises(ValueError, match="n_replicas"):
        executor.launch_replicas(0, entrypoint="x", entrypoint_args=_VALID_ARGS)
def test_launch_requires_rendezvous_and_trainer_module():
    """Omitting either required entrypoint arg raises a ValueError naming that arg."""
    executor = _make_executor()
    cases = [
        ("rendezvous_uri", {"trainer_module": "m"}),
        ("trainer_module", {"rendezvous_uri": "s3://b/r/"}),
    ]
    for missing_key, partial_args in cases:
        with pytest.raises(ValueError, match=missing_key):
            executor.launch_replicas(1, entrypoint="x", entrypoint_args=partial_args)
def test_launch_partial_failure_stops_siblings_and_raises():
    """A create failure at rank 2 surfaces 'rank=2' and rolls back ranks 0 and 1."""

    class _FailingClient(_MockSMClient):
        def create_training_job(self, **request):
            if len(self.created) >= 2:  # 3rd create fails
                raise RuntimeError("ThrottlingException")
            return super().create_training_job(**request)

    client = _FailingClient()
    ex = _make_executor(client)
    with pytest.raises(RuntimeError, match="rank=2"):
        ex.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS)
    launched = [req["TrainingJobName"] for req in client.created]
    assert len(launched) == 2
    # The already-launched siblings were best-effort stopped. Check that the
    # stopped jobs are exactly those siblings, each once, not just that the
    # count is 2.
    assert sorted(client.stopped) == sorted(launched)
# ---------------------------------------------------------------------
# poll status mapping
# ---------------------------------------------------------------------
def test_poll_status_mapping():
    """InProgress/Starting -> pending, InProgress/Training -> running, Completed -> succeeded."""
    client = _MockSMClient()
    executor = _make_executor(client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    job_name = client.created[0]["TrainingJobName"]
    expectations = (
        (("InProgress", "Starting"), "pending"),
        (("InProgress", "Training"), "running"),
        (("Completed",), "succeeded"),
    )
    for status_args, expected in expectations:
        client.set_status(job_name, *status_args)
        assert executor.poll(handle) == expected
def test_poll_failed_and_stopped():
    """A Failed job polls as 'failed'; a Stopped job polls as 'cancelled'.

    Each case uses its own client and executor, so the cases are independent.
    """
    for terminal_status, expected in (("Failed", "failed"), ("Stopped", "cancelled")):
        fresh_client = _MockSMClient()
        executor = _make_executor(fresh_client)
        handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
        fresh_client.set_status(fresh_client.created[0]["TrainingJobName"], terminal_status)
        assert executor.poll(handle) == expected
def test_poll_vanished_job_is_cancelled():
    """A job whose describe call reports not-found polls as 'cancelled'."""
    client = _MockSMClient()
    executor = _make_executor(client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    vanished_job = client.created[0]["TrainingJobName"]
    client.raise_not_found_on.add(vanished_job)
    assert executor.poll(handle) == "cancelled"
def test_poll_unknown_handle_is_cancelled():
    """A handle this executor never launched (empty metadata) polls as 'cancelled'."""
    executor = _make_executor()
    foreign_handle = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={})
    assert executor.poll(foreign_handle) == "cancelled"
# ---------------------------------------------------------------------
# cancel
# ---------------------------------------------------------------------
def test_cancel_calls_stop_training_job():
    """cancel() issues exactly one stop_training_job, for the handle's own job."""
    client = _MockSMClient()
    executor = _make_executor(client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    executor.cancel(handle)
    expected_job = client.created[0]["TrainingJobName"]
    assert client.stopped == [expected_job]
def test_cancel_swallows_errors():
    """cancel() tolerates a not-found error from stop and ignores unknown handles."""

    class _RaisingStop(_MockSMClient):
        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise _ResourceNotFoundError("already terminal")

    executor = _make_executor(_RaisingStop())
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    executor.cancel(handle)  # must not raise
    # a handle this executor never launched must also be a no-op
    stranger = ReplicaHandle(rank=42, backend_name="sagemaker", metadata={})
    executor.cancel(stranger)
def test_cancel_reraises_unexpected_error():
    """R5: an unexpected stop error (not already-terminated) must propagate.

    Such an error must not be silently swallowed and reported as a
    successful cancel.
    """

    class _BoomClient(_MockSMClient):
        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise RuntimeError("AccessDeniedException: not authorized")

    executor = _make_executor(_BoomClient())
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    with pytest.raises(RuntimeError, match="AccessDenied"):
        executor.cancel(handle)