Conversation
Code Review
This pull request implements a fused-style Kimi Delta Attention (KDA) prefill for CUDA (sm_89+), providing intra-chunk and inter-chunk kernels, PyTorch bindings, and benchmarking scripts. The review feedback identifies critical technical issues, specifically numerical instability in decay calculations that may cause division by zero, and memory alignment violations in WMMA operations due to unaligned strides. Further improvements were suggested for the shared memory attribute error handling and the prevention of potential integer overflows when casting tensor dimensions. All review comments provide actionable feedback on actual issues or improvement opportunities.
```cpp
for (int d = 0; d < K_DIM; ++d) {
    float k_i = static_cast<float>(s_K[row * k_stride + d]);
    float k_j = static_cast<float>(s_K[col * k_stride + d]);
    float decay = s_G_prefix[row * g_stride + d] / s_G_prefix[col * g_stride + d];
```
The calculation of decay using a ratio of prefix products is numerically unstable. If s_G_prefix[col * g_stride + d] underflows to zero (which is possible for long sequences as it is a product of sigmoids), this will result in a division by zero and NaN values. A more robust approach would be to compute the decay as a product of sigmoids from col + 1 to row, or by using log-space additions and then exponentiating.
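The log-space alternative suggested here can be sketched in a few lines. This is Python pseudocode for the kernel's inner loop; `log_g`, `g_stride`, and the index names are illustrative, not the kernel's actual buffers:

```python
import math

def stable_decay(log_g, row, col, d, g_stride):
    """Decay between positions col and row for feature d, computed in log
    space: instead of dividing two prefix products of sigmoids (which can
    underflow to zero for long sequences), sum the per-step log-sigmoids
    over (col, row] and exponentiate once. log_g holds log(sigmoid(g))
    per position."""
    acc = 0.0
    for t in range(col + 1, row + 1):
        acc += log_g[t * g_stride + d]
    return math.exp(acc)
```

For a sequence where every gate sigmoid is 0.5, the decay from position 2 to position 5 is 0.5 ** 3 = 0.125, and no intermediate prefix product exists that could underflow.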
```cpp
wmma::load_matrix_sync(a_frag, a_ptr, k_stride);
wmma::load_matrix_sync(b_frag, b_ptr, v_stride);
```
The wmma::load_matrix_sync calls use strides (k_stride = 65, v_stride = 17) that are not multiples of 8 or 16 bytes. For tf32 WMMA operations on sm_80+, the memory access for fragments must be aligned. A stride of 65 floats (260 bytes) is not 16-byte aligned, which will likely lead to illegal memory access or incorrect results. Consider padding the shared memory buffers so that the leading dimension is a multiple of 8 or 16 elements.
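The suggested padding can be sketched as a small host-side helper (pure Python; the multiple of 8 elements follows the review's suggestion, and the actual kernel values may differ):

```python
def pad_leading_dim(stride_elems, multiple=8):
    """Round a shared-memory leading dimension up to a multiple of 8 (or 16)
    float elements so WMMA fragment rows start at 16-byte-aligned addresses:
    65 floats (260 B) -> 72 floats (288 B), 17 floats (68 B) -> 24 (96 B)."""
    return -(-stride_elems // multiple) * multiple
```

The original odd strides were presumably chosen to avoid shared-memory bank conflicts, so a padded stride may need re-tuning (e.g. an XOR swizzle) to keep that property.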
```cpp
if (attr_err != cudaSuccess && fused_smem <= 48 * 1024) {
    check_cuda(attr_err, "set fused shared memory attribute failed");
}
```
The logic for checking `cudaFuncSetAttribute` failure is inconsistent: it only throws when the requested shared memory is at most 48 KB. If the kernel requires more than 48 KB (common for fused kernels on sm_89) and setting the attribute fails, the launch will likely fail anyway with an out-of-shared-memory error. It is better to check for success regardless of size, or to report a more descriptive error when the hardware limit is exceeded.
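One way to structure the suggested check, sketched as a pure function with the CUDA calls mocked out (`max_optin_bytes` stands in for the device's `cudaDevAttrMaxSharedMemoryPerBlockOptin` value):

```python
def smem_attr_error(set_failed, requested_bytes, max_optin_bytes):
    """Report a failed cudaFuncSetAttribute call regardless of the requested
    size, and name the device limit when the request exceeds it. Returns an
    error message string, or None when the attribute was set successfully."""
    if not set_failed:
        return None
    if requested_bytes > max_optin_bytes:
        return (f"fused kernel requests {requested_bytes} B of dynamic shared "
                f"memory, but the device opt-in limit is {max_optin_bytes} B")
    return "set fused shared memory attribute failed"
```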
```cpp
io.batch_size = static_cast<int>(B);
io.num_heads  = static_cast<int>(H);
io.seq_len    = static_cast<int>(T);
io.head_dim   = static_cast<int>(K);
io.value_dim  = static_cast<int>(V);
io.chunk_size = static_cast<int>(chunk_size);
io.num_chunks = static_cast<int>(num_chunks);
```
The dimensions B, H, T, K, V are cast from int64_t to int when filling the KdaPrefillIO structure. While typical batch sizes and head dimensions fit in a 32-bit integer, the sequence length T or the total number of elements could potentially overflow if very large tensors are processed. Consider using int64_t in the IO structure or adding explicit checks to ensure dimensions fit within int limits.
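The explicit-check option could look like this (a Python sketch of the C++ logic; `checked_int32` is a hypothetical helper, not in the repo):

```python
INT32_MAX = 2**31 - 1

def checked_int32(value, name):
    """Narrow an int64 dimension to 32 bits, raising instead of silently
    truncating when the value does not fit; the review's other option is
    to widen the KdaPrefillIO fields to int64_t instead."""
    if not 0 <= value <= INT32_MAX:
        raise OverflowError(f"dimension {name}={value} does not fit in int32")
    return int(value)
```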
# KDA prefill (CUDA)

Fused-style KDA (Kimi Delta Attention) prefill for float32: intra-chunk (KKT / W–U) plus inter-chunk recurrence. Tensor layout matches cuLA KDA prefill conventions (`[B,H,T,K]` / `[B,H,T,V]`). This repo ships:

- CUDA kernels in `include/` + `src/kda.cu`
- a C ABI, `kda_prefill_f32`, built via CMake (`libkda_prefill_runtime.so`)
- a PyTorch op, `torch.ops.kda_prefill` (`src/kda_prefill_binding.cpp` + `setup.py`)
- a benchmark script, `src/main.py`
- tests in `src/test.py` (unittest, CUDA required)

## GPU / architecture (`sm_89`)

- Requires Ada (`sm_89`) or newer, as enforced in `include/kda_api.hpp` (`check_sm89_or_newer`).
- `CMAKE_CUDA_ARCHITECTURES=89` and the PyTorch extension NVCC flags use `arch=compute_89,code=sm_89`.
- Older GPUs (`sm_80`) are not supported by this kernel path.

## Interfaces
### 1) PyTorch (recommended for integration)

After building the extension (see How to run), importing the module registers the op; alternatively, use the thin wrapper from `src/test.py`.

Signature (Python):
- `q`, `k`, `g`: `torch.Tensor`, `(B, H, T, K)`, float32, CUDA, contiguous
- `v`: `torch.Tensor`, `(B, H, T, V)`, float32, contiguous
- `beta`: `torch.Tensor`, `(B, H, T)`, float32, contiguous
- `chunk_size`: `int` (`C`)
- `w`, `u`, `o`: `Tensor`; `w` is `(B, H, Nc, C, K)`, `u` is `(B, H, Nc, C, V)`, `o` is `(B, H, T, V)`, with `Nc = ceil(T / C)`

Returns `(o, w, u)`, i.e. the output attention `o` and the intra-chunk buffers `w`, `u`.
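The output shapes can be computed ahead of a call with a small helper (illustrative only; `kda_prefill_out_shapes` is not part of the repo):

```python
import math

def kda_prefill_out_shapes(B, H, T, K, V, C):
    """Expected shapes of (o, w, u) for the given input dims and chunk size
    C, with Nc = ceil(T / C) as in the signature above."""
    Nc = math.ceil(T / C)
    return {"o": (B, H, T, V), "w": (B, H, Nc, C, K), "u": (B, H, Nc, C, V)}
```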
(K, V, chunk_size)for the PyTorch entry point (enforced insrc/kda_prefill_binding.cpp):(64, 64, 32)and(64, 64, 64)(128, 128, 32)only —(128, 128, 64)is not exposed via this binding (runtime issues on current kernel config).Registered op name:
kda_prefill::forward→ Python:torch.ops.kda_prefill.forward(...).2) C ABI (shared library from CMake)
`KdaPrefillIO<float>` is defined in `include/kda_prefill_io.hpp` (pointer fields plus `B, H, T, K, V, C, num_chunks`). See `src/main.py` for a ctypes layout mirror.
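A ctypes mirror follows the usual pattern. This is a sketch only: the authoritative field order and names live in `include/kda_prefill_io.hpp` and `src/main.py`, so the fields below are illustrative assumptions, not the struct's real layout:

```python
import ctypes

class KdaPrefillIO(ctypes.Structure):
    """Hypothetical mirror of KdaPrefillIO<float>: float pointers followed
    by 32-bit integer dimensions."""
    _fields_ = [
        ("q", ctypes.POINTER(ctypes.c_float)),
        ("k", ctypes.POINTER(ctypes.c_float)),
        ("v", ctypes.POINTER(ctypes.c_float)),
        ("g", ctypes.POINTER(ctypes.c_float)),
        ("beta", ctypes.POINTER(ctypes.c_float)),
        ("w", ctypes.POINTER(ctypes.c_float)),
        ("u", ctypes.POINTER(ctypes.c_float)),
        ("o", ctypes.POINTER(ctypes.c_float)),
        ("B", ctypes.c_int),
        ("H", ctypes.c_int),
        ("T", ctypes.c_int),
        ("K", ctypes.c_int),
        ("V", ctypes.c_int),
        ("C", ctypes.c_int),
        ("num_chunks", ctypes.c_int),
    ]
```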
### 3) C++ launch (inside a CUDA translation unit)

`kda::api::launch_kda_prefill` / `launch_kda_prefill_f32` in `include/kda_api.hpp` + `src/kda.cu` (used by the PyTorch binding via `kda_prefill_f32`).

## How to run
### A. CMake + shared library (C / ctypes)

```shell
cmake -S . -B build -DCMAKE_CUDA_ARCHITECTURES=89
cmake --build build
```

Artifacts: `build/libkda_prefill.a` and `build/libkda_prefill_runtime.so` (exporting `kda_prefill_f32`).

### B. PyTorch CUDA extension + tests
Use the same Python environment as your `import torch` (same CUDA / PyTorch ABI). `src/test.py` adds the repo root to `sys.path` and loads `kda_prefill_cuda*.so` from the root after an in-place build.

### C. Optional benchmark vs FLA (`src/main.py`)

Requires PyTorch; for the FLA comparison, install flash-linear-attention (`fla`). Modes:

- `local-benchmark`: CUDA only
- `fla-only-benchmark`: FLA only
- `fla-benchmark`: both, plus a quick accuracy check vs `chunk_kda`

## Tuning env (optional)
- `KDA_INTRA_THREADS`
- `KDA_INTER_THREADS`
- `KDA_INTER_SHARDS` (`V`)
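Reading these overrides from the environment might look like the following (the defaults passed by a caller are placeholders, not the kernel's actual values):

```python
import os

def read_tuning_env(name, default):
    """Return the integer value of a tuning variable such as
    KDA_INTRA_THREADS, or the given default when it is unset."""
    raw = os.environ.get(name)
    return int(raw) if raw is not None else default
```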
## Pull request text (for cuLA / upstream)

When you open a PR on cuLA (or another monorepo), paste the block below into the description. It mirrors the `.github/pull_request_template.md` style.

### Title (suggestion)
`feat: KDA prefill CUDA (sm_89) + PyTorch kda_prefill op`

### Description
#### Benchmark vs FLA (flash-linear-attention)

Hardware: NVIDIA RTX 4070, `sm_89` (Ada). CMake built with `CMAKE_CUDA_ARCHITECTURES=89`.

Setup: `K=V=64`, chunk `C=32`. The local path is float32, while the FLA `chunk_kda` in `src/main.py` runs bf16, so timings are not strictly dtype-matched but are useful as a baseline.

The latest sweep (machine-logged) is under `analysis/ncu/20260409-193656-large-shape-compare/`: the summary table is in that folder's README, the raw CSV is `benchmark_results.csv`, and console captures are under `benchmarks/` in that tree.

Quick accuracy line from the same harness (local fp32 vs FLA bf16): `max_abs ≈ 4.09`, `mean_abs ≈ 0.80` on the small validation shape inside `src/main.py` (see the folder README).

#### Profiling artifacts
`analysis/ncu/20260409-193656-large-shape-compare/` holds the FLA comparison CSV, console logs, and a short README. It is optional for builds.

#### Kernel template instantiations (`src/kda.cu`)

`(K,V,C)` combinations dispatched in `launch_kda_prefill_f32`: `(64,64,64)`, `(64,64,32)`, `(128,128,64)`, `(128,128,32)`. The PyTorch binding validates a subset (see Interfaces).
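The dispatch table and the binding's subset can be summarized as a predicate (a documentation sketch, not code from the repo):

```python
# The four (K, V, C) triples instantiated in launch_kda_prefill_f32.
KERNEL_SHAPES = {(64, 64, 64), (64, 64, 32), (128, 128, 64), (128, 128, 32)}

def kda_shape_supported(K, V, C, via_torch_binding=False):
    """Whether a (K, V, C) triple is instantiated in the kernel; the
    PyTorch binding additionally excludes (128, 128, 64)."""
    ok = (K, V, C) in KERNEL_SHAPES
    if via_torch_binding:
        ok = ok and (K, V, C) != (128, 128, 64)
    return ok
```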