Kernels
optimizer / docs /expert_parallel.md
wyldecat's picture
Refactor pipeline to async generator pattern (#16)
33929c0 unverified

Expert Parallelism in torchtitan

torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서. Muon optimizer의 MoE 지원에 필요한 배경 지식.

Reference: torchtitan/distributed/expert_parallel.py, torchtitan/distributed/parallel_dims.py

Overview

torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공:

Config TP EP ETP Expert Weight Placements Token Dispatch
TP Only >1 1 - [Shard(1/2)] on TP mesh None
EP Only 1 >1 - [Shard(0)] on EP mesh All-to-all
EP+ETP (etp=tp) >1 >1 =tp [Shard(0), Shard(1/2)] on [EP, TP] mesh All-to-all on EP
EP+ETP (etp=1) >1 >1 1 [Shard(0)] on EP mesh Sequence parallel on TP

Expert weights shape: (num_experts, out_dim, in_dim) (w1, w3) / (num_experts, in_dim, out_dim) (w2).

EP가 dp_shard를 빌리는 구조

EP는 새로운 물리적 차원이 아니라 dp_shard를 분해해서 사용:

dp_shard = dp_shard_mod_ep * dp_shard_in_ep

ETP=TP일 때:  ep = dp_shard_in_ep * cp
ETP=1일 때:   ep = dp_shard_in_ep * cp * tp

기존 mesh [pp, dp_replicate, dp_shard, cp, tp]가 EP 활성화 시:

[pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp]

로 확장됨. dp_shard_mod_ep는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성).

예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1

dp_shard_in_ep = ep / cp = 4
dp_shard_mod_ep = dp_shard * cp / ep = 2

mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4]
EP mesh: [dp_shard_in_ep=4]  → expert들을 4-way로 분배
FSDP mesh: [dp_shard_mod_ep=2]  → expert FSDP는 2-way로 shard

Submesh 매핑

# Data loading (no communication)
dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep]

# Non-expert parameter sharding (FSDP)
dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp]

# Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외
dp_mod_ep = [dp_replicate?, dp_shard_mod_ep]

# Expert parallelism mesh
ep = [dp_shard_in_ep, cp, (tp if etp==1)]

# Loss all-reduce
dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp]

4가지 전략 상세

1. TensorParallel (TP Only, EP=1)

EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding:

# expert_parallel.py: TensorParallel
w1: [Shard(1)] on TP mesh   # column-wise (out_dim)
w2: [Shard(2)] on TP mesh   # row-wise (out_dim, 3D에서 dim 2)
w3: [Shard(1)] on TP mesh   # column-wise (out_dim)

Token dispatch 없음. 일반 TP와 동일하게 동작.

2. ExpertParallel (EP Only, TP=1)

Expert dim (dim 0)으로 sharding. Token all-to-all dispatch:

# expert_parallel.py: ExpertParallel
w1, w2, w3: [Shard(0)] on EP mesh  # expert dim으로 분배

Forward pass:

  1. Router가 각 token을 expert에 할당
  2. all_to_all_single으로 token을 해당 expert의 rank로 dispatch
  3. 각 rank가 local expert에서 compute
  4. all_to_all_single으로 결과를 원래 rank로 combine

3. ExpertTensorParallel (EP+TP, ETP=TP)

EP와 TP를 동시에 2D로 적용:

# expert_parallel.py: ExpertTensorParallel (extends ExpertParallel)
w1: [Shard(0), Shard(1)] on [EP, TP] mesh   # expert + column
w2: [Shard(0), Shard(2)] on [EP, TP] mesh   # expert + row
w3: [Shard(0), Shard(1)] on [EP, TP] mesh   # expert + column

Token dispatch:

  1. TP mesh에서 input을 Replicate (gradient는 Partial)
  2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일)
  3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리

4. ReordererSequenceParallel (EP+TP, ETP=1)

TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작:

# expert_parallel.py: ReordererSequenceParallel
# Expert weights: [Shard(0)] on EP mesh (TP 안 씀)
# Token split: batch*seq_len을 TP rank 수로 나눠서 분배

# EP mesh = [dp_shard_in_ep, cp, tp]  ← tp가 EP에 포함됨

TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음.

EFSDP (Expert FSDP)

Expert parameter에 대한 FSDP는 non-expert parameter와 다른 mesh를 사용:

# parallelize.py: apply_fsdp
# Non-expert: dp_shard_cp mesh 전체로 shard
fully_shard(transformer_block, mesh=dp_shard_cp_mesh)

# Expert (EP 활성화 시): dp_mod_ep mesh로만 shard
# dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외
fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh)

Dynamic shard placement

Expert 수보다 dp_mod_ep * ep가 클 때 (expert dim으로 더 쪼갤 수 없을 때), dim 0 대신 dim 1로 shard.

