# Muon Optimizer: Implementation Guide
This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.
## Table of Contents
1. [Overview](#overview)
2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
3. [Execution Paths](#execution-paths)
4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys)
6. [Distributed Utilities](#distributed-utilities)
7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
8. [QK Clipping](#qk-clipping)
9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
10. [Source File Map](#source-file-map)
---
## Overview
Muon (MomentUm Orthogonalized by Newton-Schulz) applies standard SGD-momentum and then replaces each 2D parameter's update with the nearest orthogonal matrix via a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.
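As a reference point, the matrix that Newton-Schulz approximates can be computed exactly with an SVD. The NumPy sketch below is illustrative only: it assumes plain heavy-ball momentum (`buf = beta * buf + grad`), and `lr`, `beta`, `polar_factor`, and `muon_step_reference` are hypothetical names and values, not this optimizer's API or defaults. The optimizer itself uses the Newton-Schulz iteration rather than an SVD.

```python
import numpy as np

def polar_factor(m):
    """Nearest (semi-)orthogonal matrix to m: U @ V^T from the thin SVD m = U S V^T."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

def muon_step_reference(param, grad, momentum_buf, lr=0.02, beta=0.95):
    """One Muon-style step on a 2D parameter, using an exact SVD instead of Newton-Schulz.

    Assumes plain heavy-ball momentum; lr and beta are illustrative values.
    """
    momentum_buf = beta * momentum_buf + grad
    update = polar_factor(momentum_buf)  # all singular values replaced by 1
    return param - lr * update, momentum_buf

rng = np.random.default_rng(0)
p = rng.standard_normal((8, 4))
g = rng.standard_normal((8, 4))
new_p, buf = muon_step_reference(p, g, np.zeros_like(p))
```

For a tall matrix like this one, the update has orthonormal columns, so `update.T @ update` is the 4x4 identity regardless of the gradient's scale.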
The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity.
## Entry Point and Parameter Routing
**File:** `muon.py` — `Muon.step()` / `Muon._step_muon()`
Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step:
1. **Non-Muon groups** → `step_adamw()` (fused AdamW).
2. **Muon groups** → `_step_muon()`, which further classifies each parameter:
```
_step_muon(group)
|
+-- momentum update (batched _foreach_* ops)
+-- _expand_expert_params() -- 3D expert params → per-expert 2D views (cached)
|
+-- DTensor, all Replicate placements --> base() (no sharding)
+-- DTensor, sharded --> parallel() (pipelined all-to-all)
+-- plain Tensor --> base() (single device)
```
Parameters are classified by their DTensor placements:
- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
- **Sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
- `distributed_muon()` exists as a **test-only reference implementation** for correctness verification.
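The routing rule can be expressed as a small predicate over placements. In this sketch, `Replicate` and `Shard` are hypothetical stand-ins for the `torch.distributed.tensor` placement classes so that it runs without a process group, and `route()` is an illustrative name rather than a function in `muon.py`.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""

@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int

def route(placements: Optional[Sequence[object]]) -> str:
    """Pick the execution path for one 2D parameter.

    placements is None for a plain tensor, otherwise the DTensor's placements.
    """
    if placements is None:
        return "base"  # plain tensor: single-device Muon
    if all(isinstance(p, Replicate) for p in placements):
        return "base"  # fully replicated: every rank already holds the full matrix
    return "parallel"  # at least one Shard: pipelined all-to-all
```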
## Execution Paths
### base() — Single Device
Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.
### distributed_muon() — Full Gather (test-only)
Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full grad, then slices back to local shards. Simple but communication-heavy — not used in production.
### parallel() — Pipelined All-to-All
This is the main advanced feature. Instead of all-gathering the full parameter, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation.
## Parallel Pipeline
### Design Motivation
Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is:
1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all.
2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient.
3. **Scatter**: The owning rank sends each rank its slice of the orthogonalized update via all-to-all.
4. **Update**: Each rank applies weight decay and the update to its local shard.
To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently.
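The gather, compute, and scatter steps can be simulated in a single process to show the data flow for one parameter. This sketch assumes simple contiguous row sharding and substitutes an exact SVD for Newton-Schulz; `simulate_parallel_step` is a hypothetical helper, not part of the codebase, and step 4 (the local update) is left to the caller.

```python
import numpy as np

def orthogonalize(m):
    # Exact polar factor via SVD, standing in for the Newton-Schulz iteration.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

def simulate_parallel_step(local_grads):
    """Single-process simulation of gather -> compute -> scatter for one parameter.

    local_grads[r] is rank r's contiguous row shard of the parameter's gradient.
    """
    # 1. Gather: after the all-to-all, the owning rank holds every shard.
    gathered = np.concatenate(local_grads, axis=0)
    # 2. Compute: only the owner orthogonalizes the full gradient.
    full_update = orthogonalize(gathered)
    # 3. Scatter: the owner sends each rank back only the rows that rank holds.
    row_counts = [g.shape[0] for g in local_grads]
    return np.split(full_update, np.cumsum(row_counts)[:-1], axis=0)

rng = np.random.default_rng(0)
full_grad = rng.standard_normal((8, 6))
shards = np.split(full_grad, 4, axis=0)  # 4 ranks, 2 rows each
local_updates = simulate_parallel_step(shards)
```

The result matches orthogonalizing the full gradient directly; the difference in the real pipeline is that only one rank pays for the Newton-Schulz compute.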
### Architecture
```
muon.py: parallel()
|
+-- init_state_and_assign_params() -- assigns ownership, precomputes indices
|
+-- pipelines() generator -- yields muon_chunk_pipeline() per chunk
|
+-- run_pipeline(pipelines, max_concurrent=warmup_step+1)
|
+-- interleaves chunks at yield boundaries
```
### The Chunk Pipeline Generator
**File:** `pipeline.py` — `muon_chunk_pipeline()`
Each chunk is a generator that yields **twice**, splitting its work into three stages separated by async communication:
```
YIELD 1 YIELD 2
| |
[Build bufs + async gather a2a] --> [wait + NS compute + async scatter a2a] --> [wait + Update params]
```
- **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication and returns a work handle. The generator yields immediately afterwards, letting other chunks run; when it resumes, `work.wait()` blocks until the communication has completed.
- **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own.
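A toy, communication-free version of this pattern shows how two chunks interleave. `FakeWork` and `chunk_pipeline` are hypothetical stand-ins that only record events; in the real pipeline the handles come from `dist.all_to_all_single(..., async_op=True)` and the middle stage runs Newton-Schulz.

```python
class FakeWork:
    """Hypothetical stand-in for the handle returned by an async collective."""
    def __init__(self, log, label):
        self.log, self.label = log, label

    def wait(self):
        self.log.append(f"wait {self.label}")

def chunk_pipeline(chunk, log):
    """Toy chunk pipeline: three stages separated by two yields."""
    log.append(f"launch gather {chunk}")
    work = FakeWork(log, f"gather {chunk}")  # real code: async all-to-all
    yield  # YIELD 1: let other chunks launch their gathers
    work.wait()
    log.append(f"newton-schulz {chunk}")
    log.append(f"launch scatter {chunk}")
    work = FakeWork(log, f"scatter {chunk}")
    yield  # YIELD 2: let other chunks progress while the scatter is in flight
    work.wait()
    log.append(f"update {chunk}")

log = []
tasks = [chunk_pipeline(c, log) for c in ("A", "B")]
# Advance both chunks in lockstep until they finish (a minimal round-robin).
for _ in range(3):
    for t in tasks:
        next(t, None)
```

In the resulting log, chunk B launches its gather before chunk A waits on its own, which is exactly the overlap the scheduler is designed to create.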
### The Pipeline Scheduler
**File:** `async_utils.py` — `run_pipeline()`
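A simple round-robin scheduler. Simplified, its control flow looks like this:
```python
_DONE = object()  # sentinel returned by next() once an iterator is exhausted

def run_pipeline(pipelines, max_concurrent):
    have_new, previous_tasks = True, []
    while have_new or previous_tasks:
        running = []
        # Admit one new pipeline if below the concurrency limit
        if have_new and len(previous_tasks) < max_concurrent:
            task = next(pipelines, _DONE)
            have_new = task is not _DONE
            if have_new and next(task, _DONE) is not _DONE:  # runs to first yield
                running.append(task)
        # Advance all existing tasks by one yield; drop the ones that finished
        running += [t for t in previous_tasks if next(t, _DONE) is not _DONE]
        previous_tasks = running
```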
`max_concurrent = warmup_step + 1` controls how many chunks can be in-flight simultaneously. Higher values increase memory usage but improve communication/computation overlap.
### Ownership Assignment
**File:** `muon.py` — `init_state_and_assign_params()`
Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks.
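The sort-then-round-robin assignment can be sketched in a few lines. The cost model `min(m, n)**2 * max(m, n)` is an assumption that tracks the matmul cost of one Newton-Schulz iteration, not necessarily the estimate `init_state_and_assign_params()` uses, and `assign_owners` is a hypothetical name.

```python
def assign_owners(shapes, world_size):
    """Map param name -> owning rank, balancing Newton-Schulz cost round-robin.

    The cost model below is an assumption; the real FLOP estimate may differ.
    """
    def cost(name):
        m, n = shapes[name]
        return min(m, n) ** 2 * max(m, n)

    # Most expensive first; sorted() is stable, so ties keep their input order.
    ordered = sorted(shapes, key=cost, reverse=True)
    return {name: i % world_size for i, name in enumerate(ordered)}

shapes = {"wq": (4096, 4096), "w1": (11008, 4096), "wo": (4096, 4096), "w2": (4096, 11008)}
owners = assign_owners(shapes, world_size=2)
```

Here each of the two ranks ends up with one large MLP matrix and one attention matrix, rather than one rank receiving both MLP matrices.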
### Precomputed Shard Indices
Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`:
```python
from __future__ import annotations

from dataclasses import dataclass

from torch.distributed import ProcessGroup

@dataclass
class _muon_state:
    worker_rank: int  # which rank owns this param's computation
    process_group: ProcessGroup  # the all-to-all communication group
    rank_indices: dict[int, tuple]  # rank -> per-dim indices into full tensor
    rank_numels: dict[int, int]  # rank -> number of elements in shard
    name: str  # parameter name
    qk_clip_state: QKClipInfo | None  # QK clipping info (see qk_clip.py)
```
`rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages.
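To make the representation concrete, here is a NumPy sketch in which NumPy arrays stand in for `torch.Tensor` indices. The two-rank strided layout and the `numel_of` helper are hypothetical examples, not values produced by the optimizer.

```python
import numpy as np
from math import prod

full = np.arange(32).reshape(8, 4)

# Hypothetical indices for 2 ranks that split dim 0 with a stride of 2 (strided shard)
# and keep dim 1 whole (a slice).
rank_indices = {
    0: (np.array([0, 2, 4, 6]), slice(0, 4)),
    1: (np.array([1, 3, 5, 7]), slice(0, 4)),
}

def numel_of(indices, shape):
    """Number of elements selected by a per-dim tuple of slices / index arrays."""
    sizes = []
    for idx, dim_size in zip(indices, shape):
        if isinstance(idx, slice):
            sizes.append(len(range(*idx.indices(dim_size))))
        else:
            sizes.append(len(idx))
    return prod(sizes)

rank_numels = {r: numel_of(idx, full.shape) for r, idx in rank_indices.items()}
shard0 = full[rank_indices[0]]  # rows 0, 2, 4, 6; all columns
```

Indexing directly with the tuple is only safe here because at most one dimension uses an index array: with index arrays on several dimensions, both NumPy and PyTorch pair them elementwise instead of taking their cross product.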
### Pipeline Stages in Detail
#### Stages 1-2: Gather
1. **Allocate** receive buffers for gathered gradients (only on owning ranks).
2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches gather.
4. **Yield 1**: Other chunks can launch their gather while this one waits.
5. **`work.wait()`**: Complete the gather.
6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices`.
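The owner-side reconstruction can be sketched with NumPy. The sketch assumes the owner knows each sender's local shard shape (passed in explicitly here as `rank_shapes`), and `reconstruct_full` is a hypothetical name; the flattened per-rank buffers play the role of the all-to-all receive buffers.

```python
import numpy as np

def reconstruct_full(recv_flat, rank_indices, rank_shapes, full_shape):
    """Owner side of the gather: place each rank's flattened shard into the full tensor.

    recv_flat[r] is the 1D buffer received from rank r; rank_shapes[r] is that
    shard's local shape (an assumption of this sketch).
    """
    full = np.empty(full_shape, dtype=recv_flat[0].dtype)
    for r, flat in recv_flat.items():
        full[rank_indices[r]] = flat.reshape(rank_shapes[r])
    return full

rng = np.random.default_rng(0)
grad = rng.standard_normal((8, 4))
rank_indices = {0: (np.array([0, 2, 4, 6]), slice(0, 4)),
                1: (np.array([1, 3, 5, 7]), slice(0, 4))}
# Each rank flattens its local shard into the all-to-all send buffer.
send = {r: grad[idx].ravel() for r, idx in rank_indices.items()}
rebuilt = reconstruct_full(send, rank_indices, {0: (4, 4), 1: (4, 4)}, grad.shape)
```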
#### Stage 3: Compute
The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. It runs inline (no yield) because it is local GPU work with no communication handle to wait on.
#### Stages 4-5: Scatter
Inverse of gather:
1. **Allocate** receive buffers for the orthogonalized update `U`.
2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches scatter.
4. **Yield 2**: Other chunks can launch their scatter while this one waits.
5. **`work.wait()`**: Complete the scatter.
6. **Copy** received shards into local update buffers.
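The scatter is the mirror image of the gather. This NumPy sketch assumes contiguous row shards for brevity; `build_scatter_buffers` and `receive_shard` are hypothetical names.

```python
import numpy as np

def build_scatter_buffers(update, rank_indices):
    """Owner side of the scatter: slice U per destination rank and flatten it."""
    return {r: update[idx].ravel() for r, idx in rank_indices.items()}

def receive_shard(flat, local_shape):
    """Receiver side: reshape the flat buffer into the local update shard."""
    return flat.reshape(local_shape)

update = np.arange(32, dtype=np.float32).reshape(8, 4)
rank_indices = {0: (slice(0, 4), slice(0, 4)), 1: (slice(4, 8), slice(0, 4))}
buffers = build_scatter_buffers(update, rank_indices)
local_u1 = receive_shard(buffers[1], (4, 4))
```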
#### Stage 6: Update
Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.
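A minimal sketch of this stage on one local shard, assuming decoupled (AdamW-style) weight decay of the form `p <- p * (1 - lr * wd) - lr * U`. Any shape-dependent scaling of `U` that the real optimizer applies is deliberately omitted, and `apply_update` is a hypothetical name.

```python
import numpy as np

def apply_update(local_param, local_update, lr, weight_decay):
    """Decoupled weight decay followed by the Muon step, applied in place to one shard.

    Assumes p <- p * (1 - lr * wd) - lr * U; extra scaling of U is omitted.
    """
    local_param *= 1.0 - lr * weight_decay  # shrink the weights first
    local_param -= lr * local_update        # then apply the orthogonalized update
    return local_param

p = np.ones((2, 3))
apply_update(p, np.full((2, 3), 0.5), lr=0.1, weight_decay=0.1)
```

Because every rank applies the same elementwise rule to its own shard, no further communication is needed after the scatter.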
## MoE Expert Weight Support (`expert_keys`)
**File:** `muon.py` — `_expand_expert_params()`
MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling.
### Configuration
Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`:
```python
params = get_default_muon_param_groups(model, expert_keys=["experts"])
optim = Muon(params, expert_keys=["experts"], ...)
```
Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Non-matching 3D+ parameters raise `AssertionError` to catch misconfiguration.
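The matching rule amounts to a substring test plus a dimensionality check. `is_expert_param` is a hypothetical helper mirroring the documented behavior, not the function used in `muon.py`.

```python
def is_expert_param(name, ndim, expert_keys):
    """Classify a Muon parameter by name, mirroring the documented rule.

    A substring match against expert_keys marks an expert parameter; any other
    3D+ parameter is treated as a configuration error.
    """
    if any(key in name for key in expert_keys):
        return True
    assert ndim <= 2, f"{name} is {ndim}D but matches no expert_keys entry"
    return False
```

With `expert_keys=["experts"]`, a name such as `layers.0.moe.experts.w1` is routed to expert handling, while a 3D tensor named `layers.0.mlp.w1` trips the assertion.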
### How It Works
`_expand_expert_params()` runs after momentum and before routing to `base()`/`parallel()`/`distributed_muon()`:
1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back.
2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split).
3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** — the same logic handles TP `Shard(1/2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding.
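The storage-sharing property that step 1 relies on can be checked with NumPy, whose basic indexing returns views just as indexing `w[e]` does in PyTorch. The shapes are arbitrary examples.

```python
import numpy as np

num_experts, out_dim, in_dim = 4, 3, 2
expert_weight = np.zeros((num_experts, out_dim, in_dim))

# Split on dim 0: each entry is a 2D view into the same buffer, not a copy.
per_expert = [expert_weight[e] for e in range(num_experts)]

# An in-place update on one view is visible in the original 3D tensor.
per_expert[2] -= 0.1
```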
### Placement-Agnostic Design
The expansion logic does not care *why* a dimension is sharded — only whether it's on dim 0 (consumed by split) or not (preserved on submesh):
| Original Placement | After Expansion |
|-------------------|-----------------|
| `Shard(0)` (EP) | Consumed by split → plain tensor |
| `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh → 2D DTensor |
| `Shard(2)` (TP row-wise) | `Shard(1)` on submesh → 2D DTensor |
| `Replicate` | Ignored (not a shard) |
| `_StridedShard(0)` (EFSDP) | Consumed by split → plain tensor |
After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`.
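The table can be encoded as a per-mesh-dim remapping function. The placement classes are hypothetical stand-ins for the `torch.distributed.tensor` types, and `remap_for_expert_slice` is an illustrative name; returning `None` means that mesh dim no longer shards the 2D slice and would not appear in the submesh.

```python
from dataclasses import dataclass
from typing import Optional, Union

@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""

@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int

@dataclass(frozen=True)
class StridedShard:
    """Stand-in for _StridedShard (only the tensor dim is modeled)."""
    dim: int

def remap_for_expert_slice(placement: Union[Replicate, Shard, StridedShard]) -> Optional[Shard]:
    """Placement a mesh dim gives one 2D expert slice after splitting on dim 0."""
    if isinstance(placement, Replicate) or placement.dim == 0:
        return None  # not a shard, or consumed by the per-expert split
    if isinstance(placement, StridedShard):
        raise NotImplementedError("strided sharding on dim > 0 is outside this sketch")
    return Shard(placement.dim - 1)  # remaining dims shift left by one
```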
For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md).
## Distributed Utilities
**File:** `distributed/utils.py`
These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns.
### `construct_shard_mesh(placements, mesh)`
Given a DTensor's placements and device mesh, this function:
1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first).
2. **Permutes** the mesh accordingly.
3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh.
4. **Creates** a ProcessGroup for the current rank's shard mesh.
Returns `(shard_mesh, process_group, shard_placements)` — used for all-to-all communication.
**Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks belong to the same shard group (together holding one complete copy of the tensor) and create a ProcessGroup for them.
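Step 1 of `construct_shard_mesh()` can be sketched as a sort key over `(mesh_dim, placement)` pairs. The placement classes are hypothetical stand-ins, and `placement_sort_key` is an illustrative name; the resulting order of mesh dims is the permutation step 2 would apply to the mesh.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Replicate:
    """Stand-in for torch.distributed.tensor.Replicate."""

@dataclass(frozen=True)
class Shard:
    """Stand-in for torch.distributed.tensor.Shard."""
    dim: int

@dataclass(frozen=True)
class StridedShard:
    """Stand-in for _StridedShard."""
    dim: int

def placement_sort_key(item):
    """Replicate first, then shards by tensor dim, strided before plain Shard on a dim."""
    _, placement = item
    if isinstance(placement, Replicate):
        return (0, 0, 0)
    return (1, placement.dim, 0 if isinstance(placement, StridedShard) else 1)

placements = [Shard(0), Replicate(), StridedShard(0)]
order = sorted(enumerate(placements), key=placement_sort_key)
mesh_permutation = [mesh_dim for mesh_dim, _ in order]  # [1, 2, 0]
```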
### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)`
Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension.
Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension.
**Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks:
- Rank 0 might own rows `[0, 4, 8, 12]` (strided)
- Rank 1 might own rows `[1, 5, 9, 13]`
- etc.
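The return type differs by dimension because a contiguous block fits in a `slice`, while a strided pattern such as the rows above does not. This NumPy sketch assumes the simple round-robin layout from the example, which is only one of the layouts composed `_StridedShard` placements can produce; both helper functions are hypothetical.

```python
import numpy as np

num_rows, world = 16, 4

def contiguous_rows(rank):
    """Shard(0)-style ownership: one contiguous block, representable as a slice."""
    block = num_rows // world
    return slice(rank * block, (rank + 1) * block)

def strided_rows(rank):
    """Round-robin ownership: needs an explicit index array (a LongTensor in PyTorch)."""
    return np.arange(rank, num_rows, world)
```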
### PyTorch 2.10 Compatibility
In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. The helper `_is_shard()` handles both old and new hierarchies:
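```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.placement_types import _StridedShard

def _is_shard(placement):
    # Check both classes: on PyTorch 2.10, _StridedShard is not a Shard subclass
    return isinstance(placement, (Shard, _StridedShard))
```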
## Newton-Schulz Orthogonalization
**File:** `newton_schulz.py`
`_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method: quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3` (the lower end of the singular-value range the coefficients are optimized for), driving the singular values close to 1 so that the result closely approximates the polar factor `UV^T`. It is wrapped by `zeropower_via_newtonschulz5()`, which adds per-shape `torch.compile` caching with CUDA graph support.
Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
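For readers new to the iteration, here is a minimal NumPy sketch of quintic Newton-Schulz. It deliberately swaps in the fixed coefficients `(3.4445, -4.7750, 2.0315)` from Keller Jordan's original Muon implementation instead of the per-iteration Polar Express coefficients this module computes, so singular values only land near 1 rather than converging to it. The plain `x @ x.T` marks where this module uses `matmul_transpose_assign()`; `newton_schulz5` is a hypothetical name.

```python
import numpy as np

# Fixed quintic coefficients from Keller Jordan's original Muon implementation.
# This module uses per-iteration Polar Express coefficients instead.
A, B, C = 3.4445, -4.7750, 2.0315

def newton_schulz5(g, steps=5, eps=1e-7):
    """Approximate the polar factor of a 2D matrix with a quintic iteration."""
    x = g / (np.linalg.norm(g) + eps)  # Frobenius normalization puts all singular values <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T  # work on the wide orientation so the Gram matrix is the smaller one
    for _ in range(steps):
        gram = x @ x.T  # the symmetric X @ X^T product
        poly = B * gram + C * gram @ gram
        x = A * x + poly @ x  # maps each singular value s to A*s + B*s**3 + C*s**5
    return x.T if transposed else x

rng = np.random.default_rng(0)
grad = rng.standard_normal((64, 32))
update = newton_schulz5(grad)
singular_values = np.linalg.svd(update, compute_uv=False)
```

Writing the step as `A*X + (B*G + C*G@G) @ X` with `G = X @ X^T` means every matmul involves the small Gram matrix, which is why a fast symmetric `X @ X^T` kernel pays off.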
**File:** `matmul_transpose_triton.py`
The `matmul_transpose_assign(d_in, d_out)` kernel computes `d_in @ d_in^T` and writes the result into the output buffer `d_out`. It exploits the symmetry of the result by computing only the upper-triangle blocks and mirroring them into the lower triangle.
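The symmetry trick is easy to see at the block level. This NumPy sketch is not the Triton kernel, whose tiling and launch strategy differ; the block size and the name `gram_upper_blocks` are hypothetical.

```python
import numpy as np

def gram_upper_blocks(a, block=2):
    """Compute a @ a.T block by block, evaluating only blocks with j >= i.

    Each lower-triangle block is filled by transposing its upper mirror.
    """
    n = a.shape[0]
    out = np.empty((n, n), dtype=a.dtype)
    for i in range(0, n, block):
        for j in range(i, n, block):
            tile = a[i:i + block] @ a[j:j + block].T
            out[i:i + block, j:j + block] = tile
            out[j:j + block, i:i + block] = tile.T  # mirror instead of recomputing
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 3))
```

Roughly half of the off-diagonal tiles are skipped, which is where the kernel's savings come from.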
## QK Clipping
**File:** `qk_clip.py`
Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`.
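The per-head scale rule can be written directly. `qk_clip_scales` is a hypothetical helper, and the logit values are made up for illustration.

```python
import math

def qk_clip_scales(max_logits, threshold):
    """Per-head scale factors for the Q and K weight rows.

    Heads whose max QK logit exceeds the threshold get sqrt(threshold / logit);
    all other heads keep a scale of 1.0.
    """
    return [math.sqrt(threshold / logit) if logit > threshold else 1.0 for logit in max_logits]

scales = qk_clip_scales([50.0, 400.0, 100.0], threshold=100.0)  # [1.0, 0.5, 1.0]
```

Because both the Q rows and the K rows of a clipped head are scaled by the square root, their dot product (and hence the logit) shrinks by exactly `threshold / logit`, bringing the maximum back to the threshold.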
**In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads:
```python
# pipeline.py: _update_params()
ratio = p.shape[0] // scales_full.shape[0] # rows per head
idx0 = state.rank_indices[rank][0] # which global rows this rank owns
row_scales = scales_full[idx0 // ratio] # map each row to its head's scale
p._local_tensor.mul_(row_scales.view(-1, 1))
```
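A NumPy walk-through of the same three lines, with hypothetical sizes: 2 heads of 4 rows each, and a rank that owns the odd global rows under strided sharding, so its local rows span both heads.

```python
import numpy as np

num_heads, head_dim, cols = 2, 4, 3
full_rows = num_heads * head_dim           # global p.shape[0]
scales_full = np.array([0.5, 1.0])         # one clip scale per head

ratio = full_rows // scales_full.shape[0]  # rows per head
idx0 = np.arange(1, full_rows, 2)          # this rank's global rows: 1, 3, 5, 7
row_scales = scales_full[idx0 // ratio]    # head index of each row -> its scale

local_shard = np.ones((len(idx0), cols))   # stands in for p._local_tensor
local_shard *= row_scales.reshape(-1, 1)
```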
## AdamW for Non-Muon Parameters
**File:** `adamw.py`
Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call.
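The grouping step can be sketched in plain Python, presumably needed because a single fused call is meant to operate on homogeneous tensor lists. `ParamInfo` and `group_for_fused_adamw` are hypothetical; strings stand in for real devices, dtypes, and placements, and the fused update itself is omitted.

```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass(frozen=True)
class ParamInfo:
    """Hypothetical summary of a parameter: enough to choose its fused-AdamW group."""
    name: str
    device: str
    dtype: str
    placements: tuple  # () for a plain tensor

def group_for_fused_adamw(params):
    """Bucket parameters so each fused call sees one device, dtype, and placement layout."""
    groups = defaultdict(list)
    for p in params:
        groups[(p.device, p.dtype, p.placements)].append(p.name)
    return dict(groups)

params = [
    ParamInfo("norm.weight", "cuda:0", "float32", ("Replicate",)),
    ParamInfo("embed.weight", "cuda:0", "bfloat16", ("Shard(0)",)),
    ParamInfo("lm_head.weight", "cuda:0", "bfloat16", ("Shard(0)",)),
]
groups = group_for_fused_adamw(params)
```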
## Source File Map
| File | Lines | Purpose |
|------|-------|---------|
| `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching |
| `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
| `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping |
| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
| `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph |
| `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul |
| `qk_clip.py` | ~135 | QK logit clipping |
| `adamw.py` | ~170 | Fused AdamW for non-Muon params |
### Dependency Graph
```
matmul_transpose_triton.py (leaf)
|
newton_schulz.py (leaf + triton)
|
core.py ---- qk_clip.py (leaf, distributed/utils)
| | |
| pipeline.py --- async_utils.py
| |
| adamw.py
| |
muon.py (all above)
|
__init__.py
```