1616
1717# pylint: disable=too-many-positional-arguments
1818
19- import functools
2019import dataclasses
21- from typing import Literal , List , Tuple
20+ import functools
21+ from typing import List , Literal , Tuple
2222import jax
2323import jax .numpy as jnp
2424from maxtext .kernels .megablox import backend
25- from tokamax ._src .ops .ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
2625import qwix
2726import qwix .pallas as qpl
27+ import tokamax
28+
29+
30+ DRHS_RAGGED_DOT_DIM_NUMS = jax .lax .RaggedDotDimensionNumbers (
31+ dot_dimension_numbers = (([0 ], [0 ]), ([], [])),
32+ lhs_ragged_dimensions = [0 ],
33+ rhs_group_dimensions = [],
34+ )
2835
2936
3037def gmm (
3138 lhs : jnp .ndarray ,
3239 rhs : jnp .ndarray ,
3340 group_sizes : jnp .ndarray ,
3441 preferred_element_type : jnp .dtype = jnp .float32 ,
35- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
42+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (
43+ 128 ,
44+ 128 ,
45+ 128 ,
46+ 128 ,
47+ 128 ,
48+ 128 ,
49+ 128 ,
50+ 128 ,
51+ 128 ,
52+ ),
3653 group_offset : jnp .ndarray | None = None ,
3754 existing_out : jnp .ndarray | None = None ,
3855 transpose_rhs : bool = False ,
@@ -42,8 +59,6 @@ def gmm(
4259 use_qwix_quantization : bool = False ,
4360 use_tokamax_backend : bool = False ,
4461 weight_gather_axes : List [Tuple [str , int ]] | None = None ,
45- input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
46- combine_scopes : bool = False ,
4762 # TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
4863 qwix_rule : qwix .QtRule | None = None ,
4964):
@@ -65,16 +80,16 @@ def gmm(
6580 )
6681
6782 gmm_fwd_bwd = lambda * args : _gmm_fwd (* args )[0 ] # pylint: disable=C3001
68- gmm_fwd_bwd = jax .custom_vjp (gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 5 , 6 , 9 , 10 , 11 , 12 , 13 ))
83+ gmm_fwd_bwd = jax .custom_vjp (
84+ gmm_fwd_bwd , nondiff_argnums = (3 , 4 , 7 , 8 , 9 , 10 , 11 )
85+ )
6986 gmm_fwd_bwd .defvjp (_gmm_fwd , functools .partial (_gmm_bwd , lhs .dtype , rhs .dtype ))
7087 return gmm_fwd_bwd (
7188 lhs ,
7289 rhs ,
7390 group_sizes ,
7491 preferred_element_type ,
7592 tiling ,
76- input_buffer_count ,
77- combine_scopes ,
7893 group_offset ,
7994 existing_out ,
8095 transpose_rhs ,
@@ -90,9 +105,17 @@ def _gmm_fwd(
90105 rhs : jnp .ndarray ,
91106 group_sizes : jnp .ndarray ,
92107 preferred_element_type : jnp .dtype = jnp .float32 ,
93- tiling : tuple [int , int , int , int , int , int , int , int , int ] = (128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 , 128 ),
94- input_buffer_count : tuple [int , int , int ] = (2 , 2 , 2 ),
95- combine_scopes : bool = False ,
108+ tiling : tuple [int , int , int , int , int , int , int , int , int ] = (
109+ 128 ,
110+ 128 ,
111+ 128 ,
112+ 128 ,
113+ 128 ,
114+ 128 ,
115+ 128 ,
116+ 128 ,
117+ 128 ,
118+ ),
96119 group_offset : jnp .ndarray | None = None ,
97120 existing_out : jnp .ndarray | None = None ,
98121 transpose_rhs : bool = False ,
@@ -136,17 +159,18 @@ def _gmm_fwd(
136159 for axis_name , axis_idx in weight_gather_axes :
137160 rhs_qvalue = jax .lax .all_gather (rhs .qvalue , axis_name , axis = axis_idx , tiled = True )
138161 rhs = dataclasses .replace (rhs , qvalue = rhs_qvalue )
139- out = tokamax_backend .gmm (
162+ # Handle transpose_rhs manually as ragged_dot assumes (G, K, N)
163+ if transpose_rhs :
164+ rhs = rhs .swapaxes (1 , 2 )
165+
166+ out = tokamax .ragged_dot (
140167 lhs = lhs ,
141168 rhs = rhs ,
142169 group_sizes = group_sizes ,
143170 precision = jax .lax .Precision .DEFAULT ,
144- out_dtype = preferred_element_type ,
145- tiling = tiling [:3 ],
171+ preferred_element_type = preferred_element_type ,
146172 group_offset = group_offset ,
147- transpose_rhs = transpose_rhs ,
148- interpret = interpret ,
149- input_buffer_count = input_buffer_count [0 ],
173+ implementation = "mosaic" ,
150174 )
151175 else :
152176 out = backend .gmm (
@@ -168,8 +192,6 @@ def _gmm_bwd(
168192 rhs_dtype : jax .typing .DTypeLike ,
169193 preferred_element_type : jnp .dtype ,
170194 tiling : tuple [int , int , int , int , int , int , int , int , int ],
171- input_buffer_count : tuple [int , int , int ],
172- combine_scopes : bool ,
173195 transpose_rhs : bool ,
174196 interpret : bool ,
175197 quantization_rule : qwix .QtRule | None ,
@@ -224,30 +246,29 @@ def _gmm_bwd(
224246 calibration_method = quantization_rule .bwd_calibration_method ,
225247 )
226248 if use_tokamax_backend :
227- dlhs = tokamax_backend .gmm (
249+ # Handle transpose_rhs manually
250+ dlhs_rhs = rhs
251+ if not transpose_rhs :
252+ dlhs_rhs = dlhs_rhs .swapaxes (1 , 2 )
253+
254+ dlhs = tokamax .ragged_dot (
228255 lhs = dlhs_dout ,
229- rhs = rhs ,
256+ rhs = dlhs_rhs ,
230257 group_sizes = group_sizes ,
231258 precision = jax .lax .Precision .DEFAULT ,
232- out_dtype = lhs_dtype ,
233- tiling = tiling [3 :6 ],
259+ preferred_element_type = lhs_dtype ,
234260 group_offset = group_offset ,
235- transpose_rhs = not transpose_rhs ,
236- interpret = interpret ,
237- input_buffer_count = input_buffer_count [1 ],
261+ implementation = "mosaic" ,
238262 )
239- drhs = tokamax_backend . tgmm (
240- lhs = lhs . swapaxes ( 0 , 1 ) ,
263+ drhs = tokamax . ragged_dot_general (
264+ lhs = lhs ,
241265 rhs = drhs_dout ,
242266 group_sizes = group_sizes ,
267+ ragged_dot_dimension_numbers = DRHS_RAGGED_DOT_DIM_NUMS ,
243268 precision = jax .lax .Precision .DEFAULT ,
244- out_dtype = rhs_dtype ,
245- tiling = tiling [- 3 :],
269+ preferred_element_type = rhs_dtype ,
246270 group_offset = group_offset ,
247- num_actual_groups = num_actual_groups ,
248- interpret = interpret ,
249- input_buffer_count = input_buffer_count [2 ],
250- combine_scopes = combine_scopes ,
271+ implementation = "mosaic" ,
251272 )
252273 if quantization_rule and quantization_rule .bwd_qtype and weight_gather_axes :
253274 # Scatter back in reverse order of gather
0 commit comments