
Commit a4e2510

Kevin Wang authored and Google-ML-Automation committed

Add support for data parallelism along the data mesh axis for multi-slice scaling.

PiperOrigin-RevId: 899801953
1 parent fd1ad55 commit a4e2510

8 files changed: 294 additions & 111 deletions

src/maxtext/configs/base.yml

Lines changed: 2 additions & 0 deletions
@@ -987,6 +987,8 @@ optimize_mesh_for_tpu_v6e: False
 
 shardy: True # Whether to use shardy XLA backend (default in Jax starting 0.7.0), or GSPMD (to be fully deprecated ~2026)
 
+remove_size_one_mesh_axis_from_type: True # Whether to remove size one mesh axis from type through jax.config.
+
 use_ragged_attention: False
 ragged_block_size: 256
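The flag defaults to True globally. Its description says it is applied "through jax.config", but the exact JAX option it toggles is not visible in this diff, so the wiring below is a hypothetical sketch (the option name jax_remove_size_one_mesh_axis_from_type is an assumption):

import jax

def apply_mesh_axis_flag(config):
  # Hypothetical wiring for the new YAML flag. The jax.config option name
  # below is an assumption; the diff only states "through jax.config".
  # When enabled, size-one mesh axes are dropped from the sharding carried
  # in array types (e.g. jax.typeof(x).sharding).
  jax.config.update(
      "jax_remove_size_one_mesh_axis_from_type",
      config.remove_size_one_mesh_axis_from_type,
  )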

src/maxtext/configs/models/deepseek3-671b-batchsplit.yml

Lines changed: 13 additions & 14 deletions
@@ -57,14 +57,15 @@ rope_attention_scaling: False
 
 use_batch_split_schedule: True
 shard_mode: "explicit"
+remove_size_one_mesh_axis_from_type: False
 override_logical_axis_rules: True
-mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'expert', 'context']
+data_sharding: [['data', 'stage', 'fsdp', 'expert', 'context']]
 logical_axis_rules: [
-  ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
-  ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
-  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
-  ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
+  ['activation_batch', ['data', 'fsdp', 'expert', 'context']],
+  ['activation_batch_moe', ['data', 'fsdp', 'expert', 'context']],
+  ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert', 'context']],
+  ['activation_kv_batch', ['data', 'fsdp', 'expert', 'context']],
   ['activation_norm_length', []],
   ['activation_norm_length_moe', []],
   ['activation_heads', []],
@@ -76,14 +77,12 @@ logical_axis_rules: [
   ['q_lora', ['fsdp']],
   ['kv_lora', ['fsdp']],
   ['layers', 'stage'],
-  ['q_lora_up_proj', ['fsdp_transpose']],
-  ['kv_lora_up_proj', ['fsdp_transpose']],
-  ['q_heads', ['fsdp_transpose']],
-  ['kv_heads', ['fsdp_transpose']],
-  ['heads', ['fsdp_transpose']],
-  ['mlp', ['fsdp_transpose']],
-  ['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
-  ['fsdp_transpose_only', ['fsdp_transpose']],
+  ['q_lora_up_proj', []],
+  ['kv_lora_up_proj', []],
+  ['q_heads', []],
+  ['kv_heads', []],
+  ['heads', []],
+  ['mlp', []],
   ['expert_only', ['expert']],
   ['diloco', 'diloco'],
 ]
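The batchsplit config removes the fsdp_transpose mesh axis outright: it disappears from mesh_axes and data_sharding, and rules that sharded over it now map to [] (replicated). Each logical_axis_rules entry pairs a logical tensor axis with the mesh axes it shards over; a minimal sketch of that resolution, with resolve_logical_axes as an illustrative helper rather than MaxText's implementation:

import jax

# Illustrative subset of the rules above.
LOGICAL_AXIS_RULES = {
    "activation_batch": ("data", "fsdp", "expert", "context"),
    "q_heads": (),  # empty rule: the dimension is replicated
}

def resolve_logical_axes(logical_axes):
  # Turn logical axis names into a PartitionSpec; an empty rule
  # becomes None, i.e. no sharding on that dimension.
  return jax.sharding.PartitionSpec(
      *(LOGICAL_AXIS_RULES.get(name) or None for name in logical_axes)
  )

# PartitionSpec(('data', 'fsdp', 'expert', 'context'), None)
print(resolve_logical_axes(("activation_batch", "q_heads")))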

src/maxtext/configs/types.py

Lines changed: 3 additions & 0 deletions
@@ -832,6 +832,9 @@ class HardwareAndMesh(BaseModel):
   shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
   pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
   pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
+  remove_size_one_mesh_axis_from_type: bool = Field(
+      True, description="Whether to remove size one mesh axis from type through jax.config."
+  )
 
 
 class LayoutAndSharding(BaseModel):
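The typed schema mirrors the YAML default of True, which the batchsplit model config above overrides to False. A minimal standalone sketch of the pydantic pattern (the class name here is illustrative):

from pydantic import BaseModel, Field

class HardwareAndMeshSketch(BaseModel):
  # Same shape as the new field: a bool with a default and a description,
  # overridable by per-model YAML configs.
  remove_size_one_mesh_axis_from_type: bool = Field(
      True, description="Whether to remove size one mesh axis from type through jax.config."
  )

cfg = HardwareAndMeshSketch(remove_size_one_mesh_axis_from_type=False)
assert cfg.remove_size_one_mesh_axis_from_type is False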

src/maxtext/layers/moe.py

Lines changed: 2 additions & 2 deletions
@@ -102,8 +102,8 @@ def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tupl
 
 def get_batchsplit_init_kernel_axes():
   return (
-      ("embed_moe", "fsdp_transpose_only", "expert_only"),
-      ("embed_moe", "fsdp_transpose_and_expert", None),
+      ("embed_moe", None, "expert_only"),
+      ("embed_moe", "expert_only", None),
   )
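With fsdp_transpose removed, the two batch-split kernel layouts now alternate only the expert_only axis between kernel dimensions. Each tuple lists one logical axis name (or None) per kernel dimension; a minimal sketch of how such names are typically bound to an initializer in Flax so the config's logical_axis_rules can resolve them later (the initializer choice is illustrative):

import flax.linen as nn

kernel_axes = ("embed_moe", None, "expert_only")

# Attach logical axis names to an initializer; None leaves that
# kernel dimension without a logical axis (replicated by the rules).
init_fn = nn.with_logical_partitioning(
    nn.initializers.lecun_normal(),
    kernel_axes,
)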

src/maxtext/models/deepseek.py

Lines changed: 7 additions & 8 deletions
@@ -43,7 +43,6 @@
 from maxtext.utils import max_utils
 from maxtext.utils.sharding import create_sharding
 from maxtext.utils.sharding import maybe_shard_with_logical
-from maxtext.utils.sharding import remove_size_one_mesh_axis
 
 import transformers
 
@@ -492,14 +491,13 @@ def __call__(
         return outputs, None
 
       # bf16 code path
-      activation_pspec = remove_size_one_mesh_axis(
-          jax.sharding.PartitionSpec(
-              ("data", "fsdp", "fsdp_transpose", "expert", "context"),
-              None,
-              None,
-          ),
-          self.mesh,
+      input_sharding = jax.typeof(inputs).sharding
+      activation_pspec = jax.sharding.PartitionSpec(
+          ("data", "fsdp", "expert"),
+          None,
+          None,
       )
+      inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec))
       yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs(
           decoder_positions,
           embedding_dims=self.config.qk_rope_head_dim,
@@ -571,6 +569,7 @@ def extract_fn(x):
           in_specs=([activation_pspec] * self.config.batch_split_factor,),
           out_specs=activation_pspec,
       )(outputs)
+      outputs = jax.reshard(outputs, input_sharding)
       return outputs, None
 
     x = self.with_logical_constraint(inputs)
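Rather than pruning size-one mesh axes from a hand-built PartitionSpec, the bf16 path now records the incoming sharding from the array's type, reshards onto the batch-split spec for the computation, and reshards the result back on exit. A standalone sketch of that round-trip under JAX explicit sharding; the mesh shape and axis names are illustrative:

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec

# Explicit-sharding mesh, so jax.typeof(x).sharding carries a concrete spec.
mesh = jax.make_mesh(
    (jax.device_count(), 1), ("data", "expert"),
    axis_types=(AxisType.Explicit, AxisType.Explicit),
)

with jax.sharding.use_mesh(mesh):
  inputs = jax.device_put(
      jnp.zeros((8, 16)), NamedSharding(mesh, PartitionSpec("data", None)))

  input_sharding = jax.typeof(inputs).sharding  # remember the caller's sharding
  compute_spec = PartitionSpec(("data", "expert"), None)
  inputs = jax.reshard(inputs, NamedSharding(mesh, compute_spec))

  outputs = inputs * 2  # stand-in for the batch-split computation
  outputs = jax.reshard(outputs, input_sharding)  # restore the original sharding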
