@@ -859,6 +859,308 @@ def init_if_empty(x: jax.Array, value: float) -> jax.Array:
  return out


def _splash_attention_forward_ring_raw(
    mask_info: MaskInfo,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: base.SegmentIds | None,
    sinks: jax.Array | None,
    mask_value: float,
    is_mqa: bool,
    config: SplashConfig,
    mask_function: MaskFunctionType | None,
    fwd_mask_sparsity: float,
    max_logit_value: jax.Array | None = None,
) -> tuple[jax.Array, dict[str, jax.Array]]:
  """Ring-specific forward path that returns pre-reciprocal fp32 accumulators.

  Unlike `_splash_attention_forward`, this helper is intended for ring
  attention merging: it returns the raw fp32 numerator (`out_linear`) together
  with the linear softmax denominator (`l_linear`) and the per-row max logits
  (`max_logits`). This lets the outer ring kernel merge shard contributions
  and normalize only once at the very end. A short illustrative merge sketch
  follows this function.
  """
  num_q_heads, q_seq_len, head_dim_qk = q.shape
  head_dim_v = v.shape[-1]
  bq, bkv = config.block_q, config.block_kv
  bkv_compute = config.block_kv_compute
  bounds_start, bounds_end = mask_info_lib.find_bounds(mask_info.active_rows)

  if is_mqa:
    expected_kv_rank = 2
    num_kv_heads = 1
  else:
    expected_kv_rank = 3
    num_kv_heads = k.shape[0]

  if len(k.shape) != expected_kv_rank:
    raise ValueError(
        f"Expected a {expected_kv_rank}-dim 'key' tensor for"
        f" {'MQA' if is_mqa else 'MHA'}. Instead got a {len(k.shape)}-dim one."
    )

  if k.shape[-1] != head_dim_qk:
    raise ValueError(
        f"Expected 'key' head dimension to be: {head_dim_qk}."
        f" Instead got: {k.shape[-1]}."
    )

  if not is_mqa and num_q_heads % num_kv_heads != 0:
    raise ValueError(
        f"In MHA, expected the number of 'query' heads ({num_q_heads}) to be a"
        f" multiple of the number of 'key' heads ({num_kv_heads})."
    )

  if k.shape[:-1] != v.shape[:-1]:
    raise ValueError(
        f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same"
        " leading dimensions."
    )

  if bkv % bkv_compute:
    raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.")
  if bkv_compute % NUM_LANES:
    raise ValueError(f"{bkv_compute=} must be a multiple of {NUM_LANES}.")

  kv_seq_len = k.shape[-2]
  kv_steps = kv_seq_len // bkv
  q_heads_per_kv_head = num_q_heads // num_kv_heads
  dynamic_grid = mask_info.active_rows is not None

  if segment_ids is not None:
    assert isinstance(segment_ids.q, jax.Array)
    assert isinstance(segment_ids.kv, jax.Array)
    if segment_ids.q.shape != (q_seq_len,):
      raise ValueError(
          f"Invalid shape for q segment_ids: {segment_ids.q.shape}."
          f" Expected: {(q_seq_len,)}"
      )
    if segment_ids.kv.shape != (kv_seq_len,):
      raise ValueError(
          f"Invalid shape for kv segment_ids: {segment_ids.kv.shape}."
          f" Expected: {(kv_seq_len,)}"
      )

  if config.max_logit_const is not None and max_logit_value is not None:
    raise ValueError(
        f"Only one of {config.max_logit_const=} and {max_logit_value=} can be"
        " set."
    )
  if max_logit_value is not None:
    if max_logit_value.shape not in ((), (1,), (num_q_heads,)):
      raise ValueError(
          "max_logit_value should be a 0- or 1-dim jax.Array of shape (), (1,)"
          f" or ({num_q_heads=},), but got {jax.typeof(max_logit_value)}"
      )
    max_logit_value = jnp.broadcast_to(
        jnp.atleast_1d(max_logit_value), (num_q_heads,)
    )

  q_layout = config.q_layout
  k_layout = config.k_layout
  v_layout = config.v_layout

  def unravel(f):
    # Map the flattened grid index back to (q-block, kv-block) coordinates,
    # either by looking up the prefetched active rows/cols (dynamic grid) or
    # by dividing out the number of kv steps (dense grid).
    def index_map(h, grid_idx, rows_ref, cols_ref, *_):
      if dynamic_grid:
        i = to_i32(rows_ref[grid_idx])
        j = to_i32(cols_ref[grid_idx])
      else:
        i = grid_idx // kv_steps
        j = grid_idx % kv_steps
      return f(h, i, j)

    return index_map

  def create_kv_index_map(layout):
    def index_map(h, i, j):
      del i
      prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
      return from_head_minor((*prefix, j, 0), layout)

    return index_map

  q_index_map = unravel(lambda h, i, j: from_head_minor((h, i, 0), q_layout))
  out_index_map = unravel(lambda h, i, j: (h, i, 0))
  k_index_map = unravel(create_kv_index_map(k_layout))
  v_index_map = unravel(create_kv_index_map(v_layout))

  def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
    del h, rows_ref, cols_ref
    next_m = to_i32(mask_next_ref[grid_idx])
    return next_m, 0, 0

  q_segment_ids_index_map = unravel(lambda h, i, j: (i, 0))
  kv_segment_ids_index_map = unravel(lambda h, i, j: (0, j))

  in_specs = [
      pl.BlockSpec(from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map),
      pl.BlockSpec(
          from_head_minor(
              (bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), k_layout
          ),
          k_index_map,
      ),
      pl.BlockSpec(
          from_head_minor(
              (bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), v_layout
          ),
          v_index_map,
      ),
  ]
  if segment_ids is not None:
    in_specs += [
        pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map),
        pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map),
    ]
    q_segment_ids = jax.lax.broadcast_in_dim(
        segment_ids.q, (q_seq_len, NUM_LANES), (0,)
    )
    kv_segment_ids = jax.lax.broadcast_in_dim(
        segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,)
    )
  else:
    in_specs += [None, None]
    q_segment_ids = kv_segment_ids = None

  if sinks is not None:
    assert sinks.shape == (num_q_heads,), f"{sinks.shape=} != {num_q_heads=}"
    in_specs += [
        pl.BlockSpec(
            (NUM_SUBLANES, num_q_heads),
            lambda h, i, j, *_: (0, 0),
            memory_space=pltpu.SMEM,
        )
    ]
    sinks = jnp.broadcast_to(
        sinks.astype(jnp.float32)[None, :], (NUM_SUBLANES, num_q_heads)
    )
  else:
    in_specs += [None]

  if mask_info.partial_mask_blocks is not None:
    in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map))
  else:
    in_specs.append(None)

  assert mask_info.partial_mask_blocks is None or mask_info.q_sequence is None

  if mask_info.q_sequence is not None:
    q_sequence = jax.lax.broadcast_in_dim(
        mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,)
    )
    in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map))
  else:
    q_sequence = None
    in_specs.append(None)

  if max_logit_value is not None:
    max_logit_value = jnp.broadcast_to(
        max_logit_value.astype(jnp.float32)[None, :],
        (NUM_SUBLANES, num_q_heads),
    )
    in_specs += [
        pl.BlockSpec(
            (NUM_SUBLANES, num_q_heads),
            lambda *_: (0, 0),
            memory_space=pltpu.SMEM,
        )
    ]
  else:
    in_specs.append(None)

  logsumexp_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
  out_shapes = [
      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), jnp.float32),
      None,
      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32),
      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32),
  ]
  out_specs = [
      pl.BlockSpec((None, bq, head_dim_v), out_index_map),
      None,
      pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
      pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
  ]

  kernel_name = (
      get_kernel_name(
          is_mqa=is_mqa,
          save_residuals=True,
          is_segmented=segment_ids is not None,
          phase="fwd",
      )
      + "_ring_raw"
  )
  metadata = {"xprof_metadata": json.dumps(dataclasses.asdict(config))}

  vmem_inputs = [q, k, v, q_segment_ids, kv_segment_ids, mask_info.partial_mask_blocks]

  def _fwd_cost_estimate(
      q: jax.Array,
      k: jax.Array,
      v: jax.Array,
      q_segment_ids: jax.Array | None,
      kv_segment_ids: jax.Array | None,
      partial_mask_blocks: jax.Array | None,
      out_shapes: list[jax.ShapeDtypeStruct | None],
      mask_sparsity: float,
  ) -> pl.CostEstimate:
    num_q_heads, q_seq_len, head_dim_qk = q.shape
    kv_seq_len, head_dim_v = v.shape[-2:]
    matmul_flops = (
        2 * q_seq_len * kv_seq_len * head_dim_qk
        + 2 * q_seq_len * kv_seq_len * head_dim_v
    )
    total_flops = num_q_heads * matmul_flops * mask_sparsity
    transcendentals = num_q_heads * q_seq_len * kv_seq_len * mask_sparsity
    inputs_ = [q, k, v, q_segment_ids, kv_segment_ids, partial_mask_blocks]
    input_bytes = sum(map(_bytes, inputs_))
    output_bytes = sum(map(_bytes, out_shapes))
    return pl.CostEstimate(
        flops=int(total_flops),
        transcendentals=int(transcendentals),
        bytes_accessed=int(input_bytes + output_bytes),
    )

  cost_estimate = config.fwd_cost_estimate or _fwd_cost_estimate(
      *vmem_inputs, out_shapes, fwd_mask_sparsity
  )

  if dynamic_grid:
    num_active_blocks = mask_info.num_active_blocks[0]
    grid = (num_q_heads, num_active_blocks)
    is_empty_attention_block = num_active_blocks == 0
  else:
    grid = (num_q_heads, kv_steps * (q_seq_len // bq))
    is_empty_attention_block = False

  with jax.named_scope(kernel_name):
    all_out = pl.pallas_call(
        partial(
            flash_attention_kernel,
            mask_value=mask_value,
            kv_steps=kv_steps,
            bq=bq,
            bkv=bkv,
            bkv_compute=bkv_compute,
            head_dim_v=head_dim_v,
            fuse_reciprocal=False,
            config=config,
            mask_function=mask_function,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=6,
            in_specs=in_specs,
            out_specs=out_specs,
            grid=grid,
            scratch_shapes=[
                pltpu.VMEM((bq, NUM_LANES), jnp.float32),
                pltpu.VMEM((bq, NUM_LANES), jnp.float32),
                pltpu.VMEM((bq, head_dim_v), jnp.float32),
            ],
        ),
        compiler_params=pltpu.CompilerParams(
            dimension_semantics=("parallel", "arbitrary"),
            flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": config.use_experimental_scheduler},
        ),
        out_shape=out_shapes,
        name=kernel_name,
        cost_estimate=cost_estimate,
        interpret=config.interpret,
        metadata=metadata,
    )(
        mask_info.active_rows,
        mask_info.active_cols,
        mask_info.mask_next,
        bounds_start,
        bounds_end,
        mask_info.block_mask,
        q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.mT,
        k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.mT,
        v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.mT,
        q_segment_ids,
        kv_segment_ids,
        sinks,
        mask_info.partial_mask_blocks,
        q_sequence,
        max_logit_value,
    )
  out_linear, _, l_linear, max_logits = all_out

  def init_if_empty(x: jax.Array, value: float) -> jax.Array:
    if not dynamic_grid:
      return x
    return jnp.where(is_empty_attention_block, value, x)

  out_linear = init_if_empty(out_linear, 0.0)
  assert l_linear is not None
  assert max_logits is not None
  l_linear = init_if_empty(l_linear[..., 0], 0.0)
  max_logits = init_if_empty(max_logits[..., 0], mask_value)

  stats = {"l_linear": l_linear, "max_logits": max_logits}
  stats = jax.tree.map(jax.lax.stop_gradient, stats)
  return out_linear, stats


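# Illustrative sketch only, not part of this kernel: one way an outer ring
# attention loop could combine the raw per-shard outputs returned by
# `_splash_attention_forward_ring_raw`. The helper name `merge_ring_shards`
# and its argument layout are assumptions for illustration; the actual ring
# kernel may merge shards differently.
#
# Each shard contributes a numerator `out_i = sum_j exp(s_ij - m_i) * v_j`, a
# denominator `l_i = sum_j exp(s_ij - m_i)` and row-wise max logits `m_i`.
# Merging rescales both shards onto a common max before summing; the single
# division by `l` is deferred until every shard has been folded in.
def merge_ring_shards(
    out_a: jax.Array, l_a: jax.Array, m_a: jax.Array,  # (h, q, d), (h, q), (h, q)
    out_b: jax.Array, l_b: jax.Array, m_b: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  m = jnp.maximum(m_a, m_b)  # new running row-wise max
  scale_a = jnp.exp(m_a - m)  # rescale shard A into the shared scale
  scale_b = jnp.exp(m_b - m)  # rescale shard B into the shared scale
  out = out_a * scale_a[..., None] + out_b * scale_b[..., None]
  l = l_a * scale_a + l_b * scale_b
  return out, l, m
# After the last shard is merged, the normalized attention output is simply
# `out / l[..., None]`, matching what the fused-reciprocal forward would
# produce.

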
@partial(
    jax.custom_vjp,
    nondiff_argnames=(