Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
composer-replication-framework / composer_replication /diloco /serverless /tests /test_sagemaker_executor.py
| """Tests for SageMakerExecutor (composer_replication.diloco.serverless.sagemaker). | |
| The executor is exercised with an INJECTED mock boto3 sagemaker client (the | |
| `sagemaker_client=` ctor arg), so these run on any host without boto3 or AWS | |
| credentials — mirroring the _MockFunctionCall pattern in | |
| test_modal_spawn_executor.py and the _MockBatchV1Api pattern in | |
| test_eks_executor.py. | |
| Closes the test-coverage gap left when the SageMakerExecutor was first written | |
| without a test module (caught during Wave-2 integration, 2026-06-09). | |
| """ | |
| from __future__ import annotations | |
| import importlib.util | |
| import pytest | |
| from composer_replication.diloco.serverless import SageMakerExecutor | |
| from composer_replication.diloco.serverless.executor import ReplicaHandle | |
| # --------------------------------------------------------------------- | |
| # Mock boto3 sagemaker client | |
| # --------------------------------------------------------------------- | |
class _MockSMClient:
    """In-memory stand-in for the boto3 SageMaker client.

    Each ``create_training_job`` request and each ``stop_training_job`` job
    name is recorded, and ``describe_training_job`` answers from a per-job
    status table that tests script through ``set_status``.
    """

    def __init__(self):
        # job_name -> (TrainingJobStatus, SecondaryStatus)
        self._status: dict[str, tuple[str, str]] = {}
        # Job names for which describe_training_job raises "not found".
        self.raise_not_found_on: set[str] = set()
        # Call logs inspected by the tests.
        self.created: list[dict] = []
        self.stopped: list[str] = []

    def create_training_job(self, **request):
        """Record the request and register the job as InProgress/Starting."""
        self.created.append(request)
        job_name = request["TrainingJobName"]
        # A newly created job defaults to InProgress/Starting (== pending).
        self._status[job_name] = ("InProgress", "Starting")
        arn = f"arn:aws:sagemaker:::training-job/{job_name}"
        return {"TrainingJobArn": arn}

    def describe_training_job(self, TrainingJobName):  # noqa: N803 (boto3 casing)
        """Return the scripted status for the job, or raise if flagged as missing."""
        if TrainingJobName in self.raise_not_found_on:
            raise _ResourceNotFoundError(f"job {TrainingJobName} not found")
        # Jobs never registered fall back to InProgress/Training.
        primary, secondary = self._status.get(
            TrainingJobName, ("InProgress", "Training")
        )
        arn = f"arn:aws:sagemaker:::training-job/{TrainingJobName}"
        return {
            "TrainingJobName": TrainingJobName,
            "TrainingJobStatus": primary,
            "SecondaryStatus": secondary,
            "TrainingJobArn": arn,
        }

    def stop_training_job(self, TrainingJobName):  # noqa: N803
        """Record that a stop was requested for the job."""
        self.stopped.append(TrainingJobName)

    # test helper
    def set_status(self, job_name, status, secondary="Completed"):
        """Script the (status, secondary) pair that describe will report."""
        self._status[job_name] = (status, secondary)
class _ResourceNotFoundError(Exception):
    """Test double for botocore's ResourceNotFound error.

    The message becomes the exception text, and a botocore-style ``response``
    mapping is attached for implementations that inspect the error code
    (the executor matches on name/text).
    """

    def __init__(self, msg):
        error_payload = {"Code": "ResourceNotFound", "Message": msg}
        super().__init__(msg)
        # botocore-style response shape some impls check
        self.response = {"Error": error_payload}
def _make_executor(client=None):
    """Build a SageMakerExecutor wired to an injected mock client.

    Args:
        client: The SageMaker client double to inject. When ``None`` a fresh
            ``_MockSMClient`` is created.

    Returns:
        A ``SageMakerExecutor`` configured with fixed image, role, output
        path and region values.
    """
    # Compare against None explicitly: a truthiness test would silently swap
    # out an injected client that happens to evaluate falsy.
    if client is None:
        client = _MockSMClient()
    return SageMakerExecutor(
        image_uri="123.dkr.ecr.us-east-1.amazonaws.com/trainer:latest",
        role_arn="arn:aws:iam::123:role/SMRole",
        output_s3_path="s3://bucket/out/",
        region="us-east-1",
        sagemaker_client=client,
    )
