2424def route (
2525 tokens : jax .Array ,
2626 selected_experts : jax .Array ,
27- use_custom_mosaic_kernel : bool ,
27+ use_gather_mosaic_kernel : bool ,
2828) -> jax .Array :
2929 """Route tokens to selected experts."""
30- return _route_fwd (tokens , selected_experts , use_custom_mosaic_kernel )[0 ]
30+ return _route_fwd (tokens , selected_experts , use_gather_mosaic_kernel )[0 ]
3131
3232
3333def _route_fwd (
3434 tokens : jax .Array ,
3535 selected_experts : jax .Array ,
36- use_custom_mosaic_kernel : bool ,
36+ use_gather_mosaic_kernel : bool ,
3737) -> tuple [jax .Array , jax .Array ]:
3838 return (
39- _route_impl (tokens , selected_experts , use_custom_mosaic_kernel ),
39+ _route_impl (tokens , selected_experts , use_gather_mosaic_kernel ),
4040 selected_experts ,
4141 )
4242
4343
4444def _route_bwd (
45- use_custom_mosaic_kernel : bool ,
45+ use_gather_mosaic_kernel : bool ,
4646 residuals : jax .Array ,
4747 grads : jax .Array ,
4848) -> tuple [jax .Array , None ]:
4949 selected_experts = residuals
50- return _unroute_impl (grads , selected_experts , use_custom_mosaic_kernel ), None
50+ return _unroute_impl (grads , selected_experts , use_gather_mosaic_kernel ), None
5151
5252
5353route .defvjp (_route_fwd , _route_bwd )
@@ -57,25 +57,25 @@ def _route_bwd(
5757def unroute (
5858 tokens : jax .Array ,
5959 selected_experts : jax .Array ,
60- use_custom_mosaic_kernel : bool ,
60+ use_gather_mosaic_kernel : bool ,
6161) -> jax .Array :
62- return _unroute_fwd (tokens , selected_experts , use_custom_mosaic_kernel )[0 ]
62+ return _unroute_fwd (tokens , selected_experts , use_gather_mosaic_kernel )[0 ]
6363
6464
6565def _unroute_fwd (
6666 tokens : jax .Array ,
6767 selected_experts : jax .Array ,
68- use_custom_mosaic_kernel : bool ,
68+ use_gather_mosaic_kernel : bool ,
6969) -> tuple [jax .Array , jax .Array ]:
7070 return (
71- _unroute_impl (tokens , selected_experts , use_custom_mosaic_kernel ),
71+ _unroute_impl (tokens , selected_experts , use_gather_mosaic_kernel ),
7272 selected_experts ,
7373 )
7474
7575
76- def _unroute_bwd (use_custom_mosaic_kernel : bool , residuals : jax .Array , grads : jax .Array ) -> tuple [jax .Array , None ]:
76+ def _unroute_bwd (use_gather_mosaic_kernel : bool , residuals : jax .Array , grads : jax .Array ) -> tuple [jax .Array , None ]:
7777 selected_experts = residuals
78- return _route_impl (grads , selected_experts , use_custom_mosaic_kernel ), None
78+ return _route_impl (grads , selected_experts , use_gather_mosaic_kernel ), None
7979
8080
8181unroute .defvjp (_unroute_fwd , _unroute_bwd )
@@ -84,37 +84,32 @@ def _unroute_bwd(use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: ja
8484def _route_impl (
8585 tokens : jax .Array ,
8686 selected_experts : jax .Array ,
87- use_custom_mosaic_kernel : bool ,
87+ use_gather_mosaic_kernel : bool ,
8888) -> jax .Array :
8989 """Gather `tokens` according to `selected_experts`."""
9090 assert (
9191 tokens .shape [0 ] == selected_experts .shape [0 ] and selected_experts .ndim == 2
9292 ), f"{ tokens .shape = } , { selected_experts .shape = } "
93- if use_custom_mosaic_kernel :
94- raise NotImplementedError ("Custom Mosaic kernel not implemented." )
9593 inds = jnp .argsort (jnp .ravel (selected_experts )) // selected_experts .shape [1 ]
96- return _sort_impl (tokens , inds , use_custom_mosaic_kernel )
94+ return _sort_impl (tokens , inds , use_gather_mosaic_kernel )
9795
9896
9997def _unroute_impl (
10098 tokens : jax .Array ,
10199 selected_experts : jax .Array ,
102- use_custom_mosaic_kernel : bool ,
100+ use_gather_mosaic_kernel : bool ,
103101) -> jax .Array :
104102 """Reverse the routing operation, restoring tokens to their original order."""
105103 assert tokens .shape [0 ] == selected_experts .shape [0 ] * selected_experts .shape [1 ] and selected_experts .ndim == 2
106104 inds = jnp .argsort (jnp .argsort (jnp .ravel (selected_experts )))
107105 return jnp .sum (
108106 jnp .reshape (
109- _sort_impl (tokens , inds , use_custom_mosaic_kernel ),
107+ _sort_impl (tokens , inds , use_gather_mosaic_kernel ),
110108 (- 1 , selected_experts .shape [1 ]) + tokens .shape [1 :],
111109 ),
112110 axis = 1 ,
113111 )
114112
115113
116- def _sort_impl (tokens : jax .Array , inds : jax .Array , use_custom_mosaic_kernel : bool ) -> jax .Array :
117- if use_custom_mosaic_kernel :
118- raise NotImplementedError ("Custom Mosaic kernel not implemented." )
119- else :
120- return tokens [inds , ...]
114+ def _sort_impl (tokens : jax .Array , inds : jax .Array , use_gather_mosaic_kernel : bool ) -> jax .Array :
115+ return tokens [inds , ...]
0 commit comments