@@ -22,7 +22,10 @@
 from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
     flash_attn_func,
 )
-from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
+from fastdeploy.model_executor.layers.attention.ops import (
+    flash_attn_v4,
+    get_attn_mask_q,
+)
 from fastdeploy.model_executor.ops.gpu import flash_mask_attention
 
 
@@ -109,6 +112,77 @@ def test_flash_mask_attention(self):
         max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
         self.assertLessEqual(max_diff, 0.05)
 
+    def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
+        """Causal attention reference implementation for flash_attn_v4 testing."""
+        bsz = cu_seq_q.shape[0] - 1
+        q_token_sum, num_head, head_dim = q_input.shape
+        k_token_sum, num_kv_head, _ = k_input.shape
+        gqa_group_size = num_head // num_kv_head
+        qk_scale = 1 / np.sqrt(head_dim)
+        out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
+        for bi in range(bsz):
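+            # Slice one sequence out of the packed varlen layout and move heads
+            # to the leading axis: q -> [num_head, seq_q, head_dim],
+            # k -> [num_kv_head, head_dim, seq_k], v -> [num_kv_head, seq_k, head_dim].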
+            q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
+            v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
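+            # Repeat each KV head gqa_group_size times so every query head has
+            # a matching key/value head (grouped-query attention).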
+            qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
+            qk *= qk_scale
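+            # Bottom-right-aligned causal mask: query i may attend to keys
+            # j <= i + (seq_k - seq_q); masked positions get a large negative
+            # bias so they vanish in the softmax.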
+            condition = np.tril(np.ones(qk.shape), k.shape[2] - q.shape[1])
+            mask = np.ones(condition.shape).astype("float32") * -1000000
+            qk = np.where(condition > 0, qk, mask)
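+            # Numerically stable softmax over the key axis.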
+            qk_max = qk.max(axis=-1, keepdims=True)
+            qk -= qk_max
+            qk = np.exp(qk)
+            exp_sum = qk.sum(axis=-1, keepdims=True)
+            exp_sum_inv = 1.0 / exp_sum
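+            # Normalize after the value matmul; exp_sum_inv broadcasts per row.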
+            temp_out = np.matmul(qk, np.repeat(v, gqa_group_size, 0)) * exp_sum_inv
+            out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = paddle.to_tensor(temp_out).astype(out.dtype)
+        return out.transpose([1, 0, 2])
+
+    def test_flash_encoder_attn_fwd(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Encoder Attention V4 requires SM100+.")
+
+        q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+
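+        # Causal-style mask lengths: entry i is i + 1, so each query sees all
+        # keys up to and including its own position.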
+        mask = paddle.arange(self.q_len).astype("int32") + 1
+
+        bsz = self.bsz
+        # Cumulative token offsets for the packed varlen layout.
+        cu_seq_q = (paddle.arange(bsz + 1) * self.q_len).astype("int32")
+        cu_seq_k = (paddle.arange(bsz + 1) * self.q_len).astype("int32")
+
+        naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)
+
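+        # flash_attn_v4 writes its result into this preallocated output tensor.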
+        paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")
+
+        flash_attn_v4(
+            q_input,
+            k_input,
+            v_input,
+            cu_seq_q,
+            cu_seq_k,
+            paddle_attn_out,
+            mask,
+        )
+
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
     def test_fa4(
         self,
     ):