# Muon Optimizer: Implementation Guide

This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.

## Table of Contents

1. [Overview](#overview)
2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
3. [Execution Paths](#execution-paths)
4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys)
6. [Distributed Utilities](#distributed-utilities)
7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
8. [QK Clipping](#qk-clipping)
9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
10. [Source File Map](#source-file-map)

---

## Overview

Muon (MomentUm Orthogonalized by Newton-Schulz) applies standard SGD-momentum and then replaces each 2D parameter's update with the nearest (semi-)orthogonal matrix, computed with a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.

The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity.

## Entry Point and Parameter Routing

**File:** `muon.py` — `Muon.step()` / `Muon._step_muon()`

Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step:

1. **Non-Muon groups** → `step_adamw()` (fused AdamW).
2. **Muon groups** → `_step_muon()`, which further classifies each parameter:

```
_step_muon(group)
  |
  +-- momentum update (batched _foreach_* ops)
  +-- _expand_expert_params()   -- 3D expert params → per-expert 2D views (cached)
  |
  +-- DTensor, all Replicate placements  --> base()       (no sharding)
  +-- DTensor, sharded                   --> parallel()   (pipelined all-to-all)
  +-- plain Tensor                       --> base()       (single device)
```

Parameters are classified by their DTensor placements:
- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
- **Sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
- `distributed_muon()` exists as a **test-only reference implementation** for correctness verification.
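The routing rule can be sketched in plain Python. The placement classes below are hypothetical stand-ins for DTensor's `Replicate` and `Shard`, so no `torch` is needed. `route()` is an illustrative name and is not part of the optimizer's API.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for DTensor placement types.
@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


def route(placements):
    """Pick an execution path; `placements` is None for a plain tensor."""
    if placements is None:
        return "base"      # plain tensor: single-device Muon
    if all(isinstance(p, Replicate) for p in placements):
        return "base"      # fully replicated DTensor: nothing to gather
    return "parallel"      # at least one sharded mesh dim: all-to-all pipeline


print(route(None))                        # base
print(route((Replicate(), Replicate())))  # base
print(route((Replicate(), Shard(0))))     # parallel
```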

## Execution Paths

### base() — Single Device

Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.
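Below is a minimal NumPy sketch of one such step for a single 2D parameter. It rests on several assumptions: plain (non-Nesterov) momentum, decoupled weight decay, and no learning-rate scaling by shape. An exact SVD-based polar factor stands in for the Newton-Schulz iteration. `muon_step_single` and `polar_factor` are hypothetical names.

```python
import numpy as np


def polar_factor(m):
    # Exact polar factor U @ Vt via SVD (stand-in for Newton-Schulz).
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def muon_step_single(param, grad, buf, lr=0.02, momentum=0.95, weight_decay=0.01):
    """One illustrative Muon step; updates `param` and `buf` in place."""
    buf *= momentum                       # momentum: buf = momentum * buf + grad
    buf += grad
    update = polar_factor(buf)            # orthogonalize the momentum
    param *= 1.0 - lr * weight_decay      # decoupled weight decay
    param -= lr * update                  # apply the orthogonalized update
    return update


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 6))
g = rng.standard_normal((4, 6))
m = np.zeros_like(w)
u = muon_step_single(w, g, m)
print(np.allclose(u @ u.T, np.eye(4)))  # rows of the update are orthonormal
```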

### distributed_muon() — Full Gather (test-only)

Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full grad, then slices back to local shards. Simple but communication-heavy — not used in production.

### parallel() — Pipelined All-to-All

This is the main advanced feature. Instead of all-gathering the full parameter, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation.

## Parallel Pipeline

### Design Motivation

Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is:

1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all.
2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient.
3. **Scatter**: The owning rank slices the orthogonalized update and sends each rank its shard via all-to-all.
4. **Update**: Each rank applies weight decay and the update to its local shard.
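The four steps can be simulated in a single process with NumPy. This is a hypothetical sketch, not the real pipeline. Each "rank" holds a contiguous row shard of every gradient, and ownership is assigned round-robin. An SVD-based polar factor stands in for Newton-Schulz, and Python lists play the role of the all-to-all buffers. `simulate_parallel` is an illustrative name.

```python
import numpy as np


def polar_factor(m):
    # Stand-in for Newton-Schulz: exact polar factor via SVD.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def simulate_parallel(full_grads, world_size):
    """Return (updates, owner) where updates[rank][i] is rank's shard of U_i."""
    # Each rank's local gradient shard: contiguous rows (like Shard(0)).
    local = [[np.array_split(g, world_size)[r] for g in full_grads]
             for r in range(world_size)]
    owner = [i % world_size for i in range(len(full_grads))]
    updates = [[None] * len(full_grads) for _ in range(world_size)]
    for i in range(len(full_grads)):
        # 1. Gather: every rank sends its shard of param i to rank owner[i].
        gathered = np.concatenate([local[r][i] for r in range(world_size)])
        # 2. Compute: only owner[i] orthogonalizes the full gradient.
        u = polar_factor(gathered)
        # 3. Scatter: the owner slices U back into per-rank shards.
        pieces = np.array_split(u, world_size)
        # 4. Update: each rank receives the piece matching its own shard.
        for r in range(world_size):
            updates[r][i] = pieces[r]
    return updates, owner
```

Because every rank ends up with exactly its slice of the full polar factor, the result matches what a full all-gather (as in `distributed_muon()`) would produce. The difference is that each Newton-Schulz call runs on only one rank.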

To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently.

### Architecture

```
muon.py: parallel()
  |
  +-- init_state_and_assign_params()    -- assigns ownership, precomputes indices
  |
  +-- pipelines() generator             -- yields muon_chunk_pipeline() per chunk
  |
  +-- run_pipeline(pipelines, max_concurrent=warmup_step+1)
        |
        +-- interleaves chunks at yield boundaries
```

### The Chunk Pipeline Generator

**File:** `pipeline.py` — `muon_chunk_pipeline()`

Each chunk is a generator that yields **twice**, splitting its work into three stages separated by async communication:

```
 YIELD 1                                    YIELD 2
    |                                          |
[Build bufs + async gather a2a]  -->  [wait + NS compute + async scatter a2a]  -->  [wait + Update params]
```

- **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication. The generator yields immediately afterwards, allowing other chunks to run. When the generator is resumed after the yield, `work.wait()` ensures the communication has finished before its results are used.
- **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own.
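A toy example illustrates the shape of such a generator. `FakeWork` is a hypothetical stand-in for the handle returned by an async collective. The stage bodies only record what would happen. Driving two chunks by hand with `next()` shows how their stages interleave.

```python
class FakeWork:
    """Hypothetical stand-in for an async collective's work handle."""

    def __init__(self, log, name):
        self.log, self.name = log, name

    def wait(self):
        self.log.append(f"wait {self.name}")


def toy_chunk_pipeline(chunk_id, log):
    # Stage A: build buffers and launch the async gather.
    log.append(f"gather launch {chunk_id}")
    work = FakeWork(log, f"gather {chunk_id}")
    yield  # YIELD 1: other chunks may run while the gather is in flight
    work.wait()
    log.append(f"NS compute {chunk_id}")
    log.append(f"scatter launch {chunk_id}")
    work = FakeWork(log, f"scatter {chunk_id}")
    yield  # YIELD 2: other chunks may run while the scatter is in flight
    work.wait()
    log.append(f"update {chunk_id}")


log = []
a, b = toy_chunk_pipeline(0, log), toy_chunk_pipeline(1, log)
next(a)                 # chunk 0 launches its gather
next(b)                 # chunk 1 launches its gather while chunk 0's is in flight
next(a)                 # chunk 0: wait, compute, launch scatter
next(b)                 # chunk 1: wait, compute, launch scatter
for gen in (a, b):      # final segment: wait for scatter, update params
    next(gen, None)
print(log)
```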

### The Pipeline Scheduler

**File:** `async_utils.py` — `run_pipeline()`

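A simple round-robin scheduler. Its core loop, in simplified form:

```python
_DONE = object()  # sentinel returned by next() once a chunk generator is exhausted
def run_pipeline(pipelines, max_concurrent):
    have_new, tasks = True, []
    while have_new or tasks:
        # Admit one new pipeline (a generator) if below the concurrency limit
        if have_new and len(tasks) < max_concurrent:
            task = next(pipelines, None)
            have_new = task is not None
            if have_new:
                tasks.append(task)  # advanced to its first yield below
        # Advance every in-flight task by one yield; drop finished ones
        still_running = []
        for task in tasks:
            if next(task, _DONE) is not _DONE:
                still_running.append(task)
        tasks = still_running
```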

`max_concurrent = warmup_step + 1` controls how many chunks can be in flight simultaneously. Higher values allow more overlap between communication and computation, at the cost of more memory held by in-flight buffers.

### Ownership Assignment

**File:** `muon.py` &mdash; `init_state_and_assign_params()`

Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks.
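The assignment rule is sketched below. The cost model `m * n * min(m, n)` roughly tracks the matmul work of one Newton-Schulz pass on an m×n matrix. It is an assumption for illustration, and the real FLOP estimate in `init_state_and_assign_params()` may differ. `assign_owners` is a hypothetical name.

```python
def assign_owners(shapes, world_size):
    """Return {param_index: owner_rank}, heaviest parameters first."""
    def cost(shape):
        m, n = shape
        return m * n * min(m, n)  # assumed FLOP proxy for Newton-Schulz

    order = sorted(range(len(shapes)), key=lambda i: cost(shapes[i]), reverse=True)
    # Round-robin over the cost-sorted list spreads heavy params across ranks.
    return {idx: pos % world_size for pos, idx in enumerate(order)}


shapes = [(64, 64), (1024, 4096), (512, 512), (4096, 1024)]
print(assign_owners(shapes, world_size=2))
```

With this proxy, the two largest matrices land on different ranks instead of both going to rank 0.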

### Precomputed Shard Indices

Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`:

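```python
from __future__ import annotations
from dataclasses import dataclass

from torch.distributed import ProcessGroup

@dataclass
class _muon_state:
    worker_rank: int                # which rank owns this param's computation
    process_group: ProcessGroup     # the all-to-all communication group
    rank_indices: dict[int, tuple]  # rank -> per-dim indices into full tensor
    rank_numels: dict[int, int]     # rank -> number of elements in shard
    name: str                       # parameter name
    qk_clip_state: QKClipInfo | None  # QK-clip info type, defined elsewhere in the package
```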

`rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages.
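To make the roles of `rank_indices` and `rank_numels` concrete, the NumPy sketch below places flattened shards into a full tensor and slices them back out. The index values are made up, and `shard_shape` is a hypothetical helper. The sketch assumes at most one dimension uses an index array, because NumPy (like PyTorch) pairs multiple index arrays elementwise rather than taking their outer product.

```python
import math

import numpy as np


def shard_shape(idx, full_shape):
    # Per-dim length: the size of a slice's range, or of an index array.
    return tuple(len(range(*i.indices(n))) if isinstance(i, slice) else len(i)
                 for i, n in zip(idx, full_shape))


full_shape = (8, 3)
# Hypothetical indices: rank 0 owns the even rows, rank 1 the odd rows.
rank_indices = {0: (np.arange(0, 8, 2), slice(None)),
                1: (np.arange(1, 8, 2), slice(None))}
rank_numels = {r: math.prod(shard_shape(idx, full_shape))
               for r, idx in rank_indices.items()}

grad = np.arange(24, dtype=np.float64).reshape(full_shape)

# Gather: the owner receives all shards as one flat buffer, ordered by rank.
recv = np.concatenate([grad[rank_indices[r]].ravel() for r in sorted(rank_indices)])
rebuilt = np.zeros(full_shape)
offset = 0
for r in sorted(rank_indices):
    n = rank_numels[r]  # rank_numels tells the owner where each shard ends
    rebuilt[rank_indices[r]] = recv[offset:offset + n].reshape(
        shard_shape(rank_indices[r], full_shape))
    offset += n

# Scatter: the owner slices the full update back out, one flat buffer per rank.
update = 2.0 * rebuilt  # placeholder for the orthogonalized update U
outgoing = {r: update[idx].ravel() for r, idx in rank_indices.items()}
print(rank_numels)  # {0: 12, 1: 12}
```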

### Pipeline Stages in Detail

#### Stages 1-2: Gather

1. **Allocate** receive buffers for gathered gradients (only on owning ranks).
2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches gather.
4. **Yield 1**: Other chunks can launch their gather while this one waits.
5. **`work.wait()`**: Complete the gather.
6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices`.
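The bookkeeping for the gather's `dist.all_to_all_single` call can be sketched in plain Python. `input_split_sizes` and `output_split_sizes` are real keyword arguments of `torch.distributed.all_to_all_single`, but the helper `gather_split_sizes` and its input format are hypothetical. Each rank sends its shard of a parameter only to that parameter's owner. As a result, a send split is nonzero only toward owners, and a receive split is nonzero only on the owner.

```python
def gather_split_sizes(my_rank, world_size, chunk):
    """chunk: list of (owner_rank, rank_numels) pairs, one per parameter.

    Returns (input_split_sizes, output_split_sizes) for this rank.
    """
    send = [0] * world_size
    recv = [0] * world_size
    for owner, rank_numels in chunk:
        # Every rank sends its local shard of this param to the owner.
        send[owner] += rank_numels[my_rank]
        # Only the owner receives: one shard from each rank.
        if my_rank == owner:
            for src in range(world_size):
                recv[src] += rank_numels[src]
    return send, recv


# Two params over 2 ranks: param A (owner 0) and param B (owner 1).
chunk = [(0, {0: 6, 1: 6}), (1, {0: 4, 1: 5})]
print(gather_split_sizes(0, 2, chunk))  # ([6, 4], [6, 6])
print(gather_split_sizes(1, 2, chunk))  # ([6, 5], [4, 5])
```

The scatter stage uses the same sizes with the roles of send and receive swapped.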

#### Stage 3: Compute

The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. It runs inline, without a yield, because it is synchronous GPU work rather than an async communication.

#### Stages 4-5: Scatter

Inverse of gather:
1. **Allocate** receive buffers for the orthogonalized update `U`.
2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches scatter.
4. **Yield 2**: Other chunks can launch their scatter while this one waits.
5. **`work.wait()`**: Complete the scatter.
6. **Copy** received shards into local update buffers.

#### Stage 6: Update

Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured.

## MoE Expert Weight Support (`expert_keys`)

**File:** `muon.py` &mdash; `_expand_expert_params()`

MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling.

### Configuration

Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`:

```python
params = get_default_muon_param_groups(model, expert_keys=["experts"])
optim = Muon(params, expert_keys=["experts"], ...)
```

Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Non-matching 3D+ parameters raise `AssertionError` to catch misconfiguration.
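The matching rule and the misconfiguration check can be sketched as follows. `is_expert_param` is a hypothetical helper, and the parameter names are made up.

```python
def is_expert_param(name, ndim, expert_keys):
    """Substring match on the parameter name, as described above."""
    matched = any(key in name for key in expert_keys)
    # 3D+ params that match no expert key are treated as a misconfiguration.
    assert matched or ndim <= 2, f"{name}: {ndim}D parameter matches no expert_keys"
    return matched


print(is_expert_param("layers.0.moe.experts.w1", 3, ["experts"]))  # True
print(is_expert_param("layers.0.attn.wq", 2, ["experts"]))         # False
```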

### How It Works

`_expand_expert_params()` runs after momentum and before routing to `base()`/`parallel()`/`distributed_muon()`:

1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back.
2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split).
3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** &mdash; the same logic handles TP `Shard(1/2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding.
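NumPy can illustrate why the dim-0 split shares storage, because its basic indexing also returns views. This is an analogy, not the optimizer's code: `_expand_expert_params()` works on `torch` tensors and `nn.Parameter`s.

```python
import numpy as np

num_experts, out_dim, in_dim = 3, 4, 2
w = np.zeros((num_experts, out_dim, in_dim))

# Split on dim 0: each expert gets a 2D (out_dim, in_dim) view.
expert_views = [w[e] for e in range(num_experts)]
assert all(v.base is w for v in expert_views)  # views, not copies

# An in-place update on one 2D view propagates to the 3D original.
expert_views[1] -= 0.5
print(w[1, 0, 0], w[0, 0, 0])  # -0.5 0.0
```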

### Placement-Agnostic Design

The expansion logic does not care *why* a dimension is sharded &mdash; only whether it's on dim 0 (consumed by split) or not (preserved on submesh):

| Original Placement | After Expansion |
|-------------------|-----------------|
| `Shard(0)` (EP) | Consumed by split &rarr; plain tensor |
| `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh &rarr; 2D DTensor |
| `Shard(2)` (TP row-wise) | `Shard(1)` on submesh &rarr; 2D DTensor |
| `Replicate` | Ignored (not a shard) |
| `_StridedShard(0)` (EFSDP) | Consumed by split &rarr; plain tensor |

After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`.
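The table can be encoded as a small pure function. Placements are modeled as hypothetical `(kind, dim)` tuples rather than PyTorch's placement classes, and `remap_expert_placements` is an illustrative name.

```python
def remap_expert_placements(placements):
    """Map 3D expert placements to the placements of each 2D expert slice.

    placements: one (kind, dim) tuple per mesh dimension, e.g. ("shard", 1),
    ("strided_shard", 0) or ("replicate", None).
    Returns [(mesh_dim, new_placement), ...]; an empty list means the 2D
    slice is a plain tensor.
    """
    kept = []
    for mesh_dim, (kind, dim) in enumerate(placements):
        if kind == "replicate":
            continue              # not a shard: ignored
        if dim == 0:
            continue              # dim 0 is consumed by the per-expert split
        kept.append((mesh_dim, (kind, dim - 1)))  # Shard(k) -> Shard(k-1)
    return kept


print(remap_expert_placements([("shard", 0)]))                       # []
print(remap_expert_placements([("shard", 0), ("shard", 2)]))         # [(1, ('shard', 1))]
print(remap_expert_placements([("replicate", None), ("shard", 1)]))  # [(1, ('shard', 0))]
```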

For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md).

## Distributed Utilities

**File:** `distributed/utils.py`

These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns.

### `construct_shard_mesh(placements, mesh)`

Given a DTensor's placements and device mesh, this function:

1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first).
2. **Permutes** the mesh accordingly.
3. **Separates** replicate dims from shard dims &mdash; each replicate group gets its own shard sub-mesh.
4. **Creates** a ProcessGroup for the current rank's shard mesh.

Returns `(shard_mesh, process_group, shard_placements)` &mdash; used for all-to-all communication.

**Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them.
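Steps 1-3 can be sketched with NumPy on a small mesh. The model is simplified. Placements are hypothetical `(kind, dim)` tuples, and the mesh is an array of global ranks. Step 4 is omitted; it would create the `ProcessGroup`, e.g. with `torch.distributed.new_group`. `shard_submesh` and `placement_key` are illustrative names.

```python
import numpy as np


def placement_key(placement):
    kind, dim = placement
    if kind == "replicate":
        return (0, 0, 0)  # 1. Replicate dims sort first
    # Then shard dims by tensor dim, with strided shards before plain shards.
    return (1, dim, 0 if kind == "strided_shard" else 1)


def shard_submesh(mesh, placements, rank):
    """Return (shard sub-mesh containing `rank`, its sorted shard placements)."""
    order = sorted(range(mesh.ndim), key=lambda d: placement_key(placements[d]))
    mesh = mesh.transpose(order)                     # 2. permute the mesh
    sorted_pl = [placements[d] for d in order]
    n_rep = sum(kind == "replicate" for kind, _ in sorted_pl)
    # 3. Fix this rank's replicate coordinates; the remaining dims form its shard mesh.
    coords = tuple(int(c[0]) for c in np.nonzero(mesh == rank))
    return mesh[coords[:n_rep]], sorted_pl[n_rep:]


mesh = np.arange(8).reshape(2, 2, 2)
placements = [("replicate", None), ("shard", 0), ("strided_shard", 0)]
sub, shard_pl = shard_submesh(mesh, placements, rank=5)
print(sub.tolist(), shard_pl)  # [[4, 6], [5, 7]] [('strided_shard', 0), ('shard', 0)]
```

Ranks 4 to 7 share replicate coordinate 1, so they form one shard group, and ranks 0 to 3 form the other.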

### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)`

Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension.

Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension.

**Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks:
- Rank 0 might own rows `[0, 4, 8, 12]` (strided)
- Rank 1 might own rows `[1, 5, 9, 13]`
- etc.
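The following is a simplified model of composed sharding along one dimension. At each level, a rank takes either a contiguous block (`Shard`-like) or every k-th element (a strided split) of the indices left by the previous level. It is an illustration only. DTensor's actual `_StridedShard` semantics are more involved, and `composed_indices` is a hypothetical helper.

```python
import numpy as np


def composed_indices(length, levels):
    """levels: list of (kind, world, coord), outermost first.

    kind is "contiguous" (Shard-like) or "strided" (round-robin elements).
    Returns the global indices this rank owns along one dimension.
    """
    idx = np.arange(length)
    for kind, world, coord in levels:
        if kind == "contiguous":
            idx = np.array_split(idx, world)[coord]
        else:
            idx = idx[coord::world]
    return idx


print(composed_indices(16, [("strided", 4, 0)]).tolist())  # [0, 4, 8, 12]
print(composed_indices(16, [("strided", 4, 1)]).tolist())  # [1, 5, 9, 13]
print(composed_indices(16, [("contiguous", 2, 1), ("strided", 2, 0)]).tolist())  # [8, 10, 12, 14]
```

A purely contiguous result could be stored as a `slice`, and a strided one needs an index tensor. The return types listed above make the same distinction.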

### PyTorch 2.10 Compatibility

In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. The helper `_is_shard()` handles both old and new hierarchies:

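```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.placement_types import _StridedShard

def _is_shard(placement):
    # True for contiguous and strided shards, whether or not _StridedShard subclasses Shard.
    return isinstance(placement, (Shard, _StridedShard))
```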

## Newton-Schulz Orthogonalization

**File:** `newton_schulz.py`

`_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method: quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3` (the lower end of the singular-value interval the coefficients are optimized for). It drives all singular values toward 1, so the result approximates the polar factor `UV^T`. It is wrapped by `zeropower_via_newtonschulz5()`, which adds per-shape `torch.compile` caching with CUDA graph support.

Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.
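The following NumPy sketch of a quintic Newton-Schulz iteration is for intuition only. It deliberately swaps in a different technique. It uses fixed coefficients (15/8, -10/8, 3/8), which provably converge to the polar factor, instead of the per-iteration Polar Express coefficients from `_optimal_composition()`. It runs in float64 on CPU and is not the production kernel. `newton_schulz_quintic` is a hypothetical name.

```python
import numpy as np


def newton_schulz_quintic(g, steps=30):
    """Approximate the polar factor U @ Vt of g with a quintic iteration.

    Each step maps every singular value s to (15s - 10s^3 + 3s^5) / 8,
    which increases monotonically toward the fixed point 1 for s in (0, 1].
    """
    a, b, c = 15 / 8, -10 / 8, 3 / 8
    transposed = g.shape[0] > g.shape[1]
    x = g.T if transposed else g                 # keep x wide so x @ x.T is small
    x = x / (np.linalg.norm(x) + 1e-12)          # Frobenius norm => all s <= 1
    for _ in range(steps):
        gram = x @ x.T                           # the X @ X^T product
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


g = np.array([[2.0, 1.0, 0.0], [0.0, 1.0, 1.0]])
u, _, vt = np.linalg.svd(g, full_matrices=False)
print(np.allclose(newton_schulz_quintic(g), u @ vt))
```

The iteration only reshapes singular values and leaves singular vectors alone, which is why it converges to `U @ Vt`. Polar Express changes the coefficients per step to get there in far fewer iterations.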

**File:** `matmul_transpose_triton.py`

The `matmul_transpose_assign(d_in, d_out)` kernel writes `d_in @ d_in^T` into `d_out`. Because the result is symmetric, it computes only the upper-triangle blocks and mirrors them into the lower triangle.
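A blocked NumPy sketch shows the symmetry trick. Only blocks on or above the diagonal are computed, and each off-diagonal block is mirrored into its transposed position. The block size and the name `matmul_transpose_blocked` are illustrative. The Triton kernel's tiling and launch structure will differ.

```python
import numpy as np


def matmul_transpose_blocked(d_in, d_out, block=2):
    """Write d_in @ d_in.T into d_out, computing only upper-triangle blocks."""
    m = d_in.shape[0]
    for i in range(0, m, block):
        for j in range(i, m, block):            # j >= i: upper triangle only
            tile = d_in[i:i + block] @ d_in[j:j + block].T
            d_out[i:i + block, j:j + block] = tile
            if j != i:
                d_out[j:j + block, i:i + block] = tile.T  # mirror the block
    return d_out


x = np.arange(15, dtype=np.float64).reshape(5, 3)
out = matmul_transpose_blocked(x, np.empty((5, 5)))
print(np.allclose(out, x @ x.T))
```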

## QK Clipping

**File:** `qk_clip.py`

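Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`. The square root splits the correction between Q and K. A head's logits are bilinear in its Q and K projections, so scaling both by this factor shrinks them by `threshold / logit` and brings the maximum back to the threshold.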

**In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads:

```python
# pipeline.py: _update_params()
ratio = p.shape[0] // scales_full.shape[0]     # rows per head
idx0 = state.rank_indices[rank][0]              # which global rows this rank owns
row_scales = scales_full[idx0 // ratio]         # map each row to its head's scale
p._local_tensor.mul_(row_scales.view(-1, 1))
```
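Below is a self-contained NumPy version of the row-to-head mapping above, with made-up sizes. `per_head_scales` is a hypothetical helper that applies the `sqrt(threshold / logit)` rule only to heads above the threshold. The real code path operates on `p._local_tensor` of a DTensor.

```python
import numpy as np


def per_head_scales(max_logits, threshold):
    # Heads at or below the threshold keep scale 1.0.
    return np.sqrt(np.minimum(1.0, threshold / max_logits))


num_heads, head_dim, in_dim = 2, 2, 3
max_logits = np.array([400.0, 50.0])       # head 0 exceeds the threshold
scales_full = per_head_scales(max_logits, threshold=100.0)  # [0.5, 1.0]

full_rows = num_heads * head_dim           # p.shape[0] (global) in the excerpt above
idx0 = np.array([0, 2])                    # strided: this rank owns global rows 0 and 2
local = np.ones((len(idx0), in_dim))       # this rank's local shard

ratio = full_rows // scales_full.shape[0]  # rows per head
row_scales = scales_full[idx0 // ratio]    # row 0 -> head 0, row 2 -> head 1
local *= row_scales.reshape(-1, 1)
print(local[:, 0])  # [0.5 1. ]
```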

## AdamW for Non-Muon Parameters

**File:** `adamw.py`

Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call.
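The grouping step can be sketched in plain Python. Parameters are modeled as hypothetical records with `device`, `dtype`, and `placements` fields. Each resulting bucket would then go to a single fused AdamW call, which is not shown because `torch._fused_adamw_` is a private API. `ParamInfo` and `group_for_fused_adamw` are illustrative names.

```python
from collections import defaultdict
from typing import NamedTuple


class ParamInfo(NamedTuple):   # hypothetical stand-in for a parameter
    name: str
    device: str
    dtype: str
    placements: tuple          # () for a plain tensor


def group_for_fused_adamw(params):
    """Bucket parameters so each bucket shares device, dtype and placements."""
    groups = defaultdict(list)
    for p in params:
        groups[(p.device, p.dtype, p.placements)].append(p.name)
    return dict(groups)


params = [
    ParamInfo("norm.weight", "cuda:0", "float32", ("Replicate",)),
    ParamInfo("embed.weight", "cuda:0", "bfloat16", ("Shard(0)",)),
    ParamInfo("bias", "cuda:0", "float32", ("Replicate",)),
]
for key, names in group_for_fused_adamw(params).items():
    print(key, names)
```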

## Source File Map

| File | Lines | Purpose |
|------|-------|---------|
| `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching |
| `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
| `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping |
| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
| `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph |
| `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul |
| `qk_clip.py` | ~135 | QK logit clipping |
| `adamw.py` | ~170 | Fused AdamW for non-Muon params |

### Dependency Graph

```
matmul_transpose_triton.py       (leaf)
         |
    newton_schulz.py              (leaf + triton)
         |
      core.py ---- qk_clip.py    (leaf, distributed/utils)
       |    |         |
       |  pipeline.py --- async_utils.py
       |       |
       |   adamw.py
       |       |
      muon.py                     (all above)
         |
    __init__.py
```