File size: 9,737 Bytes
7a55e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d9dbbc
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker).

The executor is exercised with an INJECTED mock boto3 sagemaker client (the
`sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS
credentials — mirroring the _MockFunctionCall pattern in
test_modal_spawn_executor.py and the _MockBatchV1Api pattern in
test_eks_executor.py.

Closes the test-coverage gap left when the SageMakerExecutor was first written
without a test module (caught during Wave-2 integration, 2026-06-09).
"""
from __future__ import annotations

import importlib.util

import pytest

from composer_replication.diloco.serverless import SageMakerExecutor
from composer_replication.diloco.serverless.executor import ReplicaHandle

# ---------------------------------------------------------------------
# Mock boto3 sagemaker client
# ---------------------------------------------------------------------


class _MockSMClient:
    """In-memory stand-in for the boto3 SageMaker client.

    Every ``create_training_job`` request and every stopped job name is
    recorded, and ``describe_training_job`` answers from a per-job status
    table that tests script through ``set_status``.
    """

    def __init__(self):
        # Keyword arguments of each create_training_job call, in call order.
        self.created: list[dict] = []
        # Job names passed to stop_training_job, in call order.
        self.stopped: list[str] = []
        # job_name -> (TrainingJobStatus, SecondaryStatus)
        self._status: dict[str, tuple[str, str]] = {}
        # Job names for which describe_training_job raises instead of answering.
        self.raise_not_found_on: set[str] = set()

    def create_training_job(self, **request):
        """Record ``request`` and register its job as InProgress/Starting."""
        self.created.append(request)
        job_name = request["TrainingJobName"]
        # a freshly created job starts out InProgress/Starting (== pending)
        self._status[job_name] = ("InProgress", "Starting")
        return {"TrainingJobArn": f"arn:aws:sagemaker:::training-job/{job_name}"}

    def describe_training_job(self, TrainingJobName):  # noqa: N803 (boto3 casing)
        """Return the scripted description of a job, or raise if it is marked missing."""
        if TrainingJobName in self.raise_not_found_on:
            raise _ResourceNotFoundError(f"job {TrainingJobName} not found")
        # Jobs that were never created or scripted report InProgress/Training.
        fallback = ("InProgress", "Training")
        primary, secondary = self._status.get(TrainingJobName, fallback)
        description = {
            "TrainingJobName": TrainingJobName,
            "TrainingJobStatus": primary,
            "SecondaryStatus": secondary,
            "TrainingJobArn": f"arn:aws:sagemaker:::training-job/{TrainingJobName}",
        }
        return description

    def stop_training_job(self, TrainingJobName):  # noqa: N803
        """Record that a stop was requested; the scripted status is left as-is."""
        self.stopped.append(TrainingJobName)

    # test helper
    def set_status(self, job_name, status, secondary="Completed"):
        """Script the (status, secondary status) pair returned for ``job_name``."""
        self._status[job_name] = (status, secondary)


class _ResourceNotFoundError(Exception):
    """Fake botocore ResourceNotFound error (the executor matches on name/text).

    Besides the plain message, the instance carries a ``response`` dict in the
    botocore error layout, for implementations that look at the error code.
    """

    def __init__(self, msg):
        super().__init__(msg)
        # botocore-style response shape some impls check
        error_detail = {"Code": "ResourceNotFound", "Message": msg}
        self.response = {"Error": error_detail}


def _make_executor(client=None):
    """Build a SageMakerExecutor wired to an injected SageMaker client.

    Args:
        client: Client object passed through as ``sagemaker_client=``. When
            ``None``, a fresh ``_MockSMClient`` is created.

    Returns:
        A ``SageMakerExecutor`` built with fixed placeholder image, role,
        output-path and region values.
    """
    # Compare against None explicitly: ``client or ...`` would silently swap
    # out an injected client that happens to be falsy (e.g. one defining
    # ``__len__`` or ``__bool__``), and the test would then inspect the wrong
    # object.
    if client is None:
        client = _MockSMClient()
    return SageMakerExecutor(
        image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest",
        role_arn="arn:aws:iam::123:role/SMRole",
        output_s3_path="s3://bucket/out/",
        region="us-east-1",
        sagemaker_client=client,
    )


# Shared entrypoint_args used by the launch/poll/cancel tests below. Both keys
# are needed: test_launch_requires_rendezvous_and_trainer_module checks that
# omitting either one makes launch_replicas raise ValueError.
_VALID_ARGS = {
    "rendezvous_uri": "s3://bucket/rendezvous/run1/",
    "trainer_module": "my_pkg.trainer",
}


# ---------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------


def test_backend_identity():
    """The executor identifies as "sagemaker" and reports no inter-replica network."""
    executor = _make_executor()
    backend = executor.backend_name
    has_network = executor.supports_inter_replica_network
    assert backend == "sagemaker"
    assert has_network is False


def test_missing_boto3_raises_when_no_client_injected():
    """Constructing without a client must raise RuntimeError when boto3 is absent.

    The import-guard path only fires when boto3 is genuinely absent, so the
    test skips itself on hosts where boto3 can be imported.
    """
    boto3_spec = importlib.util.find_spec("boto3")
    if boto3_spec is not None:
        pytest.skip("boto3 importable; absent-path cannot be exercised")
    ctor_kwargs = {
        "image_uri": "x",
        "role_arn": "r",
        "output_s3_path": "s3://b/o/",
    }
    with pytest.raises(RuntimeError, match="boto3"):
        SageMakerExecutor(**ctor_kwargs)


def test_construction_with_injected_client_needs_no_boto3():
    """Construction with an injected mock client completes and yields an object."""
    executor = _make_executor()
    assert executor is not None


# ---------------------------------------------------------------------
# launch_replicas
# ---------------------------------------------------------------------


def test_launch_returns_rank_ordered_handles():
    """Launching 3 replicas yields 3 handles ranked 0..2 and 3 create requests."""
    sm_client = _MockSMClient()
    executor = _make_executor(sm_client)
    launched = executor.launch_replicas(3, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
    assert len(launched) == 3
    ranks = [handle.rank for handle in launched]
    assert ranks == [0, 1, 2]
    # Every handle must be a ReplicaHandle tagged with this backend.
    for handle in launched:
        assert isinstance(handle, ReplicaHandle) and handle.backend_name == "sagemaker"
    assert len(sm_client.created) == 3


def test_launch_injects_rank_world_size_and_rendezvous_env():
    """Each create request carries the per-rank environment and fixed job config.

    Launches two replicas and checks, for the request recorded at each
    position, that the environment holds the rank, world size and rendezvous
    URI, and that network isolation, the S3 output path and the instance count
    have the expected values.
    """
    client = _MockSMClient()
    ex = _make_executor(client)
    ex.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS)
    # Guard against a vacuous pass: with no recorded requests the loop below
    # would run zero times and the test would succeed without checking anything.
    assert len(client.created) == 2
    for rank, req in enumerate(client.created):
        env = req["Environment"]
        assert env["REPLICA_RANK"] == str(rank)
        assert env["WORLD_SIZE"] == "2"
        assert env["RENDEZVOUS_URI"] == _VALID_ARGS["rendezvous_uri"]
        # network isolation MUST stay False (else S3 rendezvous deadlocks)
        assert req["EnableNetworkIsolation"] is False
        assert req["OutputDataConfig"]["S3OutputPath"] == "s3://bucket/out/"
        assert req["ResourceConfig"]["InstanceCount"] == 1


def test_launch_validates_n_replicas():
    """A replica count of zero is rejected with a ValueError naming n_replicas."""
    executor = _make_executor()
    rejects_bad_count = pytest.raises(ValueError, match="n_replicas")
    with rejects_bad_count:
        executor.launch_replicas(0, entrypoint="x", entrypoint_args=_VALID_ARGS)


def test_launch_requires_rendezvous_and_trainer_module():
    """Omitting either required entrypoint arg raises a ValueError naming it."""
    executor = _make_executor()
    # (key left out, args that lack it) — checked in this order.
    incomplete_cases = (
        ("rendezvous_uri", {"trainer_module": "m"}),
        ("trainer_module", {"rendezvous_uri": "s3://b/r/"}),
    )
    for missing_key, partial_args in incomplete_cases:
        with pytest.raises(ValueError, match=missing_key):
            executor.launch_replicas(1, entrypoint="x", entrypoint_args=partial_args)


def test_launch_partial_failure_stops_siblings_and_raises():
    """A failed 3rd create raises an error naming rank=2 and stops the first two jobs."""

    class _FailingClient(_MockSMClient):
        """Lets two creates through, then fails every later one."""

        def create_training_job(self, **request):
            already_created = len(self.created)
            if already_created < 2:
                return super().create_training_job(**request)
            raise RuntimeError("ThrottlingException")  # 3rd create fails

    sm_client = _FailingClient()
    executor = _make_executor(sm_client)
    with pytest.raises(RuntimeError, match="rank=2"):
        executor.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS)
    # the two already-launched siblings were best-effort stopped
    stopped_count = len(sm_client.stopped)
    assert stopped_count == 2


# ---------------------------------------------------------------------
# poll status mapping
# ---------------------------------------------------------------------


def test_poll_status_mapping():
    """poll() maps scripted SageMaker statuses to pending/running/succeeded."""
    sm_client = _MockSMClient()
    executor = _make_executor(sm_client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    job_name = sm_client.created[0]["TrainingJobName"]

    # (set_status arguments after the job name, expected poll() result)
    scripted = (
        (("InProgress", "Starting"), "pending"),
        (("InProgress", "Training"), "running"),
        (("Completed",), "succeeded"),
    )
    for status_args, expected in scripted:
        sm_client.set_status(job_name, *status_args)
        assert executor.poll(handle) == expected


def test_poll_failed_and_stopped():
    """poll() reports "failed" for a Failed job and "cancelled" for a Stopped one."""
    terminal_cases = (
        ("Failed", "failed"),
        ("Stopped", "cancelled"),
    )
    # Each case gets its own client and executor so the scenarios stay isolated.
    for terminal_status, expected in terminal_cases:
        sm_client = _MockSMClient()
        executor = _make_executor(sm_client)
        handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
        job_name = sm_client.created[0]["TrainingJobName"]
        sm_client.set_status(job_name, terminal_status)
        assert executor.poll(handle) == expected


def test_poll_vanished_job_is_cancelled():
    """A job whose describe call raises not-found is polled as "cancelled"."""
    sm_client = _MockSMClient()
    executor = _make_executor(sm_client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    vanished_job = sm_client.created[0]["TrainingJobName"]
    # From here on the mock raises _ResourceNotFoundError for this job.
    sm_client.raise_not_found_on.add(vanished_job)
    assert executor.poll(handle) == "cancelled"


def test_poll_unknown_handle_is_cancelled():
    """Polling a handle this executor never launched yields "cancelled"."""
    executor = _make_executor()
    never_launched = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={})
    status = executor.poll(never_launched)
    assert status == "cancelled"


# ---------------------------------------------------------------------
# cancel
# ---------------------------------------------------------------------


def test_cancel_calls_stop_training_job():
    """cancel() issues exactly one stop request, for the launched job's name."""
    sm_client = _MockSMClient()
    executor = _make_executor(sm_client)
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    executor.cancel(handle)
    launched_job = sm_client.created[0]["TrainingJobName"]
    assert sm_client.stopped == [launched_job]


def test_cancel_swallows_errors():
    """cancel() tolerates a not-found error from stop and a never-launched handle."""

    class _RaisingStop(_MockSMClient):
        """Client whose stop call always raises the not-found stand-in."""

        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise _ResourceNotFoundError("already terminal")

    executor = _make_executor(_RaisingStop())
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    executor.cancel(handle)  # must not raise
    # unknown handle must also be a no-op
    unknown = ReplicaHandle(rank=42, backend_name="sagemaker", metadata={})
    executor.cancel(unknown)


def test_cancel_reraises_unexpected_error():
    """R5: an unexpected stop failure must surface from cancel().

    A genuinely unexpected error (not already-terminated) has to propagate
    rather than be silently swallowed as a successful cancel.
    """

    class _BoomClient(_MockSMClient):
        """Client whose stop call fails with an access-denied style RuntimeError."""

        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise RuntimeError("AccessDeniedException: not authorized")

    executor = _make_executor(_BoomClient())
    handle = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)[0]
    expects_propagation = pytest.raises(RuntimeError, match="AccessDenied")
    with expects_propagation:
        executor.cancel(handle)