torchtitan 코드 (torchtitan/models/llama4/infra/parallelize.py:339-359):

# NOTE: EP alreadys shards the routed experts on dim 0 (num_experts).
#       When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding
#       causes inefficiency, so we choose to do FSDP sharding on dim-1.
_experts_shard_placement_fn = None
if (
    dp_mod_ep_mesh.size() * ep_degree
    > transformer_block.moe.experts.num_experts
):
    _experts_shard_placement_fn = lambda param: Shard(1)

fully_shard(
    transformer_block.moe.experts,
    **fsdp_mod_ep_config,   # mesh=dp_mod_ep_mesh
    reshard_after_forward=reshard_after_forward,
    shard_placement_fn=_experts_shard_placement_fn,
)

dp_mod_ep_mesh 구성 (parallelize.py:140-159):

dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
    if parallel_dims.dp_replicate_enabled:
        dp_mod_ep_mesh_dim_names.append("dp_replicate")
    dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
# → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)]

실제 placement 검증 결과

8 GPUs, num_experts=2, etp=1 기준:

num_experts=8 (기본)

모든 config에서 expert weights는 dim 0 (expert dim)으로만 shard:

Config Expert Placements Mesh
ep=8 [Shard(0)] [ep=8]
ep=4, fsdp=2 [_StridedShard(0), Shard(0)] [dp_shard_mod_ep=2, ep=4]
ep=2, fsdp=4 [_StridedShard(0), Shard(0)] [dp_shard_mod_ep=4, ep=2]
ep=2, hsdp=2+2 [Replicate(), _StridedShard(0), Shard(0)] [dp_rep=2, dp_shard_mod_ep=2, ep=2]

EFSDP는 _StridedShard(dim=0), EP는 Shard(dim=0). 비-dim-0 shard 없음.

num_experts=2 (expert 수 < EFSDP shard count)

dp_mod_ep * ep > num_experts 조건 충족 시 EFSDP가 Shard(1)로 전환:

Config 조건 Expert Placements Mesh
ep=2, fsdp=4 4*2=8 > 2 [Shard(1), Shard(0)] [dp_shard_mod_ep=4, ep=2]
ep=2, hsdp=2+2 2*2=4 > 2 [Replicate(), Shard(1), Shard(0)] [dp_rep=2, dp_shard_mod_ep=2, ep=2]
  • EFSDP: Shard(1) on dp_shard_mod_ep → out_dim을 shard (w1: 2816/4=704)
  • EP: Shard(0) on ep → expert dim을 shard (2/2=1)
  • _StridedShard가 아닌 일반 Shard 사용

Gradient Clipping with EP

EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산:

# distributed/utils.py: _clip_grad_norm_with_ep
ep_norm = get_total_norm(ep_grads, ...)
non_ep_norm = get_total_norm(non_ep_grads, ...)
total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p)

EP parameter 판별: device_mesh.mesh_dim_names에 "ep" 포함 여부.

Muon optimizer에서의 처리

현재 Muon optimizer의 MoE 지원:

  1. _expand_expert_params: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장
  2. TP가 있을 때: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap
    • 3D (Shard(0), Shard(1)) → 2D (Shard(0),) on TP submesh
  3. construct_shard_mesh fast path: 1D submesh에서 dist.new_group() deadlock 방지

Muon이 지원하는 config

Config 지원 비고
TP Only (EP=1) O expert를 TP submesh DTensor로 처리
EP Only (TP=1) O expert를 plain tensor로 처리 (base mode)
FSDP + TP O FSDP는 expert dim, TP는 out/in dim
HSDP + TP O Replicate + FSDP + TP
EP Only (많은 experts) O EFSDP Shard(0) → plain tensor
EP + FSDP (적은 experts) 미테스트 EFSDP Shard(1) → 아래 참조
EP + TP (ETP=TP) 미테스트 2D expert DTensor [Shard(0), Shard(1/2)]
EP + TP (ETP=1) 미테스트 EP mesh에 TP가 포함된 경우

EFSDP Shard(1)과 Muon의 호환성

Muon은 placement-agnostic. _expand_expert_params의 non-dim-0 shard 처리 로직이 TP뿐 아니라 EFSDP Shard(1)에도 동일하게 적용됨 (변수명만 tp_*일 뿐 로직은 generic):

3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2]
    local shape: (1, 704, 2048)

_expand_expert_params:
  1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep
  2. submesh 추출 → dp_shard_mod_ep (1D, size 4)
  3. dim 0 split → (704, 2048)
  4. DTensor wrap → Shard(0) on dp_shard_mod_ep
     = 일반 FSDP sharded 2D 텐서와 동일

→ parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리.
  construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음).