@@ -30,6 +30,7 @@
)
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
flash_attn_v4,
flash_mask_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
@@ -51,6 +52,8 @@
else:
merge_prefill_decode_output = None

from fastdeploy.model_executor.utils import get_sm_version


@dataclass
class FlashMaskAttentionMetadata(AttentionMetadata):
@@ -124,6 +127,7 @@ def __init__(
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.sm_version = get_sm_version()

def get_kv_cache_shape(
self,
@@ -278,19 +282,30 @@ def forward_mixed(
self.rope_3d,
)

flash_mask_attention(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.attn_cu_seqlens_k,
forward_meta.seq_lens_encoder,
res_encoder,
forward_meta.attn_mask_offsets,
self.num_heads,
self.kv_num_heads,
self.head_dim,
)
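        # Prefill/encoder attention dispatch: SM100+ (Blackwell) runs the new
        # flash_attn_v4 kernel; earlier architectures keep the existing
        # flash_mask_attention path.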
if self.sm_version >= 100:
flash_attn_v4(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.attn_cu_seqlens_k,
res_encoder,
forward_meta.attn_mask_offsets,
)
else:
flash_mask_attention(
q,
k,
v,
forward_meta.cu_seqlens_q,
forward_meta.attn_cu_seqlens_k,
forward_meta.seq_lens_encoder,
res_encoder,
forward_meta.attn_mask_offsets,
self.num_heads,
self.kv_num_heads,
self.head_dim,
)

res_decoder = append_attention(
qkv,
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/attention/ops/__init__.py
@@ -15,6 +15,7 @@
"""

from .append_attention import append_attention, append_attention_with_output
from .flash_attn_v4 import flash_attn_v4
from .flash_mask_attention import flash_mask_attention
from .get_attn_mask_q import get_attn_mask_q
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
@@ -33,6 +34,7 @@
"gqa_rope_write_cache",
"pre_cache_len_concat",
"init_kv_signal_per_query",
"flash_attn_v4",
"flash_mask_attention",
"get_attn_mask_q",
]
39 changes: 39 additions & 0 deletions fastdeploy/model_executor/layers/attention/ops/flash_attn_v4.py
@@ -0,0 +1,39 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Optional

import paddle

from fastdeploy.model_executor.utils import get_sm_version
from fastdeploy.platforms import current_platform


def flash_attn_v4(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
attn_out: paddle.Tensor,
attn_mask_offsets: Optional[paddle.Tensor] = None,
):
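    """
    Variable-length (cu_seqlens-indexed) flash attention forward for the
    prefill/encoder phase. q, k and v are packed [total_tokens, num_heads,
    head_dim] tensors, and the result is written into attn_out in place.
    Currently only the Blackwell (SM100+) flash_encoder_attn_fwd kernel from
    blackwell_ops is wired up; other platforms raise NotImplementedError.
    """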
if current_platform.is_cuda() and get_sm_version() >= 100:
from blackwell_ops import flash_encoder_attn_fwd

flash_encoder_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, attn_out, attn_mask_offsets)
else:
        raise NotImplementedError("flash_attn_v4 only supports CUDA devices with SM100+ (Blackwell).")
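
For reference, a minimal self-contained call sketch of the new op. The shapes, the bfloat16 varlen layout, and the attn_mask_offsets pattern mirror the new test below; the SM guard mirrors the backend dispatch above, and the concrete sizes are illustrative only, not part of this change.

import paddle

from fastdeploy.model_executor.layers.attention.ops import flash_attn_v4
from fastdeploy.model_executor.utils import get_sm_version

# Packed varlen layout: [total_tokens, num_heads, head_dim], one sequence in the batch.
seq_len, num_heads, num_kv_heads, head_dim = 128, 8, 2, 128
q = paddle.randn([seq_len, num_heads, head_dim], dtype="bfloat16")
k = paddle.randn([seq_len, num_kv_heads, head_dim], dtype="bfloat16")
v = paddle.randn([seq_len, num_kv_heads, head_dim], dtype="bfloat16")
cu_seqlens = paddle.arange(2, dtype="int32") * seq_len  # [0, seq_len]
out = paddle.empty(q.shape, dtype="bfloat16")  # written in place by the kernel
# Per-query offsets as used in the test (presumably the number of visible keys per query).
mask_offsets = paddle.arange(seq_len, dtype="int32") + 1

if get_sm_version() >= 100:  # only a Blackwell (SM100+) kernel exists
    flash_attn_v4(q, k, v, cu_seqlens, cu_seqlens, out, mask_offsets)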
64 changes: 63 additions & 1 deletion tests/operators/test_flash_mask_attn.py
@@ -22,7 +22,10 @@
from fastdeploy.model_executor.layers.attention.flash_attn_backend import (
flash_attn_func,
)
from fastdeploy.model_executor.layers.attention.ops import get_attn_mask_q
from fastdeploy.model_executor.layers.attention.ops import (
flash_attn_v4,
get_attn_mask_q,
)
from fastdeploy.model_executor.ops.gpu import flash_mask_attention


@@ -109,6 +112,65 @@ def test_flash_mask_attention(self):
max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
self.assertLessEqual(max_diff, 0.05)

def causal_attention_naive(self, q_input, k_input, v_input, cu_seq_q, cu_seq_k):
"""Causal attention reference implementation for flash_attn_v4 testing."""
bsz = cu_seq_q.shape[0] - 1
q_token_sum, num_head, head_dim = q_input.shape
k_token_sum, num_kv_head, _ = k_input.shape
gqa_group_size = num_head // num_kv_head
qk_scale = 1 / np.sqrt(head_dim)
out = paddle.zeros([num_head, q_token_sum, head_dim], q_input.dtype)
for bi in range(bsz):
q = q_input[cu_seq_q[bi] : cu_seq_q[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
k = k_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 2, 0]).astype("float32").numpy()
v = v_input[cu_seq_k[bi] : cu_seq_k[bi + 1], :, :].transpose([1, 0, 2]).astype("float32").numpy()
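            # GQA reference: each KV head serves gqa_group_size query heads, so K and V are
            # repeated along the head axis (axis 0 after the transposes) before the matmuls.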
qk = np.matmul(q, np.repeat(k, gqa_group_size, 0))
qk *= qk_scale
condition = np.tril(np.ones(qk.shape), q.shape[1] - k.shape[2])
mask = np.ones(condition.shape).astype("float32") * -1000000
qk = np.where(condition > 0, qk, mask)
qk_max = qk.max(axis=-1, keepdims=True)
qk -= qk_max
qk = np.exp(qk)
exp_sum = qk.sum(axis=-1, keepdims=True)
exp_sum_inv = 1.0 / exp_sum
temp_out = paddle.to_tensor(np.matmul(qk, np.repeat(v, gqa_group_size, 0)))
out[:, cu_seq_q[bi] : cu_seq_q[bi + 1], :] = temp_out * exp_sum_inv
return out.transpose([1, 0, 2])

def test_flash_encoder_attn_fwd(self):
if self.sm_version < 100:
self.skipTest("Flash Encoder Attention V4 requires SM100+.")

q_input = paddle.randn([self.q_len, self.num_head, self.head_dim], dtype="bfloat16")
k_input = paddle.randn([self.q_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
v_input = paddle.randn(k_input.shape, dtype="bfloat16")

mask = paddle.arange(self.q_len).astype("int32") + 1

bsz = self.bsz
cu_seq_q = paddle.arange(bsz + 1) * self.q_len
cu_seq_k = paddle.arange(bsz + 1) * self.q_len
cu_seq_q = cu_seq_q.astype("int32")
cu_seq_k = cu_seq_k.astype("int32")

naive_attn_out = self.causal_attention_naive(q_input, k_input, v_input, cu_seq_q, cu_seq_k)

paddle_attn_out = paddle.empty(q_input.shape, dtype="bfloat16")

flash_attn_v4(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
paddle_attn_out,
mask,
)

max_diff = (paddle_attn_out - naive_attn_out).abs().max().item()
self.assertLessEqual(max_diff, 0.05)

def test_fa4(
self,
):