@@ -427,9 +427,7 @@ def _wrapped(
   def reshape_activations(activations):
     if activations.ndim == 4:  # pytype: disable=attribute-error
       kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape  # pytype: disable=attribute-error
-      return activations.reshape(
-          kv_heads * q_heads_per_kv_head, q_seq_len, head_dim
-      )  # pytype: disable=attribute-error
+      return activations.reshape(kv_heads * q_heads_per_kv_head, q_seq_len, head_dim)  # pytype: disable=attribute-error
     return activations

   def reshape_residuals(residuals):
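For orientation (not part of the commit): the reflowed `reshape_activations` call simply collapses the grouped-head layout `(kv_heads, q_heads_per_kv_head, q_seq_len, head_dim)` into the flat `(num_q_heads, q_seq_len, head_dim)` layout the kernel consumes. A tiny standalone sketch with made-up sizes:

import jax.numpy as jnp

# Hypothetical sizes: 2 KV heads, 4 query heads per KV head.
kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = 2, 4, 128, 64
activations = jnp.zeros((kv_heads, q_heads_per_kv_head, q_seq_len, head_dim))
# Collapse the two leading head axes into a single query-heads axis.
flat = activations.reshape(kv_heads * q_heads_per_kv_head, q_seq_len, head_dim)
assert flat.shape == (8, 128, 64)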
@@ -1166,10 +1164,7 @@ def _splash_attention_fwd(
     mask_function: MaskFunctionType | None,
     attn_logits_soft_cap: float | None = None,
     interpret: bool = False,
-) -> tuple[
-    tuple[jax.Array],
-    SplashResidualsType,
-]:
+) -> tuple[tuple[jax.Array], SplashResidualsType,]:
   """Forward pass for splash attention."""
   if save_residuals:
     raise NotImplementedError("Higher-order AD not supported")
@@ -1606,7 +1601,6 @@ def init():
     )

   def body(i, _):
-
     slice_k = pl.ds(i * bkv_compute, bkv_compute)
     q = q_ref[...]  # We keep q potentially transposed, since it's always RHS

@@ -2238,6 +2232,120 @@ def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
   )


+@partial(
+    jax.jit,
+    static_argnames=[
+        "is_mqa",
+        "block_sizes",
+        "save_residuals",
+        "mask_value",
+        "attn_logits_soft_cap",
+        "residual_checkpoint_name",
+        "mask_function",
+        "interpret",
+    ],
+)
+def _splash_attention_manual_fwd(
+    fwd_mask_info: mask_info_lib.MaskInfo,
+    dq_mask_info: mask_info_lib.MaskInfo | None,
+    dkv_mask_info: mask_info_lib.MaskInfo | None,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    segment_ids: SegmentIds | None = None,
+    sinks: jax.Array | None = None,
+    *,
+    is_mqa: bool,
+    block_sizes: BlockSizes | None,
+    save_residuals: bool,
+    mask_value: float,
+    attn_logits_soft_cap: float | None,
+    residual_checkpoint_name: str | None,
+    mask_function: MaskFunctionType | None,
+    interpret: bool,
+) -> SplashCustomReturnType:
+  def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
+    if mask_info is None or mask_info.partial_mask_blocks is None:
+      return mask_info
+
+    return mask_info._replace(
+        partial_mask_blocks=mask_info.partial_mask_blocks.reshape(-1, *mask_info.partial_mask_blocks.shape[-2:])
+    )
+
+  if not save_residuals:
+    raise ValueError("Expected save_residuals to be `True`.")
+
+  fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)
+  dq_mask_info = _collapse_partial_mask_blocks(dq_mask_info)
+  dkv_mask_info = _collapse_partial_mask_blocks(dkv_mask_info)
+  del dq_mask_info, dkv_mask_info
+
+  out, (logsumexp,) = _splash_attention_forward(  # pytype: disable=wrong-arg-types
+      fwd_mask_info,
+      q,
+      k,
+      v,
+      segment_ids,
+      mask_value=mask_value,
+      is_mqa=is_mqa,
+      block_sizes=block_sizes,
+      residual_checkpoint_name=residual_checkpoint_name,
+      save_residuals=True,
+      mask_function=mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+  )
+  return out, logsumexp
+
+
+def _splash_attention_manual_bwd(
+    fwd_mask_info: mask_info_lib.MaskInfo,
+    dq_mask_info: mask_info_lib.MaskInfo | None,
+    dkv_mask_info: mask_info_lib.MaskInfo | None,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    out: jax.Array,
+    logsumexp: jax.Array,
+    do: jax.Array,
+    segment_ids: SegmentIds | None = None,
+    sinks: jax.Array | None = None,
+    *,
+    is_mqa: bool,
+    block_sizes: BlockSizes | None,
+    save_residuals: bool,
+    mask_value: float,
+    attn_logits_soft_cap: float | None,
+    residual_checkpoint_name: str | None,
+    mask_function: MaskFunctionType | None,
+    interpret: bool,
+):
+  del fwd_mask_info
+  res = (
+      q,
+      k,
+      v,
+      segment_ids,
+      out,
+      logsumexp,
+      dq_mask_info,
+      dkv_mask_info,
+  )
+  _, _, _, dq, dk, dv, _ = _splash_attention_bwd(
+      save_residuals=save_residuals,
+      mask_value=mask_value,
+      is_mqa=is_mqa,
+      block_sizes=block_sizes,
+      residual_checkpoint_name=residual_checkpoint_name,
+      mask_function=mask_function,
+      attn_logits_soft_cap=attn_logits_soft_cap,
+      interpret=interpret,
+      res=res,
+      do=do,
+  )
+  return dq, dk, dv
+
+
 @jax.tree_util.register_pytree_node_class
 class SplashAttentionKernel:
   """Defines a SplashAttention kernel object."""
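As an aside (not from the commit): the `_collapse_partial_mask_blocks` helper nested in `_splash_attention_manual_fwd` only flattens the leading axes of `partial_mask_blocks` into a single block axis, presumably so each partial block can be addressed with one index. A small sketch with hypothetical block counts and sizes:

import jax.numpy as jnp

# Hypothetical layout: 2 heads x 3 partial blocks of (block_q=128, block_kv=128).
partial_mask_blocks = jnp.zeros((2, 3, 128, 128), dtype=jnp.bool_)
collapsed = partial_mask_blocks.reshape(-1, *partial_mask_blocks.shape[-2:])
assert collapsed.shape == (6, 128, 128)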
@@ -2264,6 +2372,26 @@ def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
         **self.kwargs,
     )

+  def manual_fwd(self, *args, **kwargs) -> SplashCustomReturnType:
+    return _splash_attention_manual_fwd(
+        self.fwd_mask_info,
+        self.dq_mask_info,
+        self.dkv_mask_info,
+        *args,
+        **kwargs,
+        **self.kwargs,
+    )
+
+  def manual_bwd(self, *args, **kwargs):
+    return _splash_attention_manual_bwd(
+        self.fwd_mask_info,
+        self.dq_mask_info,
+        self.dkv_mask_info,
+        *args,
+        **kwargs,
+        **self.kwargs,
+    )
+
   def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
     """Returns a value that can be used as a shard_map partition spec for the kernel."""
     if self.fwd_mask_info.data_next is not None:
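One way the new `manual_fwd`/`manual_bwd` entry points could be wired up, sketched under assumptions: `kernel` is a `SplashAttentionKernel` (e.g. from `make_splash_mha`) whose stored kwargs satisfy both paths (in particular `save_residuals=True`, which `manual_fwd` requires), and `q`, `k`, `v`, `do` are already laid out as the kernel expects. None of the wrapper names below exist in the library; this only illustrates the forward/backward contract implied by the diff.

import jax

@jax.custom_vjp
def splash(q, k, v):
  out, _ = kernel.manual_fwd(q, k, v)
  return out

def splash_fwd(q, k, v):
  # Forward returns the attention output plus the logsumexp residual.
  out, logsumexp = kernel.manual_fwd(q, k, v)
  return out, (q, k, v, out, logsumexp)

def splash_bwd(res, do):
  # Backward consumes the saved residuals and the output cotangent,
  # and returns only the input gradients (dq, dk, dv).
  q, k, v, out, logsumexp = res
  dq, dk, dv = kernel.manual_bwd(q, k, v, out, logsumexp, do)
  return dq, dk, dv

splash.defvjp(splash_fwd, splash_bwd)

Splitting the pass this way lets a caller (for example a pipeline or scheduling framework) decide when the backward kernel actually runs, instead of leaving that to the built-in custom VJP.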