File size: 9,183 Bytes
33929c0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | # Expert Parallelism in torchtitan
torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서.
Muon optimizer의 MoE 지원에 필요한 배경 지식.
Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py`
## Overview
torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공:
| Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch |
|--------|----|----|-----|--------------------------|----------------|
| TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None |
| EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all |
| EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP |
| EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP |
Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2).
## EP가 dp_shard를 빌리는 구조
EP는 새로운 물리적 차원이 아니라 `dp_shard`를 분해해서 사용:
```
dp_shard = dp_shard_mod_ep * dp_shard_in_ep
ETP=TP일 때: ep = dp_shard_in_ep * cp
ETP=1일 때: ep = dp_shard_in_ep * cp * tp
```
기존 mesh `[pp, dp_replicate, dp_shard, cp, tp]`가 EP 활성화 시:
```
[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]
```
로 확장됨. `dp_shard_mod_ep`는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성).
### 예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1
```
dp_shard_in_ep = ep / cp = 4
dp_shard_mod_ep = dp_shard * cp / ep = 2
mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
EP mesh: [dp_shard_in_ep=4] → expert들을 4-way로 분배
FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP는 2-way로 shard
```
## Submesh 매핑
```python
# Data loading (no communication)
dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]
# Non-expert parameter sharding (FSDP)
dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]
# Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외
dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]
# Expert parallelism mesh
ep = [dp_shard_in_ep, cp, (tp if etp==1)]
# Loss all-reduce
dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]
```
## 4가지 전략 상세
### 1. TensorParallel (TP Only, EP=1)
EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding:
```python
# expert_parallel.py: TensorParallel
w1: [Shard(1)] on TP mesh # column-wise (out_dim)
w2: [Shard(2)] on TP mesh # row-wise (out_dim, 3D에서 dim 2)
w3: [Shard(1)] on TP mesh # column-wise (out_dim)
```
Token dispatch 없음. 일반 TP와 동일하게 동작.
### 2. ExpertParallel (EP Only, TP=1)
Expert dim (dim 0)으로 sharding. Token all-to-all dispatch:
```python
# expert_parallel.py: ExpertParallel
w1, w2, w3: [Shard(0)] on EP mesh # expert dim으로 분배
```
Forward pass:
1. Router가 각 token을 expert에 할당
2. `all_to_all_single`으로 token을 해당 expert의 rank로 dispatch
3. 각 rank가 local expert에서 compute
4. `all_to_all_single`으로 결과를 원래 rank로 combine
### 3. ExpertTensorParallel (EP+TP, ETP=TP)
EP와 TP를 동시에 2D로 적용:
```python
# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
w1: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
w2: [Shard(0), Shard(2)] on [EP, TP] mesh # expert + row
w3: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
```
Token dispatch:
1. TP mesh에서 input을 Replicate (gradient는 Partial)
2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일)
3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리
### 4. ReordererSequenceParallel (EP+TP, ETP=1)
TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작:
```python
# expert_parallel.py: ReordererSequenceParallel
# Expert weights: [Shard(0)] on EP mesh (TP 안 씀)
# Token split: batch*seq_len을 TP rank 수로 나눠서 분배
# EP mesh = [dp_shard_in_ep, cp, tp] ← tp가 EP에 포함됨
```
TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음.
## EFSDP (Expert FSDP)
Expert parameter에 대한 FSDP는 non-expert parameter와 **다른 mesh**를 사용:
```python
# parallelize.py: apply_fsdp
# Non-expert: dp_shard_cp mesh 전체로 shard
fully_shard(transformer_block, mesh=dp_shard_cp_mesh)
# Expert (EP 활성화 시): dp_mod_ep mesh로만 shard
# dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외
fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
```
### Dynamic shard placement
Expert 수보다 `dp_mod_ep * ep`가 클 때 (expert dim으로 더 쪼갤 수 없을 때),
dim 0 대신 dim 1로 shard.
**torchtitan 코드** (`torchtitan/models/llama4/infra/parallelize.py:339-359`):
```python
# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
_experts_shard_placement_fn = None
if (
dp_mod_ep_mesh.size() * ep_degree
> transformer_block.moe.experts.num_experts
):
_experts_shard_placement_fn = lambda param: Shard(1)
fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config, # mesh=dp_mod_ep_mesh
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)
```
`dp_mod_ep_mesh` 구성 (`parallelize.py:140-159`):
```python
dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
```
### 실제 placement 검증 결과
8 GPUs, `num_experts=2`, `etp=1` 기준:
#### num_experts=8 (기본)
모든 config에서 expert weights는 **dim 0 (expert dim)으로만 shard**:
| Config | Expert Placements | Mesh |
|--------|-------------------|------|
| ep=8 | `[Shard(0)]` | `[ep=8]` |
| ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` |
| ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
EFSDP는 `_StridedShard(dim=0)`, EP는 `Shard(dim=0)`. 비-dim-0 shard 없음.
#### num_experts=2 (expert 수 < EFSDP shard count)
`dp_mod_ep * ep > num_experts` 조건 충족 시 **EFSDP가 Shard(1)로 전환**:
| Config | 조건 | Expert Placements | Mesh |
|--------|------|-------------------|------|
| ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
- EFSDP: `Shard(1)` on `dp_shard_mod_ep` → out_dim을 shard (w1: 2816/4=704)
- EP: `Shard(0)` on `ep` → expert dim을 shard (2/2=1)
- `_StridedShard`가 아닌 일반 `Shard` 사용
## Gradient Clipping with EP
EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산:
```python
# distributed/utils.py: _clip_grad_norm_with_ep
ep_norm = get_total_norm(ep_grads, ...)
non_ep_norm = get_total_norm(non_ep_grads, ...)
total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)
```
EP parameter 판별: `device_mesh.mesh_dim_names`에 "ep" 포함 여부.
## Muon optimizer에서의 처리
현재 Muon optimizer의 MoE 지원:
1. **`_expand_expert_params`**: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장
2. **TP가 있을 때**: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap
- 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh
3. **`construct_shard_mesh` fast path**: 1D submesh에서 `dist.new_group()` deadlock 방지
### Muon이 지원하는 config
| Config | 지원 | 비고 |
|--------|------|------|
| TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 |
| EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) |
| FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim |
| HSDP + TP | O | Replicate + FSDP + TP |
| EP Only (많은 experts) | O | EFSDP `Shard(0)` → plain tensor |
| EP + FSDP (적은 experts) | 미테스트 | EFSDP `Shard(1)` → 아래 참조 |
| EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor `[Shard(0), Shard(1/2)]` |
| EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 |
### EFSDP Shard(1)과 Muon의 호환성
Muon은 placement-agnostic. `_expand_expert_params`의 non-dim-0 shard 처리 로직이
TP뿐 아니라 EFSDP `Shard(1)`에도 동일하게 적용됨 (변수명만 `tp_*`일 뿐 로직은 generic):
```
3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
local shape: (1, 704, 2048)
_expand_expert_params:
1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep
2. submesh 추출 → dp_shard_mod_ep (1D, size 4)
3. dim 0 split → (704, 2048)
4. DTensor wrap → Shard(0) on dp_shard_mod_ep
= 일반 FSDP sharded 2D 텐서와 동일
→ parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리.
construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음).
```
|