Skip to content

Commit 46ae9fa

Browse files
CaptainO5 and Google-ML-Automation
authored and committed
Refactor Megablox Ops to use public Tokamax API
PiperOrigin-RevId: 885711552
1 parent 51c7f2b commit 46ae9fa

5 files changed

Lines changed: 57 additions & 99 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,6 @@ wo_tile_drhs_batch_seq: 512
223223
wo_tile_drhs_embed_dim: 1024
224224
wo_tile_drhs_mlp_dim: 1024
225225

226-
wi_tile_fwd_buffer_count: 2
227-
wi_tile_dlhs_buffer_count: 2
228-
wi_tile_drhs_buffer_count: 2
229-
wo_tile_fwd_buffer_count: 2
230-
wo_tile_dlhs_buffer_count: 2
231-
wo_tile_drhs_buffer_count: 2
232-
233-
wi_combine_scopes: False
234-
wo_combine_scopes: False
235-
236226
merge_gating_gmm: False
237227

238228
norm_topk_prob: false # boolean to enable the top-k probability normalization. qwen3-specific normalization of router weights.

src/maxtext/configs/types.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -734,15 +734,6 @@ class MoEKernels(BaseModel):
734734
wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
735735
wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
736736

737-
wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
738-
wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
739-
wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
740-
wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
741-
wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
742-
wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")
743-
744-
wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
745-
wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")
746737

747738
merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")
748739

src/maxtext/kernels/megablox/ops.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,40 @@
1616

1717
# pylint: disable=too-many-positional-arguments
1818

19-
import functools
2019
import dataclasses
21-
from typing import Literal, List, Tuple
20+
import functools
21+
from typing import List, Literal, Tuple
2222
import jax
2323
import jax.numpy as jnp
2424
from maxtext.kernels.megablox import backend
25-
from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
2625
import qwix
2726
import qwix.pallas as qpl
27+
import tokamax
28+
29+
30+
DRHS_RAGGED_DOT_DIM_NUMS = jax.lax.RaggedDotDimensionNumbers(
31+
dot_dimension_numbers=(([0], [0]), ([], [])),
32+
lhs_ragged_dimensions=[0],
33+
rhs_group_dimensions=[],
34+
)
2835

2936

3037
def gmm(
3138
lhs: jnp.ndarray,
3239
rhs: jnp.ndarray,
3340
group_sizes: jnp.ndarray,
3441
preferred_element_type: jnp.dtype = jnp.float32,
35-
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
42+
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
43+
128,
44+
128,
45+
128,
46+
128,
47+
128,
48+
128,
49+
128,
50+
128,
51+
128,
52+
),
3653
group_offset: jnp.ndarray | None = None,
3754
existing_out: jnp.ndarray | None = None,
3855
transpose_rhs: bool = False,
@@ -42,8 +59,6 @@ def gmm(
4259
use_qwix_quantization: bool = False,
4360
use_tokamax_backend: bool = False,
4461
weight_gather_axes: List[Tuple[str, int]] | None = None,
45-
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
46-
combine_scopes: bool = False,
4762
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
4863
qwix_rule: qwix.QtRule | None = None,
4964
):
@@ -65,16 +80,16 @@ def gmm(
6580
)
6681

