Commit 7a20eae

[Feature] Support cute cpp Encoder FA4 (#7016)
* add cute cpp fa4 * remove comments * fix merge errors * move sm_version inside the function * fix CI errors
1 parent 9765fa7 commit 7a20eae

4 files changed

Lines changed: 132 additions & 14 deletions

fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py

Lines changed: 28 additions & 13 deletions
@@ -30,6 +30,7 @@
 )
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    flash_attn_v4,
     flash_mask_attention,
     get_block_shape_and_split_kv_block,
     gqa_rope_write_cache,
@@ -51,6 +52,8 @@
 else:
     merge_prefill_decode_output = None

+from fastdeploy.model_executor.utils import get_sm_version
+

 @dataclass
 class FlashMaskAttentionMetadata(AttentionMetadata):
@@ -124,6 +127,7 @@ def __init__(
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
+        self.sm_version = get_sm_version()

     def get_kv_cache_shape(
         self,
@@ -278,19 +282,30 @@ def forward_mixed(
             self.rope_3d,
         )

-        flash_mask_attention(
-            q,
-            k,
-            v,
-            forward_meta.cu_seqlens_q,
-            forward_meta.attn_cu_seqlens_k,
-            forward_meta.seq_lens_encoder,
-            res_encoder,
-            forward_meta.attn_mask_offsets,
-            self.num_heads,
-            self.kv_num_heads,
-            self.head_dim,
-        )
+        if self.sm_version >= 100:
+            flash_attn_v4(
+                q,
+                k,
+                v,
+                forward_meta.cu_seqlens_q,
+                forward_meta.attn_cu_seqlens_k,
+                res_encoder,
+                forward_meta.attn_mask_offsets,
+            )
+        else:
+            flash_mask_attention(
+                q,
+                k,
+                v,
+                forward_meta.cu_seqlens_q,
+                forward_meta.attn_cu_seqlens_k,
+                forward_meta.seq_lens_encoder,
+                res_encoder,
+                forward_meta.attn_mask_offsets,
+                self.num_heads,
+                self.kv_num_heads,
+                self.head_dim,
+            )

         res_decoder = append_attention(
             qkv,
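The dispatch above gates the new FA4 encoder path on self.sm_version >= 100, i.e. compute capability 10.0 (Blackwell); on older GPUs the existing flash_mask_attention path is kept unchanged. For reference, a minimal sketch of what an SM-version helper such as get_sm_version typically computes (the real helper lives in fastdeploy.model_executor.utils; this standalone version is only an assumption for illustration):

import paddle

def get_sm_version_sketch() -> int:
    # Illustrative only: map the CUDA compute capability (major, minor) to an
    # SM number, e.g. (9, 0) -> 90 for Hopper, (10, 0) -> 100 for Blackwell.
    major, minor = paddle.device.cuda.get_device_capability()
    return major * 10 + minor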

fastdeploy/model_executor/layers/attention/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 """

 from .append_attention import append_attention, append_attention_with_output
+from .flash_attn_v4 import flash_attn_v4
 from .flash_mask_attention import flash_mask_attention
 from .get_attn_mask_q import get_attn_mask_q
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
@@ -33,6 +34,7 @@
     "gqa_rope_write_cache",
     "pre_cache_len_concat",
     "init_kv_signal_per_query",
+    "flash_attn_v4",
     "flash_mask_attention",
     "get_attn_mask_q",
 ]
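With the re-export above, callers (including the updated test below) can import the new op directly from the ops package:

from fastdeploy.model_executor.layers.attention.ops import flash_attn_v4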
fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from typing import Optional
+
+import paddle
+
+from fastdeploy.model_executor.utils import get_sm_version
+from fastdeploy.platforms import current_platform
+
+
+def flash_attn_v4(
+    q: paddle.Tensor,
+    k: paddle.Tensor,
+    v: paddle.Tensor,
+    cu_seqlens_q: paddle.Tensor,
+    cu_seqlens_k: paddle.Tensor,
+    attn_out: paddle.Tensor,
+    attn_mask_offsets: Optional[paddle.Tensor] = None,
+):
+    if current_platform.is_cuda() and get_sm_version() >= 100:
+        from blackwell_ops import flash_encoder_attn_fwd
+
+        flash_encoder_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, attn_out, attn_mask_offsets)
+    else:
+        raise NotImplementedError
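flash_attn_v4 writes its result into the pre-allocated attn_out tensor instead of returning a new one, and it is only implemented on CUDA GPUs with SM100+ (through the blackwell_ops extension); anywhere else it raises NotImplementedError. A minimal call sketch mirroring the new test case (the token count, head counts, and head_dim below are illustrative assumptions):

import paddle

from fastdeploy.model_executor.layers.attention.ops import flash_attn_v4
from fastdeploy.model_executor.utils import get_sm_version

# Ragged (varlen) layout: q is [total_tokens, num_heads, head_dim] and k/v are
# [total_tokens, num_kv_heads, head_dim]; cu_seqlens_* hold cumulative
# per-sequence token offsets, here a single sequence of 1024 tokens.
total_tokens, num_heads, num_kv_heads, head_dim = 1024, 8, 1, 128
q = paddle.randn([total_tokens, num_heads, head_dim], dtype="bfloat16")
k = paddle.randn([total_tokens, num_kv_heads, head_dim], dtype="bfloat16")
v = paddle.randn([total_tokens, num_kv_heads, head_dim], dtype="bfloat16")
cu_seqlens = paddle.to_tensor([0, total_tokens], dtype="int32")
mask_offsets = paddle.arange(total_tokens).astype("int32") + 1  # causal offsets, as in the test

out = paddle.empty(q.shape, dtype="bfloat16")  # filled in place by the kernel
if get_sm_version() >= 100:
    flash_attn_v4(q, k, v, cu_seqlens, cu_seqlens, out, mask_offsets)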

tests/operators/test_flash_mask_attn.py

Lines changed: 63 additions & 1 deletion
@@ -22,7 +22,10 @@
 from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
     flash_attn_func,
 )
-from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
+from fastdeploy.model_executor.layers.attention.ops import (
+    flash_attn_v4,
+    get_attn_mask_q,
+)
 from fastdeploy.model_executor.ops.gpu import flash_mask_attention


@@ -109,6 +112,65 @@ def test_flash_mask_attention(self):
         max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
         self.assertLessEqual(max_diff, 0.05)

+    def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
+        """Causal attention reference implementation for flash_attn_v4 testing."""
+        bsz = cu_seq_q.shape[0] - 1
+        q_token_sum, num_head, head_dim = q_input.shape
+        k_token_sum, num_kv_head, _ = k_input.shape
+        gqa_group_size = num_head // num_kv_head
+        qk_scale = 1 / np.sqrt(head_dim)
+        out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
+        for bi in range(bsz):
+            q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
+            v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
+            qk *= qk_scale
+            condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
+            mask = np.ones(condition.shape).astype("float32") * -1000000
+            qk = np.where(condition > 0, qk, mask)
+            qk_max = qk.max(axis=-1, keepdims=True)
+            qk -= qk_max
+            qk = np.exp(qk)
+            exp_sum = qk.sum(axis=-1, keepdims=True)
+            exp_sum_inv = 1.0 / exp_sum
+            temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0)))
+            out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv
+        return out.transpose([1, 0, 2])
+
+    def test_flash_encoder_attn_fwd(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Encoder Attention V4 requires SM100+.")
+
+        q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+
+        mask = paddle.arange(self.q_len).astype("int32") + 1
+
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+
+        naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)
+
+        paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")
+
+        flash_attn_v4(
+            q_input,
+            k_input,
+            v_input,
+            cu_seq_q,
+            cu_seq_k,
+            paddle_attn_out,
+            mask,
+        )
+
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
     def test_fa4(
         self,
     ):
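The added test self-skips below SM100, so on a Blackwell-class machine it can be run on its own, assuming a pytest-style runner:

python -m pytest tests/operators/test_flash_mask_attn.py -k test_flash_encoder_attn_fwd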
