# Muon Optimizer: Implementation Guide
This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.
## Table of Contents
1. [Overview](#overview)
2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
3. [Execution Paths](#execution-paths)
4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
5. [Distributed Utilities](#distributed-utilities)
6. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
7. [QK Clipping](#qk-clipping)
8. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
9. [Source File Map](#source-file-map)
---
## Overview
Muon (MomentUm Orthogonalized by Newton-Schulz) applies standard SGD momentum and then replaces each 2D parameter's update with the nearest orthogonal matrix, computed via a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.
The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity.
## Entry Point and Parameter Routing
**File:** `muon.py` — `Muon.step()` / `Muon._step_muon()`
Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step:
1. **Non-Muon groups** → `step_adamw()` (fused AdamW).
2. **Muon groups** → `_step_muon()`, which further classifies each parameter:
```
_step_muon(group)
  |
  +-- DTensor, all Replicate placements --> base() (no sharding)
  +-- DTensor, numel <= threshold --> distributed_muon() (small param fallback)
  +-- DTensor, sharded --> parallel() (pipelined all-to-all)
  +-- plain Tensor --> base() (single device)
```
Parameters are classified by their DTensor placements:
- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
- **Small sharded** DTensors (below `small_param_numel_threshold`, default 65536) use `distributed_muon()` — gathers the full tensor via `full_tensor()`, computes the update, then redistributes.
- **Large sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
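The classification above can be sketched with duck-typed stand-ins (here `placements` is a list of placement strings and `numel` a plain attribute, standing in for the real DTensor API; `route` itself is illustrative, not a function from the codebase):

```python
def route(param, small_param_numel_threshold=65536):
    # Mirrors the classification in _step_muon(). A plain tensor has no
    # placements attribute; a DTensor's placements decide the path.
    placements = getattr(param, "placements", None)
    if placements is None:
        return "base"  # plain tensor: single-device path
    if all(p == "replicate" for p in placements):
        return "base"  # fully replicated DTensor: nothing sharded to handle
    if param.numel <= small_param_numel_threshold:
        return "distributed_muon"  # small sharded param: full-gather fallback
    return "parallel"  # large sharded param: pipelined all-to-all
```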
## Execution Paths
### base() — Single Device
Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.
### distributed_muon() — Full Gather
Each parameter's gradient is gathered to full size via `g.full_tensor()` and orthogonalized on every rank, and the updated full parameter is then redistributed back to its original sharded placement. Simple but communication-heavy — used only as a fallback for small parameters.
### parallel() — Pipelined All-to-All
This is the main advanced feature. Instead of all-gathering the full parameter, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation.
## Parallel Pipeline
### Design Motivation
Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is:
1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all.
2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient.
3. **Scatter**: The owning rank sends the orthogonalized update back to all ranks via all-to-all.
4. **Update**: Each rank applies weight decay and the update to its local shard.
To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently.
### Architecture
```
muon.py: parallel()
  |
  +-- init_state_and_assign_params() -- assigns ownership, precomputes indices
  |
  +-- pipelines() generator -- yields muon_chunk_pipeline() per chunk
  |
  +-- run_pipeline(pipelines, max_concurrent=warmup_step+1)
        |
        +-- interleaves chunks at yield boundaries
```
### The Chunk Pipeline Generator
**File:** `pipeline.py` — `muon_chunk_pipeline()`
Each chunk is a generator that yields **twice**, creating three stages separated by async communication:
```
                              YIELD 1                                     YIELD 2
                                 |                                           |
[Build bufs + async gather a2a] --> [wait + NS compute + async scatter a2a] --> [wait + Update params]
```
- **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication. The generator yields immediately after, allowing other chunks to run. `work.wait()` completes the operation after the yield.
- **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own.
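A torch-free toy of this interleaving (the generator below is a hypothetical stand-in for `muon_chunk_pipeline()`; the real stages launch all-to-alls instead of appending to a log):

```python
def chunk_pipeline(name, log):
    # Stand-in for muon_chunk_pipeline(): the two yields mark the points
    # where an async all-to-all would be in flight.
    log.append(f"{name}: build bufs + launch gather")
    yield  # YIELD 1: gather all-to-all in flight
    log.append(f"{name}: wait + NS compute + launch scatter")
    yield  # YIELD 2: scatter all-to-all in flight
    log.append(f"{name}: wait + update params")

log = []
a, b = chunk_pipeline("A", log), chunk_pipeline("B", log)
for gen in (a, b, a, b, a, b):  # round-robin stepping at yield boundaries
    next(gen, None)
```

While chunk A "waits" on a collective, chunk B launches its own, so the log interleaves the two chunks stage by stage.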
### The Pipeline Scheduler
**File:** `async_utils.py` — `run_pipeline()`
A simple round-robin scheduler (shown here in simplified form):
```python
previous_tasks = []
have_new = True
while have_new or previous_tasks:
    # Admit one new pipeline if below the concurrency limit
    if have_new and len(previous_tasks) < max_concurrent:
        try:
            previous_tasks.append(next(pipelines))
        except StopIteration:
            have_new = False
    # Advance every in-flight task by one yield
    for task in list(previous_tasks):
        try:
            next(task)  # runs to the task's next (or first) yield
        except StopIteration:
            previous_tasks.remove(task)
```
`max_concurrent = warmup_step + 1` controls how many chunks can be in-flight simultaneously. Higher values increase memory usage but improve communication/computation overlap.
### Ownership Assignment
**File:** `muon.py` — `init_state_and_assign_params()`
Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks.
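A minimal sketch of that assignment policy (the function name and cost inputs are illustrative, not the real signature):

```python
def assign_owners(param_costs, num_ranks):
    # Sort parameter indices by descending FLOP cost, then deal them out
    # round-robin so each rank gets a similar share of expensive params.
    order = sorted(range(len(param_costs)), key=lambda i: -param_costs[i])
    return {param: slot % num_ranks for slot, param in enumerate(order)}
```

With costs `[10, 40, 30, 20]` over 2 ranks, the two most expensive parameters land on different ranks.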
### Precomputed Shard Indices
Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`:
```python
@dataclass
class _muon_state:
    worker_rank: int                 # which rank owns this param's computation
    process_group: ProcessGroup      # the all-to-all communication group
    rank_indices: dict[int, tuple]   # rank -> per-dim indices into full tensor
    rank_numels: dict[int, int]      # rank -> number of elements in shard
    name: str
    qk_clip_state: QKClipInfo | None
```
`rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages.
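A 1-D toy of how such indices drive reconstruction after the gather (plain lists stand in for tensors; `reconstruct` is illustrative, not a function from the codebase):

```python
def reconstruct(full_len, shards, rank_indices):
    # Place each rank's flat shard into the full buffer, the way the owning
    # rank rebuilds the full gradient after the gather all-to-all.
    full = [0.0] * full_len
    for rank, shard in shards.items():
        idx = rank_indices[rank][0]  # first (and only) dimension here
        positions = range(*idx.indices(full_len)) if isinstance(idx, slice) else idx
        for pos, val in zip(positions, shard):
            full[pos] = val
    return full

# Rank 0 owns a strided slice (even positions), rank 1 an index list (odd):
rank_indices = {0: (slice(0, 8, 2),), 1: ([1, 3, 5, 7],)}
full = reconstruct(8, {0: [10, 20, 30, 40], 1: [11, 21, 31, 41]}, rank_indices)
```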
### Pipeline Stages in Detail
#### Stages 1-2: Gather
1. **Allocate** receive buffers for gathered gradients (only on owning ranks).
2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches the gather.
4. **Yield 1**: Other chunks can launch their gather while this one waits.
5. **`work.wait()`**: Complete the gather.
6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices`.
#### Stage 3: Compute
The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. It runs inline (no yield) since it is synchronous GPU work.
#### Stages 4-5: Scatter
Inverse of gather:
1. **Allocate** receive buffers for the orthogonalized update `U`.
2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches the scatter.
4. **Yield 2**: Other chunks can launch their scatter while this one waits.
5. **`work.wait()`**: Complete the scatter.
6. **Copy** received shards into local update buffers.
#### Stage 6: Update
Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.
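The per-shard update can be sketched elementwise (a sketch assuming decoupled, AdamW-style weight decay; lists stand in for the local shard tensors, and the function name is illustrative):

```python
def update_local_shard(p_local, u_local, lr, weight_decay):
    # Decoupled weight decay shrinks each weight toward zero, then the
    # orthogonalized update u is applied scaled by the learning rate.
    return [w * (1 - lr * weight_decay) - lr * u
            for w, u in zip(p_local, u_local)]
```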
## Distributed Utilities
**File:** `distributed/utils.py`
These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns.
### `construct_shard_mesh(placements, mesh)`
Given a DTensor's placements and device mesh, this function:
1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` after regular `Shard` on the same dim).
2. **Permutes** the mesh accordingly.
3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh.
4. **Creates** a ProcessGroup for the current rank's shard mesh.
Returns `(shard_mesh, process_group, shard_placements)` — used for all-to-all communication.
**Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them.
### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)`
Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension.
Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension.
**Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks:
- Rank 0 might own rows `[0, 4, 8, 12]` (strided)
- Rank 1 might own rows `[1, 5, 9, 13]`
- etc.
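For a single level of strided sharding, ownership reduces to a simple stride pattern (a sketch only; the composed multi-level case the real helper handles is more involved):

```python
def strided_rows(dim_size, num_shards, shard_rank):
    # _StridedShard-style ownership along one dimension: every
    # num_shards-th row, offset by this shard's rank.
    return list(range(shard_rank, dim_size, num_shards))
```

This reproduces the row pattern from the example above for 16 rows over 4 ranks.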
### PyTorch 2.10 Compatibility
In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. The helper `_is_shard()` handles both old and new hierarchies:
```python
def _is_shard(placement):
    return isinstance(placement, (Shard, _StridedShard))
```
## Newton-Schulz Orthogonalization
**File:** `newton_schulz.py`
`_zeropower_via_newtonschulz5()` computes the orthogonal approximation of a matrix using 5 quintic Newton-Schulz iterations with pre-optimized coefficients. The result approximates `US'V^T` where `S'` is near-uniform on `[0.5, 1.5]`, which empirically does not hurt model performance vs. exact `UV^T`.
Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
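The iteration can be sketched in pure Python for a square matrix (the coefficients are the commonly published Muon values; `_mm` and `newton_schulz5` are illustrative stand-ins, and the real kernel runs in bfloat16 with Triton matmuls and handles non-square shapes):

```python
def _mm(m1, m2):
    # Naive matmul over nested lists (stand-in for the Triton kernels).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*m2)] for row in m1]

def newton_schulz5(g, steps=5, eps=1e-7):
    a, b, c = 3.4445, -4.7750, 2.0315  # tuned quintic coefficients
    fro = sum(v * v for row in g for v in row) ** 0.5 + eps
    x = [[v / fro for v in row] for row in g]  # normalize: singular values <= 1
    for _ in range(steps):
        xt = [list(col) for col in zip(*x)]
        A = _mm(x, xt)                                    # X @ X^T (symmetric)
        B = [[b * u + c * w for u, w in zip(r1, r2)]
             for r1, r2 in zip(A, _mm(A, A))]             # b*A + c*A@A
        x = [[a * u + w for u, w in zip(r1, r2)]
             for r1, r2 in zip(x, _mm(B, x))]             # a*X + B@X
    return x
```

For a diagonal input the iteration acts on each singular value independently, driving it into the near-orthogonal band the text describes.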
**File:** `matmul_transpose_triton.py`
The `matmul_transpose_assign(d_in, d_out)` kernel computes `d_out = d_in @ d_in^T` in-place. It exploits symmetry by computing only upper-triangle blocks and mirroring.
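The symmetry trick in plain Python (illustrative only; the real kernel operates on tiles, not scalars):

```python
def matmul_transpose(x):
    # S = X @ X^T is symmetric: compute each entry once for i <= j
    # (the upper triangle) and mirror it into (j, i).
    n = len(x)
    s = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i, n):
            v = sum(p * q for p, q in zip(x[i], x[j]))
            s[i][j] = s[j][i] = v
    return s
```

This halves the number of dot products relative to a dense matmul.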
## QK Clipping
**File:** `qk_clip.py`
Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`.
**In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads:
```python
# pipeline.py: _update_params()
ratio = p.shape[0] // scales_full.shape[0]   # rows per head
idx0 = state.rank_indices[rank][0]           # which global rows this rank owns
row_scales = scales_full[idx0 // ratio]      # map each row to its head's scale
p._local_tensor.mul_(row_scales.view(-1, 1))
```
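The scale factor itself is simple (a sketch with a hypothetical helper name): the square root appears because the logit is linear in both the Q row and the K row, so scaling both by `s` scales the logit by `s²`.

```python
import math

def qk_clip_scale(max_logit, threshold):
    # No-op below the threshold; otherwise scale Q and K rows by
    # sqrt(threshold / logit) so the head's max logit shrinks to the threshold.
    if max_logit <= threshold:
        return 1.0
    return math.sqrt(threshold / max_logit)
```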
## AdamW for Non-Muon Parameters
**File:** `adamw.py`
Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call.
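The grouping step can be sketched as a bucketing pass (duck-typed: `device`, `dtype`, and `placements` are stand-ins for the tensor attributes, and the function name is illustrative):

```python
from collections import defaultdict

def group_for_fused(params):
    # Fused kernels require homogeneous inputs, so bucket parameters by
    # (device, dtype, placements) and make one fused call per bucket.
    buckets = defaultdict(list)
    for p in params:
        key = (p.device, p.dtype, getattr(p, "placements", None))
        buckets[key].append(p)
    return buckets
```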
## Source File Map
| File | Lines | Purpose |
|------|-------|---------|
| `muon.py` | ~525 | Optimizer class, parameter routing, 3 execution paths |
| `pipeline.py` | ~290 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
| `core.py` | ~110 | `_muon_state` dataclass, momentum/update helpers, param grouping |
| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
| `newton_schulz.py` | ~50 | Newton-Schulz iteration |
| `matmul_transpose_triton.py` | ~120 | Triton kernel for symmetric matmul |
| `qk_clip.py` | ~130 | QK logit clipping |
| `adamw.py` | ~160 | Fused AdamW for non-Muon params |
### Dependency Graph
```
matmul_transpose_triton.py (leaf)
        |
newton_schulz.py (leaf + triton)
        |
core.py ---- qk_clip.py (leaf, distributed/utils)
   |              |
   |         pipeline.py --- async_utils.py
   |              |
   |          adamw.py
   |              |
muon.py (all above)
   |
__init__.py
```