Kernels
File size: 9,183 Bytes
33929c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# Expert Parallelism in torchtitan

torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서.
Muon optimizer의 MoE 지원에 필요한 배경 지식.

Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py`

## Overview

torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공:

| Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch |
|--------|----|----|-----|--------------------------|----------------|
| TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None |
| EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all |
| EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP |
| EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP |

Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2).

## EP가 dp_shard를 빌리는 구조

EP는 새로운 물리적 차원이 아니라 `dp_shard`를 분해해서 사용:

```
dp_shard = dp_shard_mod_ep * dp_shard_in_ep

ETP=TP일 때:  ep = dp_shard_in_ep * cp
ETP=1일 때:   ep = dp_shard_in_ep * cp * tp
```

기존 mesh `[pp, dp_replicate, dp_shard, cp, tp]`가 EP 활성화 시:

```
[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]
```

로 확장됨. `dp_shard_mod_ep`는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성).

### 예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1

```
dp_shard_in_ep = ep / cp = 4
dp_shard_mod_ep = dp_shard * cp / ep = 2

mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
EP mesh: [dp_shard_in_ep=4]  → expert들을 4-way로 분배
FSDP mesh: [dp_shard_mod_ep=2]  → expert FSDP는 2-way로 shard
```

## Submesh 매핑

```python
# Data loading (no communication)
dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]

# Non-expert parameter sharding (FSDP)
dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]

# Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외
dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]

# Expert parallelism mesh
ep = [dp_shard_in_ep, cp, (tp if etp==1)]

# Loss all-reduce
dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]
```

## 4가지 전략 상세

### 1. TensorParallel (TP Only, EP=1)

EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding:

```python
# expert_parallel.py: TensorParallel
w1: [Shard(1)] on TP mesh   # column-wise (out_dim)
w2: [Shard(2)] on TP mesh   # row-wise (out_dim, 3D에서 dim 2)
w3: [Shard(1)] on TP mesh   # column-wise (out_dim)
```

Token dispatch 없음. 일반 TP와 동일하게 동작.

### 2. ExpertParallel (EP Only, TP=1)

Expert dim (dim 0)으로 sharding. Token all-to-all dispatch:

```python
# expert_parallel.py: ExpertParallel
w1, w2, w3: [Shard(0)] on EP mesh  # expert dim으로 분배
```

Forward pass:
1. Router가 각 token을 expert에 할당
2. `all_to_all_single`으로 token을 해당 expert의 rank로 dispatch
3. 각 rank가 local expert에서 compute
4. `all_to_all_single`으로 결과를 원래 rank로 combine

### 3. ExpertTensorParallel (EP+TP, ETP=TP)

EP와 TP를 동시에 2D로 적용:

```python
# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
w1: [Shard(0), Shard(1)] on [EP, TP] mesh   # expert + column
w2: [Shard(0), Shard(2)] on [EP, TP] mesh   # expert + row
w3: [Shard(0), Shard(1)] on [EP, TP] mesh   # expert + column
```

Token dispatch:
1. TP mesh에서 input을 Replicate (gradient는 Partial)
2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일)
3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리

### 4. ReordererSequenceParallel (EP+TP, ETP=1)

TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작:

```python
# expert_parallel.py: ReordererSequenceParallel
# Expert weights: [Shard(0)] on EP mesh (TP 안 씀)
# Token split: batch*seq_len을 TP rank 수로 나눠서 분배

# EP mesh = [dp_shard_in_ep, cp, tp]  ← tp가 EP에 포함됨
```

TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음.

## EFSDP (Expert FSDP)

Expert parameter에 대한 FSDP는 non-expert parameter와 **다른 mesh**를 사용:

```python
# parallelize.py: apply_fsdp
# Non-expert: dp_shard_cp mesh 전체로 shard
fully_shard(transformer_block, mesh=dp_shard_cp_mesh)

