| # Expert Parallelism in torchtitan | |
| torchtitan (0.2.0)의 expert parallelism 구현을 정리한 문서. | |
| Muon optimizer의 MoE 지원에 필요한 배경 지식. | |
| Reference: `torchtitan/distributed/expert_parallel.py`, `torchtitan/distributed/parallel_dims.py` | |
| ## Overview | |
| torchtitan은 MoE expert weights에 대해 4가지 parallelism 전략을 제공: | |
| | Config | TP | EP | ETP | Expert Weight Placements | Token Dispatch | | |
| |--------|----|----|-----|--------------------------|----------------| | |
| | TP Only | >1 | 1 | - | `[Shard(1/2)]` on TP mesh | None | | |
| | EP Only | 1 | >1 | - | `[Shard(0)]` on EP mesh | All-to-all | | |
| | EP+ETP (etp=tp) | >1 | >1 | =tp | `[Shard(0), Shard(1/2)]` on [EP, TP] mesh | All-to-all on EP | | |
| | EP+ETP (etp=1) | >1 | >1 | 1 | `[Shard(0)]` on EP mesh | Sequence parallel on TP | | |
| Expert weights shape: `(num_experts, out_dim, in_dim)` (w1, w3) / `(num_experts, in_dim, out_dim)` (w2). | |
| ## EP가 dp_shard를 빌리는 구조 | |
| EP는 새로운 물리적 차원이 아니라 `dp_shard`를 분해해서 사용: | |
| ``` | |
| dp_shard = dp_shard_mod_ep * dp_shard_in_ep | |
| ETP=TP일 때: ep = dp_shard_in_ep * cp | |
| ETP=1일 때: ep = dp_shard_in_ep * cp * tp | |
| ``` | |
| 기존 mesh `[pp, dp_replicate, dp_shard, cp, tp]`가 EP 활성화 시: | |
| ``` | |
| [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp] | |
| ``` | |
| 로 확장됨. `dp_shard_mod_ep`는 값이 1이어도 mesh에 유지 (FSDP wrapping 일관성). | |
| ### 예시: 8 GPUs, ep=4, dp_shard=8, tp=1, cp=1 | |
| ``` | |
| dp_shard_in_ep = ep / cp = 4 | |
| dp_shard_mod_ep = dp_shard * cp / ep = 2 | |
| mesh: [dp_shard_mod_ep=2, dp_shard_in_ep=4] | |
| EP mesh: [dp_shard_in_ep=4] → expert들을 4-way로 분배 | |
| FSDP mesh: [dp_shard_mod_ep=2] → expert FSDP는 2-way로 shard | |
| ``` | |
| ## Submesh 매핑 | |
| ```python | |
| # Data loading (no communication) | |
| dp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep] | |
| # Non-expert parameter sharding (FSDP) | |
| dp_shard_cp = [dp_shard_mod_ep, dp_shard_in_ep, cp] | |
| # Expert parameter sharding (EFSDP) — dp_shard_in_ep 제외 | |
| dp_mod_ep = [dp_replicate?, dp_shard_mod_ep] | |
| # Expert parallelism mesh | |
| ep = [dp_shard_in_ep, cp, (tp if etp==1)] | |
| # Loss all-reduce | |
| dp_cp = [dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp] | |
| ``` | |
| ## 4가지 전략 상세 | |
| ### 1. TensorParallel (TP Only, EP=1) | |
| EP 없이 TP만 사용. Expert weights를 TP mesh에서 column/row-wise sharding: | |
| ```python | |
| # expert_parallel.py: TensorParallel | |
| w1: [Shard(1)] on TP mesh # column-wise (out_dim) | |
| w2: [Shard(2)] on TP mesh # row-wise (out_dim, 3D에서 dim 2) | |
| w3: [Shard(1)] on TP mesh # column-wise (out_dim) | |
| ``` | |
| Token dispatch 없음. 일반 TP와 동일하게 동작. | |
| ### 2. ExpertParallel (EP Only, TP=1) | |
| Expert dim (dim 0)으로 sharding. Token all-to-all dispatch: | |
| ```python | |
| # expert_parallel.py: ExpertParallel | |
| w1, w2, w3: [Shard(0)] on EP mesh # expert dim으로 분배 | |
| ``` | |
| Forward pass: | |
| 1. Router가 각 token을 expert에 할당 | |
| 2. `all_to_all_single`으로 token을 해당 expert의 rank로 dispatch | |
| 3. 각 rank가 local expert에서 compute | |
| 4. `all_to_all_single`으로 결과를 원래 rank로 combine | |
| ### 3. ExpertTensorParallel (EP+TP, ETP=TP) | |
| EP와 TP를 동시에 2D로 적용: | |
| ```python | |
| # expert_parallel.py: ExpertTensorParallel (extends ExpertParallel) | |
| w1: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column | |
| w2: [Shard(0), Shard(2)] on [EP, TP] mesh # expert + row | |
| w3: [Shard(0), Shard(1)] on [EP, TP] mesh # expert + column | |
| ``` | |
| Token dispatch: | |
| 1. TP mesh에서 input을 Replicate (gradient는 Partial) | |
| 2. EP mesh에서 all-to-all dispatch (ExpertParallel과 동일) | |
| 3. All-to-all은 EP mesh에서만 발생, TP 통신은 weight sharding으로 처리 | |
| ### 4. ReordererSequenceParallel (EP+TP, ETP=1) | |
| TP hardware를 EP에 빌려줌. TP mesh가 sequence parallel로 동작: | |
| ```python | |
| # expert_parallel.py: ReordererSequenceParallel | |
| # Expert weights: [Shard(0)] on EP mesh (TP 안 씀) | |
| # Token split: batch*seq_len을 TP rank 수로 나눠서 분배 | |
| # EP mesh = [dp_shard_in_ep, cp, tp] ← tp가 EP에 포함됨 | |
| ``` | |
| TP rank들이 token을 나눠 처리 (sequence parallel). Expert weight에는 TP sharding 없음. | |
| ## EFSDP (Expert FSDP) | |
| Expert parameter에 대한 FSDP는 non-expert parameter와 **다른 mesh**를 사용: | |
| ```python | |
| # parallelize.py: apply_fsdp | |
| # Non-expert: dp_shard_cp mesh 전체로 shard | |
| fully_shard(transformer_block, mesh=dp_shard_cp_mesh) | |
| # Expert (EP 활성화 시): dp_mod_ep mesh로만 shard | |
| # dp_shard_in_ep는 이미 EP에서 사용 중이므로 제외 | |
| fully_shard(transformer_block.moe.experts, mesh=dp_mod_ep_mesh) | |
| ``` | |
| ### Dynamic shard placement | |
| Expert 수보다 `dp_mod_ep * ep`가 클 때 (expert dim으로 더 쪼갤 수 없을 때), | |
| dim 0 대신 dim 1로 shard. | |
| **torchtitan 코드** (`torchtitan/models/llama4/infra/parallelize.py:339-359`): | |
| ```python | |
| # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). | |
| # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding | |
| # causes inefficiency, so we choose to do FSDP sharding on dim-1. | |
| _experts_shard_placement_fn = None | |
| if ( | |
| dp_mod_ep_mesh.size() * ep_degree | |
| > transformer_block.moe.experts.num_experts | |
| ): | |
| _experts_shard_placement_fn = lambda param: Shard(1) | |
| fully_shard( | |
| transformer_block.moe.experts, | |
| **fsdp_mod_ep_config, # mesh=dp_mod_ep_mesh | |
| reshard_after_forward=reshard_after_forward, | |
| shard_placement_fn=_experts_shard_placement_fn, | |
| ) | |
| ``` | |
| `dp_mod_ep_mesh` 구성 (`parallelize.py:140-159`): | |
| ```python | |
| dp_mod_ep_mesh_dim_names = [] | |
| if parallel_dims.ep_enabled: | |
| if parallel_dims.dp_replicate_enabled: | |
| dp_mod_ep_mesh_dim_names.append("dp_replicate") | |
| dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") | |
| # → dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] | |
| ``` | |
| ### 실제 placement 검증 결과 | |
| 8 GPUs, `num_experts=2`, `etp=1` 기준: | |
| #### num_experts=8 (기본) | |
| 모든 config에서 expert weights는 **dim 0 (expert dim)으로만 shard**: | |
| | Config | Expert Placements | Mesh | | |
| |--------|-------------------|------| | |
| | ep=8 | `[Shard(0)]` | `[ep=8]` | | |
| | ep=4, fsdp=2 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=2, ep=4]` | | |
| | ep=2, fsdp=4 | `[_StridedShard(0), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` | | |
| | ep=2, hsdp=2+2 | `[Replicate(), _StridedShard(0), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` | | |
| EFSDP는 `_StridedShard(dim=0)`, EP는 `Shard(dim=0)`. 비-dim-0 shard 없음. | |
| #### num_experts=2 (expert 수 < EFSDP shard count) | |
| `dp_mod_ep * ep > num_experts` 조건 충족 시 **EFSDP가 Shard(1)로 전환**: | |
| | Config | 조건 | Expert Placements | Mesh | | |
| |--------|------|-------------------|------| | |
| | ep=2, fsdp=4 | 4*2=8 > 2 | `[Shard(1), Shard(0)]` | `[dp_shard_mod_ep=4, ep=2]` | | |
| | ep=2, hsdp=2+2 | 2*2=4 > 2 | `[Replicate(), Shard(1), Shard(0)]` | `[dp_rep=2, dp_shard_mod_ep=2, ep=2]` | | |
| - EFSDP: `Shard(1)` on `dp_shard_mod_ep` → out_dim을 shard (w1: 2816/4=704) | |
| - EP: `Shard(0)` on `ep` → expert dim을 shard (2/2=1) | |
| - `_StridedShard`가 아닌 일반 `Shard` 사용 | |
| ## Gradient Clipping with EP | |
| EP parameter와 non-EP parameter의 gradient norm을 별도로 계산 후 합산: | |
| ```python | |
| # distributed/utils.py: _clip_grad_norm_with_ep | |
| ep_norm = get_total_norm(ep_grads, ...) | |
| non_ep_norm = get_total_norm(non_ep_grads, ...) | |
| total_norm = (ep_norm**p + non_ep_norm**p) ** (1/p) | |
| ``` | |
| EP parameter 판별: `device_mesh.mesh_dim_names`에 "ep" 포함 여부. | |
| ## Muon optimizer에서의 처리 | |
| 현재 Muon optimizer의 MoE 지원: | |
| 1. **`_expand_expert_params`**: 3D expert weight를 expert dim (dim 0)으로 split하여 2D param으로 확장 | |
| 2. **TP가 있을 때**: non-dim-0 shard (TP)를 TP submesh에 DTensor로 wrap | |
| - 3D `(Shard(0), Shard(1))` → 2D `(Shard(0),)` on TP submesh | |
| 3. **`construct_shard_mesh` fast path**: 1D submesh에서 `dist.new_group()` deadlock 방지 | |
| ### Muon이 지원하는 config | |
| | Config | 지원 | 비고 | | |
| |--------|------|------| | |
| | TP Only (EP=1) | O | expert를 TP submesh DTensor로 처리 | | |
| | EP Only (TP=1) | O | expert를 plain tensor로 처리 (base mode) | | |
| | FSDP + TP | O | FSDP는 expert dim, TP는 out/in dim | | |
| | HSDP + TP | O | Replicate + FSDP + TP | | |
| | EP Only (많은 experts) | O | EFSDP `Shard(0)` → plain tensor | | |
| | EP + FSDP (적은 experts) | 미테스트 | EFSDP `Shard(1)` → 아래 참조 | | |
| | EP + TP (ETP=TP) | 미테스트 | 2D expert DTensor `[Shard(0), Shard(1/2)]` | | |
| | EP + TP (ETP=1) | 미테스트 | EP mesh에 TP가 포함된 경우 | | |
| ### EFSDP Shard(1)과 Muon의 호환성 | |
| Muon은 placement-agnostic. `_expand_expert_params`의 non-dim-0 shard 처리 로직이 | |
| TP뿐 아니라 EFSDP `Shard(1)`에도 동일하게 적용됨 (변수명만 `tp_*`일 뿐 로직은 generic): | |
| ``` | |
| 3D: (Shard(1), Shard(0)) on [dp_shard_mod_ep=4, ep=2] | |
| local shape: (1, 704, 2048) | |
| _expand_expert_params: | |
| 1. non-dim-0 shard 탐색 → Shard(1) on dp_shard_mod_ep | |
| 2. submesh 추출 → dp_shard_mod_ep (1D, size 4) | |
| 3. dim 0 split → (704, 2048) | |
| 4. DTensor wrap → Shard(0) on dp_shard_mod_ep | |
| = 일반 FSDP sharded 2D 텐서와 동일 | |
| → parallel()/distributed_muon()이 all-gather → Newton-Schulz → scatter 처리. | |
| construct_shard_mesh fast path 적용 (1D submesh, deadlock 없음). | |
| ``` | |