From d6f832c8324bbcd830e1e3cddcf18b1b88800bfd Mon Sep 17 00:00:00 2001 From: mpgemm <1412599472@qq.com> Date: Wed, 25 Mar 2026 21:14:45 +0800 Subject: [PATCH 1/5] add cute cpp fa4 --- .../attention/flash_mask_attn_backend.py | 43 ++++++++---- .../layers/attention/ops/__init__.py | 2 + .../layers/attention/ops/flash_attn_v4.py | 38 +++++++++++ tests/operators/test_flash_mask_attn.py | 65 ++++++++++++++++++- 4 files changed, 134 insertions(+), 14 deletions(-) create mode 100644 fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 78c576dd13c..729d9b7625f 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -30,6 +30,7 @@ ) from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + flash_attn_v4, flash_mask_attention, get_block_shape_and_split_kv_block, gqa_rope_write_cache, @@ -52,6 +53,11 @@ merge_prefill_decode_output = None +from fastdeploy.model_executor.utils import get_sm_version + +sm_version = get_sm_version() + + @dataclass class FlashMaskAttentionMetadata(AttentionMetadata): """ @@ -278,19 +284,30 @@ def forward_mixed( self.rope_3d, ) - flash_mask_attention( - q, - k, - v, - forward_meta.cu_seqlens_q, - forward_meta.attn_cu_seqlens_k, - forward_meta.seq_lens_encoder, - res_encoder, - forward_meta.attn_mask_offsets, - self.num_heads, - self.kv_num_heads, - self.head_dim, - ) + if sm_version >= 100: + flash_attn_v4( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.attn_cu_seqlens_k, + res_encoder, + forward_meta.attn_mask_offsets, + ) + else: + flash_mask_attention( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.attn_cu_seqlens_k, + forward_meta.seq_lens_encoder, + res_encoder, + forward_meta.attn_mask_offsets, + self.num_heads, + self.kv_num_heads, + self.head_dim, + ) res_decoder = append_attention( qkv, diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index c2db3396ce4..e0175573fa3 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,7 @@ """ from .append_attention import append_attention, append_attention_with_output +from .flash_attn_v4 import flash_attn_v4 from .flash_mask_attention import flash_mask_attention from .get_attn_mask_q import get_attn_mask_q from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block @@ -33,6 +34,7 @@ "gqa_rope_write_cache", "pre_cache_len_concat", "init_kv_signal_per_query", + "flash_attn_v4", "flash_mask_attention", "get_attn_mask_q", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py b/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py new file mode 100644 index 00000000000..d1b4ee623ee --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py @@ -0,0 +1,38 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + + +def flash_attn_v4( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + attn_out: paddle.Tensor, + attn_mask_offsets: Optional[paddle.Tensor] = None, +): + if current_platform.is_cuda(): + from blackwell_ops import flash_encoder_attn_fwd + + flash_encoder_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, attn_out, attn_mask_offsets) + else: + raise NotImplementedError diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py index 462d745c782..7ee61a36fbe 100644 --- a/tests/operators/test_flash_mask_attn.py +++ b/tests/operators/test_flash_mask_attn.py @@ -22,7 +22,10 @@ from fastdeploy.model_executor.layers.attention.flash_attn_backend import ( flash_attn_func, ) -from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q +from fastdeploy.model_executor.layers.attention.ops import ( + flash_attn_v4, + get_attn_mask_q, +) from fastdeploy.model_executor.ops.gpu import flash_mask_attention @@ -109,6 +112,66 @@ def test_flash_mask_attention(self): max_diff = (paddle_attn_out - naive_attn_out).abs().max().item() self.assertLessEqual(max_diff, 0.05) + def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k): + """Causal attention reference implementation for flash_attn_v4 testing.""" + bsz = cu_seq_q.shape[0] - 1 + q_token_sum, num_head, head_dim = q_input.shape + k_token_sum, num_kv_head, _ = k_input.shape + gqa_group_size = num_head // num_kv_head + qk_scale = 1 / np.sqrt(head_dim) + out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype) + for bi in range(bsz): + q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy() + k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy() + v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy() + qk = np.matmul(q, np.repeat(k, gqa_group_size, 0)) + qk *= qk_scale + # Causal mask: lower triangular + condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2]) + mask = np.ones(condition.shape).astype("float32") * -1000000 + qk = np.where(condition > 0, qk, mask) + qk_max = qk.max(axis=-1, keepdims=True) + qk -= qk_max + qk = np.exp(qk) + exp_sum = qk.sum(axis=-1, keepdims=True) + exp_sum_inv = 1.0 / exp_sum + temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0))) + out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv + return out.transpose([1, 0, 2]) + + def test_flash_encoder_attn_fwd(self): + if self.sm_version < 100: + self.skipTest("Flash Attention V4 requires SM100+.") + + q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16") + k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16") + v_input = paddle.randn(k_input.shape, dtype="bfloat16") + + mask = paddle.arange(self.q_len).astype("int32") + 1 + + bsz = self.bsz + cu_seq_q = paddle.arange(bsz + 1) * self.q_len + cu_seq_k = paddle.arange(bsz + 1) * self.q_len + cu_seq_q = cu_seq_q.astype("int32") + cu_seq_k = cu_seq_k.astype("int32") + + naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k) + + paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16") + + flash_attn_v4( + q_input, + k_input, + v_input, + cu_seq_q, + cu_seq_k, + paddle_attn_out, + mask, + ) + + max_diff = (paddle_attn_out - naive_attn_out).abs().max().item() + self.assertLessEqual(max_diff, 0.05) + def test_fa4( self, ): From 34e44ead8f1dfcaaace13b42dec73f27a336039d Mon Sep 17 00:00:00 2001 From: mpgemm <1412599472@qq.com> Date: Wed, 25 Mar 2026 21:14:45 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E5=88=A0=E6=8E=89=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/operators/test_flash_mask_attn.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py index 7ee61a36fbe..1cf7caae960 100644 --- a/tests/operators/test_flash_mask_attn.py +++ b/tests/operators/test_flash_mask_attn.py @@ -126,7 +126,6 @@ def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k): v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy() qk = np.matmul(q, np.repeat(k, gqa_group_size, 0)) qk *= qk_scale - # Causal mask: lower triangular condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2]) mask = np.ones(condition.shape).astype("float32") * -1000000 qk = np.where(condition > 0, qk, mask) @@ -141,7 +140,7 @@ def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k): def test_flash_encoder_attn_fwd(self): if self.sm_version < 100: - self.skipTest("Flash Attention V4 requires SM100+.") + self.skipTest("Flash Encoder Attention V4 requires SM100+.") q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16") k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16") From f2fe814f01c6fda503b1bf4f02995918fa107b57 Mon Sep 17 00:00:00 2001 From: mpgemm <1412599472@qq.com> Date: Wed, 25 Mar 2026 23:48:30 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../layers/attention/flash_mask_attn_backend.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 729d9b7625f..2c75d3414f4 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -55,9 +55,6 @@ from fastdeploy.model_executor.utils import get_sm_version -sm_version = get_sm_version() - - @dataclass class FlashMaskAttentionMetadata(AttentionMetadata): """ @@ -130,6 +127,7 @@ def __init__( if fd_config.speculative_config.model_type != "main": self.rope_3d = False self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768")) + self.sm_version = get_sm_version() def get_kv_cache_shape( self, @@ -284,7 +282,7 @@ def forward_mixed( self.rope_3d, ) - if sm_version >= 100: + if self.sm_version >= 100: flash_attn_v4( q, k, From d622f3ea63d903f24eb87217bba6daa128da8119 Mon Sep 17 00:00:00 2001 From: mpgemm <1412599472@qq.com> Date: Thu, 26 Mar 2026 00:00:16 +0800 Subject: [PATCH 4/5] =?UTF-8?q?sm=5Fversion=E6=94=BE=E5=88=B0=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E5=86=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model_executor/layers/attention/flash_mask_attn_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 2c75d3414f4..7fd749a7506 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -52,9 +52,9 @@ else: merge_prefill_decode_output = None - from fastdeploy.model_executor.utils import get_sm_version + @dataclass class FlashMaskAttentionMetadata(AttentionMetadata): """ From ffc1541c28715e24351bc6fc4167c0689e3c16f3 Mon Sep 17 00:00:00 2001 From: mpgemm <1412599472@qq.com> Date: Fri, 27 Mar 2026 10:46:39 +0800 Subject: [PATCH 5/5] =?UTF-8?q?ci=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model_executor/layers/attention/ops/flash_attn_v4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py b/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py index d1b4ee623ee..c95c79c8ac1 100644 --- a/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py +++ b/fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py @@ -18,6 +18,7 @@ import paddle +from fastdeploy.model_executor.utils import get_sm_version from fastdeploy.platforms import current_platform @@ -30,7 +31,7 @@ def flash_attn_v4( attn_out: paddle.Tensor, attn_mask_offsets: Optional[paddle.Tensor] = None, ): - if current_platform.is_cuda(): + if current_platform.is_cuda() and get_sm_version() >= 100: from blackwell_ops import flash_encoder_attn_fwd flash_encoder_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, attn_out, attn_mask_offsets)