Commit 18d7751

CaptainO5 authored and Google-ML-Automation committed
Refactor Megablox Ops to use public Tokamax API
PiperOrigin-RevId: 885711552
1 parent 51c7f2b commit 18d7751

5 files changed

Lines changed: 55 additions & 102 deletions

src/maxtext/configs/base.yml

Lines changed: 0 additions & 10 deletions
@@ -223,16 +223,6 @@ wo_tile_drhs_batch_seq: 512
 wo_tile_drhs_embed_dim: 1024
 wo_tile_drhs_mlp_dim: 1024

-wi_tile_fwd_buffer_count: 2
-wi_tile_dlhs_buffer_count: 2
-wi_tile_drhs_buffer_count: 2
-wo_tile_fwd_buffer_count: 2
-wo_tile_dlhs_buffer_count: 2
-wo_tile_drhs_buffer_count: 2
-
-wi_combine_scopes: False
-wo_combine_scopes: False
-
 merge_gating_gmm: False

 norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.

src/maxtext/configs/types.py

Lines changed: 0 additions & 10 deletions
@@ -734,16 +734,6 @@ class MoEKernels(BaseModel):
   wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
   wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")

-  wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
-  wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
-  wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
-  wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
-  wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
-  wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")
-
-  wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
-  wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")
-
   merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")

src/maxtext/kernels/megablox/ops.py

Lines changed: 54 additions & 35 deletions
@@ -16,23 +16,40 @@

 # pylint: disable=too-many-positional-arguments

-import functools
 import dataclasses
-from typing import Literal, List, Tuple
+import functools
+from typing import List, Literal, Tuple
 import jax
 import jax.numpy as jnp
 from maxtext.kernels.megablox import backend
-from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
 import qwix
 import qwix.pallas as qpl
+import tokamax
+
+
+DRHS_RAGGED_DOT_DIM_NUMS = jax.lax.RaggedDotDimensionNumbers(
+    dot_dimension_numbers=(([0], [0]), ([], [])),
+    lhs_ragged_dimensions=[0],
+    rhs_group_dimensions=[],
+)


 def gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+    ),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
@@ -42,8 +59,6 @@ def gmm(
     use_qwix_quantization: bool = False,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
-    input_buffer_count: tuple[int, int, int] = (2, 2, 2),
-    combine_scopes: bool = False,
     # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
     qwix_rule: qwix.QtRule | None = None,
 ):
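
With the buffer-count and combine-scopes knobs gone, a call into this wrapper reduces to the tiling and backend arguments. A rough usage sketch of the new signature; only the argument names come from this file, while the import path, shapes, and values are illustrative assumptions:

import jax.numpy as jnp
from maxtext.kernels.megablox import ops  # import path inferred from the file tree

# Hypothetical sizes: 256 tokens routed across 8 experts.
lhs = jnp.zeros((256, 512), dtype=jnp.bfloat16)       # (tokens, embed)
rhs = jnp.zeros((8, 512, 1024), dtype=jnp.bfloat16)   # (experts, embed, mlp)
group_sizes = jnp.full((8,), 32, dtype=jnp.int32)     # tokens per expert

out = ops.gmm(
    lhs,
    rhs,
    group_sizes,
    preferred_element_type=jnp.float32,
    tiling=(128,) * 9,         # fwd / dlhs / drhs tile sizes, three each
    use_tokamax_backend=True,  # route to tokamax.ragged_dot instead of the Pallas backend
)
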
@@ -65,16 +80,14 @@ def gmm(
   )

   gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0]  # pylint: disable=C3001
-  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
+  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
   gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
   return gmm_fwd_bwd(
       lhs,
       rhs,
       group_sizes,
       preferred_element_type,
       tiling,
-      input_buffer_count,
-      combine_scopes,
       group_offset,
       existing_out,
       transpose_rhs,
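
The renumbered nondiff_argnums follow mechanically from the two deleted positional parameters: input_buffer_count (index 5) and combine_scopes (index 6) disappear, so the static indices 9 through 13 shift down to 7 through 11 while 3 and 4 stay put. For reference, a toy sketch of how jax.custom_vjp treats nondiff_argnums (illustrative only, not the commit's code):

import jax

def scale(x, factor):  # factor is a static, non-differentiable argument
  return x * factor

scale = jax.custom_vjp(scale, nondiff_argnums=(1,))

def scale_fwd(x, factor):
  return x * factor, None  # no residuals needed for this toy rule

def scale_bwd(factor, _, g):  # nondiff args are passed first to the bwd rule
  return (g * factor,)  # one cotangent per differentiable argument

scale.defvjp(scale_fwd, scale_bwd)
print(jax.grad(scale)(2.0, 3.0))  # 3.0
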
@@ -90,9 +103,17 @@ def _gmm_fwd(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
-    input_buffer_count: tuple[int, int, int] = (2, 2, 2),
-    combine_scopes: bool = False,
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+        128,
+    ),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
@@ -136,17 +157,18 @@ def _gmm_fwd(
       for axis_name, axis_idx in weight_gather_axes:
         rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
         rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
-    out = tokamax_backend.gmm(
+    # Handle transpose_rhs manually as ragged_dot assumes (G, K, N)
+    if transpose_rhs:
+      rhs = rhs.swapaxes(1, 2)
+
+    out = tokamax.ragged_dot(
         lhs=lhs,
         rhs=rhs,
         group_sizes=group_sizes,
         precision=jax.lax.Precision.DEFAULT,
-        out_dtype=preferred_element_type,
-        tiling=tiling[:3],
+        preferred_element_type=preferred_element_type,
         group_offset=group_offset,
-        transpose_rhs=transpose_rhs,
-        interpret=interpret,
-        input_buffer_count=input_buffer_count[0],
+        implementation="mosaic",
     )
   else:
     out = backend.gmm(
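
The explicit swapaxes replaces the old backend's transpose_rhs flag: ragged_dot contracts lhs of shape (M, K) against rhs of shape (G, K, N), so a weight stored as (G, N, K) has to be flipped before the call rather than flagged. A minimal sketch of that layout convention using the stock jax.lax.ragged_dot as a stand-in (toy shapes; tokamax.ragged_dot is assumed to take the same layout):

import jax
import jax.numpy as jnp

M, K, N, G = 8, 4, 6, 2
lhs = jnp.ones((M, K), jnp.float32)       # tokens, already sorted by expert
w_nk = jnp.ones((G, N, K), jnp.float32)   # weight stored "transposed"
group_sizes = jnp.array([5, 3], jnp.int32)

# ragged_dot expects the expert weight as (G, K, N), so swap the last two axes first.
out = jax.lax.ragged_dot(lhs, w_nk.swapaxes(1, 2), group_sizes)
print(out.shape)  # (8, 6)
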
@@ -168,8 +190,6 @@ def _gmm_bwd(
     rhs_dtype: jax.typing.DTypeLike,
     preferred_element_type: jnp.dtype,
     tiling: tuple[int, int, int, int, int, int, int, int, int],
-    input_buffer_count: tuple[int, int, int],
-    combine_scopes: bool,
     transpose_rhs: bool,
     interpret: bool,
     quantization_rule: qwix.QtRule | None,
@@ -224,30 +244,29 @@ def _gmm_bwd(
         calibration_method=quantization_rule.bwd_calibration_method,
     )
   if use_tokamax_backend:
-    dlhs = tokamax_backend.gmm(
+    # Handle transpose_rhs manually
+    dlhs_rhs = rhs
+    if not transpose_rhs:
+      dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
+
+    dlhs = tokamax.ragged_dot(
         lhs=dlhs_dout,
-        rhs=rhs,
+        rhs=dlhs_rhs,
         group_sizes=group_sizes,
         precision=jax.lax.Precision.DEFAULT,
-        out_dtype=lhs_dtype,
-        tiling=tiling[3:6],
+        preferred_element_type=lhs_dtype,
         group_offset=group_offset,
-        transpose_rhs=not transpose_rhs,
-        interpret=interpret,
-        input_buffer_count=input_buffer_count[1],
+        implementation="mosaic",
     )
-    drhs = tokamax_backend.tgmm(
-        lhs=lhs.swapaxes(0, 1),
+    drhs = tokamax.ragged_dot_general(
+        lhs=lhs,
         rhs=drhs_dout,
         group_sizes=group_sizes,
+        ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
         precision=jax.lax.Precision.DEFAULT,
-        out_dtype=rhs_dtype,
-        tiling=tiling[-3:],
+        preferred_element_type=rhs_dtype,
         group_offset=group_offset,
-        num_actual_groups=num_actual_groups,
-        interpret=interpret,
-        input_buffer_count=input_buffer_count[2],
-        combine_scopes=combine_scopes,
+        implementation="mosaic",
     )
   if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
     # Scatter back in reverse order of gather
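
The drhs path is the less obvious substitution: the old tokamax_backend.tgmm took lhs pre-transposed via swapaxes(0, 1), whereas ragged_dot_general with DRHS_RAGGED_DOT_DIM_NUMS contracts over the ragged token axis directly and emits one (K, N) block per group, i.e. lhs_g.T @ dout_g for every expert g. A small numerical sketch of that equivalence using jax.lax.ragged_dot_general as a reference (toy shapes; the tokamax kernel is assumed to match these semantics on TPU):

import jax
import jax.numpy as jnp

M, K, N, G = 12, 4, 6, 3
lhs = jnp.arange(M * K, dtype=jnp.float32).reshape(M, K)  # activations, tokens sorted by expert
dout = jnp.ones((M, N), dtype=jnp.float32)                # incoming cotangent
group_sizes = jnp.array([5, 3, 4], dtype=jnp.int32)

dims = jax.lax.RaggedDotDimensionNumbers(
    dot_dimension_numbers=(([0], [0]), ([], [])),  # contract the token axis of both operands
    lhs_ragged_dimensions=[0],                     # that axis is ragged, grouped by expert
    rhs_group_dimensions=[],                       # dout carries no expert axis
)
drhs = jax.lax.ragged_dot_general(lhs, dout, group_sizes, dims)  # (G, K, N)

# Reference: per-expert lhs_g.T @ dout_g, the quantity the old tgmm produced.
ref = jnp.stack([lhs[s : s + n].T @ dout[s : s + n] for s, n in zip((0, 5, 8), (5, 3, 4))])
assert jnp.allclose(drhs, ref)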

src/maxtext/layers/moe.py

Lines changed: 1 addition & 23 deletions
@@ -968,9 +968,7 @@ def get_quantization_dtypes():
     rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
     return lhs_quantize_dtype, rhs_quantize_dtype

-  def gmm(
-      inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
-  ):
+  def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
     if inputs.shape[0] != expert_assignments.shape[0]:
       raise ValueError("The number of input tokens must match the number of expert assignments!")

@@ -996,8 +994,6 @@ def gmm(
         use_qwix_quantization=self.config.use_qwix_quantization,
         use_tokamax_backend=self.config.use_tokamax_gmm,
         weight_gather_axes=weight_gather_axes,
-        input_buffer_count=input_buffer_count,
-        combine_scopes=combine_scopes,
     )
   else:  # tokamax (unquantized)
     output = tokamax.ragged_dot(
@@ -1253,26 +1249,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       self.config.wo_tile_drhs_embed_dim,
       self.config.wo_tile_drhs_mlp_dim,
   )
-  wi_input_buffer_count = (
-      self.config.wi_tile_fwd_buffer_count,
-      self.config.wi_tile_dlhs_buffer_count,
-      self.config.wi_tile_drhs_buffer_count,
-  )
-  wo_input_buffer_count = (
-      self.config.wo_tile_fwd_buffer_count,
-      self.config.wo_tile_dlhs_buffer_count,
-      self.config.wo_tile_drhs_buffer_count,
-  )

-  wi_combine_scopes = self.config.wi_combine_scopes
-  wo_combine_scopes = self.config.wo_combine_scopes
   layer_w0 = gmm_fn(
       x,
       w0,
       tiling=wi_tile_size,
       weight_gather_axes=wi_gather_axes,
-      input_buffer_count=wi_input_buffer_count,
-      combine_scopes=wi_combine_scopes,
   )
   if self.get_tensor_transpose_parallelism_size() > 1:
     layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
@@ -1285,8 +1267,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       w1,
       tiling=wi_tile_size,
       weight_gather_axes=wi_gather_axes,
-      input_buffer_count=wi_input_buffer_count,
-      combine_scopes=wi_combine_scopes,
   )
   if self.get_tensor_transpose_parallelism_size() > 1:
     layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
@@ -1300,8 +1280,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo,
       tiling=wo_tile_size,
       weight_gather_axes=wo_gather_axes,
-      input_buffer_count=wo_input_buffer_count,
-      combine_scopes=wo_combine_scopes,
   )
   if self.get_tensor_parallelism_size() > 1:
     intermediate_output = jax.lax.psum_scatter(

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 0 additions & 24 deletions
@@ -949,8 +949,6 @@ def gmm(
     group_sizes,
     preferred_element_type,
     weight_gather_axes,
-    input_buffer_count,
-    combine_scopes,
 ):
   if config.use_qwix_quantization:
     output = megablox.gmm(
@@ -962,8 +960,6 @@ def gmm(
         use_qwix_quantization=config.use_qwix_quantization,
         use_tokamax_backend=config.use_tokamax_gmm,
         weight_gather_axes=weight_gather_axes,
-        input_buffer_count=input_buffer_count,
-        combine_scopes=combine_scopes,
         qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
     )
   else:
@@ -1006,19 +1002,7 @@ def gmm(
       config.wo_tile_drhs_embed_dim,
       config.wo_tile_drhs_mlp_dim,
   )
-  wi_input_buffer_count = (
-      config.wi_tile_fwd_buffer_count,
-      config.wi_tile_dlhs_buffer_count,
-      config.wi_tile_drhs_buffer_count,
-  )
-  wo_input_buffer_count = (
-      config.wo_tile_fwd_buffer_count,
-      config.wo_tile_dlhs_buffer_count,
-      config.wo_tile_drhs_buffer_count,
-  )

-  wi_combine_scopes = config.wi_combine_scopes
-  wo_combine_scopes = config.wo_combine_scopes
   if config.use_qwix_quantization:
     gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
     w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
@@ -1047,8 +1031,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         w01,
         tiling=wi_tile_size,
         weight_gather_axes=wi_gather_axes,
-        input_buffer_count=wi_input_buffer_count,
-        combine_scopes=wi_combine_scopes,
     )
     layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
   else:
@@ -1057,16 +1039,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
         w0,
         tiling=wi_tile_size,
         weight_gather_axes=wi_gather_axes,
-        input_buffer_count=wi_input_buffer_count,
-        combine_scopes=wi_combine_scopes,
     )
     layer_w1 = gmm_fn(
         x,
         w1,
         tiling=wi_tile_size,
         weight_gather_axes=wi_gather_axes,
-        input_buffer_count=wi_input_buffer_count,
-        combine_scopes=wi_combine_scopes,
     )
     layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
     layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
@@ -1077,8 +1055,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       wo,
       tiling=wo_tile_size,
       weight_gather_axes=wo_gather_axes,
-      input_buffer_count=wo_input_buffer_count,
-      combine_scopes=wo_combine_scopes,
   )
   return layer_wo