Expert Parallelism in torchtitan
torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서. Muon optimizer의 MoE 지원에 필요한 배경 지식.
Reference: torchtitan/distributed/expert_parallel.py, torchtitan/distributed/parallel_dims.py
Overview
torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공:
| Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch |
|---|---|---|---|---|---|
| TP Only | >1 | 1 | - | [Shard(1/2)] on TP mesh |
None |
| EP Only | 1 | >1 | - | [Shard(0)] on EP mesh |
All-to-all |
| EP+ETP (etp=tp) | >1 | >1 | =tp | [Shard(0), Shard(1/2)] on [EP, TP] mesh |
All-to-all on EP |
| EP+ETP (etp=1) | >1 | >1 | 1 | [Shard(0)] on EP mesh |
Sequence parallel on TP |
Expert weights shape: (num_experts, out_dim, in_dim) (w1, w3) / (num_experts, in_dim, out_dim) (w2).
EP가 dp_shard를 빌리는 구조
EP는 새로운 물리적 차원이 아니라 dp_shard를 분해해서 사용:
dp_shard = dp_shard_mod_ep * dp_shard_in_ep
ETP=TP일 때: ep = dp_shard_in_ep * cp
ETP=1일 때: ep = dp_shard_in_ep * cp * tp
기존 mesh [pp, dp_replicate, dp_shard, cp, tp]가 EP 활성화 시:
[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]
로 확장됨. dp_shard_mod_ep는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성).
예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1
dp_shard_in_ep = ep / cp = 4
dp_shard_mod_ep = dp_shard * cp / ep = 2
mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
EP mesh: [dp_shard_in_ep=4] → expert들을 4-way로 분배
FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP는 2-way로 shard
Submesh 매핑
# Data loading (no communication)
dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]
# Non-expert parameter sharding (FSDP)
dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]
# Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외
dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]
# Expert parallelism mesh
ep = [dp_shard_in_ep, cp, (tp if etp==1)]
# Loss all-reduce
dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]
4가지 전략 상세
1. TensorParallel (TP Only, EP=1)
EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding:
# expert_parallel.py: TensorParallel
w1: [Shard(1)] on TP mesh # column-wise (out_dim)
w2: [Shard(2)] on TP mesh # row-wise (out_dim, 3D에서 dim 2)
w3: [Shard(1)] on TP mesh # column-wise (out_dim)
Token dispatch 없음. 일반 TP와 동일하게 동작.
2. ExpertParallel (EP Only, TP=1)
Expert dim (dim 0)으로 sharding. Token all-to-all dispatch:
# expert_parallel.py: ExpertParallel
w1, w2, w3: [Shard(0)] on EP mesh # expert dim으로 분배
Forward pass:
- Router가 각 token을 expert에 할당
all_to_all_single으로 token을 해당 expert의 rank로 dispatch- 각 rank가 local expert에서 compute
all_to_all_single으로 결과를 원래 rank로 combine
3. ExpertTensorParallel (EP+TP, ETP=TP)
EP와 TP를 동시에 2D로 적용:
# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
w1: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
w2: [Shard(0), Shard(2)] on [EP, TP] mesh # expert + row
w3: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column
Token dispatch:
- TP mesh에서 input을 Replicate (gradient는 Partial)
- EP mesh에서 all-to-all dispatch (ExpertParallel과 동일)
- All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리
4. ReordererSequenceParallel (EP+TP, ETP=1)
TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작:
# expert_parallel.py: ReordererSequenceParallel
# Expert weights: [Shard(0)] on EP mesh (TP 안 씀)
# Token split: batch*seq_len을 TP rank 수로 나눠서 분배
# EP mesh = [dp_shard_in_ep, cp, tp] ← tp가 EP에 포함됨
TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음.
EFSDP (Expert FSDP)
Expert parameter에 대한 FSDP는 non-expert parameter와 다른 mesh를 사용:
# parallelize.py: apply_fsdp
# Non-expert: dp_shard_cp mesh 전체로 shard
fully_shard(transformer_block, mesh=dp_shard_cp_mesh)
# Expert (EP 활성화 시): dp_mod_ep mesh로만 shard
# dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외
fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)
Dynamic shard placement
Expert 수보다 dp_mod_ep * ep가 클 때 (expert dim으로 더 쪼갤 수 없을 때),
dim 0 대신 dim 1로 shard.
torchtitan 코드 (torchtitan/models/llama4/infra/parallelize.py:339-359):
# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
# When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
# causes inefficiency, so we choose to do FSDP sharding on dim-1.
_experts_shard_placement_fn = None
if (
dp_mod_ep_mesh.size() * ep_degree
> transformer_block.moe.experts.num_experts
):
_experts_shard_placement_fn = lambda param: Shard(1)
fully_shard(
transformer_block.moe.experts,
**fsdp_mod_ep_config, # mesh=dp_mod_ep_mesh
reshard_after_forward=reshard_after_forward,
shard_placement_fn=_experts_shard_placement_fn,
)
dp_mod_ep_mesh 구성 (parallelize.py:140-159):
dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
실제 placement 검증 결과
8 GPUs, num_experts=2, etp=1 기준:
num_experts=8 (기본)
모든 config에서 expert weights는 dim 0 (expert dim)으로만 shard:
| Config | Expert Placements | Mesh |
|---|---|---|
| ep=8 | [Shard(0)] |
[ep=8] |
| ep=4, fsdp=2 | [_StridedShard(0), Shard(0)] |
[dp_shard_mod_ep=2, ep=4] |
| ep=2, fsdp=4 | [_StridedShard(0), Shard(0)] |
[dp_shard_mod_ep=4, ep=2] |
| ep=2, hsdp=2+2 | [Replicate(), _StridedShard(0), Shard(0)] |
[dp_rep=2, dp_shard_mod_ep=2, ep=2] |
EFSDP는 _StridedShard(dim=0), EP는 Shard(dim=0). 비-dim-0 shard 없음.
num_experts=2 (expert 수 < EFSDP shard count)
dp_mod_ep * ep > num_experts 조건 충족 시 EFSDP가 Shard(1)로 전환:
| Config | 조건 | Expert Placements | Mesh |
|---|---|---|---|
| ep=2, fsdp=4 | 4*2=8 > 2 | [Shard(1), Shard(0)] |
[dp_shard_mod_ep=4, ep=2] |
| ep=2, hsdp=2+2 | 2*2=4 > 2 | [Replicate(), Shard(1), Shard(0)] |
[dp_rep=2, dp_shard_mod_ep=2, ep=2] |
- EFSDP:
Shard(1)ondp_shard_mod_ep→ out_dim을 shard (w1: 2816/4=704) - EP:
Shard(0)onep→ expert dim을 shard (2/2=1) _StridedShard가 아닌 일반Shard사용
Gradient Clipping with EP
EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산:
# distributed/utils.py: _clip_grad_norm_with_ep
ep_norm = get_total_norm(ep_grads, ...)
non_ep_norm = get_total_norm(non_ep_grads, ...)
total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)
EP parameter 판별: device_mesh.mesh_dim_names에 "ep" 포함 여부.
Muon optimizer에서의 처리
현재 Muon optimizer의 MoE 지원:
_expand_expert_params: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장- TP가 있을 때: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap
- 3D
(Shard(0), Shard(1))→ 2D(Shard(0),)on TP submesh
- 3D
construct_shard_meshfast path: 1D submesh에서dist.new_group()deadlock 방지
Muon이 지원하는 config
| Config | 지원 | 비고 |
|---|---|---|
| TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 |
| EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) |
| FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim |
| HSDP + TP | O | Replicate + FSDP + TP |
| EP Only (많은 experts) | O | EFSDP Shard(0) → plain tensor |
| EP + FSDP (적은 experts) | 미테스트 | EFSDP Shard(1) → 아래 참조 |
| EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor [Shard(0), Shard(1/2)] |
| EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 |
EFSDP Shard(1)과 Muon의 호환성
Muon은 placement-agnostic. _expand_expert_params의 non-dim-0 shard 처리 로직이
TP뿐 아니라 EFSDP Shard(1)에도 동일하게 적용됨 (변수명만 tp_*일 뿐 로직은 generic):
3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
local shape: (1, 704, 2048)
_expand_expert_params:
1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep
2. submesh 추출 → dp_shard_mod_ep (1D, size 4)
3. dim 0 split → (704, 2048)
4. DTensor wrap → Shard(0) on dp_shard_mod_ep
= 일반 FSDP sharded 2D 텐서와 동일
→ parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리.
construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음).