ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix#27992
titaiwangms wants to merge 24 commits into microsoft:main
Conversation
Enable Memory Efficient Attention (CUTLASS FMHA) to handle decode with past_key/past_value via internal KV cache concatenation.

Changes:
- Relax MEA eligibility: allow past_key when head_size == v_head_size (was: past_key == nullptr)
- Add decode path in RunMemoryEfficientAttention:
  1. LaunchFillInt32 for uniform past_seqlens (ONNX past_key has fixed shape)
  2. Transpose K/V to BSNH if 4D BNSH input
  3. LaunchConcatNewToPastKV to fuse past + new into present (BNSH)
  4. Point MEA at present buffers, track kv_is_bsnh layout
- Update LaunchUngroup and MEA params to use kv_is_bsnh
- Skip present_key/value population when already done by concat
- Update dispatch comments and error messages

Design: Uses uniform past_seqlens + additive bias for masks. No new kernels needed. No memset needed (all present positions written). Bool masks handled correctly via ConvertAttnMaskToBias (no NaN risk).

Closes gaps: GQA+decode+mask, GQA+decode+h>256, softcap+fp32+decode, softcap+decode+mask (4 of 9 gaps from issue microsoft#27880).

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
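The concat step can be illustrated with a minimal numpy reference sketch. This is not the CUDA implementation of `LaunchConcatNewToPastKV`; it only shows the shape semantics the decode path relies on (past K/V in BNSH layout joined with the new step's K/V along the sequence axis), with illustrative shapes.

```python
import numpy as np

def concat_new_to_past_kv(past_kv: np.ndarray, new_kv: np.ndarray) -> np.ndarray:
    # past_kv: (B, N, S_past, H), new_kv: (B, N, S_new, H) in BNSH layout.
    # Every position of the result is written, so no memset is needed.
    return np.concatenate([past_kv, new_kv], axis=2)

past = np.zeros((1, 2, 31, 8), dtype=np.float32)   # past_kv_sequence_length = 31
new = np.ones((1, 2, 1, 8), dtype=np.float32)      # one decode step
present = concat_new_to_past_kv(past, new)

assert present.shape == (1, 2, 32, 8)
assert (present[:, :, -1, :] == 1.0).all()  # the new token occupies the last slot
```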
Add comprehensive test coverage for the MEA decode path with past_kv:

common.py:
- Add v_head_size field to AttentionConfig (defaults to head_size)
- Update all tensor shape logic to use effective_v_head_size for V and output tensors (graph inputs/outputs, prompt/past bindings)

test_gqa.py:
- Add GQA+MEA decode tests (fp16): test_gqa_past_memory_efficient
- Add GQA+MEA decode tests (bf16): TestONNXAttentionMemoryEfficientGQABF16
- Re-enable GQA+padding mask decode tests via MEA (was skipped, now works)
- Add GQA+4D BNSH decode via MEA: TestONNXAttentionGQA4DBNSHMEA
- Add GQA+float mask decode via MEA: TestONNXAttentionMemoryEfficientGQAFloatMaskDecode

test_mha.py:
- Add MHA+MEA decode tests: TestONNXAttentionMHAPastMEA
- Add asymmetric head_size regression test: TestONNXAttentionMHAAsymmetricHeadSize (verifies MEA fallback to unfused when head_size != v_head_size)

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…tics
- Rename TestONNXAttentionPaddingMaskGQA to TestONNXAttentionPaddingMaskMEAGQA to reflect that it now tests MEA (not Flash) for decode with past_key
- Add print_diff_statistics for present_k/v assertions in the asymmetric head_size regression test, for consistency with other tests

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Replace `!(past_key != nullptr && head_size != v_head_size)` with equivalent `(past_key == nullptr || head_size == v_head_size)`. Same logic, reads more naturally and is safer when adding future conditions. Per readability and code review feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Change past_kv_sequence_length from 32 to 31 in tests that use attn_mask with MEA decode. MEA requires total_sequence_length % 4 == 0 when an attention mask is present (CUTLASS kMinimumAlignment=4). With past=32 + new=1 = 33, 33%4=1 → MEA disabled → NOT_IMPLEMENTED. With past=31 + new=1 = 32, 32%4=0 → MEA eligible → the test exercises MEA.

Affected tests:
- gqa_past_padding_test_cases(): seqs (1,32) → (1,31)
- test_gqa_past_float_mask_4d: past=32→31, total=33→32

Per code review feedback from @9c209479.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Log which attention runner is selected at VERBOSE level in ComputeInternal:
- Flash Attention: batch, q_seq, total_seq, past
- Memory Efficient Attention: batch, q_seq, total_seq, past, mask
- Unfused Attention: batch, q_seq, total_seq, past, mask

Helps verify runner targeting at runtime for debugging and testing.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add utility function to compute past_kv_sequence_length values that satisfy CUTLASS MEA's bias alignment requirement (total_seq % 4 == 0). Prevents future test authors from accidentally picking sequence lengths that disable MEA when attn_mask is present, causing tests to silently fall through to unfused instead of exercising MEA. Per architect suggestion from @a2065222. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add ORT_ENFORCE(nonpad_kv_seqlen == nullptr) at the start of the MEA decode path. nonpad_kv_seqlen (external cache, opset 24) and past_key (internal cache) are mutually exclusive — enforced at validation but adding a defensive check here too. Per code review feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Prevent negative results when past_seq is too small for the alignment adjustment. Per code review edge case feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
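The alignment utility and the clamp can be sketched in Python. `mea_aligned_past_seq()` is the utility name used by the tests; this body is an assumed reconstruction based on the rules described above (total_seq % 4 == 0 and the no-negative-result clamp), not the repo's exact code.

```python
def mea_aligned_past_seq(desired_past_seq: int, new_seq: int, alignment: int = 4) -> int:
    """Largest past length <= desired_past_seq such that
    (past + new_seq) % alignment == 0, clamped at zero so the
    alignment adjustment can never produce a negative length."""
    total = desired_past_seq + new_seq
    aligned_total = total - (total % alignment)
    return max(aligned_total - new_seq, 0)

assert mea_aligned_past_seq(32, 1) == 31  # 31 + 1 = 32, and 32 % 4 == 0
assert mea_aligned_past_seq(31, 1) == 31  # already aligned, unchanged
assert mea_aligned_past_seq(1, 2) == 0    # clamped instead of going negative
```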
Add element-wise softcap kernel (score = softcap * tanh(score / softcap)) applied after Q*K' GEMM and before softmax in the unfused attention path.

Changes:
- attention_impl.cu: Add ApplySoftcapKernel and ApplySoftcap template function. Insert call after GEMM, before ComputeSoftmax, when parameters.softcap > 0. Works for all dtypes (fp16/bf16/fp32).
- attention.cc: Remove unfused softcap rejection, pass softcap to contribop_parameters, update comments.
- onnx_backend_test_series_filters.jsonc: Remove 2 softcap+diff_head_size filters (unfused now handles softcap). Keep qk_matmul+softcap filters (qk_matmul modes still not implemented).
- attention_op_test.cc: Enable CUDA for softcap tests (was disabled because unfused didn't support softcap).
- test_mha.py: Add TestONNXAttentionMHAUnfusedSoftcap with prompt/decode tests for fp16 and fp32.

Closes gap R4 from issue microsoft#27880 (softcap + h!=v_h decode).

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
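The softcap transform is easy to state as a numpy reference (a sketch of the element-wise math, not the CUDA kernel): outputs are bounded by ±softcap, and small scores pass through nearly unchanged.

```python
import numpy as np

def apply_softcap(scores: np.ndarray, softcap: float) -> np.ndarray:
    # Element-wise: score = softcap * tanh(score / softcap).
    return softcap * np.tanh(scores / softcap)

scores = np.array([-100.0, -1.0, 0.0, 1.0, 100.0], dtype=np.float32)
capped = apply_softcap(scores, softcap=30.0)

assert (np.abs(capped) <= 30.0).all()                     # bounded by +/- softcap
assert np.allclose(capped[1:4], scores[1:4], atol=1e-3)   # near-identity for small scores
```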
Move the softcap kernel and wrapper function definition above UnfusedAttention so readers encounter the definition before its call site. Per readability review feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add CUDA_RETURN_IF_ERROR(cudaGetLastError()) after kernel launch to catch silent launch failures. - Replace total_elements <= 0 early return with ORT_ENFORCE since zero elements indicates a caller bug, not a valid no-op. Per critical and readability review feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add `float softcap = 0.0f` to the base contrib::AttentionParameters struct so the unfused path can access it via contribop_parameters. The field was only in derived structs (GroupQueryAttentionParameters, PagedAttentionParameters) — the base struct lacked it, which would cause a compile error. Also update RunUnfusedAttention docstring to reflect that softcap is now supported via the ApplySoftcap kernel. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Python tests (test_mha.py):
- TestONNXAttentionMHAPastMEAFP32: MHA decode via MEA with fp32
- TestONNXAttentionMHAPastMEABoolMask: MHA decode via MEA with bool mask
- TestONNXAttentionMHAPastMEAFloatMask: MHA decode via MEA with float 4D mask

All use past_kv_sequence_length=31 for CUTLASS alignment.

C++ test (attention_op_test.cc):
- Attention4DMEADecodeFloat16: MEA decode with fp16; forces MEA via ScopedEnvironmentVariables disabling Flash. Uses the same one-hot data as Attention4DDefaultBasic for predictable expected output.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1. Update bool mask filter comments: note that bool masks now work via ConvertAttnMaskToBias in both unfused and MEA, but these ONNX backend tests need re-verification before removing the filter.
2. Fix GQA filter comment: was "does not support fp16 and 4d QKV", which is misleading (fp16 GQA works via Flash/MEA). Actual blockers: fp32 GQA (no LaunchUngroup fp32 template) and 4D BNSH GQA format.
3. Standardize fp32 GQA filter comments for clarity.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.
Changes:
- Add `softcap` plumbing (`AttentionParameters.softcap`) and apply softcap to unfused attention logits via a CUDA kernel before softmax.
- Implement MEA decode by concatenating past+new KV into the present buffer (via `LaunchConcatNewToPastKV`) and updating kernel selection/verbosity logs.
- Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/llm/attention.cc | Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Implements ApplySoftcap CUDA kernel and applies it in unfused attention before softmax. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds softcap to AttentionParameters. |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var). |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds v_head_size support to graph IO shapes/bindings; adds MEA alignment helper for decode+mask tests. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py | Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and asymmetric head-size fallback regression test. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment. |
| onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | Updates exclusions/comments for attention backend tests based on new softcap/decode behavior. |
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with scratch buffer allocation when outputs are nullptr. MEA decode now works even when present outputs are not requested. Use ORT_RETURN_IF_NOT for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with an early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms
left a comment
Probably not, but please check that we are not breaking the graph at any point. For example, #27484.
Also, are the nonpad_kv_seqlens paths completely unrelated to the gap table? I don't see them mentioned at all.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
ONNX spec defines: QK → add mask/bias → softcap → softmax. Flash/MEA kernels fuse softcap before bias (can't reorder).

Changes:
- attention.cc: Reject MEA when softcap>0 AND attn_mask!=nullptr (CUTLASS applies softcap before bias, diverging from spec). Updated unfused docstring with spec ordering note.
- attention_impl.cu: Add AddAttentionBias kernel for in-place bias addition with 2D broadcasting. Restructure unfused softcap+softmax: when both softcap and bias are present, apply bias→softcap→softmax (spec-correct); otherwise keep standard ordering.
- test_mha.py: Add test_unfused_softcap_with_mask_prompt_fp16 and test_unfused_softcap_with_mask_decode_fp16 to verify spec ordering.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
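Why the two orderings matter can be shown with a small numpy sketch (a reference model of the math, not the kernels): with a nonzero additive mask, bias-then-softcap and softcap-then-bias give different softmax results.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def spec_order(scores, bias, softcap):
    # ONNX spec: add mask/bias, then softcap, then softmax.
    return softmax(softcap * np.tanh((scores + bias) / softcap))

def fused_order(scores, bias, softcap):
    # Fused kernels (Flash/MEA): softcap first, then bias, then softmax.
    return softmax(softcap * np.tanh(scores / softcap) + bias)

scores = np.array([1.0, 2.0, 3.0])
bias = np.array([0.0, -5.0, 0.0])  # e.g. a float attention mask

# The orderings diverge whenever the bias is nonzero on unmasked-adjacent scores.
assert not np.allclose(spec_order(scores, bias, 2.0), fused_order(scores, bias, 2.0))
```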
Causal mask is applied in ComputeSoftmax (after softcap), but this is safe: causal positions get -inf, and softcap*tanh(-inf/softcap) = -softcap → 0 after softmax, identical to applying causal before softcap. Per code review feedback. Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6] Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1. attention.cc: Update MEA eligibility docstring to include softcap+mask rejection (CUTLASS ordering diverges from spec).
2. test_mha.py: Add test_mha_unfused_decode_fp32 for unfused decode with fp32 (both Flash and MEA disabled).

Linting: lintrunner reports no issues on all changed files.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
MEA NaN fix:
- Cap mask_filter_value to -1e+30f in ConvertAttnMaskToBias to prevent CUTLASS kLog2e overflow (lowest() * 1.4427 overflows fp32 to -inf, causing s_prime=0 NaN)
- Move present_k/v_scratch to function scope (prevents use-after-free when present outputs are not requested)
- Re-enable CUDA for Attention4DAttnMaskBoolAllFalseDecodeWithPast test
- Add TODO(titaiwang) for ZeroOutputForFullyMaskedBatches cross-EP consistency

Softcap test improvements:
- Fix Python reference softcap ordering (was softcap before bias; now bias before softcap, per ONNX spec)
- Add softcap passthrough to attention_ref() in all 4 parity check functions
- Fix missing attn_mask input in parity_check_mha_past for decode+mask configs
- All 5 softcap tests now properly validate against the reference with softcap applied

Comment and documentation fixes:
- Fix inaccurate Flash comments (no bias support, Q-only transpose)
- Update stale routing comments (decode+mask tests now route to MEA)
- Add head_size divisible-by-8 to MEA eligibility docstring
- Fix "Truen" typo in backend test filters
- Tighten comment accuracy across attention.cc and tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
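The overflow behind the NaN fix can be reproduced in a few lines of numpy (fp32 arithmetic mirrors the kernel's behavior here): multiplying fp32 lowest() by log2(e) ≈ 1.4427 exceeds the fp32 range and becomes -inf, while the -1e+30 cap stays finite.

```python
import numpy as np

k_log2e = np.float32(1.4426950408889634)   # log2(e), applied inside CUTLASS
lowest = np.finfo(np.float32).min          # about -3.4e38, the old mask filter value

with np.errstate(over="ignore"):
    overflowed = np.float32(lowest) * k_log2e
assert np.isneginf(overflowed)             # -inf -> s_prime == 0 -> NaN downstream

capped = np.float32(-1e30) * k_log2e       # the new capped filter value
assert np.isfinite(capped)                 # stays representable in fp32
```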
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
ApplySoftcap: add grid-stride loop and cap grid_size at 65535 to prevent overflow for large tensors.
AddAttentionBias: same grid-stride loop and grid_size cap.
Fix causal masking comment: ComputeSoftmax applies the causal mask AFTER softcap (it sets masked positions to -inf inside the softmax kernel), not before, as the old comment incorrectly implied.

Agent-signed-off: Developer (93aa25b8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
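A grid-stride loop lets a capped grid cover any tensor size: each thread starts at its global index and advances by the total thread count. The Python simulation below (an illustration of the indexing pattern, not CUDA code; the tiny grid/block sizes are illustrative) verifies that every element is visited exactly once.

```python
def grid_stride_coverage(total_elements: int, grid_size: int, block_size: int) -> list:
    # Simulates the CUDA idiom:
    #   for (i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
    stride = grid_size * block_size
    hits = [0] * total_elements
    for block in range(grid_size):
        for thread in range(block_size):
            i = block * block_size + thread
            while i < total_elements:
                hits[i] += 1
                i += stride
    return hits

# Even with the grid capped well below total_elements / block_size,
# every element is processed exactly once.
hits = grid_stride_coverage(total_elements=1000, grid_size=3, block_size=32)
assert all(h == 1 for h in hits)
```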
Captures reusable knowledge from the ONNX Attention work:
- Runner dispatch cascade (Flash → MEA → Unfused)
- CUTLASS kLog2e overflow and the mask_filter_value cap
- Bias alignment requirements
- Softcap ordering (ONNX spec vs kernel behavior)
- Grid-stride loop patterns
- Fully-masked batch handling
- Test runner targeting with env vars
- Cross-EP consistency principles
- File location map
- Parameter bridge pattern

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
Clarify that this skill covers the ONNX domain Attention op, NOT contrib MultiHeadAttention/GroupQueryAttention. Add prominent scope note distinguishing shared infrastructure vs ONNX-specific vs contrib-specific code. Restructure file locations into three categories. Clarify parameter bridge is ONNX-only. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Bug 1: parity_check_mha_past created float additive masks but bound them as TensorProto.BOOL when config.attn_mask_type == 'bool'. Fix: create actual boolean tensors for ORT and convert them to additive masks for the PyTorch reference path.

Bug 2: Both parity check functions used seqlens = total_seq_len, creating all-zero masks that cannot detect ordering bugs between bias→softcap and softcap→bias. Fix: use 75% seqlens so some positions are actually masked.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Agent-signed-off: Developer (93aa25b8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
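A minimal sketch of the Bug 2 fix (shapes and the -1e30 filler are illustrative, not the test suite's exact values): with seqlens at 75% of the total length, some key positions are genuinely masked, so the bool mask and its additive equivalent can expose ordering bugs.

```python
import numpy as np

batch, total_seq = 2, 32
seqlens = np.full(batch, int(total_seq * 0.75))               # 24 valid positions per row

# True = attend; positions >= seqlen are padding and must be masked out.
bool_mask = np.arange(total_seq)[None, :] < seqlens[:, None]

# Additive equivalent for a reference path: 0 where attending, large negative otherwise.
additive = np.where(bool_mask, 0.0, -1e30).astype(np.float32)

assert bool_mask.sum() == batch * 24                  # some positions really are masked
assert (additive[:, 24:] == np.float32(-1e30)).all()  # masked tail gets the filler value
```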
tianleiwu
left a comment
Thanks for the updates here. I re-checked the current head and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.
```cpp
// When only one of softcap/bias is present, ordering is irrelevant.
if (has_softcap) {
  int64_t total_elements = static_cast<int64_t>(batches) * sequence_length * total_sequence_length;
  ORT_RETURN_IF_ERROR(ApplySoftcap<T>(data.scratch, parameters.softcap, total_elements, stream));
```
This changes the semantics of output_qk when qk_matmul_output_mode == kQK and softcap is active. kQK is supposed to expose the scaled Q*K logits, while kQKSoftCap is the separate mode for softcapped logits. In this no-mask branch we call ApplySoftcap before CopyQK, so output_qk contains SoftCap(Q*K) instead of Q*K; in the has_softcap && has_bias branch above, output_qk is not copied at all. Please copy data.scratch to output_qk immediately after the GEMM, before AddAttentionBias or ApplySoftcap mutate it, then continue using data.scratch for the attention computation.
Summary
Adds four improvements to the ONNX Attention CUDA operator (`core/providers/cuda/llm/attention.cc`):

1. **MEA decode support**: Memory Efficient Attention (CUTLASS) now handles the decode path (past_key/present_key) when `head_size == v_head_size`, using `LaunchConcatNewToPastKV` with uniform past sequence lengths and additive attention bias for masks.
2. **Unfused softcap**: The unfused attention path now supports the `softcap` attribute via an `ApplySoftcap` CUDA kernel (`softcap * tanh(score / softcap)` element-wise).
3. **Spec-correct softcap+mask ordering**: When both `softcap` and `attn_mask` are present, MEA is rejected (CUTLASS applies softcap before bias, diverging from the ONNX spec). The unfused path handles this with spec-correct ordering: mask → softcap → softmax.
4. **MEA NaN fix for fully-masked batches**: Capped `mask_filter_value` to `-1e+30f` in `ConvertAttnMaskToBias` to prevent CUTLASS `kLog2e` overflow (`lowest() * 1.4427` overflows fp32 → `-inf` → `s_prime=0` → NaN). All-false bool mask tests now pass on CUDA.

Softcap Ordering Design Decision
The ONNX spec defines: `QK → scale → add_mask → softcap → softmax`. When softcap and attn_mask are both present, MEA is rejected, falling back to unfused for spec-correct results. For softcap without mask (the common case), all runners produce identical results.
Changes
Core dispatch (`attention.cc`):
- Allow `past_key` when `head_size == v_head_size` (relaxed MEA eligibility)
- Decode path: `LaunchFillInt32` + `LaunchConcatNewToPastKV` + present buffer tracking
- Reject MEA when `softcap > 0 && attn_mask != nullptr` for spec compliance
- `QkvToContext`: cap `mask_filter_value` to `-1e+30f` (`kCutlassSafeMaskFilterValue`), preventing CUTLASS overflow
- Move `present_k/v_scratch` to function scope (prevents use-after-free for optional present outputs)
- `LOGS_DEFAULT(VERBOSE)` at all 3 dispatch branches for runner-selection diagnostics

Unfused softcap (`attention_impl.cu`):
- `ApplySoftcapKernel` + `AddAttentionBiasKernel` for spec-correct ordering

Test infrastructure (`common.py`):
- `softcap` passthrough to `attention_ref()` in all 4 parity check functions
- Fix missing `attn_mask` input in `parity_check_mha_past` for decode+mask configs
- `v_head_size` support and `mea_aligned_past_seq()` utility

Tests: 16+ new methods covering MEA decode (GQA/MHA, fp16/bf16/fp32, bool/float masks, 4D BNSH), unfused softcap (prompt/decode, fp16/fp32, with/without mask), C++ MEA decode, and asymmetric head_size regression.
Gaps Closed (from #27880)
Related