
KDA_prefill_sm89 #47

Open
Azir9 wants to merge 1 commit into inclusionAI:main from Azir9:feature-kda-operator

Conversation

@Azir9 Azir9 commented Apr 11, 2026

KDA prefill (CUDA)

Fused-style KDA (Kimi Delta Attention) prefill for float32: intra-chunk (KKT / W–U) plus inter-chunk recurrence. Tensor layout matches cuLA KDA prefill conventions ([B,H,T,K] / [B,H,T,V]). This repo ships:

  • CUDA kernels under include/ + src/kda.cu
  • C ABI kda_prefill_f32 via CMake (libkda_prefill_runtime.so)
  • PyTorch custom op torch.ops.kda_prefill (src/kda_prefill_binding.cpp + setup.py)
  • ctypes benchmark harness: src/main.py
  • Unit tests: src/test.py (unittest, CUDA required)

GPU / architecture (sm_89)

  • Minimum: NVIDIA Ada Lovelace (sm_89) or newer, as enforced in include/kda_api.hpp (check_sm89_or_newer).
  • Build default: CMAKE_CUDA_ARCHITECTURES=89 and PyTorch extension NVCC flags use arch=compute_89,code=sm_89.
  • Examples: RTX 4090 / 4080 / 4070 (Ada). Older GPUs (e.g. sm_80) are not supported by this kernel path.

Interfaces

1) PyTorch (recommended for integration)

After building the extension (see How to run), import registers the op:

import kda_prefill_cuda  # noqa: F401 — registers torch.ops.kda_prefill
import torch

o, w, u = torch.ops.kda_prefill.forward(
    q, k, v, g, beta, chunk_size, w_opt, u_opt, o_opt
)

Or use the thin wrapper from src/test.py:

from pathlib import Path
import sys
sys.path.insert(0, str(Path("path/to/repo/src").resolve()))  # test.py lives under src/
from test import kda_prefill

o, w, u = kda_prefill(q, k, v, g, beta, chunk_size, w=None, u=None, o=None)

Signature (Python):

Argument      Type             Shape / notes
q, k, g       torch.Tensor     (B, H, T, K), float32, CUDA, contiguous
v             torch.Tensor     (B, H, T, V), float32, CUDA, contiguous
beta          torch.Tensor     (B, H, T), float32, CUDA, contiguous
chunk_size    int              Chunk length C
w, u, o       optional Tensor  If omitted, allocated internally. Shapes:
                               w (B,H,Nc,C,K), u (B,H,Nc,C,V), o (B,H,T,V),
                               with Nc = ceil(T / C)

Returns: (o, w, u) — output attention o and intra-chunk buffers w, u.
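The buffer shapes in the table above follow directly from T and chunk_size. A minimal shape-bookkeeping sketch (no kernel call, helper name is illustrative):

```python
import math

def kda_buffer_shapes(B, H, T, K, V, chunk_size):
    """Compute the w/u/o buffer shapes expected by the op.

    Mirrors the signature table: Nc = ceil(T / chunk_size).
    """
    Nc = math.ceil(T / chunk_size)
    return {
        "w": (B, H, Nc, chunk_size, K),
        "u": (B, H, Nc, chunk_size, V),
        "o": (B, H, T, V),
    }

# Example: the benchmark shape (2, 8, 4096) with K=V=64, C=32 gives Nc=128.
shapes = kda_buffer_shapes(B=2, H=8, T=4096, K=64, V=64, chunk_size=32)
```

Preallocating w/u/o with these shapes (float32, CUDA, contiguous) and passing them in avoids the internal allocation on repeated calls.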

Supported (K, V, chunk_size) for the PyTorch entry point (enforced in src/kda_prefill_binding.cpp):

  • (64, 64, 32) and (64, 64, 64)
  • (128, 128, 32) only — (128, 128, 64) is not exposed via this binding (runtime issues on current kernel config).

Registered op name: kda_prefill::forward → Python: torch.ops.kda_prefill.forward(...).

2) C ABI (shared library from CMake)

extern "C" cudaError_t kda_prefill_f32(
    const kda::api::KdaPrefillIO<float>* io,
    cudaStream_t stream);

KdaPrefillIO<float> is defined in include/kda_prefill_io.hpp (pointer fields + B,H,T,K,V,C,num_chunks). See src/main.py for a ctypes layout mirror.
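A ctypes mirror of the struct might look like the sketch below. The integer field names are taken from the binding code quoted in the review thread further down; the pointer field names and ordering are assumptions here, so treat include/kda_prefill_io.hpp (and the real mirror in src/main.py) as authoritative:

```python
import ctypes

class KdaPrefillIO(ctypes.Structure):
    """Illustrative ctypes mirror of kda::api::KdaPrefillIO<float>.

    Pointer-field names/order are hypothetical; the integer fields match
    the assignments shown in the binding excerpt in the review below.
    """
    _fields_ = [
        ("q", ctypes.c_void_p),
        ("k", ctypes.c_void_p),
        ("v", ctypes.c_void_p),
        ("g", ctypes.c_void_p),
        ("beta", ctypes.c_void_p),
        ("w", ctypes.c_void_p),
        ("u", ctypes.c_void_p),
        ("o", ctypes.c_void_p),
        ("batch_size", ctypes.c_int),
        ("num_heads", ctypes.c_int),
        ("seq_len", ctypes.c_int),
        ("head_dim", ctypes.c_int),
        ("value_dim", ctypes.c_int),
        ("chunk_size", ctypes.c_int),
        ("num_chunks", ctypes.c_int),
    ]
```

With this in place, `ctypes.CDLL("build/libkda_prefill_runtime.so").kda_prefill_f32` can be called with `ctypes.byref(io)` and a stream handle, as src/main.py does.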

3) C++ launch (inside CUDA translation unit)

kda::api::launch_kda_prefill / launch_kda_prefill_f32 in include/kda_api.hpp + src/kda.cu (used by the PyTorch binding via kda_prefill_f32).


How to run

A. CMake + shared library (C / ctypes)

cmake -S . -B build -DCMAKE_CUDA_ARCHITECTURES=89
cmake --build build

Artifacts: build/libkda_prefill.a, build/libkda_prefill_runtime.so (kda_prefill_f32).

B. PyTorch CUDA extension + tests

Build in the same Python environment that provides import torch (matching CUDA / PyTorch ABI):

python setup.py build_ext --inplace
python src/test.py

src/test.py adds the repo root to sys.path and loads kda_prefill_cuda*.so from the root after inplace build.

C. Optional benchmark vs FLA (src/main.py)

Requires PyTorch; for FLA comparison install flash-linear-attention (fla).

python src/main.py --suite fla-benchmark --B 2 --H 8 --T 4096 --K 64 --V 64 --C 32
Suites (--suite):

  • local-benchmark — local CUDA kernel only
  • fla-only-benchmark — FLA only
  • fla-benchmark — both, plus a quick accuracy check vs chunk_kda

Tuning env (optional)

Variable             Role
KDA_INTRA_THREADS    Block size for the intra-chunk kernel
KDA_INTER_THREADS    Block size for the inter-chunk kernel
KDA_INTER_SHARDS     V-way sharding for the inter-chunk kernel (must divide V)
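The divisibility constraint on KDA_INTER_SHARDS can be validated up front on the Python side; a small sketch (the defaults here are placeholders, not the kernel's real defaults):

```python
import os

def read_tuning_env(V, defaults=(256, 256, 1)):
    """Read the optional KDA_* tuning variables with basic validation.

    Hypothetical helper: default values are illustrative only.
    """
    intra = int(os.environ.get("KDA_INTRA_THREADS", defaults[0]))
    inter = int(os.environ.get("KDA_INTER_THREADS", defaults[1]))
    shards = int(os.environ.get("KDA_INTER_SHARDS", defaults[2]))
    if V % shards != 0:
        raise ValueError(f"KDA_INTER_SHARDS={shards} must divide V={V}")
    return intra, inter, shards
```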

Pull request text (for cuLA / upstream)

When you open a PR on cuLA (or another monorepo), paste the block below into the description. It mirrors .github/pull_request_template.md style.

Title (suggestion)

feat: KDA prefill CUDA (sm_89) + PyTorch kda_prefill op

Description

## Description

Adds **KDA prefill** (Kimi Delta Attention, **prefill** path — not decode): float32 fused intra-chunk + inter-chunk CUDA kernels, optional **PyTorch** custom op `torch.ops.kda_prefill.forward`, and ctypes / unittest harness. Layout aligns with cuLA KDA tensor conventions `[B,H,T,K]` / `[B,H,T,V]`.

**GPU:** requires **sm_89+** (Ada); build targets `sm_89`.

**Interfaces:**
- PyTorch: `torch.ops.kda_prefill.forward(q,k,v,g,beta,chunk_size,w?,u?,o?) -> (o,w,u)` after `import kda_prefill_cuda`; wrapper `kda_prefill(...)` in `src/test.py`.
- C: `kda_prefill_f32(KdaPrefillIO<float>*, cudaStream_t)` in `include/kda_prefill_io.hpp`.

**How to verify:**
1. `python setup.py build_ext --inplace` (or repo-specific `setup_kda_prefill_cuda.py` if merged under cuLA).
2. `python src/test.py` — 5 CUDA `unittest` cases (K/V/chunk coverage + preallocated buffers + direct `torch.ops` call).

## Related Issues

<!-- e.g. Closes #123 -->

## Pull Request Checklist

### Pre-commit

- [ ] `pre-commit install` and `pre-commit run --all-files` (if contributing to cuLA main tree).

### Tests

- [ ] `python src/test.py` passes on **sm_89** GPU with built `kda_prefill_cuda`.

### Performance

<!-- Optional: NCU / ms vs FLA; local harness in src/main.py -->

## Reviewer Notes

- Host `.cpp` binding must **not** include `kda_kernel.hpp`; use `kda_prefill_io.hpp` + `kda_prefill_f32` only.
- PyTorch binding currently restricts `(128,128,64)` even though `kda.cu` instantiates that template — document if re-enabled after kernel fix.

Benchmark vs FLA (flash-linear-attention)

Hardware: NVIDIA RTX 4070, sm_89 (Ada). CMake built with CMAKE_CUDA_ARCHITECTURES=89.

Setup: K=V=64, chunk C=32. Local path is float32; FLA chunk_kda in src/main.py runs bf16 — timing is not strictly dtype-matched but useful as a baseline.

Latest sweep (machine-logged) is under analysis/ncu/20260409-193656-large-shape-compare/: summary table in that folder's README, raw CSV in benchmark_results.csv, console captures under benchmarks/ in that tree.

Shape (B,H,T)    iters    Local ms    FLA ms    Local / FLA
(2,8,4096)       10       1.8884      1.1918    1.58×
(2,8,8192)       10       3.6517      3.1311    1.17×
(4,8,4096)       10       3.8848      3.0597    1.27×
(2,16,4096)      10       3.8868      3.1079    1.25×
(2,8,16384)      5        7.2716      6.5675    1.11×

Quick accuracy line from the same harness (local fp32 vs FLA bf16): max_abs ≈ 4.09, mean_abs ≈ 0.80 on the small validation shape inside src/main.py (see folder README).

Profiling artifacts

analysis/ncu/20260409-193656-large-shape-compare/ holds the FLA comparison CSV, console logs, and a short README. Optional for builds.

Kernel template instantiations (src/kda.cu)

(K,V,C) dispatched in launch_kda_prefill_f32: (64,64,64), (64,64,32), (128,128,64), (128,128,32). The PyTorch binding validates a subset (see Interfaces).
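The kernel-vs-binding split above can be mirrored as a small lookup for pre-validating shapes on the Python side (illustrative only; the authoritative checks live in src/kda_prefill_binding.cpp):

```python
# (K, V, C) combinations instantiated in launch_kda_prefill_f32.
KERNEL_CONFIGS = {(64, 64, 64), (64, 64, 32), (128, 128, 64), (128, 128, 32)}

# Subset currently exposed by the PyTorch binding (see Interfaces).
BINDING_CONFIGS = {(64, 64, 64), (64, 64, 32), (128, 128, 32)}

def supported_by_binding(K, V, C):
    """True if the PyTorch entry point accepts this (K, V, chunk_size)."""
    return (K, V, C) in BINDING_CONFIGS

# (128, 128, 64) is instantiated in kda.cu but rejected by the binding.
```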


@gemini-code-assist (bot) left a comment


Code Review

This pull request implements a fused-style Kimi Delta Attention (KDA) prefill for CUDA (sm_89+), providing intra-chunk and inter-chunk kernels, PyTorch bindings, and benchmarking scripts. The review feedback identifies critical technical issues, specifically numerical instability in decay calculations that may cause division by zero, and memory alignment violations in WMMA operations due to unaligned strides. Further improvements were suggested for the shared memory attribute error handling and the prevention of potential integer overflows when casting tensor dimensions. All review comments provide actionable feedback on actual issues or improvement opportunities.

Comment thread include/kda_kernel.hpp
for (int d = 0; d < K_DIM; ++d) {
    float k_i = static_cast<float>(s_K[row * k_stride + d]);
    float k_j = static_cast<float>(s_K[col * k_stride + d]);
    float decay = s_G_prefix[row * g_stride + d] / s_G_prefix[col * g_stride + d];

Severity: high

The calculation of decay using a ratio of prefix products is numerically unstable. If s_G_prefix[col * g_stride + d] underflows to zero (which is possible for long sequences as it is a product of sigmoids), this will result in a division by zero and NaN values. A more robust approach would be to compute the decay as a product of sigmoids from col + 1 to row, or by using log-space additions and then exponentiating.
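The instability and the suggested log-space fix can be sketched in plain Python (variable names are illustrative; the real code operates on shared-memory tiles):

```python
import math

# Per-position log-decay values log(sigmoid(x)); products of many sigmoids
# underflow quickly in linear space for long ranges.
log_g = [-0.5] * 64  # illustrative: each factor is exp(-0.5) ~ 0.61

# Flagged pattern: prefix products in linear space.
prefix = [1.0]
for lg in log_g:
    prefix.append(prefix[-1] * math.exp(lg))

def decay_ratio(row, col):
    # Unstable: denominator can underflow to 0.0 for long sequences.
    return prefix[row] / prefix[col]

def decay_logspace(row, col):
    # Stable: sum logs over (col, row], exponentiate once.
    return math.exp(sum(log_g[col:row]))
```

Both agree wherever the ratio is still representable, but only the log-space form survives when the prefix product underflows.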

Comment thread include/kda_kernel.hpp
Comment on lines +633 to +634
wmma::load_matrix_sync(a_frag, a_ptr, k_stride);
wmma::load_matrix_sync(b_frag, b_ptr, v_stride);

Severity: high

The wmma::load_matrix_sync calls use strides (k_stride = 65, v_stride = 17) that are not multiples of 8 or 16 bytes. For tf32 WMMA operations on sm_80+, the memory access for fragments must be aligned. A stride of 65 floats (260 bytes) is not 16-byte aligned, which will likely lead to illegal memory access or incorrect results. Consider padding the shared memory buffers so that the leading dimension is a multiple of 8 or 16 elements.
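The padding suggestion amounts to rounding the shared-memory leading dimension up to a multiple of the required element alignment; a small sketch (hypothetical helper name):

```python
def pad_leading_dim(stride_elems, align_elems=8):
    """Round a leading dimension up to a multiple of align_elems.

    8 float32 elements = 32 bytes, which satisfies 16-byte WMMA
    fragment alignment.
    """
    return ((stride_elems + align_elems - 1) // align_elems) * align_elems

# The strides flagged above:
pad_leading_dim(65)  # -> 72 (65 floats = 260 B, not 16-byte aligned)
pad_leading_dim(17)  # -> 24
```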

Comment thread include/kda_api.hpp
Comment on lines +178 to +180
if (attr_err != cudaSuccess && fused_smem <= 48 * 1024) {
    check_cuda(attr_err, "set fused shared memory attribute failed");
}

Severity: medium

The logic for checking cudaFuncSetAttribute failure seems inconsistent. It only throws an error if the requested shared memory is less than or equal to 48KB. If the kernel requires more than 48KB (which is common for fused kernels on sm89) and the attribute setting fails, the launch will likely fail with an 'out of shared memory' error anyway. It is better to check for success regardless of the size, or provide a more descriptive error message if the hardware limit is exceeded.

Comment on lines +117 to +123
io.batch_size = static_cast<int>(B);
io.num_heads = static_cast<int>(H);
io.seq_len = static_cast<int>(T);
io.head_dim = static_cast<int>(K);
io.value_dim = static_cast<int>(V);
io.chunk_size = static_cast<int>(chunk_size);
io.num_chunks = static_cast<int>(num_chunks);

Severity: medium

The dimensions B, H, T, K, V are cast from int64_t to int when filling the KdaPrefillIO structure. While typical batch sizes and head dimensions fit in a 32-bit integer, the sequence length T or the total number of elements could potentially overflow if very large tensors are processed. Consider using int64_t in the IO structure or adding explicit checks to ensure dimensions fit within int limits.
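The suggested overflow guard can be sketched as an explicit range check before the narrowing casts (hypothetical helper, not code from the PR):

```python
INT32_MAX = 2**31 - 1

def checked_int32(name, value):
    """Reject dimensions that would not survive an int64 -> int32 cast."""
    if not (0 <= value <= INT32_MAX):
        raise OverflowError(f"{name}={value} does not fit in int32")
    return value

# The total element count, not a single dimension, is the usual risk:
B, H, T, K = 2, 8, 4096, 64
checked_int32("numel(q)", B * H * T * K)
```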