6782
gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
68-
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
83+
gmm_fwd_bwd = jax.custom_vjp(
84+
gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11)
85+
)
6986
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
7087
return gmm_fwd_bwd(
7188
lhs,
7289
rhs,
7390
group_sizes,
7491
preferred_element_type,
7592
tiling,
76-
input_buffer_count,
77-
combine_scopes,
7893
group_offset,
7994
existing_out,
8095
transpose_rhs,
@@ -90,9 +105,17 @@ def _gmm_fwd(
90105
rhs: jnp.ndarray,
91106
group_sizes: jnp.ndarray,
92107
preferred_element_type: jnp.dtype = jnp.float32,
93-
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
94-
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
95-
combine_scopes: bool = False,
108+
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
109+
128,
110+
128,
111+
128,
112+
128,
113+
128,
114+
128,
115+
128,
116+
128,
117+
128,
118+
),
96119
group_offset: jnp.ndarray | None = None,
97120
existing_out: jnp.ndarray | None = None,
98121
transpose_rhs: bool = False,
@@ -136,17 +159,18 @@ def _gmm_fwd(
136159
for axis_name, axis_idx in weight_gather_axes:
137160
rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
138161
rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
139-
out = tokamax_backend.gmm(
162+
# Handle transpose_rhs manually as ragged_dot assumes (G, K, N)
163+
if transpose_rhs:
164+
rhs = rhs.swapaxes(1, 2)
165+
166+
out = tokamax.ragged_dot(
140167
lhs=lhs,
141168
rhs=rhs,
142169
group_sizes=group_sizes,
143170
precision=jax.lax.Precision.DEFAULT,
144-
out_dtype=preferred_element_type,
145-
tiling=tiling[:3],
171+
preferred_element_type=preferred_element_type,
146172
group_offset=group_offset,
147-
transpose_rhs=transpose_rhs,
148-
interpret=interpret,
149-
input_buffer_count=input_buffer_count[0],
173+
implementation="mosaic",
150174
)
151175
else:
152176
out = backend.gmm(
@@ -168,8 +192,6 @@ def _gmm_bwd(
168192
rhs_dtype: jax.typing.DTypeLike,
169193
preferred_element_type: jnp.dtype,
170194
tiling: tuple[int, int, int, int, int, int, int, int, int],
171-
input_buffer_count: tuple[int, int, int],
172-
combine_scopes: bool,
173195
transpose_rhs: bool,
174196
interpret: bool,
175197
quantization_rule: qwix.QtRule | None,
@@ -224,30 +246,29 @@ def _gmm_bwd(
224246
calibration_method=quantization_rule.bwd_calibration_method,
225247
)
226248
if use_tokamax_backend:
227-
dlhs = tokamax_backend.gmm(
249+
# Handle transpose_rhs manually
250+
dlhs_rhs = rhs
251+
if not transpose_rhs:
252+
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)
253+
254+
dlhs = tokamax.ragged_dot(
228255
lhs=dlhs_dout,
229-
rhs=rhs,
256+
rhs=dlhs_rhs,
230257
group_sizes=group_sizes,
231258
precision=jax.lax.Precision.DEFAULT,
232-
out_dtype=lhs_dtype,
233-
tiling=tiling[3:6],
259+
preferred_element_type=lhs_dtype,
234260
group_offset=group_offset,
235-
transpose_rhs=not transpose_rhs,
236-
interpret=interpret,
237-
input_buffer_count=input_buffer_count[1],
261+
implementation="mosaic",
238262
)
239-
drhs = tokamax_backend.tgmm(
240-
lhs=lhs.swapaxes(0, 1),
263+
drhs = tokamax.ragged_dot_general(
264+
lhs=lhs,
241265
rhs=drhs_dout,
242266
group_sizes=group_sizes,
267+
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
243268
precision=jax.lax.Precision.DEFAULT,
244-
out_dtype=rhs_dtype,
245-
tiling=tiling[-3:],
269+
preferred_element_type=rhs_dtype,
246270
group_offset=group_offset,
247-
num_actual_groups=num_actual_groups,
248-
interpret=interpret,
249-
input_buffer_count=input_buffer_count[2],
250-
combine_scopes=combine_scopes,
271+
implementation="mosaic",
251272
)
252273
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
253274
# Scatter back in reverse order of gather

src/maxtext/layers/moe.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ def get_quantization_dtypes():
969969
return lhs_quantize_dtype, rhs_quantize_dtype
970970

971971
def gmm(
972-
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
972+
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes
973973
):
974974
if inputs.shape[0] != expert_assignments.shape[0]:
975975
raise ValueError("The number of input tokens must match the number of expert assignments!")
@@ -996,8 +996,6 @@ def gmm(
996996
use_qwix_quantization=self.config.use_qwix_quantization,
997997
use_tokamax_backend=self.config.use_tokamax_gmm,
998998
weight_gather_axes=weight_gather_axes,
999-
input_buffer_count=input_buffer_count,
1000-
combine_scopes=combine_scopes,
1001999
)
10021000
else: # tokamax (unquantized)
10031001
output = tokamax.ragged_dot(
@@ -1253,26 +1251,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12531251
self.config.wo_tile_drhs_embed_dim,
12541252
self.config.wo_tile_drhs_mlp_dim,
12551253
)
1256-
wi_input_buffer_count = (
1257-
self.config.wi_tile_fwd_buffer_count,
1258-
self.config.wi_tile_dlhs_buffer_count,
1259-
self.config.wi_tile_drhs_buffer_count,
1260-
)
1261-
wo_input_buffer_count = (
1262-
self.config.wo_tile_fwd_buffer_count,
1263-
self.config.wo_tile_dlhs_buffer_count,
1264-
self.config.wo_tile_drhs_buffer_count,
1265-
)
12661254

1267-
wi_combine_scopes = self.config.wi_combine_scopes
1268-
wo_combine_scopes = self.config.wo_combine_scopes
12691255
layer_w0 = gmm_fn(
12701256
x,
12711257
w0,
12721258
tiling=wi_tile_size,
12731259
weight_gather_axes=wi_gather_axes,
1274-
input_buffer_count=wi_input_buffer_count,
1275-
combine_scopes=wi_combine_scopes,
12761260
)
12771261
if self.get_tensor_transpose_parallelism_size() > 1:
12781262
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
@@ -1285,8 +1269,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
12851269
w1,
12861270
tiling=wi_tile_size,
12871271
weight_gather_axes=wi_gather_axes,
1288-
input_buffer_count=wi_input_buffer_count,
1289-
combine_scopes=wi_combine_scopes,
12901272
)
12911273
if self.get_tensor_transpose_parallelism_size() > 1:
12921274
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
@@ -1300,8 +1282,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13001282
wo,
13011283
tiling=wo_tile_size,
13021284
weight_gather_axes=wo_gather_axes,
1303-
input_buffer_count=wo_input_buffer_count,
1304-
combine_scopes=wo_combine_scopes,
13051285
)
13061286
if self.get_tensor_parallelism_size() > 1:
13071287
intermediate_output = jax.lax.psum_scatter(

src/maxtext/models/deepseek_batchsplit_fp8.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -949,8 +949,6 @@ def gmm(
949949
group_sizes,
950950
preferred_element_type,
951951
weight_gather_axes,
952-
input_buffer_count,
953-
combine_scopes,
954952
):
955953
if config.use_qwix_quantization:
956954
output = megablox.gmm(
@@ -962,8 +960,6 @@ def gmm(
962960
use_qwix_quantization=config.use_qwix_quantization,
963961
use_tokamax_backend=config.use_tokamax_gmm,
964962
weight_gather_axes=weight_gather_axes,
965-
input_buffer_count=input_buffer_count,
966-
combine_scopes=combine_scopes,
967963
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
968964
)
969965
else:
@@ -1006,19 +1002,7 @@ def gmm(
10061002
config.wo_tile_drhs_embed_dim,
10071003
config.wo_tile_drhs_mlp_dim,
10081004
)
1009-
wi_input_buffer_count = (
1010-
config.wi_tile_fwd_buffer_count,
1011-
config.wi_tile_dlhs_buffer_count,
1012-
config.wi_tile_drhs_buffer_count,
1013-
)
1014-
wo_input_buffer_count = (
1015-
config.wo_tile_fwd_buffer_count,
1016-
config.wo_tile_dlhs_buffer_count,
1017-
config.wo_tile_drhs_buffer_count,
1018-
)
10191005

1020-
wi_combine_scopes = config.wi_combine_scopes
1021-
wo_combine_scopes = config.wo_combine_scopes
10221006
if config.use_qwix_quantization:
10231007
gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
10241008
w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
@@ -1047,8 +1031,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10471031
w01,
10481032
tiling=wi_tile_size,
10491033
weight_gather_axes=wi_gather_axes,
1050-
input_buffer_count=wi_input_buffer_count,
1051-
combine_scopes=wi_combine_scopes,
10521034
)
10531035
layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
10541036
else:
@@ -1057,16 +1039,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10571039
w0,
10581040
tiling=wi_tile_size,
10591041
weight_gather_axes=wi_gather_axes,
1060-
input_buffer_count=wi_input_buffer_count,
1061-
combine_scopes=wi_combine_scopes,
10621042
)
10631043
layer_w1 = gmm_fn(
10641044
x,
10651045
w1,
10661046
tiling=wi_tile_size,
10671047
weight_gather_axes=wi_gather_axes,
1068-
input_buffer_count=wi_input_buffer_count,
1069-
combine_scopes=wi_combine_scopes,
10701048
)
10711049
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
10721050
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
@@ -1077,8 +1055,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
10771055
wo,
10781056
tiling=wo_tile_size,
10791057
weight_gather_axes=wo_gather_axes,
1080-
input_buffer_count=wo_input_buffer_count,
1081-
combine_scopes=wo_combine_scopes,
10821058
)
10831059
return layer_wo
10841060

0 commit comments

Comments (0)