@@ -427,9 +427,7 @@ def _wrapped(
427427 def reshape_activations (activations ):
428428 if activations .ndim == 4 : # pytype: disable=attribute-error
429429 kv_heads , q_heads_per_kv_head , q_seq_len , head_dim = activations .shape # pytype: disable=attribute-error
430- return activations .reshape (
431- kv_heads * q_heads_per_kv_head , q_seq_len , head_dim
432- ) # pytype: disable=attribute-error
430+ return activations .reshape (kv_heads * q_heads_per_kv_head , q_seq_len , head_dim ) # pytype: disable=attribute-error
433431 return activations
434432
435433 def reshape_residuals (residuals ):
@@ -1166,10 +1164,7 @@ def _splash_attention_fwd(
11661164 mask_function : MaskFunctionType | None ,
11671165 attn_logits_soft_cap : float | None = None ,
11681166 interpret : bool = False ,
1169- ) -> tuple [
1170- tuple [jax .Array ],
1171- SplashResidualsType ,
1172- ]:
1167+ ) -> tuple [tuple [jax .Array ], SplashResidualsType ,]:
11731168 """Forward pass for splash attention."""
11741169 if save_residuals :
11751170 raise NotImplementedError ("Higher-order AD not supported" )
@@ -1606,7 +1601,6 @@ def init():
16061601 )
16071602
16081603 def body (i , _ ):
1609-
16101604 slice_k = pl .ds (i * bkv_compute , bkv_compute )
16111605 q = q_ref [...] # We keep q potentially transposed, since it's always RHS
16121606
@@ -2238,6 +2232,126 @@ def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
22382232 )
22392233
22402234
@partial(
    jax.jit,
    static_argnames=[
        "is_mqa",
        "block_sizes",
        "save_residuals",
        "mask_value",
        "attn_logits_soft_cap",
        "residual_checkpoint_name",
        "mask_function",
        "interpret",
    ],
)
def _splash_attention_manual_fwd(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    block_sizes: BlockSizes | None,
    save_residuals: bool,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    interpret: bool,
) -> SplashCustomReturnType:
  """Runs the splash attention forward pass, returning output and logsumexp.

  This is useful when manually controlling remat in the backward pass, as
  both the attention output and the logsumexp can be returned as residuals
  from the forward pass.

  Args:
    fwd_mask_info: Mask metadata consumed by the forward kernel.
    dq_mask_info: Backward-pass (dq) mask metadata. Unused here; accepted
      only so this function mirrors the manual-bwd calling convention.
    dkv_mask_info: Backward-pass (dkv) mask metadata. Unused here.
    q: Query array.
    k: Key array.
    v: Value array.
    segment_ids: Optional segment ids restricting cross-segment attention.
    sinks: Optional attention sinks. NOTE(review): currently ignored — not
      forwarded to the kernel; confirm whether this is intentional.
    is_mqa: Whether to run in multi-query attention mode.
    block_sizes: Kernel block-size configuration, or None for defaults.
    save_residuals: Must be True; this entry point exists precisely to
      return residuals.
    mask_value: Value substituted for masked-out logits.
    attn_logits_soft_cap: Optional soft cap applied to attention logits.
    residual_checkpoint_name: Optional checkpoint name for the residuals.
    mask_function: Optional callable that generates the mask.
    interpret: Whether to run the Pallas kernel in interpret mode.

  Returns:
    A tuple of (attention output, logsumexp).

  Raises:
    ValueError: If save_residuals is False.
  """
  if not save_residuals:
    raise ValueError("Expected save_residuals to be `True`.")

  # The backward-pass mask infos are not needed in the forward pass. The
  # previous version collapsed them before deleting them, which was dead
  # work (the traced reshapes were unused and DCE'd under jit anyway).
  del dq_mask_info, dkv_mask_info

  def _collapse_partial_mask_blocks(
      mask_info: mask_info_lib.MaskInfo | None,
  ) -> mask_info_lib.MaskInfo | None:
    # Flatten all leading dimensions of partial_mask_blocks into one axis,
    # keeping the trailing two (block rows/cols) dimensions intact.
    if mask_info is None or mask_info.partial_mask_blocks is None:
      return mask_info
    blocks = mask_info.partial_mask_blocks
    return mask_info._replace(
        partial_mask_blocks=blocks.reshape(-1, *blocks.shape[-2:])
    )

  fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)

  out, (logsumexp,) = _splash_attention_forward(  # pytype: disable=wrong-arg-types
      fwd_mask_info,
      q,
      k,
      v,
      segment_ids,
      mask_value=mask_value,
      is_mqa=is_mqa,
      block_sizes=block_sizes,
      residual_checkpoint_name=residual_checkpoint_name,
      save_residuals=True,
      mask_function=mask_function,
      attn_logits_soft_cap=attn_logits_soft_cap,
      interpret=interpret,
  )
  return out, logsumexp
2304+
2305+
def _splash_attention_manual_bwd(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    out: jax.Array,
    logsumexp: jax.Array,
    do: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    block_sizes: BlockSizes | None,
    save_residuals: bool,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    interpret: bool,
):
  """Transpose of _splash_attention_manual_fwd.

  Consumes the attention output and logsumexp produced by the manual
  forward pass (instead of recomputing them) and returns the gradients
  with respect to q, k and v.
  """
  # Only the backward-pass mask infos are needed here.
  del fwd_mask_info
  residuals = (
      q,
      k,
      v,
      segment_ids,
      out,
      logsumexp,
      dq_mask_info,
      dkv_mask_info,
  )
  grads = _splash_attention_bwd(
      save_residuals=save_residuals,
      mask_value=mask_value,
      is_mqa=is_mqa,
      block_sizes=block_sizes,
      residual_checkpoint_name=residual_checkpoint_name,
      mask_function=mask_function,
      attn_logits_soft_cap=attn_logits_soft_cap,
      interpret=interpret,
      res=residuals,
      do=do,
  )
  # _splash_attention_bwd returns a cotangent slot for every forward input;
  # only the q, k and v cotangents are meaningful to the caller.
  _, _, _, dq, dk, dv, _ = grads
  return dq, dk, dv
2353+
2354+
22412355@jax .tree_util .register_pytree_node_class
22422356class SplashAttentionKernel :
22432357 """Defines a SplashAttention kernel object."""
@@ -2264,6 +2378,26 @@ def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
22642378 ** self .kwargs ,
22652379 )
22662380
2381+ def manual_fwd (self , * args , ** kwargs ) -> SplashCustomReturnType :
2382+ return _splash_attention_manual_fwd (
2383+ self .fwd_mask_info ,
2384+ self .dq_mask_info ,
2385+ self .dkv_mask_info ,
2386+ * args ,
2387+ ** kwargs ,
2388+ ** self .kwargs ,
2389+ )
2390+
2391+ def manual_bwd (self , * args , ** kwargs ):
2392+ return _splash_attention_manual_bwd (
2393+ self .fwd_mask_info ,
2394+ self .dq_mask_info ,
2395+ self .dkv_mask_info ,
2396+ * args ,
2397+ ** kwargs ,
2398+ ** self .kwargs ,
2399+ )
2400+
22672401 def manual_sharding_spec (self , sharding : jax .sharding .NamedSharding ):
22682402 """Returns a value that can be used as a shard_map partition spec for the kernel."""
22692403 if self .fwd_mask_info .data_next is not None :
0 commit comments