Muon Optimizer: Implementation Guide
This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.
Table of Contents
- Overview
- Entry Point and Parameter Routing
- Execution Paths
- Parallel Pipeline (the core feature)
- Distributed Utilities
- Newton-Schulz Orthogonalization
- QK Clipping
- AdamW for Non-Muon Parameters
- Source File Map
Overview
Muon (MomentUm Orthogonalized by Newton-schulz) applies standard SGD-momentum and then replaces each 2D parameter's update with the nearest orthogonal matrix via a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.
The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like 2 TP x 2 DP-Replicate x 2 DP-Shard. This generality is what drives most of the code complexity.
Entry Point and Parameter Routing
File: muon.py → Muon.step() / Muon._step_muon()
Users must provide parameter groups with use_muon=True/False flags (via get_default_muon_param_groups()). At each step:
- Non-Muon groups → step_adamw() (fused AdamW).
- Muon groups → _step_muon(), which further classifies each parameter:
_step_muon(group)
|
+-- DTensor, all Replicate placements --> base() (no sharding)
+-- DTensor, numel <= threshold --> distributed_muon() (small param fallback)
+-- DTensor, sharded --> parallel() (pipelined all-to-all)
+-- plain Tensor --> base() (single device)
Parameters are classified by their DTensor placements:
- Fully replicated DTensors and plain tensors use base(), the standard single-device Muon path.
- Small sharded DTensors (below small_param_numel_threshold, default 65536) use distributed_muon(), which gathers the full tensor via full_tensor(), computes the update, then redistributes.
- Large sharded DTensors use parallel(), the pipelined all-to-all approach described below. A sketch of this routing logic follows this list.
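The routing can be pictured as a small classification function. The sketch below is illustrative only; the function name is hypothetical, the import paths assume PyTorch ≥ 2.5, and the real _step_muon() performs this check inline:

```python
from torch.distributed.tensor import DTensor, Replicate  # PyTorch >= 2.5 import path

def route_param_sketch(p, small_param_numel_threshold: int = 65536) -> str:
    """Illustrative routing decision; not the actual muon.py code."""
    if not isinstance(p, DTensor):
        return "base"               # plain tensor: single-device path
    if all(isinstance(pl, Replicate) for pl in p.placements):
        return "base"               # fully replicated: nothing to gather
    if p.numel() <= small_param_numel_threshold:
        return "distributed_muon"   # small sharded param: full_tensor() fallback
    return "parallel"               # large sharded param: pipelined all-to-all
```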
Execution Paths
base(): Single Device
Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.
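As a rough sketch of this flow (the function name and hyperparameters are illustrative; the momentum formulation and any update scaling in the real base() may differ, and QK clipping is omitted):

```python
def base_step_sketch(p, grad, momentum_buf, orthogonalize, lr, momentum, weight_decay):
    """Minimal single-device Muon step. orthogonalize stands in for
    _zeropower_via_newtonschulz5() described later in this document."""
    momentum_buf.mul_(momentum).add_(grad)   # SGD-momentum accumulation
    update = orthogonalize(momentum_buf)     # replace the update with its orthogonal approximation
    p.mul_(1 - lr * weight_decay)            # decoupled weight decay
    p.add_(update, alpha=-lr)                # apply the orthogonalized update
```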
distributed_muon(): Full Gather
Each parameter's gradient is gathered in full via g.full_tensor(), orthogonalized on every rank, and the updated full parameter is then redistributed back to its original sharded placement. Simple but communication-heavy; used only as a fallback for small parameters.
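A minimal sketch of this fallback, assuming DTensor parameters and gradients, a stand-in orthogonalize() for the Newton-Schulz step, and no momentum handling (all simplifications relative to the real function):

```python
from torch.distributed.tensor import distribute_tensor

def distributed_muon_sketch(p, grad, orthogonalize, lr, weight_decay):
    """Gather -> compute everywhere -> re-shard. Communication-heavy but simple."""
    g_full = grad.full_tensor()              # all-gather the sharded gradient
    u_full = orthogonalize(g_full)           # every rank computes the same update
    p_full = p.full_tensor()
    p_full.mul_(1 - lr * weight_decay).add_(u_full, alpha=-lr)
    # Re-shard the updated full tensor back to the parameter's original placement.
    p_new = distribute_tensor(p_full, p.device_mesh, p.placements)
    p._local_tensor.copy_(p_new.to_local())
```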
parallel(): Pipelined All-to-All
This is the main advanced feature. Instead of all-gathering the full parameter, it uses all-to-all to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation.
Parallel Pipeline
Design Motivation
Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns", not all parameters. So the flow is:
- Gather: Each rank sends its local gradient shard to the owning rank via all-to-all.
- Compute: The owning rank runs Newton-Schulz on the full (gathered) gradient.
- Scatter: The owning rank sends the orthogonalized update back to all ranks via all-to-all.
- Update: Each rank applies weight decay and the update to its local shard.
To overlap communication and computation, parameters are split into chunks, and multiple chunks are processed concurrently.
Architecture
muon.py: parallel()
|
+-- init_state_and_assign_params() -- assigns ownership, precomputes indices
|
+-- pipelines() generator -- yields muon_chunk_pipeline() per chunk
|
+-- run_pipeline(pipelines, max_concurrent=warmup_step+1)
|
+-- interleaves chunks at yield boundaries
The Chunk Pipeline Generator
File: pipeline.py → muon_chunk_pipeline()
Each chunk is a generator that yields twice, creating stages separated by async communication:
YIELD 1 YIELD 2
| |
[Build bufs + async gather a2a] --> [wait + NS compute + async scatter a2a] --> [wait + Update params]
- Async communication: dist.all_to_all_single(..., async_op=True) launches non-blocking communication. The generator yields immediately after, allowing other chunks to run. work.wait() completes the operation after the yield.
- Chunk-level overlap: run_pipeline() interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own. A simplified sketch of such a chunk generator follows this list.
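A heavily simplified sketch of the generator structure. The chunk object, its attributes, and the function name are all hypothetical (the real pipeline.py batches every parameter of the chunk into the buffers); only the two-yield shape and the async_op/wait pattern are the point:

```python
import torch.distributed as dist

def chunk_pipeline_sketch(chunk, group):
    """Two yields split the work into three stages separated by async all-to-all."""
    # Stage 1: build buffers and launch the async gather.
    gather_work = dist.all_to_all_single(
        chunk.gather_recv, chunk.gather_send,
        output_split_sizes=chunk.gather_out_splits,
        input_split_sizes=chunk.gather_in_splits,
        group=group, async_op=True)
    yield                                  # let other chunks launch their gathers

    # Stage 2: finish the gather, run Newton-Schulz, launch the async scatter.
    gather_work.wait()
    chunk.compute_update()                 # owning rank runs Newton-Schulz here
    scatter_work = dist.all_to_all_single(
        chunk.scatter_recv, chunk.scatter_send,
        output_split_sizes=chunk.scatter_out_splits,
        input_split_sizes=chunk.scatter_in_splits,
        group=group, async_op=True)
    yield                                  # let other chunks launch their scatters

    # Stage 3: finish the scatter and apply the update to local shards.
    scatter_work.wait()
    chunk.apply_update()
```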
The Pipeline Scheduler
File: async_utils.py → run_pipeline()
A simple round-robin scheduler:
while have_new or previous_tasks:
    # Admit one new pipeline if below the concurrency limit
    if have_new and len(previous_tasks) < max_concurrent:
        task = next(pipelines)        # runs the new chunk to its first yield
        previous_tasks.append(task)
    # Advance all in-flight tasks by one yield
    for task in previous_tasks:
        task.step()                   # runs to the next yield
    # (finished tasks are dropped from previous_tasks; details omitted here)
max_concurrent = warmup_step + 1 controls how many chunks can be in-flight simultaneously. Higher values increase memory usage but improve communication/computation overlap.
Ownership Assignment
File: muon.py → init_state_and_assign_params()
Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks.
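A hypothetical sketch of such an assignment. The cost model and function name are assumptions; the real init_state_and_assign_params() also builds the _muon_state entries described next:

```python
import torch

def assign_owners_sketch(params: list[torch.Tensor], shard_world_size: int) -> dict[int, int]:
    """Round-robin ownership over descending FLOP cost, keyed by parameter index."""
    def flop_cost(idx: int) -> int:
        m, n = params[idx].shape[0], params[idx].shape[1:].numel()
        return m * n * min(m, n)      # rough Newton-Schulz cost of an (m, n) matrix
    order = sorted(range(len(params)), key=flop_cost, reverse=True)
    return {idx: i % shard_world_size for i, idx in enumerate(order)}
```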
Precomputed Shard Indices
Instead of computing per-rank shard indices on every step, they are precomputed once during init_state_and_assign_params() and stored in _muon_state:
@dataclass
class _muon_state:
    worker_rank: int                # which rank owns this param's computation
    process_group: ProcessGroup     # the all-to-all communication group
    rank_indices: dict[int, tuple]  # rank -> per-dim indices into full tensor
    rank_numels: dict[int, int]     # rank -> number of elements in shard
    name: str
    qk_clip_state: QKClipInfo | None
rank_indices[r] is a tuple containing one slice or torch.Tensor per dimension, describing which elements of the full tensor rank r owns. rank_numels[r] is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages.
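For example, reconstructing the owner's full gradient from received shards might look like the fragment below (illustrative only; recv_bufs and full_grad are hypothetical names, not the actual pipeline.py variables):

```python
# On the owning rank: place each rank's flattened shard back into the full tensor.
for r, recv_buf in recv_bufs.items():
    idx = state.rank_indices[r]                       # per-dim slices / index tensors
    full_grad[idx] = recv_buf.view(full_grad[idx].shape)
```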
Pipeline Stages in Detail
Stages 1-2: Gather
- Allocate receive buffers for gathered gradients (only on owning ranks).
- Build send buffer: Each rank flattens its local gradient shard for each destination rank.
- Async all-to-all: dist.all_to_all_single(..., async_op=True) launches the gather.
- Yield 1: Other chunks can launch their gather while this one waits.
- work.wait(): Complete the gather.
- Reconstruct: The owning rank places received shards into the full gradient using rank_indices (a sketch of the gather launch follows this list).
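The launch for a single parameter might look like the sketch below, assuming rank and world_size are taken within the shard process group. The function name is hypothetical, and the real code batches all parameters of the chunk into one all-to-all rather than one per parameter:

```python
import torch
import torch.distributed as dist

def launch_gather_sketch(local_shard, state, rank, world_size, group):
    """Every rank sends its flattened shard to the owning rank; only the owner receives."""
    owner = state.worker_rank
    send_buf = local_shard.reshape(-1)
    in_splits = [send_buf.numel() if dst == owner else 0 for dst in range(world_size)]
    if rank == owner:
        out_splits = [state.rank_numels[src] for src in range(world_size)]
    else:
        out_splits = [0] * world_size
    recv_buf = torch.empty(sum(out_splits), dtype=send_buf.dtype, device=send_buf.device)
    work = dist.all_to_all_single(recv_buf, send_buf,
                                  output_split_sizes=out_splits,
                                  input_split_sizes=in_splits,
                                  group=group, async_op=True)
    return work, recv_buf
```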
Stage 3: Compute
The owning rank runs _zeropower_via_newtonschulz5() on the full gathered gradient. This is the most compute-intensive stage. Runs inline (no yield) since it is synchronous GPU work.
Stages 4-5: Scatter
Inverse of gather:
- Allocate receive buffers for the orthogonalized update U.
- Build send buffer: The owning rank slices U using rank_indices for each destination rank (see the sketch after this list).
- Async all-to-all: dist.all_to_all_single(..., async_op=True) launches the scatter.
- Yield 2: Other chunks can launch their scatter while this one waits.
- work.wait(): Complete the scatter.
- Copy received shards into local update buffers.
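For instance, the owner's send buffer could be assembled like this (an illustrative fragment; U, state, and world_size are names from the surrounding description, not the actual code):

```python
import torch

# On the owning rank only: slice the orthogonalized update U with the precomputed
# per-rank indices and flatten everything into a single scatter send buffer.
send_buf = torch.cat([U[state.rank_indices[dst]].reshape(-1)
                      for dst in range(world_size)])
```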
Stage 6: Update
Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.
Distributed Utilities
File: distributed/utils.py
These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns.
construct_shard_mesh(placements, mesh)
Given a DTensor's placements and device mesh, this function:
- Sorts placements: Replicate dims first, then Shard dims by dimension (with _StridedShard after regular Shard on the same dim).
- Permutes the mesh accordingly.
- Separates replicate dims from shard dims; each replicate group gets its own shard sub-mesh.
- Creates a ProcessGroup for the current rank's shard mesh.
Returns (shard_mesh, process_group, shard_placements), which are used for all-to-all communication.
Why this is needed: A model might use [Replicate, Shard(0), _StridedShard(0)] across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them.
get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)
Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (Shard) and strided (_StridedShard) sharding, including composed multi-level sharding on the same dimension.
Returns a tuple of slice (contiguous) or torch.LongTensor (strided) per dimension.
Example: With [Shard(0), _StridedShard(0)] on a (16, 2048) tensor across 4 ranks:
- Rank 0 might own rows [0, 4, 8, 12] (strided)
- Rank 1 might own rows [1, 5, 9, 13]
- etc.
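This interleaved pattern is easy to reproduce for intuition (illustrative only; the real indices come from get_slices_of_dtensor() and depend on the actual mesh layout):

```python
import torch

rows = torch.arange(16)                        # row indices of the (16, 2048) tensor
for r in range(4):
    print(f"rank {r}: {rows[r::4].tolist()}")  # rank 0 -> [0, 4, 8, 12], rank 1 -> [1, 5, 9, 13], ...
```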
PyTorch 2.10 Compatibility
In PyTorch 2.10, _StridedShard no longer inherits from Shard. The helper _is_shard() handles both old and new hierarchies:
def _is_shard(placement):
    return isinstance(placement, (Shard, _StridedShard))
Newton-Schulz Orthogonalization
File: newton_schulz.py
_zeropower_via_newtonschulz5() computes the orthogonal approximation of a matrix using 5 quintic Newton-Schulz iterations with pre-optimized coefficients. The result approximates US'V^T where S' is near-uniform on [0.5, 1.5], which empirically does not hurt model performance vs. exact UV^T.
Each iteration uses matmul_transpose_assign() (a Triton kernel for X @ X^T) for efficiency.
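For reference, the widely used quintic Newton-Schulz iteration looks roughly like the sketch below. The coefficients come from the public Muon reference implementation; the exact constants, dtype handling, and the Triton-accelerated matmul in newton_schulz.py may differ:

```python
import torch

def newton_schulz_sketch(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximate orthogonalization of a 2D matrix via a quintic iteration in bfloat16."""
    a, b, c = 3.4445, -4.7750, 2.0315        # coefficients of the quintic polynomial
    X = G.bfloat16()
    X = X / (X.norm() + eps)                 # spectral norm <= Frobenius norm <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.mT                             # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.mT                         # the X @ X^T product the Triton kernel accelerates
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
```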
File: matmul_transpose_triton.py
The matmul_transpose_assign(d_in, d_out) kernel computes d_out = d_in @ d_in^T in-place. It exploits symmetry by computing only upper-triangle blocks and mirroring.
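In plain PyTorch terms, ignoring the symmetry optimization, the kernel is equivalent to:

```python
# Dense equivalent of matmul_transpose_assign(d_in, d_out),
# without the upper-triangle-only optimization.
d_out.copy_(d_in @ d_in.mT)
```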
QK Clipping
File: qk_clip.py
Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by sqrt(threshold / logit).
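Per head, the scale factor could be computed as sketched below (the function name and tensor shapes are assumptions, not the exact qk_clip.py code):

```python
import torch

def qk_scales_sketch(max_logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """One scale per attention head: 1.0 below the threshold, sqrt(threshold / logit) above."""
    return torch.where(max_logits > threshold,
                       torch.sqrt(threshold / max_logits),
                       torch.ones_like(max_logits))
```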
In the parallel pipeline: QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads:
# pipeline.py: _update_params()
ratio = p.shape[0] // scales_full.shape[0] # rows per head
idx0 = state.rank_indices[rank][0] # which global rows this rank owns
row_scales = scales_full[idx0 // ratio] # map each row to its head's scale
p._local_tensor.mul_(row_scales.view(-1, 1))
AdamW for Non-Muon Parameters
File: adamw.py
Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via torch._fused_adamw_. Parameters are grouped by device/dtype and DTensor placement before the fused call.
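The grouping step might look like this sketch (bucket keys are simplified and the actual torch._fused_adamw_ invocation is omitted, since its argument list is long and version-dependent):

```python
from collections import defaultdict
import torch
from torch.distributed.tensor import DTensor

def bucket_params_sketch(params: list[torch.Tensor]) -> dict:
    """Fused kernels require every tensor in a single call to share device and dtype;
    DTensor placements are included so local shards line up within a bucket."""
    buckets = defaultdict(list)
    for p in params:
        placements = tuple(p.placements) if isinstance(p, DTensor) else None
        buckets[(p.device, p.dtype, placements)].append(p)
    return dict(buckets)
```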
Source File Map
| File | Lines | Purpose |
|---|---|---|
| muon.py | ~525 | Optimizer class, parameter routing, 3 execution paths |
| pipeline.py | ~290 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| async_utils.py | ~75 | Pipeline scheduler with bounded concurrency |
| core.py | ~110 | _muon_state dataclass, momentum/update helpers, param grouping |
| distributed/utils.py | ~230 | Shard mesh construction, DTensor index computation |
| newton_schulz.py | ~50 | Newton-Schulz iteration |
| matmul_transpose_triton.py | ~120 | Triton kernel for symmetric matmul |
| qk_clip.py | ~130 | QK logit clipping |
| adamw.py | ~160 | Fused AdamW for non-Muon params |
Dependency Graph
matmul_transpose_triton.py (leaf)
|
newton_schulz.py (leaf + triton)
|
core.py ---- qk_clip.py (leaf, distributed/utils)
| | |
| pipeline.py --- async_utils.py
| |
| adamw.py
| |
muon.py (all above)
|
__init__.py