
Commit a7cad55

Kevin Wang authored and Google-ML-Automation committed
Implement batch-split schedule in deepseek_batchsplit.
PiperOrigin-RevId: 892685342
1 parent 1e97f2e commit a7cad55

12 files changed

Lines changed: 3002 additions & 638 deletions

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss
 use_random_routing: false # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
 use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
+use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
 # tunable tiling dimensions used for mlp gmm
 # megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
 # tokamax ragged dot - supports all 18 configs

src/maxtext/configs/models/deepseek3-671b-batchsplit.yml

Lines changed: 9 additions & 8 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes
+# model config for DeepSeek V3 - 671B that uses batch split schedule
 
 # For DeepSeek default device-limited routing,
 # please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
@@ -55,17 +55,18 @@ rope_interleave: True
 rope_truncate: True
 rope_attention_scaling: False
 
+use_batch_split_schedule: True
+shard_mode: "explicit"
 override_logical_axis_rules: True
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
   ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
-  ['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-  ['activation_norm_length', ['context']],
-  ['activation_norm_length_moe', ['context']],
+  ['activation_norm_length', []],
+  ['activation_norm_length_moe', []],
   ['activation_heads', []],
   ['activation_stage', 'stage'],
   ['embed', ['fsdp']],
@@ -81,8 +82,8 @@ logical_axis_rules: [
   ['kv_heads', ['fsdp_transpose']],
   ['heads', ['fsdp_transpose']],
   ['mlp', ['fsdp_transpose']],
-  ['mlp_only_fsdp_transpose', ['fsdp_transpose']],
-  ['expert_only', ['expert']],
-  ['fsdp_transpose_only', ['fsdp_transpose']],
   ['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
+  ['fsdp_transpose_only', ['fsdp_transpose']],
+  ['expert_only', ['expert']],
+  ['diloco', 'diloco'],
 ]
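
Not part of the diff: a minimal JAX sketch of what the extended mesh_axes list above implies, assuming a single-host run where the new 'diloco' axis (and most other axes) have size 1. The axis sizes and the choice to fold all devices into 'fsdp' are illustrative assumptions, not MaxText's actual mesh-building code.

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative sizes only: all parallelism folded into 'fsdp', the rest (including
# the new 'diloco' axis) left at 1, mirroring the seven-entry mesh_axes list above.
axis_names = ("diloco", "data", "stage", "fsdp", "fsdp_transpose", "expert", "context")
axis_sizes = (1, 1, 1, jax.device_count(), 1, 1, 1)

devices = mesh_utils.create_device_mesh(axis_sizes)
mesh = Mesh(devices, axis_names)

# data_sharding above shards the batch dimension over the bracketed group of axes.
batch_sharding = NamedSharding(mesh, P(("data", "stage", "fsdp", "fsdp_transpose", "expert", "context")))
print(batch_sharding)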

src/maxtext/configs/types.py

Lines changed: 4 additions & 0 deletions
@@ -634,6 +634,10 @@ class MoEGeneral(BaseModel):
       False,
       description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
   )
+  use_gather_mosaic_kernel: bool = Field(
+      False,
+      description="Whether to use a custom mosaic kernel for token gather ops.",
+  )
   use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
   interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
   expert_shard_attention_option: Literal["fsdp", "context"] = Field(
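
Not part of the diff: a small, self-contained pydantic sketch of how a Field like the one added above behaves — default False, overridable, and documented via description. The class name here is a simplified stand-in for the real MoEGeneral model.

from pydantic import BaseModel, Field

class MoEGeneralSketch(BaseModel):  # simplified stand-in for MoEGeneral
  use_gather_mosaic_kernel: bool = Field(
      False,
      description="Whether to use a custom mosaic kernel for token gather ops.",
  )

print(MoEGeneralSketch().use_gather_mosaic_kernel)                                # False by default
print(MoEGeneralSketch(use_gather_mosaic_kernel=True).use_gather_mosaic_kernel)   # True when overridden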
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Attention kernels."""
+
+from maxtext.kernels.attention import splash_attention_kernel

src/maxtext/kernels/attention/splash_attention_kernel.py

Lines changed: 136 additions & 8 deletions
@@ -427,9 +427,7 @@ def _wrapped(
   def reshape_activations(activations):
     if activations.ndim == 4:  # pytype: disable=attribute-error
       kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape  # pytype: disable=attribute-error
-      return activations.reshape(
-          kv_heads * q_heads_per_kv_head, q_seq_len, head_dim
-      )  # pytype: disable=attribute-error
+      return activations.reshape(kv_heads * q_heads_per_kv_head, q_seq_len, head_dim)  # pytype: disable=attribute-error
     return activations
 
   def reshape_residuals(residuals):
@@ -1166,10 +1164,7 @@ def _splash_attention_fwd(
     mask_function: MaskFunctionType | None,
     attn_logits_soft_cap: float | None = None,
     interpret: bool = False,
-) -> tuple[
-    tuple[jax.Array],
-    SplashResidualsType,
-]:
+) -> tuple[tuple[jax.Array], SplashResidualsType,]:
   """Forward pass for splash attention."""
   if save_residuals:
     raise NotImplementedError("Higher-order AD not supported")
@@ -1606,7 +1601,6 @@ def init():
     )
 
     def body(i, _):
-
       slice_k = pl.ds(i * bkv_compute, bkv_compute)
       q = q_ref[...]  # We keep q potentially transposed, since it's always RHS
 
@@ -2238,6 +2232,120 @@ def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
   )
 
 
+@partial(
+    jax.jit,
+    static_argnames=[
+        "is_mqa",
+        "block_sizes",
+        "save_residuals",
+        "mask_value",
+        "attn_logits_soft_cap",
+        "residual_checkpoint_name",
+        "mask_function",
+        "interpret",
+    ],
+)
+def _splash_attention_manual_fwd(
+    fwd_mask_info: mask_info_lib.MaskInfo,
+    dq_mask_info: mask_info_lib.MaskInfo | None,
+    dkv_mask_info: mask_info_lib.MaskInfo | None,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    segment_ids: SegmentIds | None = None,
+    sinks: jax.Array | None = None,
+    *,
+    is_mqa: bool,
+    block_sizes: BlockSizes | None,
+    save_residuals: bool,
+    mask_value: float,
+    attn_logits_soft_cap: float | None,
+    residual_checkpoint_name: str | None,
+    mask_function: MaskFunctionType | None,
+    interpret: bool,
+) -> SplashCustomReturnType:
+  def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
+    if mask_info is None or mask_info.partial_mask_blocks is None:
+      return mask_info
+
+    return mask_info._replace(
+        partial_mask_blocks=mask_info.partial_mask_blocks.reshape(-1, *mask_info.partial_mask_blocks.shape[-2:])
+    )
+
+  if not save_residuals:
+    raise ValueError("Expected save_residuals to be `True`.")
+
+  fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)
+  dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info)
+  dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info)
+  del dq_mask_info, dkv_mask_info
+
+  out, (logsumexp,) = _splash_attention_forward(  # pytype: disable=wrong-arg-types
+      fwd_mask_info,
+      q,
+      k,
+      v,
+      segment_ids,
+      mask_value=mask_value,
+      is_mqa=is_mqa,
+      block_sizes=block_sizes,
+      residual_checkpoint_name=residual_checkpoint_name,
+      save_residuals=True,
+      mask_function=mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+  )
+  return out, logsumexp
+
+
+def _splash_attention_manual_bwd(
+    fwd_mask_info: mask_info_lib.MaskInfo,
+    dq_mask_info: mask_info_lib.MaskInfo | None,
+    dkv_mask_info: mask_info_lib.MaskInfo | None,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    out: jax.Array,
+    logsumexp: jax.Array,
+    do: jax.Array,
+    segment_ids: SegmentIds | None = None,
+    sinks: jax.Array | None = None,
+    *,
+    is_mqa: bool,
+    block_sizes: BlockSizes | None,
+    save_residuals: bool,
+    mask_value: float,
+    attn_logits_soft_cap: float | None,
+    residual_checkpoint_name: str | None,
+    mask_function: MaskFunctionType | None,
+    interpret: bool,
+):
+  del fwd_mask_info
+  res = (
+      q,
+      k,
+      v,
+      segment_ids,
+      out,
+      logsumexp,
+      dq_mask_info,
+      dkv_mask_info,
+  )
+  _, _, _, dq, dk, dv, _ = _splash_attention_bwd(
+      save_residuals=save_residuals,
+      mask_value=mask_value,
+      is_mqa=is_mqa,
+      block_sizes=block_sizes,
+      residual_checkpoint_name=residual_checkpoint_name,
+      mask_function=mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+      res=res,
+      do=do,
+  )
+  return dq, dk, dv
+
+
 @jax.tree_util.register_pytree_node_class
 class SplashAttentionKernel:
   """Defines a SplashAttention kernel object."""
@@ -2264,6 +2372,26 @@ def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
         **self.kwargs,
     )
 
+  def manual_fwd(self, *args, **kwargs) -> SplashCustomReturnType:
+    return _splash_attention_manual_fwd(
+        self.fwd_mask_info,
+        self.dq_mask_info,
+        self.dkv_mask_info,
+        *args,
+        **kwargs,
+        **self.kwargs,
+    )
+
+  def manual_bwd(self, *args, **kwargs):
+    return _splash_attention_manual_bwd(
+        self.fwd_mask_info,
+        self.dq_mask_info,
+        self.dkv_mask_info,
+        *args,
+        **kwargs,
+        **self.kwargs,
+    )
+
   def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
     """Returns a value that can be used as a shard_map partition spec for the kernel."""
     if self.fwd_mask_info.data_next is not None:
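
Not part of the diff: the new manual_fwd/manual_bwd pair splits the attention forward (which saves the logsumexp residual) from the backward, so a scheduler can issue the two at different points. A minimal sketch of that pattern with plain jax.vjp on a toy attention function — not the Pallas kernel and not MaxText's API:

import jax
import jax.numpy as jnp

def toy_attention(q, k, v):
  # Stand-in for the splash attention kernel: plain softmax attention.
  logits = q @ k.T / jnp.sqrt(q.shape[-1])
  return jax.nn.softmax(logits, axis=-1) @ v

q = jnp.ones((8, 4)); k = jnp.ones((16, 4)); v = jnp.ones((16, 4))

# "manual_fwd": run the forward now; the residuals are closed over by bwd_fn.
out, bwd_fn = jax.vjp(toy_attention, q, k, v)

# ... a batch-split schedule could run the other micro-batch's work here ...

# "manual_bwd": issue the backward later, reusing the saved residuals.
dq, dk, dv = bwd_fn(jnp.ones_like(out))
print(dq.shape, dk.shape, dv.shape)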

src/maxtext/kernels/sort_activations.py

Lines changed: 1 addition & 6 deletions
@@ -90,8 +90,6 @@ def _route_impl(
   assert (
       tokens.shape[0] == selected_experts.shape[0] and selected_experts.ndim == 2
   ), f"{tokens.shape=}, {selected_experts.shape=}"
-  if use_custom_mosaic_kernel:
-    raise NotImplementedError("Custom Mosaic kernel not implemented.")
   inds = jnp.argsort(jnp.ravel(selected_experts)) // selected_experts.shape[1]
   return _sort_impl(tokens, inds, use_custom_mosaic_kernel)
 
@@ -114,7 +112,4 @@ def _unroute_impl(
 
 
 def _sort_impl(tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool) -> jax.Array:
-  if use_custom_mosaic_kernel:
-    raise NotImplementedError("Custom Mosaic kernel not implemented.")
-  else:
-    return tokens[inds, ...]
+  return tokens[inds, ...]
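
Not part of the diff: a tiny, self-contained example of the routing math in _route_impl above — a flattened argsort of the expert assignments groups (token, choice) slots by expert, and integer division by top_k recovers the source token for each slot; tokens[inds, ...] is the plain-gather path that _sort_impl now takes unconditionally. The shapes and expert ids here are made up.

import jax.numpy as jnp

tokens = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)          # 4 tokens, hidden size 2
selected_experts = jnp.array([[1, 0], [0, 2], [1, 2], [0, 0]])   # top-2 expert ids per token

# Sort the flattened expert ids; dividing by top_k (=2) maps each sorted slot
# back to the token that produced it, so rows end up grouped by expert.
inds = jnp.argsort(jnp.ravel(selected_experts)) // selected_experts.shape[1]
routed = tokens[inds, ...]
print(inds)     # token indices ordered by expert id
print(routed)   # corresponding token rows, ready for per-expert matmuls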

src/maxtext/layers/decoders.py

Lines changed: 1 addition & 4 deletions
@@ -919,12 +919,9 @@ def __call__(
           y,
           self.variables["params"]["moe_layers"],
           decoder_positions,
-          decoder_segment_ids,
-          model_mode=model_mode,
           mesh=mesh,
-          quant=self.quant,
           cfg=cfg,
-          policy=policy,
+          num_layers=num_moe_layers,
       )
     else:
       y, _ = self.scan_decoder_layers(
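
Not part of the diff: the call above hands a stacked parameter pytree ("moe_layers") plus num_layers to a scan-style driver. A minimal sketch of the scan-over-stacked-weights pattern with a toy layer — this is not scan_batch_split_layers' real signature:

import jax
import jax.numpy as jnp

def toy_layer(y, w):
  # One stand-in decoder layer; returns the new carry and no per-layer output.
  return jnp.tanh(y @ w), None

def scan_layers(y, stacked_w):
  # stacked_w carries a leading num_layers dimension, like the scanned "moe_layers" params.
  y, _ = jax.lax.scan(toy_layer, y, stacked_w)
  return y

num_layers, d = 4, 8
y = jnp.ones((2, d))
stacked_w = 0.1 * jnp.ones((num_layers, d, d))
print(scan_layers(y, stacked_w).shape)  # (2, 8)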

src/maxtext/layers/embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
             "activation_embed",
         )
     )
-    out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh)
+    out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=self.config.logical_axis_rules)
 
     out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None
 
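Not part of the diff: passing the rules explicitly means the logical-to-mesh translation no longer relies on an ambient rule context. A minimal sketch of the idea using Flax's flax.linen.logical_to_mesh_axes (MaxText's helper of the same name also takes the mesh, as seen above); the rule set and axis names below are illustrative only:

import jax
from flax import linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding

rules = (
    ("activation_embed_and_logits_batch", "fsdp"),
    ("activation_length", None),
    ("activation_embed", None),
)
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("fsdp",))

# Translate logical axis names to mesh axes with an explicit rule set, then build the sharding.
pspec = nn.logical_to_mesh_axes(
    ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), rules=rules
)
out_sharding = NamedSharding(mesh, pspec)
print(pspec)  # e.g. PartitionSpec('fsdp', None, None) under this rule set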
src/maxtext/layers/nnx_decoders.py

Lines changed: 1 addition & 6 deletions
@@ -976,20 +976,15 @@ def __call__(
     num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers
 
     if cfg.use_batch_split_schedule:
-      policy = self.get_remat_policy()
-
       mock_params = self._build_linen_params(self.moe_layer)
 
       y = deepseek_batchsplit.scan_batch_split_layers(
           y,
           mock_params,
           decoder_positions,
-          decoder_segment_ids,
-          model_mode=model_mode,
           mesh=self.mesh,
-          quant=self.quant,
           cfg=cfg,
-          policy=policy,
+          num_layers=num_moe,
       )
     else:
       y, self.moe_layer = self._apply_layers_sequentially(
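
Not part of the diff: the schedule's name suggests the global batch is split into halves that move through the MoE stack separately, presumably so compute can overlap expert dispatch/combine communication; that interleaving lives inside scan_batch_split_layers and is not shown in this hunk. A toy sketch of just the splitting, with made-up shapes:

import jax.numpy as jnp

def toy_moe_stack(y_half, stacked_w):
  # Stand-in for one pass of the MoE layer stack over one half of the batch.
  for w in stacked_w:
    y_half = jnp.tanh(y_half @ w)
  return y_half

y = jnp.ones((8, 16))                     # global batch of 8
stacked_w = 0.5 * jnp.ones((3, 16, 16))   # 3 stacked layers, like num_layers=num_moe
first_half, second_half = jnp.split(y, 2, axis=0)

# A real batch-split schedule would interleave the two halves' forward/backward work;
# here they simply run back to back and are re-concatenated.
y_out = jnp.concatenate(
    [toy_moe_stack(first_half, stacked_w), toy_moe_stack(second_half, stacked_w)], axis=0
)
print(y_out.shape)  # (8, 16)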
