Skip to content

Commit e4bd151

Browse files
committed
Internal test about EP
PiperOrigin-RevId: 894290655
1 parent 44fc6d0 commit e4bd151

4 files changed

Lines changed: 641 additions & 119 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss
195195
use_random_routing: false # whether to use random routing for debug/test purpose
196196
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
197197
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
198+
use_iterative_moe: false # whether to use iterative routing for sparse matmul to save memory
199+
ra2a_num_chunks: 1 # number of chunks to split tokens into for iterative MoE
198200
# tunable tiling dimensions used for mlp gmm
199201
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
200202
# tokamax ragged dot - supports all 18 configs

src/maxtext/configs/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,14 @@ class MoEGeneral(BaseModel):
650650
False,
651651
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
652652
)
653+
use_iterative_moe: bool = Field(
654+
False,
655+
description="Whether to use iterative MoE routing to save memory.",
656+
)
657+
ra2a_num_chunks: int = Field(
658+
1,
659+
description="Number of chunks for iterative MoE routing.",
660+
)
653661
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
654662
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
655663
expert_shard_attention_option: Literal["fsdp", "context"] = Field(

0 commit comments

Comments
 (0)