# Expert (EP 활성화 시): dp_mod_ep mesh로만 shard
# dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외
fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
```

### Dynamic shard placement

Expert 수보다 `dp_mod_ep * ep`가 클 때 (expert dim으로 더 쪼갤 수 없을 때),
dim 0 대신 dim 1로 shard.

**torchtitan 코드** (`torchtitan/models/llama4/infra/parallelize.py:339-359`):

```python
# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
#       When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
#       causes inefficiency, so we choose to do FSDP sharding on dim-1.
_experts_shard_placement_fn = None
if (
    dp_mod_ep_mesh.size() * ep_degree
    > transformer_block.moe.experts.num_experts
):
    _experts_shard_placement_fn = lambda param: Shard(1)

fully_shard(
    transformer_block.moe.experts,
    **fsdp_mod_ep_config,   # mesh=dp_mod_ep_mesh
    reshard_after_forward=reshard_after_forward,
    shard_placement_fn=_experts_shard_placement_fn,
)
```

`dp_mod_ep_mesh` 구성 (`parallelize.py:140-159`):

```python
dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
    if parallel_dims.dp_replicate_enabled:
        dp_mod_ep_mesh_dim_names.append("dp_replicate")
    dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
```

### 실제 placement 검증 결과

8 GPUs, `num_experts=2`, `etp=1` 기준:

#### num_experts=8 (기본)

모든 config에서 expert weights는 **dim 0 (expert dim)으로만 shard**:

| Config | Expert Placements | Mesh |
|--------|-------------------|------|
| ep=8 | `[Shard(0)]` | `[ep=8]` |
| ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` |
| ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |

EFSDP는 `_StridedShard(dim=0)`, EP는 `Shard(dim=0)`. 비-dim-0 shard 없음.

#### num_experts=2 (expert 수 < EFSDP shard count)

`dp_mod_ep * ep > num_experts` 조건 충족 시 **EFSDP가 Shard(1)로 전환**:

| Config | 조건 | Expert Placements | Mesh |
|--------|------|-------------------|------|
| ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |

- EFSDP: `Shard(1)` on `dp_shard_mod_ep` → out_dim을 shard (w1: 2816/4=704)
- EP: `Shard(0)` on `ep` → expert dim을 shard (2/2=1)
- `_StridedShard`가 아닌 일반 `Shard` 사용

## Gradient Clipping with EP

EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산:

```python
# distributed/utils.py: _clip_grad_norm_with_ep
ep_norm = get_total_norm(ep_grads, ...)
non_ep_norm = get_total_norm(non_ep_grads, ...)
total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)
```

EP parameter 판별: `device_mesh.mesh_dim_names`에 "ep" 포함 여부.

## Muon optimizer에서의 처리

현재 Muon optimizer의 MoE 지원:

1. **`_expand_expert_params`**: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장
2. **TP가 있을 때**: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap
   - 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh
3. **`construct_shard_mesh` fast path**: 1D submesh에서 `dist.new_group()` deadlock 방지

### Muon이 지원하는 config

| Config | 지원 | 비고 |
|--------|------|------|
| TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 |
| EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) |
| FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim |
| HSDP + TP | O | Replicate + FSDP + TP |
| EP Only (많은 experts) | O | EFSDP `Shard(0)` → plain tensor |
| EP + FSDP (적은 experts) | 미테스트 | EFSDP `Shard(1)` → 아래 참조 |
| EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor `[Shard(0), Shard(1/2)]` |
| EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 |

### EFSDP Shard(1)과 Muon의 호환성

Muon은 placement-agnostic. `_expand_expert_params`의 non-dim-0 shard 처리 로직이
TP뿐 아니라 EFSDP `Shard(1)`에도 동일하게 적용됨 (변수명만 `tp_*`일 뿐 로직은 generic):

```
3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
    local shape: (1, 704, 2048)

_expand_expert_params:
  1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep
  2. submesh 추출 → dp_shard_mod_ep (1D, size 4)
  3. dim 0 split → (704, 2048)
  4. DTensor wrap → Shard(0) on dp_shard_mod_ep
     = 일반 FSDP sharded 2D 텐서와 동일

→ parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리.
  construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음).
```