ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix#27992

Open
titaiwangms wants to merge 24 commits into microsoft:main from titaiwangms:feature/mea-decode-support-v2

Conversation

@titaiwangms
Contributor

@titaiwangms titaiwangms commented Apr 6, 2026

Summary

Adds four improvements to the ONNX Attention CUDA operator (core/providers/cuda/llm/attention.cc):

  1. MEA decode support: Memory Efficient Attention (CUTLASS) now handles the decode path (past_key/present_key) when head_size == v_head_size, using LaunchConcatNewToPastKV with uniform past sequence lengths and additive attention bias for masks.

  2. Unfused softcap: The unfused attention path now supports the softcap attribute via an ApplySoftcap CUDA kernel (softcap * tanh(score / softcap) element-wise).

  3. Spec-correct softcap+mask ordering: When both softcap and attn_mask are present, MEA is rejected (CUTLASS applies softcap before bias, diverging from ONNX spec). The unfused path handles this with spec-correct ordering: mask → softcap → softmax.

  4. MEA NaN fix for fully-masked batches: Capped mask_filter_value to -1e+30f in ConvertAttnMaskToBias to prevent CUTLASS kLog2e overflow (lowest() * 1.4427 overflows fp32 to -inf → s_prime=0 → NaN). All-false bool mask tests now pass on CUDA.
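The overflow mechanics behind the NaN fix are easy to reproduce with NumPy (an illustrative sketch, not the kernel code; the constant names follow the PR description):

```python
import numpy as np

K_LOG2E = 1.4426950408889634  # log2(e); CUTLASS scales logits by this to use exp2

# Filling masked positions with float lowest overflows once scaled by kLog2e:
unsafe_fill = np.finfo(np.float32).min                         # ~ -3.4e38
unsafe_scaled = np.float32(unsafe_fill) * np.float32(K_LOG2E)  # overflows fp32 to -inf
# In a fully-masked row every logit becomes -inf, the softmax row sum
# (s_prime) is 0, and the final division produces NaN.

# Capping the fill value at -1e30 (kCutlassSafeMaskFilterValue) keeps the
# scaled logit finite, while exp2 of ~ -1.44e30 still underflows to a
# zero attention weight.
safe_fill = np.float32(-1e30)
safe_scaled = safe_fill * np.float32(K_LOG2E)
```

This is why the cap can be applied unconditionally: it changes nothing for normally masked rows but removes the -inf that poisons fully-masked ones.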

Softcap Ordering Design Decision

The ONNX spec defines: QK → scale → add_mask → softcap → softmax.

  • Flash/CUTLASS: apply softcap before bias/mask (Flash Attention 2 / Gemma convention)
  • Unfused (this PR): applies mask → softcap → softmax (matches ONNX spec)

When softcap + attn_mask are both present, MEA is rejected, falling back to unfused for spec-correct results. For softcap without a mask (the common case), all runners produce identical results.
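To see concretely why the two orderings diverge when both softcap and a mask are present, here is a small NumPy sketch with toy values (illustrative only, not the kernel code):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

softcap = 2.0
scores = np.array([[1.0, 3.0, 5.0]])   # scaled Q*K' logits
bias = np.array([[0.0, 0.0, -1e30]])   # additive mask: last position masked

# ONNX spec ordering: mask -> softcap -> softmax
spec = softmax(softcap * np.tanh((scores + bias) / softcap))

# Flash/CUTLASS fused ordering: softcap -> mask -> softmax
fused = softmax(softcap * np.tanh(scores / softcap) + bias)
```

Under spec ordering, tanh saturates the masked logit at -softcap, so the masked position keeps a small nonzero weight; under the fused ordering the large bias is added after softcap and fully suppresses it. The two agree only when no mask is present, which is exactly why MEA is rejected for the softcap+mask combination.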

Changes

Core dispatch (attention.cc):

  • Relaxed MEA eligibility: accept past_key when head_size == v_head_size
  • Added MEA decode path: LaunchFillInt32 + LaunchConcatNewToPastKV + present buffer tracking
  • MEA rejects softcap > 0 && attn_mask != nullptr for spec compliance
  • Removed unfused softcap rejection; softcap flows to QkvToContext
  • Capped mask_filter_value to -1e+30f (kCutlassSafeMaskFilterValue) preventing CUTLASS overflow
  • Moved present_k/v_scratch to function scope (prevent use-after-free for optional present outputs)
  • Added LOGS_DEFAULT(VERBOSE) at all 3 dispatch branches for runner selection diagnostics
  • Updated stale routing comments throughout

Unfused softcap (attention_impl.cu):

  • Added ApplySoftcapKernel + AddAttentionBiasKernel for spec-correct ordering
  • Softcap applied between Q*K' GEMM and softmax, after mask/bias composition

Test infrastructure (common.py):

  • Fixed Python reference softcap ordering (was softcap before bias → now bias before softcap per spec)
  • Added softcap passthrough to attention_ref() in all 4 parity check functions
  • Fixed missing attn_mask input in parity_check_mha_past for decode+mask configs
  • Added v_head_size support and mea_aligned_past_seq() utility

Tests: 16+ new methods covering MEA decode (GQA/MHA, fp16/bf16/fp32, bool/float masks, 4D BNSH), unfused softcap (prompt/decode, fp16/fp32, with/without mask), C++ MEA decode, and asymmetric head_size regression.

Gaps Closed (from #27880)

| Gap | Status | How |
| --- | --- | --- |
| GQA + decode + attn_mask | ✅ CLOSED | MEA decode with additive bias |
| GQA + decode + head_size > 256 | ✅ CLOSED | MEA supports head_size ≤ 1024 |
| Softcap + fp32 + decode | ✅ CLOSED | Unfused now accepts softcap |
| Softcap + decode + mask | ✅ CLOSED | Unfused spec-correct ordering |
| Softcap + h≠v_h + decode | ✅ CLOSED | Unfused handles softcap + h≠v_h |

Related

titaiwangms and others added 15 commits April 6, 2026 20:42
Enable Memory Efficient Attention (CUTLASS FMHA) to handle decode
with past_key/past_value via internal KV cache concatenation.

Changes:
- Relax MEA eligibility: allow past_key when head_size == v_head_size
  (was: past_key == nullptr)
- Add decode path in RunMemoryEfficientAttention:
  1. LaunchFillInt32 for uniform past_seqlens (ONNX past_key has fixed shape)
  2. Transpose K/V to BSNH if 4D BNSH input
  3. LaunchConcatNewToPastKV to fuse past + new into present (BNSH)
  4. Point MEA at present buffers, track kv_is_bsnh layout
- Update LaunchUngroup and MEA params to use kv_is_bsnh
- Skip present_key/value population when already done by concat
- Update dispatch comments and error messages

Design: Uses uniform past_seqlens + additive bias for masks.
No new kernels needed. No memset needed (all present positions written).
Bool masks handled correctly via ConvertAttnMaskToBias (no NaN risk).

Closes gaps: GQA+decode+mask, GQA+decode+h>256, softcap+fp32+decode,
softcap+decode+mask (4 of 9 gaps from issue microsoft#27880).

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
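The decode-path data movement this commit describes can be illustrated in NumPy (shapes and the BNSH layout follow the commit message; the named CUDA kernels are only mirrored here, not called):

```python
import numpy as np

# Decode step: concatenate past KV with the new token's KV into present (BNSH).
B, N, S_past, S_new, H = 2, 4, 31, 1, 64
past_k = np.random.rand(B, N, S_past, H).astype(np.float32)
new_k = np.random.rand(B, N, S_new, H).astype(np.float32)

# LaunchConcatNewToPastKV analogue: every present position is written,
# so no memset of the present buffer is needed.
present_k = np.concatenate([past_k, new_k], axis=2)

# LaunchFillInt32 analogue: ONNX past_key has a fixed shape, so all
# batches share one uniform past sequence length.
past_seqlens = np.full(B, S_past, dtype=np.int32)
```

MEA is then pointed at the present buffers, with masks expressed as additive attention bias rather than per-batch variable lengths.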
Add comprehensive test coverage for the MEA decode path with past_kv:

common.py:
- Add v_head_size field to AttentionConfig (defaults to head_size)
- Update all tensor shape logic to use effective_v_head_size for V
  and output tensors (graph inputs/outputs, prompt/past bindings)

test_gqa.py:
- Add GQA+MEA decode tests (fp16): test_gqa_past_memory_efficient
- Add GQA+MEA decode tests (bf16): TestONNXAttentionMemoryEfficientGQABF16
- Re-enable GQA+padding mask decode tests via MEA (was skipped, now works)
- Add GQA+4D BNSH decode via MEA: TestONNXAttentionGQA4DBNSHMEA
- Add GQA+float mask decode via MEA: TestONNXAttentionMemoryEfficientGQAFloatMaskDecode

test_mha.py:
- Add MHA+MEA decode tests: TestONNXAttentionMHAPastMEA
- Add asymmetric head_size regression test: TestONNXAttentionMHAAsymmetricHeadSize
  (verifies MEA fallback to unfused when head_size != v_head_size)

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…tics

- Rename TestONNXAttentionPaddingMaskGQA to TestONNXAttentionPaddingMaskMEAGQA
  to reflect that it now tests MEA (not Flash) for decode with past_key
- Add print_diff_statistics for present_k/v assertions in asymmetric
  head_size regression test for consistency with other tests

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Replace `!(past_key != nullptr && head_size != v_head_size)` with
equivalent `(past_key == nullptr || head_size == v_head_size)`.
Same logic, reads more naturally and is safer when adding future conditions.

Per readability and code review feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Change past_kv_sequence_length from 32 to 31 in tests that use
attn_mask with MEA decode. MEA requires total_sequence_length % 4 == 0
when an attention mask is present (CUTLASS kMinimumAlignment=4).

With past=32 + new=1 = 33, 33%4=1 → MEA disabled → NOT_IMPLEMENTED.
With past=31 + new=1 = 32, 32%4=0 → MEA eligible → test exercises MEA.

Affected tests:
- gqa_past_padding_test_cases(): seqs (1,32) → (1,31)
- test_gqa_past_float_mask_4d: past=32→31, total=33→32

Per code review feedback from @9c209479.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Log which attention runner is selected at VERBOSE level in ComputeInternal:
- Flash Attention: batch, q_seq, total_seq, past
- Memory Efficient Attention: batch, q_seq, total_seq, past, mask
- Unfused Attention: batch, q_seq, total_seq, past, mask

Helps verify runner targeting at runtime for debugging and testing.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add utility function to compute past_kv_sequence_length values that
satisfy CUTLASS MEA's bias alignment requirement (total_seq % 4 == 0).

Prevents future test authors from accidentally picking sequence lengths
that disable MEA when attn_mask is present, causing tests to silently
fall through to unfused instead of exercising MEA.

Per architect suggestion from @a2065222.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
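A plain-Python sketch of what such a helper might look like (the real mea_aligned_past_seq lives in common.py; this signature and body are assumptions based on the commit messages, including the no-negative-result edge case noted below):

```python
def mea_aligned_past_seq(desired_past: int, q_seq_len: int = 1, alignment: int = 4) -> int:
    """Largest past_kv_sequence_length <= desired_past such that
    (past + q_seq_len) % alignment == 0, clamped at 0.

    CUTLASS MEA requires total_sequence_length % 4 == 0 when an
    attention mask (additive bias) is present (kMinimumAlignment=4).
    """
    total = desired_past + q_seq_len
    adjusted = total - (total % alignment)  # round total down to alignment
    return max(adjusted - q_seq_len, 0)     # clamp: never a negative past length

# Example from the commits: past=32, new=1 gives total 33 (misaligned);
# the helper picks past=31 so total 32 keeps MEA eligible.
print(mea_aligned_past_seq(32, 1))  # 31
```

This keeps future decode+mask tests from silently falling through to the unfused path because of an unlucky sequence length.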
Add ORT_ENFORCE(nonpad_kv_seqlen == nullptr) at the start of the
MEA decode path. nonpad_kv_seqlen (external cache, opset 24) and
past_key (internal cache) are mutually exclusive — enforced at
validation but adding a defensive check here too.

Per code review feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Prevent negative results when past_seq is too small for the
alignment adjustment. Per code review edge case feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add element-wise softcap kernel (score = softcap * tanh(score / softcap))
applied after Q*K' GEMM and before softmax in the unfused attention path.

Changes:
- attention_impl.cu: Add ApplySoftcapKernel and ApplySoftcap template
  function. Insert call after GEMM, before ComputeSoftmax, when
  parameters.softcap > 0. Works for all dtypes (fp16/bf16/fp32).
- attention.cc: Remove unfused softcap rejection, pass softcap to
  contribop_parameters, update comments.
- onnx_backend_test_series_filters.jsonc: Remove 2 softcap+diff_head_size
  filters (unfused now handles softcap). Keep qk_matmul+softcap filters
  (qk_matmul modes still not implemented).
- attention_op_test.cc: Enable CUDA for softcap tests (was disabled
  because unfused didn't support softcap).
- test_mha.py: Add TestONNXAttentionMHAUnfusedSoftcap with prompt/decode
  tests for fp16 and fp32.

Closes gap R4 from issue microsoft#27880 (softcap + h!=v_h decode).

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Move the softcap kernel and wrapper function definition above
UnfusedAttention so readers encounter the definition before its
call site. Per readability review feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Add CUDA_RETURN_IF_ERROR(cudaGetLastError()) after kernel launch
  to catch silent launch failures.
- Replace total_elements <= 0 early return with ORT_ENFORCE since
  zero elements indicates a caller bug, not a valid no-op.

Per critical and readability review feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add `float softcap = 0.0f` to the base contrib::AttentionParameters
struct so the unfused path can access it via contribop_parameters.
The field was only in derived structs (GroupQueryAttentionParameters,
PagedAttentionParameters) — the base struct lacked it, which would
cause a compile error.

Also update RunUnfusedAttention docstring to reflect that softcap
is now supported via the ApplySoftcap kernel.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Python tests (test_mha.py):
- TestONNXAttentionMHAPastMEAFP32: MHA decode via MEA with fp32
- TestONNXAttentionMHAPastMEABoolMask: MHA decode via MEA with bool mask
- TestONNXAttentionMHAPastMEAFloatMask: MHA decode via MEA with float 4D mask
  All use past_kv_sequence_length=31 for CUTLASS alignment.

C++ test (attention_op_test.cc):
- Attention4DMEADecodeFloat16: MEA decode with fp16, forces MEA via
  ScopedEnvironmentVariables disabling Flash. Uses same one-hot data
  as Attention4DDefaultBasic for predictable expected output.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1. Update bool mask filter comments: note that bool masks now work via
   ConvertAttnMaskToBias in both unfused and MEA, but these ONNX backend
   tests need re-verification before removing the filter.
2. Fix GQA filter comment: was 'does not support fp16 and 4d QKV' which
   is misleading (fp16 GQA works via Flash/MEA). Actual blockers: fp32
   GQA (no LaunchUngroup fp32 template) and 4D BNSH GQA format.
3. Standardize fp32 GQA filter comments for clarity.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.

Changes:

  • Add softcap plumbing (AttentionParameters.softcap) and apply softcap to unfused attention logits via a CUDA kernel before softmax.
  • Implement MEA decode by concatenating past+new KV into the present buffer (via LaunchConcatNewToPastKV) and updating kernel selection/verbosity logs.
  • Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Summary per file:

| File | Description |
| --- | --- |
| onnxruntime/core/providers/cuda/llm/attention.cc | Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Implements ApplySoftcap CUDA kernel and applies it in unfused attention before softmax. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds softcap to AttentionParameters. |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var). |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds v_head_size support to graph IO shapes/bindings; adds MEA alignment helper for decode+mask tests. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py | Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and asymmetric head-size fallback regression test. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment. |
| onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | Updates exclusions/comments for attention backend tests based on new softcap/decode behavior. |


Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "Add Memory Efficient Attention decode support and tests for ONNX" to "ONNX Attention CUDA: Add MEA decode support and unfused softcap" Apr 6, 2026
Contributor Author

@titaiwangms titaiwangms left a comment


Probably not, but check that we are not breaking the graph at any point. For example, #27484

Also, are the nonpad_kv_seqlens paths totally unrelated to the gap table? I don't see them mentioned at all.

Comment thread onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/test/providers/cpu/llm/attention_op_test.cc
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.



Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
titaiwangms and others added 3 commits April 6, 2026 23:46
ONNX spec defines: QK → add mask/bias → softcap → softmax.
Flash/MEA kernels fuse softcap before bias (can't reorder).

Changes:
- attention.cc: Reject MEA when softcap>0 AND attn_mask!=nullptr
  (CUTLASS applies softcap before bias, diverging from spec).
  Updated unfused docstring with spec ordering note.
- attention_impl.cu: Add AddAttentionBias kernel for in-place bias
  addition with 2D broadcasting. Restructure unfused softcap+softmax:
  when both softcap and bias present, apply bias→softcap→softmax
  (spec-correct); otherwise keep standard ordering.
- test_mha.py: Add test_unfused_softcap_with_mask_prompt_fp16 and
  test_unfused_softcap_with_mask_decode_fp16 to verify spec ordering.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Causal mask is applied in ComputeSoftmax (after softcap), but this
is safe: the softmax kernel pins causal positions to -inf, so they
receive exactly zero weight regardless of what softcap did to the
logits beforehand. Per code review feedback.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
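The safety argument reduces to: once the softmax kernel sets causal positions to -inf, their weight is exactly zero no matter what softcap did to the logits first. A small NumPy check (illustrative only; the real masking happens inside the CUDA softmax kernel):

```python
import numpy as np

def softmax_with_causal(logits, keep_mask):
    # Causal mask applied *inside* softmax: masked positions pinned to -inf.
    x = np.where(keep_mask, logits, -np.inf)
    e = np.exp(x - x[np.isfinite(x)].max())  # exp(-inf) underflows to exactly 0
    return e / e.sum()

softcap = 5.0
logits = softcap * np.tanh(np.array([0.3, 1.2, 2.8]) / softcap)  # softcap first
keep = np.array([True, True, False])                             # last is a future token
probs = softmax_with_causal(logits, keep)
```

Because the -inf is injected after softcap, the future position contributes exactly zero, identical to a hard causal mask.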
1. attention.cc: Update MEA eligibility docstring to include
   softcap+mask rejection (CUTLASS ordering diverges from spec).
2. test_mha.py: Add test_mha_unfused_decode_fp32 for unfused decode
   with fp32 (both Flash and MEA disabled).

Linting: lintrunner reports no issues on all changed files.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: Add MEA decode support and unfused softcap" to "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" Apr 7, 2026
MEA NaN fix:
- Cap mask_filter_value to -1e+30f in ConvertAttnMaskToBias to prevent CUTLASS
  kLog2e overflow (lowest() * 1.4427 overflows fp32 to -inf, causing s_prime=0 NaN)
- Move present_k/v_scratch to function scope (prevent use-after-free when
  present outputs are not requested)
- Re-enable CUDA for Attention4DAttnMaskBoolAllFalseDecodeWithPast test
- Add TODO(titaiwang) for ZeroOutputForFullyMaskedBatches cross-EP consistency

Softcap test improvements:
- Fix Python reference softcap ordering (was softcap before bias, now bias before
  softcap per ONNX spec)
- Add softcap passthrough to attention_ref() in all 4 parity check functions
- Fix missing attn_mask input in parity_check_mha_past for decode+mask configs
- All 5 softcap tests now properly validate against reference with softcap applied

Comment and documentation fixes:
- Fix inaccurate Flash comments (no bias support, Q-only transpose)
- Update stale routing comments (decode+mask tests now route to MEA)
- Add head_size divisible-by-8 to MEA eligibility docstring
- Fix Truen typo in backend test filters
- Tighten various comment accuracy across attention.cc and tests

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" to "ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix" Apr 7, 2026
@titaiwangms titaiwangms requested a review from Copilot April 7, 2026 23:25
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
titaiwangms and others added 2 commits April 8, 2026 18:39
ApplySoftcap: add grid-stride loop and cap grid_size to 65535
to prevent overflow for large tensors.

AddAttentionBias: same grid-stride loop and grid_size cap.

Fix causal masking comment: ComputeSoftmax applies causal mask
AFTER softcap (sets masked positions to -inf inside softmax kernel),
not before as the old comment incorrectly implied.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (93aa25b8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Captures reusable knowledge from ONNX Attention work:
- Runner dispatch cascade (Flash → MEA → Unfused)
- CUTLASS kLog2e overflow and mask_filter_value cap
- Bias alignment requirements
- Softcap ordering (ONNX spec vs kernel behavior)
- Grid-stride loop patterns
- Fully-masked batch handling
- Test runner targeting with env vars
- Cross-EP consistency principles
- File location map
- Parameter bridge pattern

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.



titaiwangms and others added 2 commits April 8, 2026 21:53
Clarify that this skill covers the ONNX domain Attention op, NOT
contrib MultiHeadAttention/GroupQueryAttention. Add prominent scope
note distinguishing shared infrastructure vs ONNX-specific vs
contrib-specific code. Restructure file locations into three
categories. Clarify parameter bridge is ONNX-only.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Bug 1: parity_check_mha_past created float additive masks but bound
them as TensorProto.BOOL when config.attn_mask_type=='bool'. Fix:
create actual boolean tensors for ORT and convert to additive for
the PyTorch reference path.

Bug 2: Both parity check functions used seqlens=total_seq_len,
creating all-zero masks that can't detect ordering bugs between
bias→softcap vs softcap→bias. Fix: use 75% seqlens so some
positions are actually masked.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Agent-signed-off: Developer (93aa25b8) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

@tianleiwu tianleiwu left a comment


Thanks for the updates here. I re-checked the current head and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.

```cpp
// When only one of softcap/bias is present, ordering is irrelevant.
if (has_softcap) {
  int64_t total_elements = static_cast<int64_t>(batches) * sequence_length * total_sequence_length;
  ORT_RETURN_IF_ERROR(ApplySoftcap<T>(data.scratch, parameters.softcap, total_elements, stream));
```
Contributor


This changes the semantics of output_qk when qk_matmul_output_mode == kQK and softcap is active. kQK is supposed to expose the scaled Q*K logits, while kQKSoftCap is the separate mode for softcapped logits. In this no-mask branch we call ApplySoftcap before CopyQK, so output_qk contains SoftCap(Q*K) instead of Q*K; in the has_softcap && has_bias branch above, output_qk is not copied at all. Please copy data.scratch to output_qk immediately after the GEMM, before AddAttentionBias or ApplySoftcap mutate it, then continue using data.scratch for the attention computation.
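In NumPy pseudo-form, the ordering the reviewer asks for looks like this (names mirror the review comment; the real implementation is CUDA C++ operating in place on data.scratch, so this is a sketch of the data flow only):

```python
import numpy as np

def unfused_logits(scratch, bias, softcap, qk_mode):
    # Snapshot the raw scaled Q*K' logits for kQK *before* bias/softcap mutate them.
    output_qk = scratch.copy() if qk_mode == "kQK" else None
    if bias is not None:
        scratch = scratch + bias                        # AddAttentionBias analogue
    if softcap > 0:
        scratch = softcap * np.tanh(scratch / softcap)  # ApplySoftcap analogue
    return scratch, output_qk

scores, qk = unfused_logits(np.array([1.0, 3.0]), np.array([0.0, -2.0]), 2.0, "kQK")
```

The copy right after the GEMM preserves kQK semantics (raw scaled logits) while kQKSoftCap remains the mode that exposes the softcapped values.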

