
Commit ba08097

Author: guanshihui (committed)

Remove comment
1 parent d83f9ef commit ba08097

1 file changed

tests/operators/test_flash_mask_attn.py

Lines changed: 1 addition & 2 deletions
@@ -126,7 +126,6 @@ def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
 v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
 qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
 qk *= qk_scale
-# Causal mask: lower triangular
 condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
 mask = np.ones(condition.shape).astype("float32") * -1000000
 qk = np.where(condition > 0, qk, mask)
@@ -141,7 +140,7 @@ def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
 
 def test_flash_encoder_attn_fwd(self):
     if self.sm_version < 100:
-        self.skipTest("Flash Attention V4 requires SM100+.")
+        self.skipTest("Flash Encoder Attention V4 requires SM100+.")
 
     q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
     k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
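
For context on the first hunk: the removed comment annotated the naive causal-attention reference that the test compares against. Below is a minimal, self-contained NumPy sketch of the same masking idea (repeat KV heads for GQA, scale QK, apply a lower-triangular mask with a large negative fill, softmax, weight V). The shapes, names, and the diagonal offset here are illustrative assumptions, not the exact code in test_flash_mask_attn.py.

# Hedged sketch of a naive causal-attention reference (assumed shapes:
#   q: [num_head, seq_q, head_dim], k/v: [num_kv_head, seq_k, head_dim]).
# It mirrors the masking style in the hunk above (np.tril + large negative
# fill) but is not the exact test fixture.
import numpy as np

def naive_causal_attention(q, k, v, gqa_group_size, qk_scale):
    # Repeat each KV head so every query head has a matching key/value head (GQA).
    k_rep = np.repeat(k, gqa_group_size, axis=0)   # [num_head, seq_k, head_dim]
    v_rep = np.repeat(v, gqa_group_size, axis=0)

    # Scaled attention scores: [num_head, seq_q, seq_k].
    qk = np.matmul(q, k_rep.transpose(0, 2, 1)) * qk_scale

    # Lower-triangular causal mask; the diagonal offset is chosen so the last
    # query token attends to all keys (the test's own offset convention may differ).
    seq_q, seq_k = q.shape[1], k_rep.shape[1]
    condition = np.tril(np.ones((seq_q, seq_k)), seq_k - seq_q)
    qk = np.where(condition > 0, qk, -1_000_000.0)

    # Softmax over keys, then weight the values.
    qk = qk - qk.max(axis=-1, keepdims=True)
    p = np.exp(qk)
    p /= p.sum(axis=-1, keepdims=True)
    return np.matmul(p, v_rep)                     # [num_head, seq_q, head_dim]

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    num_head, num_kv_head, seq, head_dim = 8, 2, 16, 64
    q = rng.standard_normal((num_head, seq, head_dim)).astype("float32")
    k = rng.standard_normal((num_kv_head, seq, head_dim)).astype("float32")
    v = rng.standard_normal((num_kv_head, seq, head_dim)).astype("float32")
    out = naive_causal_attention(q, k, v, num_head // num_kv_head, head_dim ** -0.5)
    print(out.shape)  # (8, 16, 64)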
