Skip to content

Commit b303738

Browse files
authored
update decode_update_mla_metadata_v1 for atom dp attention (#2392)
* update `decode_update_mla_metadata_v1` natively_supported logic
* edit `get_mla_metadata_v1_2_device`: `params.num_heads = num_heads;`
1 parent 5b3eb3d commit b303738

2 files changed

Lines changed: 19 additions & 6 deletions

File tree

aiter/ops/attention.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,8 +1211,24 @@ def decode_update_mla_metadata_v1(
12111211
assert kv_granularity >= 16
12121212
assert page_size == 1
12131213
# assert not (dtype_q == dtypes.bf16 and dtype_kv == dtypes.bf16 and num_heads_per_head_k == 128), "In this case, use get_mla_metadata_v1 instead"
1214-
natively_supported = (num_heads_per_head_k == 16) or (
1215-
num_heads_per_head_k == 128 and dtype_q == dtypes.fp8 and dtype_kv == dtypes.fp8
1214+
q_is_fp8 = dtype_q == dtypes.fp8
1215+
kv_is_fp8 = dtype_kv == dtypes.fp8
1216+
arch_id = get_gfx()
1217+
natively_supported = (
1218+
(num_heads_per_head_k == 16)
1219+
or (
1220+
arch_id == "gfx950"
1221+
and num_heads_per_head_k == 32
1222+
and q_is_fp8
1223+
and kv_is_fp8
1224+
and max_seqlen_qo == 4
1225+
)
1226+
or (
1227+
arch_id == "gfx942"
1228+
and num_heads_per_head_k == 128
1229+
and q_is_fp8
1230+
and kv_is_fp8
1231+
)
12161232
)
12171233
cu_num = work_indptr.shape[0] - 1
12181234
tile_reduce_cnt = reduce_indptr.shape[0] - 1

csrc/kernels/mla/metadata/v1_2_device.cuh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,6 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
478478
? num_clusters
479479
: min(num_clusters, max_split_per_batch * num_batches);
480480

481-
const bool fold_to_qh16 = !natively_supported && q_is_fp8 && kv_is_fp8;
482-
483481
MlaMetadataV1KernelParameter params = {};
484482
params.p_work_metadata_ptrs = work_metadata_ptrs.data_ptr<uint64_t>();
485483
params.p_work_indptr = work_indptr.data_ptr<int32_t>();
@@ -491,8 +489,7 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
491489
params.p_seqlens_kv_indptr = seqlens_kv_indptr.data_ptr<int32_t>();
492490
params.p_kv_last_page_lens = kv_last_page_lens.data_ptr<int32_t>();
493491
params.num_batches = num_batches;
494-
params.num_heads = fold_to_qh16 ? num_heads
495-
: num_heads_k * num_heads_per_head_k;
492+
params.num_heads = num_heads;
496493
params.num_cu = num_clusters;
497494
params.num_splits = num_splits;
498495
params.reduce_indptr_size = reduce_indptr.size(0);

0 commit comments

Comments (0)