Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,6 @@ wo_tile_drhs_batch_seq: 512
wo_tile_drhs_embed_dim: 1024
wo_tile_drhs_mlp_dim: 1024

wi_tile_fwd_buffer_count: 2
wi_tile_dlhs_buffer_count: 2
wi_tile_drhs_buffer_count: 2
wo_tile_fwd_buffer_count: 2
wo_tile_dlhs_buffer_count: 2
wo_tile_drhs_buffer_count: 2

wi_combine_scopes: False
wo_combine_scopes: False

merge_gating_gmm: False

norm_topk_prob: false # Boolean enabling top-k probability normalization, a Qwen3-specific normalization of router weights.
Expand Down
10 changes: 0 additions & 10 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,16 +734,6 @@ class MoEKernels(BaseModel):
wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")

wi_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wi.")
wi_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wi.")
wi_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wi.")
wo_tile_fwd_buffer_count: int = Field(2, description="forward pass tiling buffer count in GMM for wo.")
wo_tile_dlhs_buffer_count: int = Field(2, description="bwd pass dlhs tiling buffer count in GMM for wo.")
wo_tile_drhs_buffer_count: int = Field(2, description="bwd pass drhs tiling buffer count in GMM for wo.")

wi_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wi.")
wo_combine_scopes: bool = Field(False, description="whether to use combine_scopes features for tgmm for wo.")

merge_gating_gmm: bool = Field(False, description="whether to merge the two gating gmm kernels into one.")


Expand Down
89 changes: 54 additions & 35 deletions src/maxtext/kernels/megablox/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,40 @@

# pylint: disable=too-many-positional-arguments

import functools
import dataclasses
from typing import Literal, List, Tuple
import functools
from typing import List, Literal, Tuple
import jax
import jax.numpy as jnp
from maxtext.kernels.megablox import backend
from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
import qwix
import qwix.pallas as qpl
import tokamax


DRHS_RAGGED_DOT_DIM_NUMS = jax.lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(([0], [0]), ([], [])),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[],
)


def gmm(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
preferred_element_type: jnp.dtype = jnp.float32,
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
128,
128,
128,
128,
128,
128,
128,
128,
128,
),
group_offset: jnp.ndarray | None = None,
existing_out: jnp.ndarray | None = None,
transpose_rhs: bool = False,
Expand All @@ -42,8 +59,6 @@ def gmm(
use_qwix_quantization: bool = False,
use_tokamax_backend: bool = False,
weight_gather_axes: List[Tuple[str, int]] | None = None,
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
combine_scopes: bool = False,
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
qwix_rule: qwix.QtRule | None = None,
):
Expand All @@ -65,16 +80,14 @@ def gmm(
)

gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
return gmm_fwd_bwd(
lhs,
rhs,
group_sizes,
preferred_element_type,
tiling,
input_buffer_count,
combine_scopes,
group_offset,
existing_out,
transpose_rhs,
Expand All @@ -90,9 +103,17 @@ def _gmm_fwd(
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
preferred_element_type: jnp.dtype = jnp.float32,
tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
combine_scopes: bool = False,
tiling: tuple[int, int, int, int, int, int, int, int, int] = (
128,
128,
128,
128,
128,
128,
128,
128,
128,
),
group_offset: jnp.ndarray | None = None,
existing_out: jnp.ndarray | None = None,
transpose_rhs: bool = False,
Expand Down Expand Up @@ -136,17 +157,18 @@ def _gmm_fwd(
for axis_name, axis_idx in weight_gather_axes:
rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
out = tokamax_backend.gmm(
# Handle transpose_rhs manually, as ragged_dot assumes rhs is laid out as (G, K, N).
if transpose_rhs:
rhs = rhs.swapaxes(1, 2)

out = tokamax.ragged_dot(
lhs=lhs,
rhs=rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
out_dtype=preferred_element_type,
tiling=tiling[:3],
preferred_element_type=preferred_element_type,
group_offset=group_offset,
transpose_rhs=transpose_rhs,
interpret=interpret,
input_buffer_count=input_buffer_count[0],
implementation="mosaic",
)
else:
out = backend.gmm(
Expand All @@ -168,8 +190,6 @@ def _gmm_bwd(
rhs_dtype: jax.typing.DTypeLike,
preferred_element_type: jnp.dtype,
tiling: tuple[int, int, int, int, int, int, int, int, int],
input_buffer_count: tuple[int, int, int],
combine_scopes: bool,
transpose_rhs: bool,
interpret: bool,
quantization_rule: qwix.QtRule | None,
Expand Down Expand Up @@ -224,30 +244,29 @@ def _gmm_bwd(
calibration_method=quantization_rule.bwd_calibration_method,
)
if use_tokamax_backend:
dlhs = tokamax_backend.gmm(
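# Handle transpose_rhs manually: dlhs needs rhs in the opposite orientation to the forward pass, so swap the last two axes only when the forward pass did not.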
dlhs_rhs = rhs
if not transpose_rhs:
dlhs_rhs = dlhs_rhs.swapaxes(1, 2)

dlhs = tokamax.ragged_dot(
lhs=dlhs_dout,
rhs=rhs,
rhs=dlhs_rhs,
group_sizes=group_sizes,
precision=jax.lax.Precision.DEFAULT,
out_dtype=lhs_dtype,
tiling=tiling[3:6],
preferred_element_type=lhs_dtype,
group_offset=group_offset,
transpose_rhs=not transpose_rhs,
interpret=interpret,
input_buffer_count=input_buffer_count[1],
implementation="mosaic",
)
drhs = tokamax_backend.tgmm(
lhs=lhs.swapaxes(0, 1),
drhs = tokamax.ragged_dot_general(
lhs=lhs,
rhs=drhs_dout,
group_sizes=group_sizes,
ragged_dot_dimension_numbers=DRHS_RAGGED_DOT_DIM_NUMS,
precision=jax.lax.Precision.DEFAULT,
out_dtype=rhs_dtype,
tiling=tiling[-3:],
preferred_element_type=rhs_dtype,
group_offset=group_offset,
num_actual_groups=num_actual_groups,
interpret=interpret,
input_buffer_count=input_buffer_count[2],
combine_scopes=combine_scopes,
implementation="mosaic",
)
if quantization_rule and quantization_rule.bwd_qtype and weight_gather_axes:
# Scatter back in reverse order of gather
Expand Down
24 changes: 1 addition & 23 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,9 +965,7 @@ def get_quantization_dtypes():
rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()
return lhs_quantize_dtype, rhs_quantize_dtype

def gmm(
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
):
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
if inputs.shape[0] != expert_assignments.shape[0]:
raise ValueError("The number of input tokens must match the number of expert assignments!")

Expand All @@ -993,8 +991,6 @@ def gmm(
use_qwix_quantization=self.config.use_qwix_quantization,
use_tokamax_backend=self.config.use_tokamax_gmm,
weight_gather_axes=weight_gather_axes,
input_buffer_count=input_buffer_count,
combine_scopes=combine_scopes,
)
else: # tokamax (unquantized)
output = tokamax.ragged_dot(
Expand Down Expand Up @@ -1250,26 +1246,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
self.config.wo_tile_drhs_embed_dim,
self.config.wo_tile_drhs_mlp_dim,
)
wi_input_buffer_count = (
self.config.wi_tile_fwd_buffer_count,
self.config.wi_tile_dlhs_buffer_count,
self.config.wi_tile_drhs_buffer_count,
)
wo_input_buffer_count = (
self.config.wo_tile_fwd_buffer_count,
self.config.wo_tile_dlhs_buffer_count,
self.config.wo_tile_drhs_buffer_count,
)

wi_combine_scopes = self.config.wi_combine_scopes
wo_combine_scopes = self.config.wo_combine_scopes
layer_w0 = gmm_fn(
x,
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
input_buffer_count=wi_input_buffer_count,
combine_scopes=wi_combine_scopes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
Expand All @@ -1282,8 +1264,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
input_buffer_count=wi_input_buffer_count,
combine_scopes=wi_combine_scopes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
Expand All @@ -1297,8 +1277,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
wo,
tiling=wo_tile_size,
weight_gather_axes=wo_gather_axes,
input_buffer_count=wo_input_buffer_count,
combine_scopes=wo_combine_scopes,
)
if self.get_tensor_parallelism_size() > 1:
intermediate_output = jax.lax.psum_scatter(
Expand Down
24 changes: 0 additions & 24 deletions src/maxtext/models/deepseek_batchsplit_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,6 @@ def gmm(
group_sizes,
preferred_element_type,
weight_gather_axes,
input_buffer_count,
combine_scopes,
):
if config.use_qwix_quantization:
output = megablox.gmm(
Expand All @@ -961,8 +959,6 @@ def gmm(
use_qwix_quantization=config.use_qwix_quantization,
use_tokamax_backend=config.use_tokamax_gmm,
weight_gather_axes=weight_gather_axes,
input_buffer_count=input_buffer_count,
combine_scopes=combine_scopes,
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
)
else:
Expand Down Expand Up @@ -1002,19 +998,7 @@ def gmm(
config.wo_tile_drhs_embed_dim,
config.wo_tile_drhs_mlp_dim,
)
wi_input_buffer_count = (
config.wi_tile_fwd_buffer_count,
config.wi_tile_dlhs_buffer_count,
config.wi_tile_drhs_buffer_count,
)
wo_input_buffer_count = (
config.wo_tile_fwd_buffer_count,
config.wo_tile_dlhs_buffer_count,
config.wo_tile_drhs_buffer_count,
)

wi_combine_scopes = config.wi_combine_scopes
wo_combine_scopes = config.wo_combine_scopes
if config.use_qwix_quantization:
gating_pspec, linear_pspec = moe_lib.get_batchsplit_init_kernel_axes()
w0_pspec = nn.logical_to_mesh_axes(gating_pspec)
Expand Down Expand Up @@ -1043,8 +1027,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w01,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
input_buffer_count=wi_input_buffer_count,
combine_scopes=wi_combine_scopes,
)
layer_w0, layer_w1 = jnp.split(layer_w01, 2, axis=-1)
else:
Expand All @@ -1053,16 +1035,12 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
input_buffer_count=wi_input_buffer_count,
combine_scopes=wi_combine_scopes,
)
layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
input_buffer_count=wi_input_buffer_count,
combine_scopes=wi_combine_scopes,
)
layer_w0 = jax.ad_checkpoint.checkpoint_name(layer_w0, "mlpwi_0")
layer_w1 = jax.ad_checkpoint.checkpoint_name(layer_w1, "mlpwi_1")
Expand All @@ -1073,8 +1051,6 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
wo,
tiling=wo_tile_size,
weight_gather_axes=wo_gather_axes,
input_buffer_count=wo_input_buffer_count,
combine_scopes=wo_combine_scopes,
)
return layer_wo

Expand Down
Loading