diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
index 78c576dd13c..7fd749a7506 100644
--- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py
@@ -30,6 +30,7 @@
 )
 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
+    flash_attn_v4,
     flash_mask_attention,
     get_block_shape_and_split_kv_block,
     gqa_rope_write_cache,
@@ -51,6 +52,8 @@
 else:
     merge_prefill_decode_output = None

+from fastdeploy.model_executor.utils import get_sm_version
+

 @dataclass
 class FlashMaskAttentionMetadata(AttentionMetadata):
@@ -124,6 +127,7 @@ def __init__(
         if fd_config.speculative_config.model_type != "main":
             self.rope_3d = False
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
+        self.sm_version = get_sm_version()

     def get_kv_cache_shape(
         self,
@@ -278,19 +282,30 @@ def forward_mixed(
             self.rope_3d,
         )

-        flash_mask_attention(
-            q,
-            k,
-            v,
-            forward_meta.cu_seqlens_q,
-            forward_meta.attn_cu_seqlens_k,
-            forward_meta.seq_lens_encoder,
-            res_encoder,
-            forward_meta.attn_mask_offsets,
-            self.num_heads,
-            self.kv_num_heads,
-            self.head_dim,
-        )
+        if self.sm_version >= 100:
+            flash_attn_v4(
+                q,
+                k,
+                v,
+                forward_meta.cu_seqlens_q,
+                forward_meta.attn_cu_seqlens_k,
+                res_encoder,
+                forward_meta.attn_mask_offsets,
+            )
+        else:
+            flash_mask_attention(
+                q,
+                k,
+                v,
+                forward_meta.cu_seqlens_q,
+                forward_meta.attn_cu_seqlens_k,
+                forward_meta.seq_lens_encoder,
+                res_encoder,
+                forward_meta.attn_mask_offsets,
+                self.num_heads,
+                self.kv_num_heads,
+                self.head_dim,
+            )

         res_decoder = append_attention(
             qkv,
diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py
index c2db3396ce4..e0175573fa3 100644
--- a/fastdeploy/model_executor/layers/attention/ops/__init__.py
+++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py
@@ -15,6 +15,7 @@
 """

 from .append_attention import append_attention, append_attention_with_output
+from .flash_attn_v4 import flash_attn_v4
 from .flash_mask_attention import flash_mask_attention
 from .get_attn_mask_q import get_attn_mask_q
 from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
@@ -33,6 +34,7 @@
     "gqa_rope_write_cache",
     "pre_cache_len_concat",
     "init_kv_signal_per_query",
+    "flash_attn_v4",
     "flash_mask_attention",
     "get_attn_mask_q",
 ]
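The backend change above dispatches on compute capability: on SM100-class (Blackwell) GPUs the encoder branch of forward_mixed() now calls flash_attn_v4, while older architectures keep the existing flash_mask_attention path. Note that flash_attn_v4 takes fewer arguments, since head counts and head_dim are recoverable from the tensor shapes. A minimal sketch of the gate, assuming get_sm_version() returns the CUDA compute capability as major*10+minor (e.g. 90 on Hopper, 100+ on Blackwell); pick_encoder_kernel is a hypothetical helper for illustration, not part of this PR:

    # Hypothetical illustration of the runtime dispatch added above.
    from fastdeploy.model_executor.utils import get_sm_version

    def pick_encoder_kernel() -> str:
        # SM100+ (Blackwell) takes the new flash_attn_v4 wrapper;
        # everything older keeps the flash_mask_attention kernel.
        if get_sm_version() >= 100:
            return "flash_attn_v4"
        return "flash_mask_attention"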
+""" + +from typing import Optional + +import paddle + +from fastdeploy.model_executor.utils import get_sm_version +from fastdeploy.platforms import current_platform + + +def flash_attn_v4( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + attn_out: paddle.Tensor, + attn_mask_offsets: Optional[paddle.Tensor] = None, +): + if current_platform.is_cuda() and get_sm_version() >= 100: + from blackwell_ops import flash_encoder_attn_fwd + + flash_encoder_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, attn_out, attn_mask_offsets) + else: + raise NotImplementedError diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py index 462d745c782..1cf7caae960 100644 --- a/tests/operators/test_flash_mask_attn.py +++ b/tests/operators/test_flash_mask_attn.py @@ -22,7 +22,10 @@ from fastdeploy.model_executor.layers.attention.flash_attn_backend import ( flash_attn_func, ) -from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q +from fastdeploy.model_executor.layers.attention.ops import ( + flash_attn_v4, + get_attn_mask_q, +) from fastdeploy.model_executor.ops.gpu import flash_mask_attention @@ -109,6 +112,65 @@ def test_flash_mask_attention(self): max_diff = (paddle_attn_out - naive_attn_out).abs().max().item() self.assertLessEqual(max_diff, 0.05) + def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k): + """Causal attention reference implementation for flash_attn_v4 testing.""" + bsz = cu_seq_q.shape[0] - 1 + q_token_sum, num_head, head_dim = q_input.shape + k_token_sum, num_kv_head, _ = k_input.shape + gqa_group_size = num_head // num_kv_head + qk_scale = 1 / np.sqrt(head_dim) + out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype) + for bi in range(bsz): + q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy() + k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy() + v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy() + qk = np.matmul(q, np.repeat(k, gqa_group_size, 0)) + qk *= qk_scale + condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2]) + mask = np.ones(condition.shape).astype("float32") * -1000000 + qk = np.where(condition > 0, qk, mask) + qk_max = qk.max(axis=-1, keepdims=True) + qk -= qk_max + qk = np.exp(qk) + exp_sum = qk.sum(axis=-1, keepdims=True) + exp_sum_inv = 1.0 / exp_sum + temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0))) + out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv + return out.transpose([1, 0, 2]) + + def test_flash_encoder_attn_fwd(self): + if self.sm_version < 100: + self.skipTest("Flash Encoder Attention V4 requires SM100+.") + + q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16") + k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16") + v_input = paddle.randn(k_input.shape, dtype="bfloat16") + + mask = paddle.arange(self.q_len).astype("int32") + 1 + + bsz = self.bsz + cu_seq_q = paddle.arange(bsz + 1) * self.q_len + cu_seq_k = paddle.arange(bsz + 1) * self.q_len + cu_seq_q = cu_seq_q.astype("int32") + cu_seq_k = cu_seq_k.astype("int32") + + naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k) + + paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16") + + flash_attn_v4( + q_input, + k_input, + v_input, + 
diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py
index 462d745c782..1cf7caae960 100644
--- a/tests/operators/test_flash_mask_attn.py
+++ b/tests/operators/test_flash_mask_attn.py
@@ -22,7 +22,10 @@
 from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
     flash_attn_func,
 )
-from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
+from fastdeploy.model_executor.layers.attention.ops import (
+    flash_attn_v4,
+    get_attn_mask_q,
+)
 from fastdeploy.model_executor.ops.gpu import flash_mask_attention


@@ -109,6 +112,65 @@ def test_flash_mask_attention(self):
         max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
         self.assertLessEqual(max_diff, 0.05)

+    def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
+        """Causal attention reference implementation for flash_attn_v4 testing."""
+        bsz = cu_seq_q.shape[0] - 1
+        q_token_sum, num_head, head_dim = q_input.shape
+        k_token_sum, num_kv_head, _ = k_input.shape
+        gqa_group_size = num_head // num_kv_head
+        qk_scale = 1 / np.sqrt(head_dim)
+        out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
+        for bi in range(bsz):
+            q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
+            v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
+            qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
+            qk *= qk_scale
+            condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
+            mask = np.ones(condition.shape).astype("float32") * -1000000
+            qk = np.where(condition > 0, qk, mask)
+            qk_max = qk.max(axis=-1, keepdims=True)
+            qk -= qk_max
+            qk = np.exp(qk)
+            exp_sum = qk.sum(axis=-1, keepdims=True)
+            exp_sum_inv = 1.0 / exp_sum
+            temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0)))
+            out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv
+        return out.transpose([1, 0, 2])
+
+    def test_flash_encoder_attn_fwd(self):
+        if self.sm_version < 100:
+            self.skipTest("Flash Encoder Attention V4 requires SM100+.")
+
+        q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
+        k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
+        v_input = paddle.randn(k_input.shape, dtype="bfloat16")
+
+        mask = paddle.arange(self.q_len).astype("int32") + 1
+
+        bsz = self.bsz
+        cu_seq_q = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_k = paddle.arange(bsz + 1) * self.q_len
+        cu_seq_q = cu_seq_q.astype("int32")
+        cu_seq_k = cu_seq_k.astype("int32")
+
+        naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)
+
+        paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")
+
+        flash_attn_v4(
+            q_input,
+            k_input,
+            v_input,
+            cu_seq_q,
+            cu_seq_k,
+            paddle_attn_out,
+            mask,
+        )
+
+        max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
+        self.assertLessEqual(max_diff, 0.05)
+
     def test_fa4(
         self,
     ):
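The reference implementation in the test runs the softmax in float32 with max subtraction for numerical stability, which is why a 0.05 tolerance against the bfloat16 kernel output is reasonable. The attn_mask_offsets tensor passed to the kernel (arange(q_len) + 1) appears to give row i visibility over keys 0..i, matching the causal structure the reference builds with np.tril. A small sketch of that tril construction, illustrative only:

    import numpy as np

    # With q_len == k_len the tril offset is 0, i.e. entry (i, j) survives
    # only when j <= i: plain causal attention.
    q_len = k_len = 4
    allowed = np.tril(np.ones([q_len, k_len]), q_len - k_len) > 0
    print(allowed.astype(int))
    # [[1 0 0 0]
    #  [1 1 0 0]
    #  [1 1 1 0]
    #  [1 1 1 1]]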