# Minimal entrypoint_args the launch tests pass to launch_replicas.
_VALID_ARGS = dict(
    rendezvous_uri="s3://bucket/rendezvous/run1/",
    trainer_module="my_pkg.trainer",
)
| # --------------------------------------------------------------------- | |
| # Construction | |
| # --------------------------------------------------------------------- | |
def test_backend_identity():
    """The executor reports the sagemaker backend and no inter-replica network."""
    executor = _make_executor()
    reported_backend = executor.backend_name
    assert reported_backend == "sagemaker"
    assert executor.supports_inter_replica_network is False
def test_missing_boto3_raises_when_no_client_injected():
    """The import-guard path only fires when boto3 is genuinely absent."""
    boto3_available = importlib.util.find_spec("boto3") is not None
    if boto3_available:
        pytest.skip("boto3 importable; absent-path cannot be exercised")

    # No sagemaker_client is passed, so the constructor has to reach for boto3.
    ctor_kwargs = {
        "image_uri": "x",
        "role_arn": "r",
        "output_s3_path": "s3://b/o/",
    }
    with pytest.raises(RuntimeError, match="boto3"):
        SageMakerExecutor(**ctor_kwargs)
def test_construction_with_injected_client_needs_no_boto3():
    """Constructing with an injected mock client completes without error."""
    executor = _make_executor()
    assert not (executor is None)
| # --------------------------------------------------------------------- | |
| # launch_replicas | |
| # --------------------------------------------------------------------- | |
def test_launch_returns_rank_ordered_handles():
    """Launching three replicas yields three sagemaker handles ranked 0..2."""
    mock_client = _MockSMClient()
    executor = _make_executor(mock_client)

    launched = executor.launch_replicas(
        3, entrypoint="ignored", entrypoint_args=_VALID_ARGS
    )

    assert len(launched) == 3
    ranks = [handle.rank for handle in launched]
    assert ranks == [0, 1, 2]
    for handle in launched:
        assert isinstance(handle, ReplicaHandle) and handle.backend_name == "sagemaker"
    # One create_training_job call per replica.
    assert len(mock_client.created) == 3
def test_launch_injects_rank_world_size_and_rendezvous_env():
    """Each create request carries rank/world-size/rendezvous env and fixed config."""
    mock_client = _MockSMClient()
    executor = _make_executor(mock_client)
    executor.launch_replicas(2, entrypoint="ignored", entrypoint_args=_VALID_ARGS)

    expected_rendezvous = _VALID_ARGS["rendezvous_uri"]
    for expected_rank, request in enumerate(mock_client.created):
        environment = request["Environment"]
        assert environment["REPLICA_RANK"] == str(expected_rank)
        assert environment["WORLD_SIZE"] == "2"
        assert environment["RENDEZVOUS_URI"] == expected_rendezvous
        # network isolation MUST stay False (else S3 rendezvous deadlocks)
        assert request["EnableNetworkIsolation"] is False
        output_config = request["OutputDataConfig"]
        assert output_config["S3OutputPath"] == "s3://bucket/out/"
        resource_config = request["ResourceConfig"]
        assert resource_config["InstanceCount"] == 1
def test_launch_validates_n_replicas():
    """Requesting zero replicas is rejected with a ValueError naming n_replicas."""
    executor = _make_executor()
    zero_replicas = 0
    with pytest.raises(ValueError, match="n_replicas"):
        executor.launch_replicas(
            zero_replicas, entrypoint="x", entrypoint_args=_VALID_ARGS
        )
def test_launch_requires_rendezvous_and_trainer_module():
    """Omitting either required entrypoint arg raises a ValueError naming it."""
    executor = _make_executor()
    # (key expected in the error message, entrypoint_args lacking that key)
    incomplete_cases = (
        ("rendezvous_uri", {"trainer_module": "m"}),
        ("trainer_module", {"rendezvous_uri": "s3://b/r/"}),
    )
    for missing_key, partial_args in incomplete_cases:
        with pytest.raises(ValueError, match=missing_key):
            executor.launch_replicas(1, entrypoint="x", entrypoint_args=partial_args)
def test_launch_partial_failure_stops_siblings_and_raises():
    """A failed create mid-launch raises and stops the already-created jobs.

    The third ``create_training_job`` call fails. The executor must raise a
    RuntimeError mentioning ``rank=2`` and issue a stop for each of the two
    jobs that were created before the failure.
    """

    class _FailingClient(_MockSMClient):
        def create_training_job(self, **request):
            if len(self.created) >= 2:  # 3rd create fails
                raise RuntimeError("ThrottlingException")
            return super().create_training_job(**request)

    client = _FailingClient()
    ex = _make_executor(client)
    with pytest.raises(RuntimeError, match="rank=2"):
        ex.launch_replicas(3, entrypoint="x", entrypoint_args=_VALID_ARGS)

    # the two already-launched siblings were best-effort stopped
    assert len(client.stopped) == 2
    # A bare count would also pass if the same job were stopped twice or an
    # unrelated name were stopped; pin the stopped names to the created jobs.
    # Order is not asserted because cleanup order is not part of the contract
    # visible from this test.
    created_names = sorted(req["TrainingJobName"] for req in client.created)
    assert sorted(client.stopped) == created_names
