diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 65bb2077642..b20a661805f 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -1,8 +1,8 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# CMakeLists.txt for sparse attention (Jenga and VSA) +#Copyright(c) Advanced Micro Devices, Inc., or its affiliates. +#SPDX - License - Identifier : MIT +#CMakeLists.txt for sparse attention(Jenga and VSA) -# Use SUPPORTED_GPU_TARGETS directly +#Use SUPPORTED_GPU_TARGETS directly set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -16,7 +16,7 @@ endif() message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") -# Code generation scripts +#Code generation scripts file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py @@ -153,4 +153,47 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# ============================================================================ +# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen) +# ============================================================================ +set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") + +add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp +) +target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties( + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + PROPERTIES LANGUAGE HIP +) +set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# ---------------------------------------------------------------------------- +# Build unified Sparge test: combines blockmap, Jenga, and VSA attention +# for end-to-end evaluation and timing in a single executable. +# ---------------------------------------------------------------------------- +set(EXAMPLE_SPARGE "tile_example_sparge") +message(DEBUG "adding example ${EXAMPLE_SPARGE}") +add_executable(${EXAMPLE_SPARGE} EXCLUDE_FROM_ALL test_sparge.cpp) +target_link_libraries(${EXAMPLE_SPARGE} + ${SPARSE_ATTN_JENGA_INSTANCES} + ${SPARSE_ATTN_VSA_INSTANCES} + ${SPARGE_BLOCKMAP_INSTANCES} +) +target_include_directories(${EXAMPLE_SPARGE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal +) + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a3d32652a98..1f0a78048d9 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_jenga_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_jenga_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_jenga_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 038738de246..217cfcfe2a4 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -141,6 +141,17 @@ def update_file(file_path, content): constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +template<> +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::stream_config{{s.stream_id_}}); +}} """ FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp" @@ -219,6 +230,45 @@ def update_file(file_path, content): }} """ +FMHA_FWD_ONESHOT_API_FILENAME = "fmha_vsa_fwd_oneshot_api.cpp" +FMHA_FWD_ONESHOT_API = """ +#include "fmha_fwd_trek.hpp" +#include + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + std::cerr << "fmha_vsa_fwd_oneshot: no matching dispatch (dtype=" << t.data_type + << " hdim_q=" << t.hdim_q << " hdim_v=" << t.hdim_v + << " seqlen_q=" << a.seqlen_q << " seqlen_k=" << a.seqlen_k + << " mask=" << static_cast(t.mask_type) << ")" << std::endl; +}} +""" + +FMHA_FWD_ONESHOT_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_ONESHOT_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_ONESHOT_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_ONESHOT_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + fmha_vsa_fwd_oneshot_(s, a); + return; + }} +""" + @dataclass class CppConstraint: @@ -274,10 +324,7 @@ def scheck(self) -> str: @property def seqtune(self) -> str: - if self.bm0 == 128: - return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true - else: - return f"a.seqlen_q <= {self.bm0}" + return "true" @property def skcheck(self) -> str: @@ -447,6 +494,67 @@ def api(self) -> str: per_tr_load += " (void)t ; (void)s ; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @property + def oneshot_api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_ONESHOT_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_ONESHOT_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_ONESHOT_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_ONESHOT_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_ONESHOT_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: @@ -582,6 +690,27 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ + FmhaFwdTileSize( # fmt: skip -- 64x128 tile matching blockmap kM0=64, kN0=128 + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), FmhaFwdTileSize( # fmt: skip 16, 32, @@ -780,7 +909,7 @@ def get_fwd_blobs( for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): - if tile.F_bm0 != 128 or tile.F_bn0 != 128: + if tile.F_bm0 != 64 or tile.F_bn0 != 128: continue if pipeline.tag != "qr_async_vsa": continue @@ -846,6 +975,7 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + update_file(autogen_dir / FMHA_FWD_ONESHOT_API_FILENAME, api_pool.oneshot_api) def write_blobs( @@ -865,3 +995,4 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_ONESHOT_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 7349c3576e8..350d1803f66 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -280,7 +280,10 @@ float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile:: template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); -float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); +template +void fmha_jenga_fwd_oneshot_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +void fmha_jenga_fwd_oneshot(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); // VSA uses the same traits structure as Jenga; aliases for clarity template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); -float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&); +template +void fmha_vsa_fwd_oneshot_(const ck_tile::stream_config&, fmha_vsa_fwd_args); + +void fmha_vsa_fwd_oneshot(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp new file mode 100644 index 00000000000..a2df5bac569 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -0,0 +1,227 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128). + +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include + +// ============================================================================ +// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) +// ============================================================================ + +// fp16: D=128, kM0=64, kN0=128 +using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; +// kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D) + +using bmap_fp16_shape = + ck_tile::TileFmhaShape, // Gemm0BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but + // needed by shape) + ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile + true>; // VLayout row-major + +using bmap_fp16_trait = ck_tile::TileFmhaTraits; // kIsVRowMajorSkip + +using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_fp16_mask = ck_tile::GenericAttentionMask; + +using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; + +// ============================================================================ +// bf16: D=128, kM0=64, kN0=128 +// ============================================================================ + +using bmap_bf16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; + +using bmap_bf16_shape = + ck_tile::TileFmhaShape, + ck_tile::sequence<16, 16, 16>, + ck_tile::sequence<4, 1, 1>, + ck_tile::sequence<16, 16, 16>, + true>; + +using bmap_bf16_trait = ck_tile::TileFmhaTraits; + +using bmap_bf16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_bf16_mask = ck_tile::GenericAttentionMask; + +using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel; + +// ============================================================================ +// Dispatch +// ============================================================================ + +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_fp16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_bf16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + + if(s.log_level_ > 0) + std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; + return -1.f; +} + +// ============================================================================ +// Oneshot version: launches kernel without timing wrapper +// ============================================================================ + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + if(traits.data_type == "bf16" && traits.hdim_q == 128) + { + using k_ = bmap_bf16_kernel; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)( + ck_tile::stream_config{s.stream_id_}); + return; + } + + std::cerr << "sparge_blockmap_fwd_oneshot: unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; +} + +// ============================================================================ +// Combined functions: blockmap + attention timed together via launch_kernel +// ============================================================================ + +float sparge_jenga_fwd(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_jenga_fwd_traits attn_t, fmha_jenga_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_jenga_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_jenga_fwd_oneshot(attn_t, attn_a, s_); + }); +} + +float sparge_vsa_fwd_combined(sparge_blockmap_traits bmap_t, sparge_blockmap_args bmap_a, + fmha_vsa_fwd_traits attn_t, fmha_vsa_fwd_args attn_a, + const ck_tile::stream_config& s) +{ + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_" << bmap_t.data_type << "_d" << bmap_t.hdim_q + << ", fmha_vsa_fwd_" << attn_t.data_type << "_d" << attn_t.hdim_q + << std::flush; + + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + sparge_blockmap_fwd_oneshot(bmap_t, bmap_a, s_); + }, + [=](const ck_tile::stream_config& s_) { + fmha_vsa_fwd_oneshot(attn_t, attn_a, s_); + }); +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp new file mode 100644 index 00000000000..6eaeb9ea77b --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -0,0 +1,106 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" +#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp" + +#include "fmha_fwd_trek.hpp" + +#include +#include + +// ============================================================================ +// Args and traits for sparge block map GPU kernel +// ============================================================================ +struct sparge_blockmap_args +{ + const void* q_ptr; + const void* k_ptr; + + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; +}; + +struct sparge_blockmap_traits +{ + std::string data_type; + int hdim_q; +}; + +// ============================================================================ +// Create kernel args and grid dimensions +// ============================================================================ +template +auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = BlockMapKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.stride_q, + args.stride_k, + args.nhead_stride_q, + args.nhead_stride_k, + args.batch_stride_q, + args.batch_stride_k, + args.simthreshd1, + args.cdfthreshd, + args.topk, + args.scale, + args.block_map_ptr, + args.lut_ptr, + args.valid_block_num_ptr); + + dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// ============================================================================ +// Hand-written template instantiation dispatch +// ============================================================================ +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); + +// Combined functions: blockmap + attention with unified timing +float sparge_jenga_fwd(sparge_blockmap_traits, sparge_blockmap_args, + fmha_jenga_fwd_traits, fmha_jenga_fwd_args, + const ck_tile::stream_config&); + +float sparge_vsa_fwd_combined(sparge_blockmap_traits, sparge_blockmap_args, + fmha_vsa_fwd_traits, fmha_vsa_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp new file mode 100644 index 00000000000..49c69cc6f74 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp @@ -0,0 +1,408 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace sparge { + +struct SpargeParams +{ + int BLKQ = 128; + int BLKK = 128; + + // Similarity gate threshold (TODO: per-head support). + float simthreshd1 = 0.6f; + + // Exactly one of the following should be used: + // - Use CDF threshold if topk < 0 + // - Both should be in [0, 1] <-- NEED TO CHECK THIS + float cdfthreshd = 0.98f; + float topk = -1.0f; + + // If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples). + bool i_perm = true; +}; + +// Output format CK VSA expects. +struct VSALut +{ + ck_tile::HostTensor lut; // [B, Hq, Q_blk, K_blk] delta-encoded + ck_tile::HostTensor valid_block_num; // [B, Hq, Q_blk] +}; + +namespace detail { + +template +inline float to_f32(const T& x) +{ + return ck_tile::type_convert(x); +} + +// Read element from HostTensor with either BHSD or BSHD layout. +// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D] +// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D] +template +inline float load(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s, int d) +{ + return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d)); +} + +// Compute pooled mean vector of one block: mean over tokens in [s0, s1). +template +std::vector +pooled_mean_block(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s0, int s1, int d) +{ + std::vector mean(d, 0.0f); + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return mean; + + for(int s = s0; s < s1; ++s) + { + for(int d_ = 0; d_ < d; ++d_) + { + mean[d_] += load(X, i_perm, b, h, s, d_); + } + } + const float inv = 1.0f / static_cast(bs); + for(int d_ = 0; d_ < d; ++d_) + mean[d_] *= inv; + return mean; +} + +// Compute "sim" flag of one block following SpargeAttn's intent: +// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D. +// +// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly +// instead of O(BS_^2 * D). +template +bool sim_block_flag(const ck_tile::HostTensor& X, + bool i_perm, + int b, + int h, + int s0, + int s1, + int d, + float simthreshd1) +{ + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return false; + + std::vector sum_hat(d, 0.0f); + + for(int s = s0; s < s1; ++s) + { + // Compute L2 norm over D. + float norm2 = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + { + const float v = load(X, i_perm, b, h, s, d_); + norm2 += v * v; + } + float inv_norm = 1.0f; + // spargeAttn use eps to prevent division by zero + if(norm2 > 0.0f) + inv_norm = 1.0f / std::sqrt(norm2); + + // Accumulate normalized vector. + for(int d_ = 0; d_ < d; ++d_) + { + sum_hat[d_] += load(X, i_perm, b, h, s, d_) * inv_norm; + } + } + + float sum_gram = 0.0f; + for(int d_ = 0; d_ < d; ++d_) + sum_gram += sum_hat[d_] * sum_hat[d_]; + + const float denom = static_cast(bs) * static_cast(bs); + const float mean_sim = sum_gram / denom; + + return mean_sim > simthreshd1; +} + +inline int select_count_from_cdf(const std::vector& sorted_probs, float cdfthreshd) +{ + // Choose the smallest n such that cdf[n-1] >= cdfthreshd. + // Ensure at least 1. + if(sorted_probs.empty()) + return 0; + if(cdfthreshd <= 0.0f) + return 1; + + float c = 0.0f; + for(int i = 0; i < static_cast(sorted_probs.size()); ++i) + { + c += sorted_probs[i]; + if(c >= cdfthreshd) + return i + 1; + } + return static_cast(sorted_probs.size()); +} + +inline int select_count_from_topk(int K_blk, float topk) +{ + if(K_blk <= 0) + return 0; + int n = static_cast(std::floor(topk * static_cast(K_blk))); + n = std::max(1, n); + return n; +} + +} // namespace detail + +// Build one-hot block_map[b,hq,qb,kb] in {0,1}. +// - No causal mask +// - No attention sink +// - Logic matches SpargeAttn's structure: +// - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later +// - if a Q-block is not "similar", force the whole row ON +template +ck_tile::HostTensor build_block_map_meansim(const ck_tile::HostTensor& Q, + const ck_tile::HostTensor& K, + const SpargeParams& p) +{ + const auto qlens = Q.get_lengths(); + const auto klens = K.get_lengths(); + + const int B = static_cast(qlens[0]); + const int Hq = p.i_perm ? static_cast(qlens[1]) : static_cast(qlens[2]); + const int Sq = p.i_perm ? static_cast(qlens[2]) : static_cast(qlens[1]); + const int D = static_cast(qlens[3]); + + [[maybe_unused]] const int Bk = static_cast(klens[0]); + const int Hk = p.i_perm ? static_cast(klens[1]) : static_cast(klens[2]); + const int Sk = p.i_perm ? static_cast(klens[2]) : static_cast(klens[1]); + [[maybe_unused]] const int Dk = static_cast(klens[3]); + + assert(B == Bk && D == Dk && Hq % Hk == 0); + assert(p.BLKQ > 0 && p.BLKK > 0); + + const int nhead_ratio_qk = Hq / Hk; + const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ); + const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK); + + ck_tile::HostTensor block_map({B, Hq, Q_blk, K_blk}); + + // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D] + // sim_q: [B,Hq,Q_blk], sim_k: [B,Hk,K_blk] + std::vector pooled_q(static_cast(B) * Hq * Q_blk * D, 0.0f); + std::vector pooled_k(static_cast(B) * Hk * K_blk * D, 0.0f); + std::vector sim_q(static_cast(B) * Hq * Q_blk, 0); + std::vector sim_k(static_cast(B) * Hk * K_blk, 0); + + auto idx_pq = [&](int b, int hq, int qb, int d) { + return (((b * Hq + hq) * Q_blk + qb) * D + d); + }; + auto idx_pk = [&](int b, int hk, int kb, int d) { + return (((b * Hk + hk) * K_blk + kb) * D + d); + }; + auto idx_sq = [&](int b, int hq, int qb) { return ((b * Hq + hq) * Q_blk + qb); }; + auto idx_sk = [&](int b, int hk, int kb) { return ((b * Hk + hk) * K_blk + kb); }; + + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + // Q blocks + for(int qb = 0; qb < Q_blk; ++qb) + { + const int s0 = qb * p.BLKQ; + const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); + + // pooled mean + auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_q[idx_pq(b, hq, qb, d)] = mean[d]; + + // sim flag + sim_q[idx_sq(b, hq, qb)] = + detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 1 : 0; + } + } + + for(int hk = 0; hk < Hk; ++hk) + { + // K blocks + for(int kb = 0; kb < K_blk; ++kb) + { + const int s0 = kb * p.BLKK; + const int s1 = std::min(Sk, (kb + 1) * p.BLKK); + + auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_k[idx_pk(b, hk, kb, d)] = mean[d]; + + sim_k[idx_sk(b, hk, kb)] = + detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0; + } + } + } + + const float scale = 1.0f / std::sqrt(static_cast(D)); + + // Main loop + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + const int hk = hq / nhead_ratio_qk; + + for(int qb = 0; qb < Q_blk; ++qb) + { + const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0); + + // If Q-block is not "similar", force dense row. + if(!q_is_sim) + { + for(int kb = 0; kb < K_blk; ++kb) + block_map(b, hq, qb, kb) = 1; + continue; + } + + // Compute scores over K blocks (only sim_kblocks participate in softmax; others set + // to -inf). + std::vector score(K_blk, -std::numeric_limits::infinity()); + for(int kb = 0; kb < K_blk; ++kb) + { + const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0); + if(!k_is_sim) + { + block_map(b, hq, qb, kb) = 1; + continue; + } + + float dot = 0.0f; + for(int d = 0; d < D; ++d) + { + dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)]; + } + score[kb] = dot * scale; + } + + // Softmax over K_blk (numerically stable). If all -inf, probs become all zeros. + float maxv = -std::numeric_limits::infinity(); + for(int kb = 0; kb < K_blk; ++kb) + maxv = std::max(maxv, score[kb]); + + std::vector prob(K_blk, 0.0f); + if(std::isfinite(maxv)) + { + float sumexp = 0.0f; + for(int kb = 0; kb < K_blk; ++kb) + { + if(!std::isfinite(score[kb])) + continue; + const float e = std::exp(score[kb] - maxv); + prob[kb] = e; + sumexp += e; + } + if(sumexp > 0.0f) + { + const float inv = 1.0f / sumexp; + for(int kb = 0; kb < K_blk; ++kb) + prob[kb] *= inv; + } + else + { + // All exponentials underflowed: keep zeros. + std::fill(prob.begin(), prob.end(), 0.0f); + } + } + + // Sort indices by prob descending. + std::vector order(K_blk); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&](int a, int c) { + if(prob[a] != prob[c]) + return prob[a] > prob[c]; + return a < c; // tie-breaker for determinism + }); + + // Determine how many to select. + int num_to_select = 0; + if(p.topk > 0.0f) + { + num_to_select = detail::select_count_from_topk(K_blk, p.topk); + } + else + { + // Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd). + std::vector sorted_probs(K_blk); + for(int i = 0; i < K_blk; ++i) + sorted_probs[i] = prob[order[i]]; + num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd); + num_to_select = std::max(1, num_to_select); + } + + // Select top-kb blocks by order[0..num_to_select-1]. + for(int i = 0; i < num_to_select; ++i) + { + const int kb = order[i]; + block_map(b, hq, qb, kb) = 1; + } + } + } + } + + return block_map; +} + +// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format). +template +VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor& block_map) +{ + const auto lens = block_map.get_lengths(); + const int B = static_cast(lens[0]); + const int H = static_cast(lens[1]); + const int Q = static_cast(lens[2]); + const int K = static_cast(lens[3]); + + VSALut out{ + ck_tile::HostTensor({B, H, Q, K}), + ck_tile::HostTensor({B, H, Q}), + }; + + for(int b = 0; b < B; ++b) + { + for(int h = 0; h < H; ++h) + { + for(int q = 0; q < Q; ++q) + { + int32_t valid = 0; + int32_t prev = 0; + + for(int k = 0; k < K; ++k) + { + const bool on = static_cast(block_map(b, h, q, k)) != 0; + if(on) + { + out.lut(b, h, q, valid) = static_cast(k - prev); + prev = static_cast(k); + ++valid; + } + } + + out.valid_block_num(b, h, q) = valid; + + // Optional: zero-fill the unused tail for determinism. + for(int i = valid; i < K; ++i) + out.lut(b, h, q, i) = 0; + } + } + } + + return out; +} + +} // namespace sparge diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp new file mode 100644 index 00000000000..7c30a10b062 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp @@ -0,0 +1,432 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Unified test for Sparge pipeline: blockmap generation + sparse attention (Jenga/VSA). + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "fmha_fwd_trek.hpp" +#include "sparge_blockmap_trek.hpp" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helpers +// ============================================================================ + +template +ck_tile::HostTensor +make_qkv_tensor(ck_tile::index_t batch, ck_tile::index_t nhead, ck_tile::index_t seqlen, ck_tile::index_t hdim, bool i_perm) +{ + if(i_perm) + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + for(ck_tile::index_t h = 0; h < nhead; ++h) + for(ck_tile::index_t s = 0; s < seqlen; ++s) + for(ck_tile::index_t d = 0; d < hdim; ++d) + out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Arg parser +// ============================================================================ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser + .insert("v", "1", "0:no validation, 1:cpu validation") + .insert("pipeline", "jenga", "attention pipeline: jenga / vsa") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("topk", "0.3", "topk ratio for blockmap (fraction of K-blocks to keep)") + .insert("cdfthreshd", "-1", "CDF threshold for blockmap (overrides topk if >= 0)") + .insert("simthreshd1", "0.6", "similarity threshold for blockmap") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// ============================================================================ +// Main test +// ============================================================================ +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + std::string pipeline = arg_parser.get_str("pipeline"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + float topk = arg_parser.get_float("topk"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + if(nhead_k < 0) nhead_k = nhead; + if(seqlen_k < 0) seqlen_k = seqlen_q; + if(hdim_v < 0) hdim_v = hdim_q; + + // If cdfthreshd >= 0, use CDF mode; otherwise use topk mode + if(cdfthreshd >= 0.0f) + topk = -1.0f; + + constexpr ck_tile::index_t BLKQ = 64; + constexpr ck_tile::index_t BLKK = 128; + + if(hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<\n" + << "Kernel instances are generated for hdim=128 only.\n"; + return true; + } + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::string prec_str = std::is_same_v ? "fp16" : "bf16"; + std::cout << "[" << pipeline << "|" << prec_str + << "] b=" << batch << " h=" << nhead << " s=" << seqlen_q + << " d=" << hdim_q << " topk=" << topk + << " sim1=" << simthreshd1 << std::flush; + + // ---- allocate host tensors ---- + auto q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + auto k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + auto v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + auto output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + + ck_tile::HostTensor block_map_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor valid_block_num_host({batch, nhead, num_q_blocks}); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // ---- device tensors ---- + ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_dev(output_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_map_dev(block_map_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_dev(lut_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_bn_dev(valid_block_num_host.get_element_space_size_in_bytes()); + + q_dev.ToDevice(q_host.data()); + k_dev.ToDevice(k_host.data()); + v_dev.ToDevice(v_host.data()); + o_dev.SetZero(); + block_map_dev.SetZero(); + lut_dev.SetZero(); + valid_bn_dev.SetZero(); + + // ---- strides (BHSD when i_perm=true) ---- + auto q_strides = q_host.get_strides(); + auto k_strides = k_host.get_strides(); + auto v_strides = v_host.get_strides(); + auto o_strides = output_host.get_strides(); + + float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + // ---- build blockmap args ---- + sparge_blockmap_traits bmap_traits; + bmap_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + bmap_traits.hdim_q = hdim_q; + + sparge_blockmap_args bmap_args; + bmap_args.q_ptr = q_dev.GetDeviceBuffer(); + bmap_args.k_ptr = k_dev.GetDeviceBuffer(); + bmap_args.batch = batch; + bmap_args.seqlen_q = seqlen_q; + bmap_args.seqlen_k = seqlen_k; + bmap_args.hdim_q = hdim_q; + bmap_args.nhead_q = nhead; + bmap_args.nhead_k = nhead_k; + bmap_args.stride_q = q_strides[i_perm ? 2 : 1]; + bmap_args.stride_k = k_strides[i_perm ? 2 : 1]; + bmap_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + bmap_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + bmap_args.batch_stride_q = q_strides[0]; + bmap_args.batch_stride_k = k_strides[0]; + bmap_args.simthreshd1 = simthreshd1; + bmap_args.cdfthreshd = (topk < 0.0f) ? cdfthreshd : -1.0f; + bmap_args.topk = topk; + bmap_args.scale = scale_s; + bmap_args.block_map_ptr = block_map_dev.GetDeviceBuffer(); + bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr; + bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr; + + // ---- build attention args ---- + ck_tile::stream_config stream_cfg; + stream_cfg.stream_id_ = nullptr; + stream_cfg.time_kernel_ = true; + stream_cfg.log_level_ = kname; + stream_cfg.cold_niters_ = warmup; + stream_cfg.nrepeat_ = repeat; + + float avg_ms = -1.0f; + + if(pipeline == "jenga") + { + fmha_jenga_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_jenga_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.block_relation_onehot_ptr = block_map_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_jenga_fwd(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else if(pipeline == "vsa") + { + fmha_vsa_fwd_traits attn_traits; + attn_traits.hdim_q = hdim_q; + attn_traits.hdim_v = hdim_v; + attn_traits.data_type = std::is_same_v ? "fp16" : "bf16"; + attn_traits.is_v_rowmajor = true; + attn_traits.mask_type = mask_enum::no_mask; + + fmha_vsa_fwd_args attn_args; + attn_args.q_ptr = q_dev.GetDeviceBuffer(); + attn_args.k_ptr = k_dev.GetDeviceBuffer(); + attn_args.v_ptr = v_dev.GetDeviceBuffer(); + attn_args.lut_ptr = lut_dev.GetDeviceBuffer(); + attn_args.valid_block_num_ptr = valid_bn_dev.GetDeviceBuffer(); + attn_args.o_ptr = o_dev.GetDeviceBuffer(); + attn_args.seqlen_q = seqlen_q; + attn_args.seqlen_k = seqlen_k; + attn_args.batch = batch; + attn_args.max_seqlen_q = seqlen_q; + attn_args.hdim_q = hdim_q; + attn_args.hdim_v = hdim_v; + attn_args.nhead_q = nhead; + attn_args.nhead_k = nhead_k; + attn_args.scale_s = scale_s; + attn_args.stride_q = q_strides[i_perm ? 2 : 1]; + attn_args.stride_k = k_strides[i_perm ? 2 : 1]; + attn_args.stride_v = v_strides[i_perm ? 2 : 1]; + attn_args.stride_o = o_strides[o_perm ? 2 : 1]; + attn_args.nhead_stride_q = q_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_k = k_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_v = v_strides[i_perm ? 1 : 2]; + attn_args.nhead_stride_o = o_strides[o_perm ? 1 : 2]; + attn_args.batch_stride_q = q_strides[0]; + attn_args.batch_stride_k = k_strides[0]; + attn_args.batch_stride_v = v_strides[0]; + attn_args.batch_stride_o = o_strides[0]; + attn_args.window_size_left = -1; + attn_args.window_size_right = -1; + attn_args.mask_type = 0; + + avg_ms = sparge_vsa_fwd_combined(bmap_traits, bmap_args, attn_traits, attn_args, stream_cfg); + } + else + { + std::cerr << "Unknown pipeline: " << pipeline << " (use jenga or vsa)\n"; + return false; + } + + // ---- TFLOPS calculation (dense FMHA formula, so sparsity gains show as higher TFLOPS) ---- + std::size_t flop = static_cast(batch) * nhead * + (static_cast(2) * seqlen_q * seqlen_k * hdim_q + + static_cast(2) * seqlen_q * seqlen_k * hdim_v); + float tflops = (avg_ms > 0.f) ? static_cast(flop) / 1.E9f / avg_ms : 0.f; + + if(avg_ms > 0.f) + { + std::cout << std::fixed << ", " << std::setprecision(3) << avg_ms << " ms, " + << std::setprecision(2) << tflops << " TFlops" << std::flush; + } + + // ---- copy results back ---- + o_dev.FromDevice(output_host.data()); + block_map_dev.FromDevice(block_map_host.data()); + + // ---- count active blocks ---- + ck_tile::index_t total_blocks = batch * nhead * num_q_blocks * num_k_blocks; + ck_tile::index_t active_blocks = 0; + for(size_t i = 0; i < block_map_host.mData.size(); ++i) + if(block_map_host.mData[i]) + active_blocks++; + float actual_sparsity = 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << ", sparsity=" << std::setprecision(2) << actual_sparsity + << "(" << active_blocks << "/" << total_blocks << ")" << std::flush; + + // ---- validation ---- + bool pass = true; + if(do_validation) + { + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_map_host, output_ref, BLKQ, BLKK, scale_s); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + + if(diff > atol && rel_diff > rtol) + num_errors++; + } + + pass = (num_errors == 0); + std::cout << ", " << (pass ? "PASS" : "FAIL") + << "(err=" << num_errors << "/" << output_host_bhsd.mData.size() + << " maxdiff=" << max_diff << ")"; + } + + std::cout << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments\n"; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << "\n"; + return -1; + } + + return test_result ? 0 : -1; +} diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp new file mode 100644 index 00000000000..ca177abf23a --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +template +struct SpargeBlockMapKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kM0 = Pipeline::kM0; + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(QDataType); + + struct Kargs + { + const void* q_ptr; + const void* k_ptr; + + index_t seqlen_q; + index_t seqlen_k; + index_t hdim_q; + + index_t nhead_q; + index_t nhead_ratio_qk; + + index_t stride_q; + index_t stride_k; + index_t nhead_stride_q; + index_t nhead_stride_k; + index_t batch_stride_q; + index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; + + index_t N_k; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr, + const void* k_ptr, + index_t seqlen_q, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_q, + index_t nhead_ratio_qk, + index_t stride_q, + index_t stride_k, + index_t nhead_stride_q, + index_t nhead_stride_k, + index_t batch_stride_q, + index_t batch_stride_k, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + void* block_map_ptr, + void* lut_ptr, + void* valid_block_num_ptr) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{q_ptr, + k_ptr, + seqlen_q, + seqlen_k, + hdim_q, + nhead_q, + nhead_ratio_qk, + stride_q, + stride_k, + nhead_stride_q, + nhead_stride_k, + batch_stride_q, + batch_stride_k, + simthreshd1, + cdfthreshd, + topk, + scale, + block_map_ptr, + lut_ptr, + valid_block_num_ptr, + N_k}; + } + + CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q) + { + const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0); + return dim3(Q_blk, nhead_q, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t qb = static_cast(blockIdx.x); + const index_t hq = static_cast(blockIdx.y); + const index_t b = static_cast(blockIdx.z); + + const index_t hk = hq / kargs.nhead_ratio_qk; + + // Q pointer for this (batch, head, q_block) + const auto* q_base = reinterpret_cast(kargs.q_ptr) + + b * kargs.batch_stride_q + hq * kargs.nhead_stride_q + + qb * kM0 * kargs.stride_q; + + // K pointer for this (batch, head_k) + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k; + + // Q DRAM view with OOB padding + const auto q_dram_naive = make_naive_tensor_view( + q_base, + make_tuple(kargs.seqlen_q - qb * kM0, D), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = pad_tensor_view( + q_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto q_window = make_tile_window(q_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeQBlockDistribution()); + + // K DRAM view with OOB padding + const auto k_dram_naive = + make_naive_tensor_view(k_base, + make_tuple(kargs.seqlen_k, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + // Output pointers for this (batch, head, q_block) + const index_t N_k = kargs.N_k; + const index_t bmap_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k; + auto* bmap_ptr = reinterpret_cast(kargs.block_map_ptr) + bmap_offset; + + int32_t* lut_out = nullptr; + int32_t* valid_out = nullptr; + if(kargs.lut_ptr != nullptr) + { + lut_out = reinterpret_cast(kargs.lut_ptr) + bmap_offset; + const index_t valid_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb; + valid_out = reinterpret_cast(kargs.valid_block_num_ptr) + valid_offset; + } + + // Shared memory + __shared__ char smem[Pipeline::GetSmemSize()]; + + Pipeline{}(q_window, + k_window, + kargs.seqlen_q, + kargs.seqlen_k, + qb, + N_k, + kargs.nhead_ratio_qk, + kargs.simthreshd1, + kargs.cdfthreshd, + kargs.topk, + kargs.scale, + bmap_ptr, + lut_out, + valid_out, + static_cast(smem)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 67936c4353f..9fe8b365b00 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -318,26 +318,26 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga { if(!block_relation_onehot[i_total_loops]) { - i_total_loops++; - if(i_total_loops < num_total_loop) - { - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - - if(block_relation_onehot[i_total_loops]) - { - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - } - move_tile_window(k_dram_window, {0, kK0}); - move_tile_window(v_dram_window, {0, kN0}); - continue; - } - break; + // scan-ahead: find the next active block in one shot + index_t next = i_total_loops + 1; + while(next < num_total_loop && !block_relation_onehot[next]) + next++; + if(next >= num_total_loop) + break; + const index_t delta = next - i_total_loops; + i_total_loops = next; + // jump K/V windows to the next active block + move_tile_window(k_dram_block_window, {kN0 * delta, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + move_tile_window(v_dram_window, {0, kN0 * delta}); + // immediately prefetch the active K tile + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + continue; } // STAGE 1, QK gemm diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 2b097ae5827..578ad7e6039 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - int seqlen_k_start = kv_block_idx_ptr[0] * kM0; + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp new file mode 100644 index 00000000000..222e73c60e2 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -0,0 +1,521 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +namespace ck_tile { + +template +struct SpargeBlockMapPipeline +{ + using Problem = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t D = BlockFmhaShape::kQKHeaddim; + static constexpr index_t NumWarps = BlockFmhaShape::NumWarps; + static constexpr index_t WarpSize = get_warp_size(); + + static constexpr index_t KPerThread = 16 / sizeof(QDataType); + static constexpr index_t KThreads = D / KPerThread; + static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads; + static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps); + static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps); + + static constexpr index_t kBlockPerCu = 1; + static constexpr index_t kMaxKBlocks = 1024; + + // LDS layout (non-overlapping, all used simultaneously in Phase 2): + // [0 .. kReduceBytes) cross-warp reduction scratch + // [kScoreOffset ..) scores[N_k] + // [kBmapOffset ..) block_map[N_k] + // [kSmallOffset ..) Phase 3 argmax scratch (2*NumWarps floats) + static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float); + static constexpr index_t kScoreOffset = kReduceBytes; + static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); + static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return kSmallOffset + 2 * NumWarps * sizeof(float); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // Extract tile data into a local float array via static_for (compile-time indices). + template + CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize]) + { + static_assert(Tile::get_thread_buffer_size() == BufSize); + const auto& buf = tile.get_thread_buffer(); + static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert(buf[i]); }); + } + + // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums, + // then xor-shuffle across m_idx within warp. + template + CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data, + float (&col_acc)[KPerThread]) + { + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k]; + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Cross-warp LDS reduction for column sums. + CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], + float* __restrict__ smem_reduce) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + const index_t k_idx = lane_id % KThreads; + const index_t m_idx = lane_id / KThreads; + + if(m_idx == 0) + for(index_t k = 0; k < KPerThread; ++k) + smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k]; + __syncthreads(); + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k]; + __syncthreads(); + } + + // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx. + template + CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data, + float (&row_norms)[SeqPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + { + float v = data[m * KPerThread + k]; + sq += v * v; + } + for(index_t stride = 1; stride < KThreads; stride *= 2) + sq += warp_shuffle(sq, __lane_id() ^ stride); + + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + row_norms[m] = (gsq < actual_seq) ? sq : 0.f; + } + } + + // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||. + template + CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data, + const float* __restrict__ row_norms, + float (&col_acc)[KPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f; + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + if(gsq < actual_seq) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k] * inv_norm; + } + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Scalar reduce across k_idx lanes (within warp). + CK_TILE_DEVICE static float reduce_across_k(float v) + { + for(index_t stride = 1; stride < KThreads; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + return v; + } + + // Full-block scalar reduce (warp xor + cross-warp LDS). + CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + s += smem_small[w]; + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v = max(v, warp_shuffle(v, __lane_id() ^ stride)); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = smem_small[0]; + for(index_t w = 1; w < NumWarps; ++w) + s = max(s, smem_small[w]); + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + // ====================================================================== + template + CK_TILE_DEVICE void operator()(const QWindowType& q_window_in, + const KWindowType& k_window_in, + index_t seqlen_q, + index_t seqlen_k, + index_t qb, + index_t N_k, + index_t /*nhead_ratio_qk*/, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + uint8_t* block_map_ptr, + int32_t* lut_ptr, + int32_t* valid_block_num_ptr, + void* smem_ptr) const + { + const index_t tid = static_cast(threadIdx.x); + + auto* smem_float = reinterpret_cast(smem_ptr); + auto* smem_scores = + reinterpret_cast(reinterpret_cast(smem_ptr) + kScoreOffset); + auto* smem_bmap = + reinterpret_cast(reinterpret_cast(smem_ptr) + kBmapOffset); + auto* smem_small = + reinterpret_cast(reinterpret_cast(smem_ptr) + kSmallOffset); + + const index_t bs_q = min(static_cast(kM0), seqlen_q - qb * kM0); + const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast(bs_q)) : 0.f; + + // ================================================================== + // Phase 1: Q Block Statistics + // ================================================================== + auto q_tile = load_tile(q_window_in); + + float q_data[MPerThread * KPerThread]; + tile_to_float(q_tile, q_data); + + // 1a. L2 norm per token + float psq[MPerThread]; + row_reduce_sq_norm(q_data, psq, bs_q); + + // 1b. Column sum -> mean + float pooled_q_mean[KPerThread]; + column_reduce_thread_and_warp(q_data, pooled_q_mean); + column_reduce_cross_warp(pooled_q_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_q_mean[k] *= inv_bs_q; + + // 1c. Normalised sum_hat + float sum_hat[KPerThread]; + column_reduce_normalised(q_data, psq, sum_hat, bs_q); + column_reduce_cross_warp(sum_hat, smem_float); + + // 1d. sim_q = ||sum_hat||^2 / bs_q^2 + float sh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + sh_sq += sum_hat[k] * sum_hat[k]; + sh_sq = reduce_across_k(sh_sq); + const float denom_q = static_cast(bs_q) * static_cast(bs_q); + const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1); + + // Not similar → force all K blocks ON, early exit + if(!sim_q) + { + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = 1; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + return; + } + + // ================================================================== + // Phase 2: K Block Loop + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_bmap[i] = 0; + __syncthreads(); + + auto k_window = k_window_in; + + for(index_t kb = 0; kb < N_k; ++kb) + { + const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); + const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast(bs_k)) : 0.f; + + auto k_tile = load_tile(k_window); + + float k_data[NPerThread * KPerThread]; + tile_to_float(k_tile, k_data); + + // K mean + float pooled_k_mean[KPerThread]; + column_reduce_thread_and_warp(k_data, pooled_k_mean); + column_reduce_cross_warp(pooled_k_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_mean[k] *= inv_bs_k; + + // dot(pooled_q_mean, pooled_k_mean) + float dot = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + dot += pooled_q_mean[k] * pooled_k_mean[k]; + dot = reduce_across_k(dot); + + // K L2 norms + normalised sum_hat + float k_psq[NPerThread]; + row_reduce_sq_norm(k_data, k_psq, bs_k); + + float k_sum_hat[KPerThread]; + column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); + column_reduce_cross_warp(k_sum_hat, smem_float); + + // sim_k + float ksh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + ksh_sq += k_sum_hat[k] * k_sum_hat[k]; + ksh_sq = reduce_across_k(ksh_sq); + const float denom_k = static_cast(bs_k) * static_cast(bs_k); + const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1); + + if(tid == 0) + { + if(!sim_k) + { + smem_bmap[kb] = 1; + smem_scores[kb] = -numeric::infinity(); + } + else + { + smem_scores[kb] = dot * scale; + } + } + __syncthreads(); + + move_tile_window(k_window, {kN0, 0}); + } + + // ================================================================== + // Phase 3: Softmax + Selection + // ================================================================== + + // max + float lmax = -numeric::infinity(); + for(index_t i = tid; i < N_k; i += kBlockSize) + lmax = max(lmax, smem_scores[i]); + const float max_score = block_reduce_max(lmax, smem_small); + + // exp + sum + float lsum = 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + float e = (smem_scores[i] > -numeric::infinity()) + ? __builtin_expf(smem_scores[i] - max_score) + : 0.f; + smem_scores[i] = e; + lsum += e; + } + const float sum_exp = block_reduce_sum(lsum, smem_small); + + // normalise + const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_scores[i] *= inv_sum; + __syncthreads(); + + // Selection: iterative argmax + index_t num_to_select = + (topk > 0.f) + ? max(static_cast(1), static_cast(topk * static_cast(N_k))) + : N_k; + + float cumulative_prob = 0.f; + for(index_t round = 0; round < num_to_select; ++round) + { + // thread-local argmax + float best_val = -1.f; + index_t best_idx = 0; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx)) + { + best_val = smem_scores[i]; + best_idx = i; + } + } + + // warp argmax + for(index_t stride = 1; stride < WarpSize; stride *= 2) + { + float rv = warp_shuffle(best_val, __lane_id() ^ stride); + index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride); + if(rv > best_val || (rv == best_val && ri < best_idx)) + { + best_val = rv; + best_idx = ri; + } + } + + // cross-warp argmax via LDS + const index_t lane_id = tid % WarpSize; + const index_t warp_id = tid / WarpSize; + if(lane_id == 0) + { + smem_small[warp_id] = best_val; + smem_small[NumWarps + warp_id] = bit_cast(static_cast(best_idx)); + } + __syncthreads(); + + if(tid == 0) + { + float bv = smem_small[0]; + index_t bi = bit_cast(smem_small[NumWarps]); + for(index_t w = 1; w < NumWarps; ++w) + { + float wv = smem_small[w]; + index_t wi = bit_cast(smem_small[NumWarps + w]); + if(wv > bv || (wv == bv && wi < bi)) + { + bv = wv; + bi = wi; + } + } + smem_small[0] = bv; + smem_small[1] = bit_cast(static_cast(bi)); + } + __syncthreads(); + + float g_val = smem_small[0]; + index_t g_idx = bit_cast(smem_small[1]); + + if(g_val <= 0.f) + break; + + if(tid == 0) + { + smem_bmap[g_idx] = 1; + smem_scores[g_idx] = -1.f; + } + __syncthreads(); + + if(topk > 0.f) + { + if(round + 1 >= num_to_select) + break; + } + else + { + cumulative_prob += g_val; + if(cumulative_prob >= cdfthreshd) + break; + } + } + + // ================================================================== + // Write outputs to global memory + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = smem_bmap[i]; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + if(smem_bmap[kb] != 0) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + } +}; + +} // namespace ck_tile