Kernels
optimizer / docs /expert_parallel.md
wyldecat's picture
Refactor pipeline to async generator pattern (#16)
33929c0 unverified
# Expert Parallelism in torchtitan
torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서.
Muon optimizer의 MoE 지원에 필요한 배경 지식.
Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py`
## Overview
torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공:
| Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch |
|--------|----|----|-----|--------------------------|----------------|
| TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None |
| EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all |
| EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP |
| EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP |
Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2).
## EP가 dp_shard를 빌리는 구조
EP는 새로운 물리적 차원이 아니라 `dp_shard`를 분해해서 사용:
```
dp_shard = dp_shard_mod_ep * dp_shard_in_ep
ETP=TP일 때: ep = dp_shard_in_ep * cp
ETP=1일 때: ep = dp_shard_in_ep * cp * tp
```
기존 mesh `[pp, dp_replicate, dp_shard, cp, tp]`가 EP 활성화 시:
```
[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]
```
로 확장됨. `dp_shard_mod_ep`는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성).
### 예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1
```
dp_shard_in_ep = ep / cp = 4
dp_shard_mod_ep = dp_shard * cp / ep = 2
mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
EP mesh: [dp_shard_in_ep=4] → expert들을 4-way로 분배
FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP는 2-way로 shard
```
## Submesh 매핑
```python
# Data loading (no communication)
dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]
# Non-expert parameter sharding (FSDP)
dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]
# Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외
dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]
# Expert parallelism mesh
ep = [dp_shard_in_ep, cp, (tp if etp==1)]
# Loss all-reduce
dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]
```
## 4가지 전략 상세
### 1. TensorParallel (TP Only, EP=1)
EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding:
```python
# expert_parallel.py: TensorParallel
w1: [Shard(1)] on TP mesh # column-wise (out_dim)
w2: [Shard(2)] on TP mesh # row-wise (out_dim, 3D에서 dim 2)
w3: [Shard(1)] on TP mesh # column-wise (out_dim)
```
Token dispatch 없음. 일반 TP와 동일하게 동작.
### 2. ExpertParallel (EP Only, TP=1)
Expert dim (dim 0)으로 sharding. Token all-to-all dispatch:
```python
# expert_parallel.py: ExpertParallel
w1, w2, w3: [Shard(0)] on EP mesh # expert dim으로 분배
```
Forward pass:
1. Router가 각 token을 expert에 할당
2. `all_to_all_single`으로 token을 해당 expert의 rank로 dispatch
3. 각 rank가 local expert에서 compute
4. `all_to_all_single`으로 결과를 원래 rank로 combine
### 3. ExpertTensorParallel (EP+TP, ETP=TP)
EP와 TP를 동시에 2D로 적용:
```python
# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
w1: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
w2: [Shard(0), Shard(2)] on [EP, TP] mesh # expert + row
w3: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
```
Token dispatch:
1. TP mesh에서 input을 Replicate (gradient는 Partial)
2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일)
3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리
### 4. ReordererSequenceParallel (EP+TP, ETP=1)
TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작:
```python
# expert_parallel.py: ReordererSequenceParallel
# Expert weights: [Shard(0)] on EP mesh (TP 안 씀)
# Token split: batch*seq_len을 TP rank 수로 나눠서 분배
# EP mesh = [dp_shard_in_ep, cp, tp] ← tp가 EP에 포함됨
```
TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음.
## EFSDP (Expert FSDP)
Expert parameter에 대한 FSDP는 non-expert parameter와 **다른 mesh**를 사용:
```python
# parallelize.py: apply_fsdp
# Non-expert: dp_shard_cp mesh 전체로 shard
fully_shard(transformer_block, mesh=dp_shard_cp_mesh)
# Expert (EP 활성화 시): dp_mod_ep mesh로만 shard
# dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외
fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
```
### Dynamic shard placement
Expert 수보다 `dp_mod_ep * ep`가 클 때 (expert dim으로 더 쪼갤 수 없을 때),
dim 0 대신 dim 1로 shard.
**torchtitan 코드** (`torchtitan/models/llama4/infra/parallelize.py:339-359`):
```python
# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
_experts_shard_placement_fn = None
if (
dp_mod_ep_mesh.size() * ep_degree
> transformer_block.moe.experts.num_experts
):
_experts_shard_placement_fn = lambda param: Shard(1)
fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config, # mesh=dp_mod_ep_mesh
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)
```
`dp_mod_ep_mesh` 구성 (`parallelize.py:140-159`):
```python
dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
```
### 실제 placement 검증 결과
8 GPUs, `num_experts=2`, `etp=1` 기준:
#### num_experts=8 (기본)
모든 config에서 expert weights는 **dim 0 (expert dim)으로만 shard**:
| Config | Expert Placements | Mesh |
|--------|-------------------|------|
| ep=8 | `[Shard(0)]` | `[ep=8]` |
| ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` |
| ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
EFSDP는 `_StridedShard(dim=0)`, EP는 `Shard(dim=0)`. 비-dim-0 shard 없음.
#### num_experts=2 (expert 수 < EFSDP shard count)
`dp_mod_ep * ep > num_experts` 조건 충족 시 **EFSDP가 Shard(1)로 전환**:
| Config | 조건 | Expert Placements | Mesh |
|--------|------|-------------------|------|
| ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` |
| ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` |
- EFSDP: `Shard(1)` on `dp_shard_mod_ep` → out_dim을 shard (w1: 2816/4=704)
- EP: `Shard(0)` on `ep` → expert dim을 shard (2/2=1)
- `_StridedShard`가 아닌 일반 `Shard` 사용
## Gradient Clipping with EP
EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산:
```python
# distributed/utils.py: _clip_grad_norm_with_ep
ep_norm = get_total_norm(ep_grads, ...)
non_ep_norm = get_total_norm(non_ep_grads, ...)
total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)
```
EP parameter 판별: `device_mesh.mesh_dim_names`에 "ep" 포함 여부.
## Muon optimizer에서의 처리
현재 Muon optimizer의 MoE 지원:
1. **`_expand_expert_params`**: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장
2. **TP가 있을 때**: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap
- 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh
3. **`construct_shard_mesh` fast path**: 1D submesh에서 `dist.new_group()` deadlock 방지
### Muon이 지원하는 config
| Config | 지원 | 비고 |
|--------|------|------|
| TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 |
| EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) |
| FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim |
| HSDP + TP | O | Replicate + FSDP + TP |
| EP Only (많은 experts) | O | EFSDP `Shard(0)` → plain tensor |
| EP + FSDP (적은 experts) | 미테스트 | EFSDP `Shard(1)` → 아래 참조 |
| EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor `[Shard(0), Shard(1/2)]` |
| EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 |
### EFSDP Shard(1)과 Muon의 호환성
Muon은 placement-agnostic. `_expand_expert_params`의 non-dim-0 shard 처리 로직이
TP뿐 아니라 EFSDP `Shard(1)`에도 동일하게 적용됨 (변수명만 `tp_*`일 뿐 로직은 generic):
```
3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
local shape: (1, 704, 2048)
_expand_expert_params:
1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep
2. submesh 추출 → dp_shard_mod_ep (1D, size 4)
3. dim 0 split → (704, 2048)
4. DTensor wrap → Shard(0) on dp_shard_mod_ep
= 일반 FSDP sharded 2D 텐서와 동일
→ parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리.
construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음).
```