| # --------------------------------------------------------------------- | |
| # poll status mapping | |
| # --------------------------------------------------------------------- | |
def test_poll_status_mapping():
    """poll() maps SageMaker (status, secondary) pairs onto executor states."""
    mock_client = _MockSMClient()
    executor = _make_executor(mock_client)
    launched = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
    handle = launched[0]
    job_name = mock_client.created[0]["TrainingJobName"]

    # (TrainingJobStatus, SecondaryStatus, expected poll() result), checked in order.
    expectations = (
        ("InProgress", "Starting", "pending"),
        ("InProgress", "Training", "running"),
        ("Completed", "Completed", "succeeded"),
    )
    for status, secondary, expected_state in expectations:
        mock_client.set_status(job_name, status, secondary)
        assert executor.poll(handle) == expected_state
def test_poll_failed_and_stopped():
    """Failed jobs poll as "failed"; Stopped jobs poll as "cancelled"."""
    terminal_cases = (
        ("Failed", "failed"),
        ("Stopped", "cancelled"),
    )
    # Each case gets its own client/executor pair so the scenarios stay independent.
    for terminal_status, expected_state in terminal_cases:
        mock_client = _MockSMClient()
        executor = _make_executor(mock_client)
        launched = executor.launch_replicas(
            1, entrypoint="x", entrypoint_args=_VALID_ARGS
        )
        handle = launched[0]
        job_name = mock_client.created[0]["TrainingJobName"]
        mock_client.set_status(job_name, terminal_status)
        assert executor.poll(handle) == expected_state
def test_poll_vanished_job_is_cancelled():
    """A job whose describe call raises not-found polls as "cancelled"."""
    mock_client = _MockSMClient()
    executor = _make_executor(mock_client)
    launched = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
    handle = launched[0]

    # Make the mock raise its not-found error for this job from now on.
    vanished_job = mock_client.created[0]["TrainingJobName"]
    mock_client.raise_not_found_on.add(vanished_job)

    assert executor.poll(handle) == "cancelled"
def test_poll_unknown_handle_is_cancelled():
    """Polling a handle the executor never launched reports "cancelled"."""
    executor = _make_executor()
    # A handle with empty metadata that did not come from launch_replicas.
    stray_handle = ReplicaHandle(rank=99, backend_name="sagemaker", metadata={})
    observed_state = executor.poll(stray_handle)
    assert observed_state == "cancelled"
| # --------------------------------------------------------------------- | |
| # cancel | |
| # --------------------------------------------------------------------- | |
def test_cancel_calls_stop_training_job():
    """cancel() issues exactly one stop for the handle's training job."""
    mock_client = _MockSMClient()
    executor = _make_executor(mock_client)
    launched = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)

    executor.cancel(launched[0])

    launched_job = mock_client.created[0]["TrainingJobName"]
    assert mock_client.stopped == [launched_job]
def test_cancel_swallows_errors():
    """cancel() tolerates a not-found error from stop and ignores unknown handles."""

    class _RaisingStop(_MockSMClient):
        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise _ResourceNotFoundError("already terminal")

    executor = _make_executor(_RaisingStop())
    launched = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
    executor.cancel(launched[0])  # must not raise

    # unknown handle must also be a no-op
    stray_handle = ReplicaHandle(rank=42, backend_name="sagemaker", metadata={})
    executor.cancel(stray_handle)
def test_cancel_reraises_unexpected_error():
    """R5: a genuinely unexpected error (not already-terminated) must propagate,
    not be silently swallowed as a successful cancel."""

    class _BoomClient(_MockSMClient):
        def stop_training_job(self, TrainingJobName):  # noqa: N803
            raise RuntimeError("AccessDeniedException: not authorized")

    executor = _make_executor(_BoomClient())
    launched = executor.launch_replicas(1, entrypoint="x", entrypoint_args=_VALID_ARGS)
    handle = launched[0]

    with pytest.raises(RuntimeError, match="AccessDenied"):
        executor.cancel(handle)