
Commit 0fa8678: fixing kernel precision
1 parent: 768416a
2 files changed: 315 additions & 9 deletions

File: src/maxdiffusion/kernels/splash_attention/ring_attention_kernel.py (13 additions & 9 deletions)
```diff
@@ -38,6 +38,7 @@
 SplashCustomReturnType = base.SplashCustomReturnType
 MaskFunctionType = splash_kernel.MaskFunctionType
 _splash_attention_forward = splash_kernel._splash_attention_forward  # pylint: disable=protected-access
+_splash_attention_forward_ring_raw = splash_kernel._splash_attention_forward_ring_raw  # pylint: disable=protected-access
 _splash_attention_bwd = splash_kernel._splash_attention_bwd  # pylint: disable=protected-access


```
```diff
@@ -104,15 +105,17 @@ def _ring_attention_forward(
   # permute_idx 1, offset (0-1) % 4 = 3, etc.

   splash_fwd_partial = partial(
-      _splash_attention_forward,
-      save_residuals=True,
+      _splash_attention_forward_ring_raw,
       mask_value=mask_value,
       is_mqa=is_mqa,
       config=config,
       mask_function=mask_function,
       fwd_mask_sparsity=fwd_mask_sparsity,
       max_logit_value=None,
   )
+
+  exp_fn = jnp.exp2 if config.use_base2_exp else jnp.exp
+  log_fn = jnp.log2 if config.use_base2_exp else jnp.log
   # Initial accumulator values
   o_shape = q.shape
   o_init = jnp.zeros(o_shape, dtype=jnp.float32)
```
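Why `exp_fn`/`log_fn` must track `config.use_base2_exp`: when the kernel computes its softmax in base 2 (logits typically pre-scaled by log2(e)), the `max_logits` it reports live in base-2 space, so the cross-shard rescaling factors are only correct if the merge exponentiates in the same base. A standalone sketch with hypothetical values:

```python
import jax.numpy as jnp

# Hypothetical per-row maxima as a base-2 kernel would report them:
# natural-log-space maxima of 3.0 and 5.0, scaled by log2(e).
LOG2_E = 1.4426950408889634
m_prev, m_curr = 3.0 * LOG2_E, 5.0 * LOG2_E

m_next = jnp.maximum(m_prev, m_curr)
alpha_right = jnp.exp2(m_prev - m_next)  # matches the kernel's base
alpha_wrong = jnp.exp(m_prev - m_next)   # mixed bases: silently wrong weight
print(alpha_right)  # ~0.1353 == exp(3.0 - 5.0), the intended rescale
print(alpha_wrong)  # ~0.0558, would under-weight the previous shard
```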
```diff
@@ -141,13 +144,12 @@ def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], None]:
         segment_ids=segment_ids_current,
         sinks=sinks,
     )
-    lse_curr = stats["logsumexp"]
-    m_curr = stats["max_logits"]
-    l_curr = jnp.exp(lse_curr - m_curr)
-    o_curr = out_curr.astype(jnp.float32) * l_curr[..., None]
+    m_curr = stats["max_logits"].astype(jnp.float32)
+    l_curr = stats["l_linear"].astype(jnp.float32)
+    o_curr = out_curr.astype(jnp.float32)
     m_next = jnp.maximum(m_prev, m_curr)
-    alpha = jnp.exp(m_prev - m_next)
-    beta = jnp.exp(m_curr - m_next)
+    alpha = exp_fn(m_prev - m_next)
+    beta = exp_fn(m_curr - m_next)
     l_next = alpha * l_prev + beta * l_curr
     o_next = alpha[..., None] * o_prev + beta[..., None] * o_curr
     return (m_next, l_next, o_next, k_next, v_next, segment_ids_next), None
```
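The `body` above is the standard streaming-softmax (flash-attention) merge: each ring step contributes raw statistics `(m, l, o)`, and the carry is rescaled so everything stays expressed relative to the running max. A self-contained sketch (toy shapes, natural base; `shard_stats` and `merge` are illustrative names, not from the diff) showing that the merge reproduces full softmax attention exactly:

```python
import jax
import jax.numpy as jnp

def shard_stats(logits, v):
  """Raw per-shard stats: max logit, linear denominator, unnormalized numerator."""
  m = logits.max(axis=-1)
  p = jnp.exp(logits - m[..., None])
  return m, p.sum(axis=-1), p @ v

def merge(a, b, exp_fn=jnp.exp):
  (m_a, l_a, o_a), (m_b, l_b, o_b) = a, b
  m = jnp.maximum(m_a, m_b)
  alpha, beta = exp_fn(m_a - m), exp_fn(m_b - m)
  return m, alpha * l_a + beta * l_b, alpha[..., None] * o_a + beta[..., None] * o_b

logits = jnp.array([[0.5, 2.0, -1.0, 3.0]])  # 1 query row x 4 kv positions
v = jnp.arange(8.0).reshape(4, 2)            # 4 kv positions x head_dim 2
m, l, o = merge(shard_stats(logits[:, :2], v[:2]), shard_stats(logits[:, 2:], v[2:]))
out = o / l[..., None]                       # normalize once, at the very end
assert jnp.allclose(out, jax.nn.softmax(logits, axis=-1) @ v)
```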
```diff
@@ -167,7 +169,7 @@ def body(carry, i: int) -> tuple[tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array], None]:
   l_inv = jnp.where(l_final == 0.0, 0.0, 1.0 / l_final)
   out = (o_final * l_inv[..., None]).astype(q.dtype)
   # Final logsumexp for residuals
-  lse = jnp.log(l_final) + m_final
+  lse = log_fn(l_final) + m_final
   lse = jnp.where(l_final == 0.0, mask_value, lse)

   return out, (lse, m_final)
```
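This is the precision issue the commit title refers to: the old path received only `logsumexp` and re-derived the linear denominator as `exp(lse - m)`, a log/exp round trip that does not generally return `l` exactly in float32, while the raw path hands `l_linear` back directly. A toy illustration with hypothetical values:

```python
import jax.numpy as jnp

l = jnp.float32(1234.5678)      # true linear denominator inside the kernel
m = jnp.float32(17.25)          # per-row max logit
lse = jnp.log(l) + m            # what the kernel used to return
l_roundtrip = jnp.exp(lse - m)  # what the old merge reconstructed
print(l_roundtrip - l)          # typically a small nonzero drift, paid once
                                # per ring step
```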
```diff
@@ -596,13 +598,15 @@ def _resolve_spec(x):
         mask_info_specs,
         mask_info_specs if self.dkv_mask_info is not None else None,
         ring_axis=self.ring_axis,
+        rotate_segment_ids=self.rotate_segment_ids,
         **self.kwargs,
     )

   def tree_flatten(self):
     children = (self.fwd_mask_info, self.dkv_mask_info)
     aux_data = self.kwargs.copy()
     aux_data["ring_axis"] = self.ring_axis
+    aux_data["rotate_segment_ids"] = self.rotate_segment_ids
     return children, aux_data

   @classmethod
```
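The `tree_flatten` change matters because the kernel object is a registered JAX pytree: it is rebuilt from `(children, aux_data)` every time it crosses a `jit` or `tree_map` boundary, so a field left out of `aux_data` is silently dropped. A minimal standalone sketch (`Cfg` is a hypothetical stand-in, not the class in this file):

```python
import jax

@jax.tree_util.register_pytree_node_class
class Cfg:
  def __init__(self, ring_axis, rotate_segment_ids):
    self.ring_axis = ring_axis
    self.rotate_segment_ids = rotate_segment_ids

  def tree_flatten(self):
    # Any attribute missing here would not survive a flatten/unflatten cycle.
    return (), {"ring_axis": self.ring_axis,
                "rotate_segment_ids": self.rotate_segment_ids}

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del children
    return cls(**aux_data)

leaves, treedef = jax.tree_util.tree_flatten(Cfg("ring", True))
assert jax.tree_util.tree_unflatten(treedef, leaves).rotate_segment_ids
```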

File: src/maxdiffusion/kernels/splash_attention/splash_attention_kernel.py (302 additions & 0 deletions)
```diff
@@ -859,6 +859,308 @@ def init_if_empty(x: jax.Array, value: float) -> jax.Array:
   return out


+def _splash_attention_forward_ring_raw(
+    mask_info: MaskInfo,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    segment_ids: base.SegmentIds | None,
+    sinks: jax.Array | None,
+    mask_value: float,
+    is_mqa: bool,
+    config: SplashConfig,
+    mask_function: MaskFunctionType | None,
+    fwd_mask_sparsity: float,
+    max_logit_value: jax.Array | None = None,
+) -> tuple[jax.Array, dict[str, jax.Array]]:
+  """Ring-specific forward path that returns pre-reciprocal fp32 accumulators.
+
+  Unlike `_splash_attention_forward`, this helper is intended for ring-attention
+  merging and returns the raw fp32 numerator (`out_linear`) together with the
+  linear softmax denominator (`l_linear`) and per-row max logits (`max_logits`).
+  This lets the outer ring kernel merge shard contributions and normalize only
+  once at the very end.
+  """
```
```diff
+  num_q_heads, q_seq_len, head_dim_qk = q.shape
+  head_dim_v = v.shape[-1]
+  bq, bkv = config.block_q, config.block_kv
+  bkv_compute = config.block_kv_compute
+  bounds_start, bounds_end = mask_info_lib.find_bounds(mask_info.active_rows)
+
+  if is_mqa:
+    expected_kv_rank = 2
+    num_kv_heads = 1
+  else:
+    expected_kv_rank = 3
+    num_kv_heads = k.shape[0]
+
+  if len(k.shape) != expected_kv_rank:
+    raise ValueError(
+        f"Expected a {expected_kv_rank}-dim 'key' tensor. Instead got a {len(k.shape)}-dim one."
+    )
+
+  if k.shape[-1] != head_dim_qk:
+    raise ValueError(f"Expected 'key' head dimension to be: {head_dim_qk}. Instead got: {k.shape[-1]}.")
+
+  if not is_mqa and num_q_heads % num_kv_heads != 0:
+    raise ValueError(
+        f"In MHA, expected the number of 'query' heads ({num_q_heads}) to be a multiple of the "
+        f"number of 'key' heads ({num_kv_heads})"
+    )
+
+  if k.shape[:-1] != v.shape[:-1]:
+    raise ValueError(f"Expected 'key' {k.shape} and 'value' {v.shape} to have the same leading dimensions.")
+
+  if bkv % bkv_compute:
+    raise ValueError(f"{bkv=} must be a multiple of {bkv_compute=}.")
+  if bkv_compute % NUM_LANES:
+    raise ValueError(f"{bkv_compute=} must be a multiple of {NUM_LANES}.")
+
+  kv_seq_len = k.shape[-2]
+  kv_steps = kv_seq_len // bkv
+  q_heads_per_kv_head = num_q_heads // num_kv_heads
+  dynamic_grid = mask_info.active_rows is not None
+
+  if segment_ids is not None:
+    assert isinstance(segment_ids.q, jax.Array)
+    assert isinstance(segment_ids.kv, jax.Array)
+    if segment_ids.q.shape != (q_seq_len,):
+      raise ValueError(f"Invalid shape for q segment_ids: {segment_ids.q.shape}. Expected: {(q_seq_len,)}")
+    if segment_ids.kv.shape != (kv_seq_len,):
+      raise ValueError(f"Invalid shape for kv segment_ids: {segment_ids.kv.shape}. Expected: {(kv_seq_len,)}")
+
+  if config.max_logit_const is not None and max_logit_value is not None:
+    raise ValueError(f"Only one of {config.max_logit_const=} and {max_logit_value=} can be set.")
+  if max_logit_value is not None:
+    if max_logit_value.shape not in ((), (1,), (num_q_heads,)):
+      raise ValueError(
+          "max_logit_value should be a 0- or 1-dim jax.Array of shape (), (1,) or "
+          f"({num_q_heads=},) but got {jax.typeof(max_logit_value)}"
+      )
+    max_logit_value = jnp.broadcast_to(jnp.atleast_1d(max_logit_value), (num_q_heads,))
+
+  q_layout = config.q_layout
+  k_layout = config.k_layout
+  v_layout = config.v_layout
+
+  def unravel(f):
+    def index_map(h, grid_idx, rows_ref, cols_ref, *_):
+      if dynamic_grid:
+        i = to_i32(rows_ref[grid_idx])
+        j = to_i32(cols_ref[grid_idx])
+      else:
+        i = grid_idx // kv_steps
+        j = grid_idx % kv_steps
+      return f(h, i, j)
+
+    return index_map
```
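`unravel` exists because the kernel iterates a flattened 1-D grid over (q block, kv block) pairs: with a dynamic grid the pair comes from the prefetched `rows_ref`/`cols_ref` tables of active blocks, otherwise it is recovered by divmod against `kv_steps`. A plain-Python sketch of the static case (hypothetical sizes):

```python
kv_steps = 4                     # kv blocks per q-block row
for grid_idx in range(8):        # 2 q blocks x 4 kv blocks, flattened
  i, j = grid_idx // kv_steps, grid_idx % kv_steps
  print(grid_idx, "->", (i, j))  # 0->(0,0), 1->(0,1), ..., 7->(1,3)
```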
```diff
+  def create_kv_index_map(layout):
+    def index_map(h, i, j):
+      del i
+      prefix = () if is_mqa else (_div(h, q_heads_per_kv_head),)
+      return from_head_minor((*prefix, j, 0), layout)
+
+    return index_map
+
+  q_index_map = unravel(lambda h, i, j: from_head_minor((h, i, 0), q_layout))
+  out_index_map = unravel(lambda h, i, j: (h, i, 0))
+  k_index_map = unravel(create_kv_index_map(k_layout))
+  v_index_map = unravel(create_kv_index_map(v_layout))
+
+  def mask_index_map(h, grid_idx, rows_ref, cols_ref, mask_next_ref=None, *_):
+    del h, rows_ref, cols_ref
+    next_m = to_i32(mask_next_ref[grid_idx])
+    return next_m, 0, 0
+
+  q_segment_ids_index_map = unravel(lambda h, i, j: (i, 0))
+  kv_segment_ids_index_map = unravel(lambda h, i, j: (0, j))
+
+  in_specs = [
+      pl.BlockSpec(from_head_minor((None, bq, head_dim_qk), q_layout), q_index_map),
+      pl.BlockSpec(
+          from_head_minor((bkv, head_dim_qk) if is_mqa else (None, bkv, head_dim_qk), k_layout),
+          k_index_map,
+      ),
+      pl.BlockSpec(
+          from_head_minor((bkv, head_dim_v) if is_mqa else (None, bkv, head_dim_v), v_layout),
+          v_index_map,
+      ),
+  ]
+  if segment_ids is not None:
+    in_specs += [
+        pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map),
+        pl.BlockSpec((NUM_SUBLANES, bkv), kv_segment_ids_index_map),
+    ]
+    q_segment_ids = jax.lax.broadcast_in_dim(segment_ids.q, (q_seq_len, NUM_LANES), (0,))
+    kv_segment_ids = jax.lax.broadcast_in_dim(segment_ids.kv, (NUM_SUBLANES, kv_seq_len), (1,))
+  else:
+    in_specs += [None, None]
+    q_segment_ids = kv_segment_ids = None
```
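The 1-D segment ids are widened into 2-D tiles shaped for TPU loads: q ids are copied across `NUM_LANES` columns and kv ids across `NUM_SUBLANES` rows. A standalone sketch of that broadcast (the tile widths below are hypothetical stand-ins for `NUM_LANES`/`NUM_SUBLANES`):

```python
import jax
import jax.numpy as jnp

LANES, SUBLANES = 128, 8                    # typical TPU tile dimensions
seg = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
q_tiled = jax.lax.broadcast_in_dim(seg, (4, LANES), (0,))      # (seq, lanes)
kv_tiled = jax.lax.broadcast_in_dim(seg, (SUBLANES, 4), (1,))  # (sublanes, seq)
print(q_tiled.shape, kv_tiled.shape)        # (4, 128) (8, 4)
```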
```diff
+  if sinks is not None:
+    assert sinks.shape == (num_q_heads,), f"{sinks.shape=} != {num_q_heads=}"
+    in_specs += [
+        pl.BlockSpec(
+            (NUM_SUBLANES, num_q_heads),
+            lambda h, i, j, *_: (0, 0),
+            memory_space=pltpu.SMEM,
+        )
+    ]
+    sinks = jnp.broadcast_to(sinks.astype(jnp.float32)[None, :], (NUM_SUBLANES, num_q_heads))
+  else:
+    in_specs += [None]
+
+  if mask_info.partial_mask_blocks is not None:
+    in_specs.append(pl.BlockSpec((None, bq, bkv), mask_index_map))
+  else:
+    in_specs.append(None)
+
+  assert mask_info.partial_mask_blocks is None or mask_info.q_sequence is None
+
+  if mask_info.q_sequence is not None:
+    q_sequence = jax.lax.broadcast_in_dim(mask_info.q_sequence, (q_seq_len, NUM_LANES), (0,))
+    in_specs.append(pl.BlockSpec((bq, NUM_LANES), q_segment_ids_index_map))
+  else:
+    q_sequence = None
+    in_specs.append(None)
+
+  if max_logit_value is not None:
+    max_logit_value = jnp.broadcast_to(
+        max_logit_value.astype(jnp.float32)[None, :],
+        (NUM_SUBLANES, num_q_heads),
+    )
+    in_specs += [
+        pl.BlockSpec(
+            (NUM_SUBLANES, num_q_heads),
+            lambda *_: (0, 0),
+            memory_space=pltpu.SMEM,
+        )
+    ]
+  else:
+    in_specs.append(None)
+
+  logsumexp_index_map = unravel(lambda h, i, j, *_: (h, i, 0))
+  out_shapes = [
+      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, head_dim_v), jnp.float32),
+      None,
+      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32),
+      jax.ShapeDtypeStruct((num_q_heads, q_seq_len, NUM_LANES), jnp.float32),
+  ]
+  out_specs = [
+      pl.BlockSpec((None, bq, head_dim_v), out_index_map),
+      None,
+      pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
+      pl.BlockSpec((None, bq, NUM_LANES), logsumexp_index_map),
+  ]
+
+  kernel_name = f"{get_kernel_name(is_mqa=is_mqa, save_residuals=True, is_segmented=segment_ids is not None, phase='fwd')}_ring_raw"
+  metadata = {"xprof_metadata": json.dumps(dataclasses.asdict(config))}
+
+  vmem_inputs = [q, k, v, q_segment_ids, kv_segment_ids, mask_info.partial_mask_blocks]
+
+  def _fwd_cost_estimate(
+      q: jax.Array,
+      k: jax.Array,
+      v: jax.Array,
+      q_segment_ids: jax.Array | None,
+      kv_segment_ids: jax.Array | None,
+      partial_mask_blocks: jax.Array | None,
+      out_shapes: list[jax.ShapeDtypeStruct | None],
+      mask_sparsity: float,
+  ) -> pl.CostEstimate:
+    num_q_heads, q_seq_len, head_dim_qk = q.shape
+    kv_seq_len, head_dim_v = v.shape[-2:]
+    matmul_flops = 2 * q_seq_len * kv_seq_len * head_dim_qk + 2 * q_seq_len * kv_seq_len * head_dim_v
+    total_flops = num_q_heads * matmul_flops * mask_sparsity
+    transcendentals = num_q_heads * q_seq_len * kv_seq_len * mask_sparsity
+    inputs_ = [q, k, v, q_segment_ids, kv_segment_ids, partial_mask_blocks]
+    input_bytes = sum(map(_bytes, inputs_))
+    output_bytes = sum(map(_bytes, out_shapes))
+    return pl.CostEstimate(
+        flops=int(total_flops),
+        transcendentals=int(transcendentals),
+        bytes_accessed=int(input_bytes + output_bytes),
+    )
+
+  cost_estimate = config.fwd_cost_estimate or _fwd_cost_estimate(*vmem_inputs, out_shapes, fwd_mask_sparsity)
```
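The cost model counts the two matmuls per head (QK^T at 2·T·S·d_qk FLOPs and PV at 2·T·S·d_v) plus one transcendental per logit, all scaled by the fraction of mask blocks actually computed. Plugging in hypothetical sizes:

```python
# Hypothetical: 8 heads, 4k x 4k attention, head_dim 128, 40% of blocks active.
heads, T, S, d_qk, d_v, sparsity = 8, 4096, 4096, 128, 128, 0.4
matmul_flops = 2 * T * S * d_qk + 2 * T * S * d_v   # QK^T and PV, per head
total_flops = heads * matmul_flops * sparsity
transcendentals = heads * T * S * sparsity          # one exp per logit
print(f"{total_flops:.3e} flops, {transcendentals:.3e} transcendentals")
```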
```diff
+  if dynamic_grid:
+    num_active_blocks = mask_info.num_active_blocks[0]
+    grid = (num_q_heads, num_active_blocks)
+    is_empty_attention_block = num_active_blocks == 0
+  else:
+    grid = (num_q_heads, kv_steps * (q_seq_len // bq))
+    is_empty_attention_block = False
+
+  with jax.named_scope(kernel_name):
+    all_out = pl.pallas_call(
+        partial(
+            flash_attention_kernel,
+            mask_value=mask_value,
+            kv_steps=kv_steps,
+            bq=bq,
+            bkv=bkv,
+            bkv_compute=bkv_compute,
+            head_dim_v=head_dim_v,
+            fuse_reciprocal=False,
+            config=config,
+            mask_function=mask_function,
+        ),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=6,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            grid=grid,
+            scratch_shapes=[
+                pltpu.VMEM((bq, NUM_LANES), jnp.float32),
+                pltpu.VMEM((bq, NUM_LANES), jnp.float32),
+                pltpu.VMEM((bq, head_dim_v), jnp.float32),
+            ],
+        ),
+        compiler_params=pltpu.CompilerParams(
+            dimension_semantics=("parallel", "arbitrary"),
+            flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": (config.use_experimental_scheduler)},
+        ),
+        out_shape=out_shapes,
+        name=kernel_name,
+        cost_estimate=cost_estimate,
+        interpret=config.interpret,
+        metadata=metadata,
+    )(
+        mask_info.active_rows,
+        mask_info.active_cols,
+        mask_info.mask_next,
+        bounds_start,
+        bounds_end,
+        mask_info.block_mask,
+        q if q_layout == QKVLayout.HEAD_DIM_MINOR else q.mT,
+        k if k_layout == QKVLayout.HEAD_DIM_MINOR else k.mT,
+        v if v_layout == QKVLayout.HEAD_DIM_MINOR else v.mT,
+        q_segment_ids,
+        kv_segment_ids,
+        sinks,
+        mask_info.partial_mask_blocks,
+        q_sequence,
+        max_logit_value,
+    )
+  out_linear, _, l_linear, max_logits = all_out
```
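`fuse_reciprocal=False` is what makes this path "raw": per the docstring, the kernel skips its usual end-of-row multiply by `1/l`, so `out_linear` leaves the kernel unnormalized and the reciprocal is applied exactly once after all ring steps are merged. A one-row toy contrast (standalone, hypothetical values):

```python
import jax.numpy as jnp

p = jnp.exp(jnp.array([0.2, 1.0, -0.5]))  # shifted exps of one row of logits
v = jnp.arange(6.0).reshape(3, 2)
l = p.sum()
fused = (p / l) @ v                       # fused reciprocal: normalized output
raw = p @ v                               # raw path: numerator only
assert jnp.allclose(fused, raw / l)       # normalizing later is equivalent
```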
```diff
+
+  def init_if_empty(x: jax.Array, value: float) -> jax.Array:
+    if not dynamic_grid:
+      return x
+    return jnp.where(is_empty_attention_block, value, x)
+
+  out_linear = init_if_empty(out_linear, 0.0)
+  assert l_linear is not None
+  assert max_logits is not None
+  l_linear = init_if_empty(l_linear[..., 0], 0.0)
+  max_logits = init_if_empty(max_logits[..., 0], mask_value)
+
+  stats = {"l_linear": l_linear, "max_logits": max_logits}
+  stats = jax.tree.map(jax.lax.stop_gradient, stats)
+  return out_linear, stats
```
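The fill values for empty attention blocks form the identity element of the ring merge: with `o = 0`, `l = 0`, and `m = mask_value` (a very negative number), merging an empty shard into any carry leaves the carry unchanged. A standalone check (`MASK_VALUE` below is a hypothetical stand-in for `mask_value`):

```python
import jax.numpy as jnp

MASK_VALUE = -1e30                         # stands in for `mask_value`
m_id, l_id, o_id = jnp.float32(MASK_VALUE), jnp.float32(0.0), jnp.zeros((2,))

m, l, o = jnp.float32(3.0), jnp.float32(1.7), jnp.ones((2,))
m_next = jnp.maximum(m, m_id)              # == m
alpha, beta = jnp.exp(m - m_next), jnp.exp(m_id - m_next)  # == 1.0, 0.0
assert jnp.allclose(alpha * l + beta * l_id, l)
assert jnp.allclose(alpha * o + beta * o_id, o)
```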
```diff
+
+
 @partial(
     jax.custom_vjp,
     nondiff_argnames=(
```
