From eed42a9dfa2d81358344689f04489517ee8c0510 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Thu, 19 Mar 2026 23:28:36 -0400 Subject: [PATCH 1/4] Add host-side Sparge block-map pipeline for sparse attention examples - Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim scoring, CDF/topk selection) and VSA delta-LUT converter. - Add test_sparge_jenga_sparse_attn.cpp and test_sparge_vsa_sparse_attn.cpp as end-to-end demos. - Update CMakeLists.txt to register both new executables. Note: block size is currently fixed at 128; flexible block size support is not yet addressed. --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 22 + .../ck_tile/50_sparse_attn/sparge_tool.hpp | 408 +++++++++++++++++ .../test_sparge_jenga_sparse_attn.cpp | 422 +++++++++++++++++ .../test_sparge_vsa_sparse_attn.cpp | 429 ++++++++++++++++++ 4 files changed, 1281 insertions(+) create mode 100644 example/ck_tile/50_sparse_attn/sparge_tool.hpp create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp create mode 100644 example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 65bb2077642..c916f642ebb 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -88,6 +88,17 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# Sparge + Jenga Example executable +set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn") +message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}") +add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp) +target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) +target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE + 
-Wno-undefined-func-template
+    -Wno-float-equal
+)
+
 # ============================================================================
 # VSA Sparse Attention
 # ============================================================================
@@ -153,4 +164,15 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE
     -Wno-float-equal
 )
 
+# Sparge + VSA Example executable
+set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
+message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
+add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
+target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES})
+target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
+    -Wno-undefined-func-template
+    -Wno-float-equal
+)
+
 set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
diff --git a/example/ck_tile/50_sparse_attn/sparge_tool.hpp b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
new file mode 100644
index 00000000000..49c69cc6f74
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_tool.hpp
@@ -0,0 +1,408 @@
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+
+namespace sparge {
+
+struct SpargeParams
+{
+    int BLKQ = 128;
+    int BLKK = 128;
+
+    // Similarity gate threshold (TODO: per-head support).
+    float simthreshd1 = 0.6f;
+
+    // Exactly one of the following should be used:
+    //   - Use CDF threshold if topk < 0
+    //   - Both should be in [0, 1]  <-- NEED TO CHECK THIS
+    float cdfthreshd = 0.98f;
+    float topk = -1.0f;
+
+    // If true, treat Q/K as BHSD; otherwise BSHD (same convention as CK examples).
+    bool i_perm = true;
+};
+
+// Output format CK VSA expects.
+struct VSALut +{ + ck_tile::HostTensor lut; // [B, Hq, Q_blk, K_blk] delta-encoded + ck_tile::HostTensor valid_block_num; // [B, Hq, Q_blk] +}; + +namespace detail { + +template +inline float to_f32(const T& x) +{ + return ck_tile::type_convert(x); +} + +// Read element from HostTensor with either BHSD or BSHD layout. +// Q: [B, Hq, Sq, D] if i_perm else [B, Sq, Hq, D] +// K: [B, Hk, Sk, D] if i_perm else [B, Sk, Hk, D] +template +inline float load(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s, int d) +{ + return i_perm ? to_f32(X(b, h, s, d)) : to_f32(X(b, s, h, d)); +} + +// Compute pooled mean vector of one block: mean over tokens in [s0, s1). +template +std::vector +pooled_mean_block(const ck_tile::HostTensor& X, bool i_perm, int b, int h, int s0, int s1, int d) +{ + std::vector mean(d, 0.0f); + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return mean; + + for(int s = s0; s < s1; ++s) + { + for(int d_ = 0; d_ < d; ++d_) + { + mean[d_] += load(X, i_perm, b, h, s, d_); + } + } + const float inv = 1.0f / static_cast(bs); + for(int d_ = 0; d_ < d; ++d_) + mean[d_] *= inv; + return mean; +} + +// Compute "sim" flag of one block following SpargeAttn's intent: +// mean_sim = sum(Gram(x_hat)) / (BS_*BS_), where x_hat are token vectors normalized along D. +// +// Important: sum(Gram) = ||sum_i x_hat_i||^2, so we can compute it in O(BS_*D) exactly +// instead of O(BS_^2 * D). +template +bool sim_block_flag(const ck_tile::HostTensor& X, + bool i_perm, + int b, + int h, + int s0, + int s1, + int d, + float simthreshd1) +{ + const int bs = std::max(0, s1 - s0); + if(bs == 0) + return false; + + std::vector sum_hat(d, 0.0f); + + for(int s = s0; s < s1; ++s) + { + // Compute L2 norm over D. 
+        float norm2 = 0.0f;
+        for(int d_ = 0; d_ < d; ++d_)
+        {
+            const float v = load(X, i_perm, b, h, s, d_);
+            norm2 += v * v;
+        }
+        float inv_norm = 1.0f;
+        // spargeAttn use eps to prevent division by zero
+        if(norm2 > 0.0f)
+            inv_norm = 1.0f / std::sqrt(norm2);
+
+        // Accumulate normalized vector.
+        for(int d_ = 0; d_ < d; ++d_)
+        {
+            sum_hat[d_] += load(X, i_perm, b, h, s, d_) * inv_norm;
+        }
+    }
+
+    float sum_gram = 0.0f;
+    for(int d_ = 0; d_ < d; ++d_)
+        sum_gram += sum_hat[d_] * sum_hat[d_];
+
+    const float denom = static_cast<float>(bs) * static_cast<float>(bs);
+    const float mean_sim = sum_gram / denom;
+
+    return mean_sim > simthreshd1;
+}
+
+inline int select_count_from_cdf(const std::vector<float>& sorted_probs, float cdfthreshd)
+{
+    // Choose the smallest n such that cdf[n-1] >= cdfthreshd.
+    // Ensure at least 1.
+    if(sorted_probs.empty())
+        return 0;
+    if(cdfthreshd <= 0.0f)
+        return 1;
+
+    float c = 0.0f;
+    for(int i = 0; i < static_cast<int>(sorted_probs.size()); ++i)
+    {
+        c += sorted_probs[i];
+        if(c >= cdfthreshd)
+            return i + 1;
+    }
+    return static_cast<int>(sorted_probs.size());
+}
+
+inline int select_count_from_topk(int K_blk, float topk)
+{
+    if(K_blk <= 0)
+        return 0;
+    int n = static_cast<int>(std::floor(topk * static_cast<float>(K_blk)));
+    n = std::max(1, n);
+    return n;
+}
+
+} // namespace detail
+
+// Build one-hot block_map[b,hq,qb,kb] in {0,1}.
+// - No causal mask
+// - No attention sink
+// - Logic matches SpargeAttn's structure:
+//   - score softmax is only over sim_kblocks; ~sim_kblocks are forced ON later
+//   - if a Q-block is not "similar", force the whole row ON
+template <typename DataType>
+ck_tile::HostTensor<uint8_t> build_block_map_meansim(const ck_tile::HostTensor<DataType>& Q,
+                                                     const ck_tile::HostTensor<DataType>& K,
+                                                     const SpargeParams& p)
+{
+    const auto qlens = Q.get_lengths();
+    const auto klens = K.get_lengths();
+
+    const int B  = static_cast<int>(qlens[0]);
+    const int Hq = p.i_perm ? static_cast<int>(qlens[1]) : static_cast<int>(qlens[2]);
+    const int Sq = p.i_perm ?
static_cast(qlens[2]) : static_cast(qlens[1]); + const int D = static_cast(qlens[3]); + + [[maybe_unused]] const int Bk = static_cast(klens[0]); + const int Hk = p.i_perm ? static_cast(klens[1]) : static_cast(klens[2]); + const int Sk = p.i_perm ? static_cast(klens[2]) : static_cast(klens[1]); + [[maybe_unused]] const int Dk = static_cast(klens[3]); + + assert(B == Bk && D == Dk && Hq % Hk == 0); + assert(p.BLKQ > 0 && p.BLKK > 0); + + const int nhead_ratio_qk = Hq / Hk; + const int Q_blk = ck_tile::integer_divide_ceil(Sq, p.BLKQ); + const int K_blk = ck_tile::integer_divide_ceil(Sk, p.BLKK); + + ck_tile::HostTensor block_map({B, Hq, Q_blk, K_blk}); + + // pooled_q: [B,Hq,Q_blk,D], pooled_k: [B,Hk,K_blk,D] + // sim_q: [B,Hq,Q_blk], sim_k: [B,Hk,K_blk] + std::vector pooled_q(static_cast(B) * Hq * Q_blk * D, 0.0f); + std::vector pooled_k(static_cast(B) * Hk * K_blk * D, 0.0f); + std::vector sim_q(static_cast(B) * Hq * Q_blk, 0); + std::vector sim_k(static_cast(B) * Hk * K_blk, 0); + + auto idx_pq = [&](int b, int hq, int qb, int d) { + return (((b * Hq + hq) * Q_blk + qb) * D + d); + }; + auto idx_pk = [&](int b, int hk, int kb, int d) { + return (((b * Hk + hk) * K_blk + kb) * D + d); + }; + auto idx_sq = [&](int b, int hq, int qb) { return ((b * Hq + hq) * Q_blk + qb); }; + auto idx_sk = [&](int b, int hk, int kb) { return ((b * Hk + hk) * K_blk + kb); }; + + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + // Q blocks + for(int qb = 0; qb < Q_blk; ++qb) + { + const int s0 = qb * p.BLKQ; + const int s1 = std::min(Sq, (qb + 1) * p.BLKQ); + + // pooled mean + auto mean = detail::pooled_mean_block(Q, p.i_perm, b, hq, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_q[idx_pq(b, hq, qb, d)] = mean[d]; + + // sim flag + sim_q[idx_sq(b, hq, qb)] = + detail::sim_block_flag(Q, p.i_perm, b, hq, s0, s1, D, p.simthreshd1) ? 
1 : 0; + } + } + + for(int hk = 0; hk < Hk; ++hk) + { + // K blocks + for(int kb = 0; kb < K_blk; ++kb) + { + const int s0 = kb * p.BLKK; + const int s1 = std::min(Sk, (kb + 1) * p.BLKK); + + auto mean = detail::pooled_mean_block(K, p.i_perm, b, hk, s0, s1, D); + for(int d = 0; d < D; ++d) + pooled_k[idx_pk(b, hk, kb, d)] = mean[d]; + + sim_k[idx_sk(b, hk, kb)] = + detail::sim_block_flag(K, p.i_perm, b, hk, s0, s1, D, p.simthreshd1) ? 1 : 0; + } + } + } + + const float scale = 1.0f / std::sqrt(static_cast(D)); + + // Main loop + for(int b = 0; b < B; ++b) + { + for(int hq = 0; hq < Hq; ++hq) + { + const int hk = hq / nhead_ratio_qk; + + for(int qb = 0; qb < Q_blk; ++qb) + { + const bool q_is_sim = (sim_q[idx_sq(b, hq, qb)] != 0); + + // If Q-block is not "similar", force dense row. + if(!q_is_sim) + { + for(int kb = 0; kb < K_blk; ++kb) + block_map(b, hq, qb, kb) = 1; + continue; + } + + // Compute scores over K blocks (only sim_kblocks participate in softmax; others set + // to -inf). + std::vector score(K_blk, -std::numeric_limits::infinity()); + for(int kb = 0; kb < K_blk; ++kb) + { + const bool k_is_sim = (sim_k[idx_sk(b, hk, kb)] != 0); + if(!k_is_sim) + { + block_map(b, hq, qb, kb) = 1; + continue; + } + + float dot = 0.0f; + for(int d = 0; d < D; ++d) + { + dot += pooled_q[idx_pq(b, hq, qb, d)] * pooled_k[idx_pk(b, hk, kb, d)]; + } + score[kb] = dot * scale; + } + + // Softmax over K_blk (numerically stable). If all -inf, probs become all zeros. 
+                float maxv = -std::numeric_limits<float>::infinity();
+                for(int kb = 0; kb < K_blk; ++kb)
+                    maxv = std::max(maxv, score[kb]);
+
+                std::vector<float> prob(K_blk, 0.0f);
+                if(std::isfinite(maxv))
+                {
+                    float sumexp = 0.0f;
+                    for(int kb = 0; kb < K_blk; ++kb)
+                    {
+                        if(!std::isfinite(score[kb]))
+                            continue;
+                        const float e = std::exp(score[kb] - maxv);
+                        prob[kb] = e;
+                        sumexp += e;
+                    }
+                    if(sumexp > 0.0f)
+                    {
+                        const float inv = 1.0f / sumexp;
+                        for(int kb = 0; kb < K_blk; ++kb)
+                            prob[kb] *= inv;
+                    }
+                    else
+                    {
+                        // All exponentials underflowed: keep zeros.
+                        std::fill(prob.begin(), prob.end(), 0.0f);
+                    }
+                }
+
+                // Sort indices by prob descending.
+                std::vector<int> order(K_blk);
+                std::iota(order.begin(), order.end(), 0);
+                std::sort(order.begin(), order.end(), [&](int a, int c) {
+                    if(prob[a] != prob[c])
+                        return prob[a] > prob[c];
+                    return a < c; // tie-breaker for determinism
+                });
+
+                // Determine how many to select.
+                int num_to_select = 0;
+                if(p.topk > 0.0f)
+                {
+                    num_to_select = detail::select_count_from_topk(K_blk, p.topk);
+                }
+                else
+                {
+                    // Use CDF threshold selection (smallest n s.t. cumulative prob >= cdfthreshd).
+                    std::vector<float> sorted_probs(K_blk);
+                    for(int i = 0; i < K_blk; ++i)
+                        sorted_probs[i] = prob[order[i]];
+                    num_to_select = detail::select_count_from_cdf(sorted_probs, p.cdfthreshd);
+                    num_to_select = std::max(1, num_to_select);
+                }
+
+                // Select top-kb blocks by order[0..num_to_select-1].
+                for(int i = 0; i < num_to_select; ++i)
+                {
+                    const int kb = order[i];
+                    block_map(b, hq, qb, kb) = 1;
+                }
+            }
+        }
+    }
+
+    return block_map;
+}
+
+// Convert one-hot block_map -> delta-encoded LUT + valid_block_num (CK VSA format).
+template <typename OneHotT>
+VSALut block_map_to_vsa_lut_delta(const ck_tile::HostTensor<OneHotT>& block_map)
+{
+    const auto lens = block_map.get_lengths();
+    const int B = static_cast<int>(lens[0]);
+    const int H = static_cast<int>(lens[1]);
+    const int Q = static_cast<int>(lens[2]);
+    const int K = static_cast<int>(lens[3]);
+
+    VSALut out{
+        ck_tile::HostTensor<int32_t>({B, H, Q, K}),
+        ck_tile::HostTensor<int32_t>({B, H, Q}),
+    };
+
+    for(int b = 0; b < B; ++b)
+    {
+        for(int h = 0; h < H; ++h)
+        {
+            for(int q = 0; q < Q; ++q)
+            {
+                int32_t valid = 0;
+                int32_t prev = 0;
+
+                for(int k = 0; k < K; ++k)
+                {
+                    const bool on = static_cast<int>(block_map(b, h, q, k)) != 0;
+                    if(on)
+                    {
+                        out.lut(b, h, q, valid) = static_cast<int32_t>(k - prev);
+                        prev = static_cast<int32_t>(k);
+                        ++valid;
+                    }
+                }
+
+                out.valid_block_num(b, h, q) = valid;
+
+                // Optional: zero-fill the unused tail for determinism.
+                for(int i = valid; i < K; ++i)
+                    out.lut(b, h, q, i) = 0;
+            }
+        }
+    }
+
+    return out;
+}
+
+} // namespace sparge
diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
new file mode 100644
index 00000000000..0bd664adf68
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp
@@ -0,0 +1,422 @@
+// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> Jenga sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> Jenga Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Print actual sparsity + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; + + try + { + if(kname) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + for(int i = 0; i < warmup; ++i) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + 
jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 
0 : -1; +} diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp new file mode 100644 index 00000000000..dd1d3e60bee --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -0,0 +1,429 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include "jenga_sparse_attention.h" +#include "sparge_tool.hpp" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? 
tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; + if constexpr(std::is_same_v) + { + atol = 2e-1; + rtol = 2e-1; + } + return ck_tile::make_tuple(rtol, atol); +} + +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("prec", "fp16", "data type: fp16/bf16") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name") + // Sparge-specific + .insert("blkq", "128", "Sparge BLKQ") + .insert("blkk", "128", "Sparge BLKK") + .insert("simthreshd1", "0.6", "Sparge sim threshold") + .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") + .insert("topk", "-1.0", "Sparge topk ratio in (0,1]; if > 0, overrides cdfthreshd"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, 
arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ + +template +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + int do_validation = arg_parser.get_int("v"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int kname = arg_parser.get_int("kname"); + + // Sparge params + ck_tile::index_t blkq = arg_parser.get_int("blkq"); + ck_tile::index_t blkk = arg_parser.get_int("blkk"); + float simthreshd1 = arg_parser.get_float("simthreshd1"); + float cdfthreshd = arg_parser.get_float("cdfthreshd"); + float topk = arg_parser.get_float("topk"); + + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + "hdim_q=128, hdim_v=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + ck_tile::index_t BLKQ = blkq; + ck_tile::index_t BLKK = blkk; + + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Sparge -> VSA Sparse Attention Demo]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " BLKQ=" << BLKQ << ", BLKK=" << BLKK << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " Sparge(simthreshd1=" << simthreshd1 << ", cdfthreshd=" << cdfthreshd + << ", topk=" << topk << ")" << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + std::cout << "\nInitializing tensors..." 
<< std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Build block map using Sparge tool + std::cout << "Building Sparge block map..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + // Convert to VSA LUT (delta-encoded) + valid_block_num + std::cout << "Converting block map to VSA LUT (delta)..." << std::endl; + auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + + // Print actual sparsity (based on one-hot) + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_relation_onehot(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; + + try + { + if(kname) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + + for(int i = 0; i < warmup; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + 
seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + vsa_lut.lut, + vsa_lut.valid_block_num, + output_host, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 0); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." << std::endl; + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); + + auto [rtol, atol] = get_error_tolerance(); + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + std::size_t num_errors = 0; + + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(std::size_t i = 0; i < output_host_bhsd.mData.size(); ++i) + { + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + + return test_result ? 
0 : -1; +} From 9317fc4a8508cb53ec9bd829781d4781d84ce428 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 24 Mar 2026 05:57:54 -0400 Subject: [PATCH 2/4] Support 64x128 tile size in sparge fwd for Jenga and VSA paths --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 116 ++- .../codegen/ops/sparge_fwd_jenga.py | 799 ++++++++++++++++++ .../codegen/ops/sparge_fwd_vsa.py | 799 ++++++++++++++++++ .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 6 + .../50_sparse_attn/jenga_sparge_attention.cpp | 189 +++++ .../50_sparse_attn/jenga_sparge_attention.h | 27 + .../test_sparge_jenga_sparse_attn.cpp | 14 +- .../test_sparge_vsa_sparse_attn.cpp | 14 +- .../50_sparse_attn/vsa_sparge_attention.cpp | 195 +++++ .../50_sparse_attn/vsa_sparge_attention.h | 28 + ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 2 +- 11 files changed, 2167 insertions(+), 22 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparge_attention.h create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.h diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index c916f642ebb..0ac86f6affa 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -1,8 +1,8 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT -# CMakeLists.txt for sparse attention (Jenga and VSA) +#Copyright(c) Advanced Micro Devices, Inc., or its affiliates. 
+#SPDX - License - Identifier : MIT +#CMakeLists.txt for sparse attention(Jenga and VSA) -# Use SUPPORTED_GPU_TARGETS directly +#Use SUPPORTED_GPU_TARGETS directly set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) @@ -16,7 +16,7 @@ endif() message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") -# Code generation scripts +#Code generation scripts file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ${CMAKE_CURRENT_LIST_DIR}/generate.py ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py @@ -88,11 +88,62 @@ target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# ============================================================================ +# Sparge Jenga (64x128 tile) +# ============================================================================ +set(SPARGE_JENGA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api sparge_fwd_jenga + --receipt 600 +) + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge Jenga kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_jenga_blob_list.txt SPARGE_JENGA_GEN_BLOBS) + +add_custom_command( + OUTPUT ${SPARGE_JENGA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARGE_JENGA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge Jenga kernels" +) + +message(STATUS "Sparge Jenga kernel files to be generated: ${SPARGE_JENGA_GEN_BLOBS}") + +set(SPARGE_JENGA_INSTANCES "tile_sparge_jenga_instances") + +add_library(${SPARGE_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARGE_JENGA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp +) +target_include_directories(${SPARGE_JENGA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + 
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARGE_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparge_attention.cpp PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARGE_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_JENGA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # Sparge + Jenga Example executable set(EXAMPLE_SPARGE_JENGA_SPARSE_ATTN "tile_example_sparge_jenga_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_jenga_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} ${SPARGE_JENGA_INSTANCES}) target_include_directories(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_JENGA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template @@ -164,11 +215,62 @@ target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-float-equal ) +# ============================================================================ +# Sparge VSA (64x128 tile) +# ============================================================================ +set(SPARGE_VSA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api sparge_fwd_vsa + --receipt 600 +) + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Sparge VSA kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/sparge_vsa_blob_list.txt SPARGE_VSA_GEN_BLOBS) + +add_custom_command( + OUTPUT 
${SPARGE_VSA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARGE_VSA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Sparge VSA kernels" +) + +message(STATUS "Sparge VSA kernel files to be generated: ${SPARGE_VSA_GEN_BLOBS}") + +set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") + +add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARGE_VSA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp +) +target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + # Sparge + VSA Example executable set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES}) target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py new file mode 100644 index 00000000000..872da2326ea --- /dev/null +++ 
b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_jenga.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. + + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "kernel/fmha_fwd_jenga_kernel.hpp" + +""" + +# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features 
accordingly. + +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + 
fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdJengaKernel; + +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_jenga_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + 
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_jenga_fwd_(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: 
CppConstraint + + @property + def name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] 
== "s_": + if self.F_mask == "s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + 
F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx 
here + return ( + f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # jenga fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + 
api_pool = FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async": + continue + k = FmhaFwdKernel( + F_idx=2, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: 
FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py new file mode 100644 index 00000000000..c9a389df3fa --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/sparge_fwd_vsa.py @@ -0,0 +1,799 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +import os.path as path +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import ( + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) + +GEN_DIR = "" + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. 
+ + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} + +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "kernel/fmha_fwd_vsa_kernel.hpp" + +""" + +# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. 
+ +FMHA_FWD_KERNEL_BODY = """ +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported + {F_occupancy}, + false>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA< + 
fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdVSAKernel; + +using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + +#include + +template<> +float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME = "sparge_vsa_fwd_api.cpp" +FMHA_FWD_API = """ +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float sparge_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const 
float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + return fmha_vsa_fwd_(s, a); + }} +""" + + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return "true" + else: + return f"{self.bool_expr}" + + def __and__(self, other): + return CppConstraint(f"({str(self)}) && ({str(other)})") + + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag: str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + spad: str + skpad: str + dpad: str + dvpad: str + tr_load: str + constraint: CppConstraint + + @property + def 
name(self) -> str: + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + ) + + @property + def scheck(self) -> str: + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.spad == "t": + return "true" # always support + return "true" + + @property + def seqtune(self) -> str: + return "true" + + @property + def skcheck(self) -> str: + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + + @property + def dcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False + + @property + def dvcheck(self) -> str: + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False + + +@dataclass +class FmhaFwdPipeline: + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_mask: str # value from MASK_MAP + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + def pad_name() -> str: + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n + return n + + pn = pad_name() + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" + + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" + + n += "_nbias" + + if self.F_mask[0:2] == "s_": + if self.F_mask == 
"s_mask": + n += "_mask" + else: + n += "_nmask" + else: + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" + + n += "_nskip" + + n += "_nsquant" + + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" + + return n + + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait: FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() + for tr_load in ["t", "f"]: + per_dtypes = str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case = str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() + for k, trait in enumerate(traits): + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + # F_logits removed - hardcoded to false (NOT supported) + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_trload=BOOL_MAP[trait.tr_load], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + 
F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + + +@dataclass +class FmhaFwdTileSize: + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=CppConstraint) + + @property + def name(self) -> str: + return ( + 
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + + +@dataclass +class FmhaFwdKernel: + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str + + @property + def template(self) -> str: + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + # F_logits removed - hardcoded to false in template (NOT supported) + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, + ) + + @property + def name(self) -> str: + # TODO: we don't encode idx 
here + return ( + f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + ), + ], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 
32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) + pipelines = [] + if dtype in ["fp16", "bf16"]: + for logits, mask in itertools.product( + ["f"], # logits soft-cap NOT supported, always false + get_mask_map(mask_impl).keys(), + ): + if hdim == 256 and hdim_v == 256: + # vsa fmha only supports dim <= 192 for now. + continue + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "f", + "t", + "t", + logits, + mask, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "t", + "t", + "t", + logits, + mask, + "f", + ) + ) + else: + assert False + return pipelines + + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == "fp16" or dtype == "bf16": + if (128, 128) in result.keys(): + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) + return result + + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = 
FmhaFwdApiPool(mask_impl) + + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) + + # Only generate fp16/bf16 kernels for now. + # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) + for dtype in ["fp16", "bf16"]: + d = factory.get_hdim_tile_size_dict(dtype) + if d is None: + continue + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): + if pipeline.tag != "qr_async_vsa": + continue + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= mode == "batch" + cond &= pipeline.F_logits == "f" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, 
autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") + f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 7349c3576e8..25e3513d2fa 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -277,6 +277,9 @@ struct fmha_jenga_fwd_traits float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); +// sparge jenga +float sparge_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); + template float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); @@ -322,6 +325,9 @@ using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits; float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); +// sparge vsa +float sparge_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); + template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp 
b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp new file mode 100644 index 00000000000..88f3e08204e --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.cpp @@ -0,0 +1,189 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "jenga_sparge_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include + +template +ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "Jenga sparse attention supports fp16/bf16 only."); + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + log_level, + 0, + 1, + false}; + + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + 
q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + block_relation_buf.ToDevice(Tblock_relation_onehot.data()); + + const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + }; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + }; + + fmha_jenga_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_jenga_fwd_args args; + init_args(args); + + sparge_jenga_fwd(fmha_traits, args, stream_config); + + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); + + return Y; +} + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const 
ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); + +template ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h new file mode 100644 index 00000000000..6259fcc73cf --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparge_attention.h @@ -0,0 +1,27 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +ck_tile::HostTensor +jenga_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp index 0bd664adf68..590e51db144 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_jenga_sparse_attn.cpp @@ -16,7 +16,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "jenga_sparse_attention.h" +#include "jenga_sparge_attention.h" #include "sparge_tool.hpp" // ============================================================================ @@ -115,7 
+115,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "20", "benchmark iterations") .insert("kname", "0", "print kernel name") // Sparge-specific - .insert("blkq", "128", "Sparge BLKQ") + .insert("blkq", "64", "Sparge BLKQ") .insert("blkk", "128", "Sparge BLKK") .insert("simthreshd1", "0.6", "Sparge sim threshold") .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") @@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; - if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) { std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Jenga/VSA kernel instances are generated for BLKQ=BLKK=128, " + std::cout << "Sparge Jenga kernel instances are generated for BLKQ=64, BLKK=128, " "hdim_q=128, hdim_v=128 only." << std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -247,7 +247,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -268,7 +268,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, @@ -292,7 +292,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - jenga_sparse_attention(q_host, + jenga_sparge_attention(q_host, k_host, v_host, block_relation_onehot, diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index dd1d3e60bee..c0feb23e581 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -16,7 +16,7 @@ #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include 
"ck_tile/core/utility/bit_cast.hpp" -#include "jenga_sparse_attention.h" +#include "vsa_sparge_attention.h" #include "sparge_tool.hpp" // ============================================================================ @@ -115,7 +115,7 @@ auto create_args(int argc, char* argv[]) .insert("repeat", "20", "benchmark iterations") .insert("kname", "0", "print kernel name") // Sparge-specific - .insert("blkq", "128", "Sparge BLKQ") + .insert("blkq", "64", "Sparge BLKQ") .insert("blkk", "128", "Sparge BLKK") .insert("simthreshd1", "0.6", "Sparge sim threshold") .insert("cdfthreshd", "0.98", "Sparge CDF threshold (used when topk < 0)") @@ -161,10 +161,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(hdim_v < 0) hdim_v = hdim_q; - if(blkq != 128 || blkk != 128 || hdim_q != 128 || hdim_v != 128) + if(blkq != 64 || blkk != 128 || hdim_q != 128 || hdim_v != 128) { std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "VSA kernel instances are generated for BLKQ=BLKK=128, " + std::cout << "Sparge VSA kernel instances are generated for BLKQ=64, BLKK=128, " "hdim_q=128, hdim_v=128 only." 
<< std::endl; std::cout << "TEST SKIPPED" << std::endl; @@ -251,7 +251,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { if(kname) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -273,7 +273,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < warmup; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, @@ -298,7 +298,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - vsa_sparse_attention(q_host, + vsa_sparge_attention(q_host, k_host, v_host, vsa_lut.lut, diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp new file mode 100644 index 00000000000..5f9c2676ddb --- /dev/null +++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "vsa_sparge_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include + +template +ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, + const ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level) +{ + static_assert(std::is_same_v || + std::is_same_v, + "VSA sparse attention supports fp16/bf16 only."); + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } + + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / 
ck_tile::sqrt(static_cast(hdim_q)); + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + log_level, + 0, + 1, + false}; + + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes()); + ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + lut_buf.ToDevice(TKV_block_idx.data()); + valid_block_num_buf.ToDevice(TKV_blocks.data()); + + const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqlen_k = shape_seqlen_k; + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + }; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + }; + + fmha_vsa_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_vsa_fwd_args args; + init_args(args); + + sparge_vsa_fwd(fmha_traits, args, stream_config); + + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); + + return Y; +} + +template ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor&, + const 
ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); + +template ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h new file mode 100644 index 00000000000..d51a7e8c00b --- /dev/null +++ b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h @@ -0,0 +1,28 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +template +ck_tile::HostTensor +vsa_sparge_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, + const ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level = 0); diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 2b097ae5827..578ad7e6039 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -200,7 +200,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto 
gemm_1 = Policy::template GetKVBlockGemm(); - int seqlen_k_start = kv_block_idx_ptr[0] * kM0; + int seqlen_k_start = kv_block_idx_ptr[0] * kN0; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), From d1d457b82a63fdf6e68461194fa2c1098ace5f93 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 13 Apr 2026 03:34:08 -0400 Subject: [PATCH 3/4] Add sparge gpu pipeline in tile_example_sparge_vsa_sparse_attn --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 34 +- .../50_sparse_attn/sparge_blockmap.cpp | 156 ++++++ .../ck_tile/50_sparse_attn/sparge_blockmap.h | 26 + .../50_sparse_attn/sparge_blockmap_inst.cpp | 88 +++ .../50_sparse_attn/sparge_blockmap_trek.hpp | 93 ++++ .../test_sparge_vsa_sparse_attn.cpp | 234 ++++++-- .../kernel/sparge_blockmap_kernel.hpp | 195 +++++++ .../pipeline/sparge_blockmap_pipeline.hpp | 521 ++++++++++++++++++ 8 files changed, 1296 insertions(+), 51 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.cpp create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.h create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp create mode 100644 include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp create mode 100644 include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 0ac86f6affa..169ed87ac3b 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -266,11 +266,41 @@ target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE -Wno-float-equal ) -# Sparge + VSA Example executable +# ============================================================================ +# Sparge BlockMap GPU Kernel (hand-written 
instantiation, no codegen) +# ============================================================================ +set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") + +add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp +) +target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties( + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp + ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp + PROPERTIES LANGUAGE HIP +) +set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# Sparge + VSA Example executable (now links blockmap kernel too) set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn") message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}") add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp) -target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES}) +target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} + ${SPARGE_VSA_INSTANCES} + ${SPARGE_BLOCKMAP_INSTANCES} +) target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp new file mode 100644 index 00000000000..b9ac56c533c --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp @@ -0,0 +1,156 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#include "sparge_blockmap.h" +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" +#include <cmath> +#include <string> + +template <typename DataType> +sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType>& TQ, + const ck_tile::HostTensor<DataType>& TK, + ck_tile::HostTensor<uint8_t>& block_map_out, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + bool i_perm, + float simthreshd1, + float cdfthreshd, + float topk, + int blkq, + int blkk, + int log_level) +{ + static_assert(std::is_same_v<DataType, ck_tile::half_t> || + std::is_same_v<DataType, ck_tile::bf16_t>, + "sparge_blockmap_gpu supports fp16/bf16 only."); + + std::string data_type = "fp16"; + if constexpr(std::is_same_v<DataType, ck_tile::bf16_t>) + { + data_type = "bf16"; + } + + const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq); + const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk); + + const float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q)); + + // Allocate device memory + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + + const std::size_t bmap_bytes = + static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t); + const std::size_t lut_bytes = + static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t); + const std::size_t valid_bytes = + static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t); + + ck_tile::DeviceMem bmap_buf(bmap_bytes); + ck_tile::DeviceMem lut_buf(lut_bytes); + ck_tile::DeviceMem valid_buf(valid_bytes); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + bmap_buf.SetZero(); + lut_buf.SetZero(); + valid_buf.SetZero(); + + // Compute strides (assumes BHSD if i_perm, BSHD otherwise) + const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q; + const ck_tile::index_t stride_k = i_perm ? 
hdim_q : nhead_k * hdim_q; + const ck_tile::index_t nhead_stride_q = + i_perm ? static_cast(seqlen_q) * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_k = + i_perm ? static_cast(seqlen_k) * hdim_q : hdim_q; + const ck_tile::index_t batch_stride_q = + static_cast(nhead_q) * seqlen_q * hdim_q; + const ck_tile::index_t batch_stride_k = + static_cast(nhead_k) * seqlen_k * hdim_q; + + ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false}; + + sparge_blockmap_args args; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.nhead_q = nhead_q; + args.nhead_k = nhead_k; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.simthreshd1 = simthreshd1; + args.cdfthreshd = cdfthreshd; + args.topk = topk; + args.scale = scale; + args.block_map_ptr = bmap_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + + sparge_blockmap_traits traits; + traits.data_type = data_type; + traits.hdim_q = hdim_q; + + sparge_blockmap_fwd(traits, args, stream_config); + + // Copy results back to host + bmap_buf.FromDevice(block_map_out.data(), bmap_bytes); + + sparge::VSALut vsa_lut{ + ck_tile::HostTensor({batch, nhead_q, num_q_blocks, num_k_blocks}), + ck_tile::HostTensor({batch, nhead_q, num_q_blocks}), + }; + lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes); + valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes); + + return vsa_lut; +} + +// Explicit template instantiations +template sparge::VSALut +sparge_blockmap_gpu(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + bool, + float, + float, + float, + 
int, + int, + int); + +template sparge::VSALut +sparge_blockmap_gpu(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + bool, + float, + float, + float, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h new file mode 100644 index 00000000000..3057257ca14 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.h @@ -0,0 +1,26 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include "sparge_tool.hpp" + +template +sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + ck_tile::HostTensor& block_map_out, + int batch, + int nhead_q, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + bool i_perm, + float simthreshd1, + float cdfthreshd, + float topk, + int blkq, + int blkk, + int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp new file mode 100644 index 00000000000..fbd18b9ff24 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128). 
+ +#include "sparge_blockmap_trek.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include + +// ============================================================================ +// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig) +// ============================================================================ + +// fp16: D=128, kM0=64, kN0=128 +using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>; +// kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D) + +using bmap_fp16_shape = + ck_tile::TileFmhaShape, // Gemm0BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but + // needed by shape) + ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps + ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile + true>; // VLayout row-major + +using bmap_fp16_trait = ck_tile::TileFmhaTraits; // kIsVRowMajorSkip + +using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; +using bmap_fp16_mask = ck_tile::GenericAttentionMask; + +using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem; + +using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline; +using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel; + +// ============================================================================ +// Dispatch +// ============================================================================ + +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& s) +{ + if(traits.data_type == "fp16" && traits.hdim_q == 128) + { + using k_ = bmap_fp16_kernel; + if(s.log_level_ > 0) + std::cout << ", sparge_blockmap_fp16_d128" << std::flush; + auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids(args); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(k_{}, grids, blocks, 0, kargs)); + } + + if(s.log_level_ > 0) + std::cerr << "sparge_blockmap_fwd: 
unsupported config (data_type=" << traits.data_type + << ", hdim_q=" << traits.hdim_q << ")" << std::endl; + return -1.f; +} diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp new file mode 100644 index 00000000000..1e7e33248a2 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" +#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp" + +#include "fmha_fwd_trek.hpp" + +#include +#include + +// ============================================================================ +// Args and traits for sparge block map GPU kernel +// ============================================================================ +struct sparge_blockmap_args +{ + const void* q_ptr; + const void* k_ptr; + + ck_tile::index_t batch; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; +}; + +struct sparge_blockmap_traits +{ + std::string data_type; + int hdim_q; +}; + +// ============================================================================ +// Create kernel args and grid dimensions +// 
============================================================================ +template <typename BlockMapKernel> +auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = BlockMapKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.stride_q, + args.stride_k, + args.nhead_stride_q, + args.nhead_stride_k, + args.batch_stride_q, + args.batch_stride_k, + args.simthreshd1, + args.cdfthreshd, + args.topk, + args.scale, + args.block_map_ptr, + args.lut_ptr, + args.valid_block_num_ptr); + + dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q); + return ck_tile::make_tuple(kargs, grids); +} + +// ============================================================================ +// Hand-written template instantiation dispatch +// ============================================================================ +float sparge_blockmap_fwd(sparge_blockmap_traits traits, + sparge_blockmap_args args, + const ck_tile::stream_config& stream_config); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index c0feb23e581..638a867b0f3 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -17,6 +17,7 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include "vsa_sparge_attention.h" +#include "sparge_blockmap.h" #include "sparge_tool.hpp" // ============================================================================ @@ -198,53 +199,37 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor output_host = o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); - ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); std::cout << "\nInitializing tensors..." 
<< std::endl; ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - // Build block map using Sparge tool - std::cout << "Building Sparge block map..." << std::endl; - sparge::SpargeParams p; - p.BLKQ = static_cast(BLKQ); - p.BLKK = static_cast(BLKK); - p.simthreshd1 = simthreshd1; - p.cdfthreshd = cdfthreshd; - p.topk = topk; - p.i_perm = i_perm; - - ck_tile::HostTensor block_relation_onehot = - sparge::build_block_map_meansim(q_host, k_host, p); - - // Convert to VSA LUT (delta-encoded) + valid_block_num - std::cout << "Converting block map to VSA LUT (delta)..." << std::endl; - auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); - - // Print actual sparsity (based on one-hot) - std::size_t total_blocks = 0; - std::size_t active_blocks = 0; - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - total_blocks++; - if(block_relation_onehot(b, h, qb, kb) != 0) - active_blocks++; - } - } - } - } - float actual_sparsity = - 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" - << total_blocks << " blocks active)" << std::endl; - + // ================================================================== + // GPU: Build block map + VSA LUT in one kernel (always run) + // ================================================================== + std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; + ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); + auto vsa_lut_gpu = sparge_blockmap_gpu(q_host, + k_host, + block_map_gpu, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + i_perm, + simthreshd1, + cdfthreshd, + topk, + static_cast(BLKQ), + static_cast(BLKK), + 0); + + // ================================================================== + // VSA sparse attention kernel (always run) + // ================================================================== std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; try @@ -254,8 +239,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -276,8 +261,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -301,8 +286,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) vsa_sparge_attention(q_host, k_host, v_host, - vsa_lut.lut, - vsa_lut.valid_block_num, + vsa_lut_gpu.lut, + vsa_lut_gpu.valid_block_num, output_host, batch, nhead, @@ -332,17 +317,168 @@ bool run_test(const ck_tile::ArgParser& arg_parser) return false; } + // ================================================================== + // Sparsity statistics (always run, pure CPU read of HostTensor) + // ================================================================== + std::size_t total_blocks = 0; + std::size_t active_blocks = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + if(block_map_gpu(b, h, qb, kb) != 0) + active_blocks++; + } + } + } + } + float 
actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + // ================================================================== + // Validation (only when -v=1) + // ================================================================== bool pass = true; if(do_validation) { std::cout << "\n--- Performing CPU validation ---" << std::endl; + + // CPU golden: block map + VSA LUT + std::cout << "Building Sparge block map (CPU golden)..." << std::endl; + sparge::SpargeParams p; + p.BLKQ = static_cast(BLKQ); + p.BLKK = static_cast(BLKK); + p.simthreshd1 = simthreshd1; + p.cdfthreshd = cdfthreshd; + p.topk = topk; + p.i_perm = i_perm; + + ck_tile::HostTensor block_relation_onehot = + sparge::build_block_map_meansim(q_host, k_host, p); + + std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl; + auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + + // Validate block map + std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; + { + std::size_t bmap_mismatches = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(block_map_gpu(b, h, qb, kb) != + block_relation_onehot(b, h, qb, kb)) + { + bmap_mismatches++; + if(bmap_mismatches <= 10) + { + std::cout + << " block_map mismatch at [" << b << "," << h << "," + << qb << "," << kb + << "]: GPU=" + << static_cast(block_map_gpu(b, h, qb, kb)) + << " CPU=" + << static_cast( + block_relation_onehot(b, h, qb, kb)) + << std::endl; + } + } + } + } + } + } + std::cout << " Block map mismatches: " << bmap_mismatches << " / " + << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl; + if(bmap_mismatches > 0) + { + std::cout 
<< ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl; + pass = false; + } + else + { + std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl; + } + } + + // Validate VSA LUT + std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl; + { + std::size_t lut_mismatches = 0; + std::size_t valid_mismatches = 0; + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + if(vsa_lut_gpu.valid_block_num(b, h, qb) != + vsa_lut_cpu.valid_block_num(b, h, qb)) + { + valid_mismatches++; + if(valid_mismatches <= 5) + { + std::cout + << " valid_block_num mismatch at [" << b << "," << h + << "," << qb + << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) + << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) + << std::endl; + } + } + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(vsa_lut_gpu.lut(b, h, qb, kb) != + vsa_lut_cpu.lut(b, h, qb, kb)) + { + lut_mismatches++; + if(lut_mismatches <= 10) + { + std::cout + << " LUT mismatch at [" << b << "," << h << "," << qb + << "," << kb + << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) + << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) + << std::endl; + } + } + } + } + } + } + std::cout << " LUT mismatches: " << lut_mismatches << std::endl; + std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl; + if(lut_mismatches == 0 && valid_mismatches == 0) + { + std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + // Validate attention output float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - std::cout << "Computing reference output..." << std::endl; + std::cout << "\nComputing reference attention output..." 
<< std::endl; auto q_ref = to_bhsd(q_host, i_perm); auto k_ref = to_bhsd(k_host, i_perm); auto v_ref = to_bhsd(v_host, i_perm); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); ck_tile::reference_blocked_attention( q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); @@ -374,7 +510,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) } } - std::cout << "\nValidation results:" << std::endl; + std::cout << "\nAttention validation results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; std::cout << " Number of mismatches: " << num_errors << " / " diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp new file mode 100644 index 00000000000..ca177abf23a --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp @@ -0,0 +1,195 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +template +struct SpargeBlockMapKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kM0 = Pipeline::kM0; + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(QDataType); + + struct Kargs + { + const void* q_ptr; + const void* k_ptr; + + index_t seqlen_q; + index_t seqlen_k; + index_t hdim_q; + + index_t nhead_q; + index_t nhead_ratio_qk; + + index_t stride_q; + index_t stride_k; + index_t nhead_stride_q; + index_t nhead_stride_k; + index_t batch_stride_q; + index_t batch_stride_k; + + float simthreshd1; + float cdfthreshd; + float topk; + float scale; + + void* block_map_ptr; + void* lut_ptr; + void* valid_block_num_ptr; + + index_t N_k; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr, + const void* k_ptr, + index_t seqlen_q, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_q, + index_t nhead_ratio_qk, + index_t stride_q, + index_t stride_k, + index_t nhead_stride_q, + index_t nhead_stride_k, + index_t batch_stride_q, + index_t batch_stride_k, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + void* block_map_ptr, + void* lut_ptr, + void* valid_block_num_ptr) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{q_ptr, + k_ptr, + seqlen_q, + seqlen_k, + hdim_q, + nhead_q, + nhead_ratio_qk, + stride_q, + stride_k, + nhead_stride_q, + nhead_stride_k, + batch_stride_q, + batch_stride_k, + simthreshd1, + cdfthreshd, + topk, + scale, + block_map_ptr, + lut_ptr, + valid_block_num_ptr, + N_k}; + } + + CK_TILE_HOST static constexpr auto 
GridSize(index_t batch, index_t nhead_q, index_t seqlen_q) + { + const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0); + return dim3(Q_blk, nhead_q, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t qb = static_cast(blockIdx.x); + const index_t hq = static_cast(blockIdx.y); + const index_t b = static_cast(blockIdx.z); + + const index_t hk = hq / kargs.nhead_ratio_qk; + + // Q pointer for this (batch, head, q_block) + const auto* q_base = reinterpret_cast(kargs.q_ptr) + + b * kargs.batch_stride_q + hq * kargs.nhead_stride_q + + qb * kM0 * kargs.stride_q; + + // K pointer for this (batch, head_k) + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k; + + // Q DRAM view with OOB padding + const auto q_dram_naive = make_naive_tensor_view( + q_base, + make_tuple(kargs.seqlen_q - qb * kM0, D), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + const auto q_dram = pad_tensor_view( + q_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto q_window = make_tile_window(q_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeQBlockDistribution()); + + // K DRAM view with OOB padding + const auto k_dram_naive = + make_naive_tensor_view(k_base, + make_tuple(kargs.seqlen_k, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + // Output pointers for this (batch, head, q_block) + const index_t N_k = kargs.N_k; + const index_t bmap_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k; + auto* bmap_ptr = reinterpret_cast(kargs.block_map_ptr) + bmap_offset; + + int32_t* lut_out = nullptr; + 
int32_t* valid_out = nullptr; + if(kargs.lut_ptr != nullptr) + { + lut_out = reinterpret_cast(kargs.lut_ptr) + bmap_offset; + const index_t valid_offset = + (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb; + valid_out = reinterpret_cast(kargs.valid_block_num_ptr) + valid_offset; + } + + // Shared memory + __shared__ char smem[Pipeline::GetSmemSize()]; + + Pipeline{}(q_window, + k_window, + kargs.seqlen_q, + kargs.seqlen_k, + qb, + N_k, + kargs.nhead_ratio_qk, + kargs.simthreshd1, + kargs.cdfthreshd, + kargs.topk, + kargs.scale, + bmap_ptr, + lut_out, + valid_out, + static_cast(smem)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp new file mode 100644 index 00000000000..222e73c60e2 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -0,0 +1,521 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" + +namespace ck_tile { + +template +struct SpargeBlockMapPipeline +{ + using Problem = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using BlockFmhaShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t D = BlockFmhaShape::kQKHeaddim; + static constexpr index_t NumWarps = BlockFmhaShape::NumWarps; + static constexpr index_t WarpSize = get_warp_size(); + + static constexpr index_t KPerThread = 16 / sizeof(QDataType); + static constexpr index_t KThreads = D / KPerThread; + static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads; + static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps); + static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps); + + static constexpr index_t kBlockPerCu = 1; + static constexpr index_t kMaxKBlocks = 1024; + + // LDS layout (non-overlapping, all used simultaneously in Phase 2): + // [0 .. kReduceBytes) cross-warp reduction scratch + // [kScoreOffset ..) scores[N_k] + // [kBmapOffset ..) block_map[N_k] + // [kSmallOffset ..) 
Phase 3 argmax scratch (2*NumWarps floats) + static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float); + static constexpr index_t kScoreOffset = kReduceBytes; + static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); + static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return kSmallOffset + 2 * NumWarps * sizeof(float); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + // Extract tile data into a local float array via static_for (compile-time indices). + template + CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize]) + { + static_assert(Tile::get_thread_buffer_size() == BufSize); + const auto& buf = tile.get_thread_buffer(); + static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert(buf[i]); }); + } + + // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums, + // then xor-shuffle across m_idx within warp. 
+ template + CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data, + float (&col_acc)[KPerThread]) + { + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k]; + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Cross-warp LDS reduction for column sums. + CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], + float* __restrict__ smem_reduce) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + const index_t k_idx = lane_id % KThreads; + const index_t m_idx = lane_id / KThreads; + + if(m_idx == 0) + for(index_t k = 0; k < KPerThread; ++k) + smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k]; + __syncthreads(); + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k]; + __syncthreads(); + } + + // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx. 
+ template + CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data, + float (&row_norms)[SeqPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + { + float v = data[m * KPerThread + k]; + sq += v * v; + } + for(index_t stride = 1; stride < KThreads; stride *= 2) + sq += warp_shuffle(sq, __lane_id() ^ stride); + + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + row_norms[m] = (gsq < actual_seq) ? sq : 0.f; + } + } + + // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||. + template + CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data, + const float* __restrict__ row_norms, + float (&col_acc)[KPerThread], + index_t actual_seq) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t m_idx = (tid % WarpSize) / KThreads; + + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] = 0.f; + + for(index_t m = 0; m < SeqPerThread; ++m) + { + float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f; + index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; + if(gsq < actual_seq) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += data[m * KPerThread + k] * inv_norm; + } + + for(index_t stride = KThreads; stride < WarpSize; stride *= 2) + for(index_t k = 0; k < KPerThread; ++k) + col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride); + } + + // Scalar reduce across k_idx lanes (within warp). 
+ CK_TILE_DEVICE static float reduce_across_k(float v) + { + for(index_t stride = 1; stride < KThreads; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + return v; + } + + // Full-block scalar reduce (warp xor + cross-warp LDS). + CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v += warp_shuffle(v, __lane_id() ^ stride); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = 0.f; + for(index_t w = 0; w < NumWarps; ++w) + s += smem_small[w]; + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small) + { + const index_t tid = static_cast(threadIdx.x); + const index_t warp_id = tid / WarpSize; + const index_t lane_id = tid % WarpSize; + + for(index_t stride = 1; stride < WarpSize; stride *= 2) + v = max(v, warp_shuffle(v, __lane_id() ^ stride)); + if(lane_id == 0) + smem_small[warp_id] = v; + __syncthreads(); + if(tid == 0) + { + float s = smem_small[0]; + for(index_t w = 1; w < NumWarps; ++w) + s = max(s, smem_small[w]); + smem_small[0] = s; + } + __syncthreads(); + return smem_small[0]; + } + + // ====================================================================== + template + CK_TILE_DEVICE void operator()(const QWindowType& q_window_in, + const KWindowType& k_window_in, + index_t seqlen_q, + index_t seqlen_k, + index_t qb, + index_t N_k, + index_t /*nhead_ratio_qk*/, + float simthreshd1, + float cdfthreshd, + float topk, + float scale, + uint8_t* block_map_ptr, + int32_t* lut_ptr, + int32_t* valid_block_num_ptr, + void* smem_ptr) const + { + const index_t tid = static_cast(threadIdx.x); + + auto* smem_float = reinterpret_cast(smem_ptr); + auto* smem_scores = + reinterpret_cast(reinterpret_cast(smem_ptr) 
+ kScoreOffset); + auto* smem_bmap = + reinterpret_cast(reinterpret_cast(smem_ptr) + kBmapOffset); + auto* smem_small = + reinterpret_cast(reinterpret_cast(smem_ptr) + kSmallOffset); + + const index_t bs_q = min(static_cast(kM0), seqlen_q - qb * kM0); + const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast(bs_q)) : 0.f; + + // ================================================================== + // Phase 1: Q Block Statistics + // ================================================================== + auto q_tile = load_tile(q_window_in); + + float q_data[MPerThread * KPerThread]; + tile_to_float(q_tile, q_data); + + // 1a. L2 norm per token + float psq[MPerThread]; + row_reduce_sq_norm(q_data, psq, bs_q); + + // 1b. Column sum -> mean + float pooled_q_mean[KPerThread]; + column_reduce_thread_and_warp(q_data, pooled_q_mean); + column_reduce_cross_warp(pooled_q_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_q_mean[k] *= inv_bs_q; + + // 1c. Normalised sum_hat + float sum_hat[KPerThread]; + column_reduce_normalised(q_data, psq, sum_hat, bs_q); + column_reduce_cross_warp(sum_hat, smem_float); + + // 1d. 
sim_q = ||sum_hat||^2 / bs_q^2 + float sh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + sh_sq += sum_hat[k] * sum_hat[k]; + sh_sq = reduce_across_k(sh_sq); + const float denom_q = static_cast(bs_q) * static_cast(bs_q); + const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1); + + // Not similar → force all K blocks ON, early exit + if(!sim_q) + { + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = 1; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + return; + } + + // ================================================================== + // Phase 2: K Block Loop + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_bmap[i] = 0; + __syncthreads(); + + auto k_window = k_window_in; + + for(index_t kb = 0; kb < N_k; ++kb) + { + const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); + const float inv_bs_k = (bs_k > 0) ? 
(1.0f / static_cast(bs_k)) : 0.f; + + auto k_tile = load_tile(k_window); + + float k_data[NPerThread * KPerThread]; + tile_to_float(k_tile, k_data); + + // K mean + float pooled_k_mean[KPerThread]; + column_reduce_thread_and_warp(k_data, pooled_k_mean); + column_reduce_cross_warp(pooled_k_mean, smem_float); + for(index_t k = 0; k < KPerThread; ++k) + pooled_k_mean[k] *= inv_bs_k; + + // dot(pooled_q_mean, pooled_k_mean) + float dot = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + dot += pooled_q_mean[k] * pooled_k_mean[k]; + dot = reduce_across_k(dot); + + // K L2 norms + normalised sum_hat + float k_psq[NPerThread]; + row_reduce_sq_norm(k_data, k_psq, bs_k); + + float k_sum_hat[KPerThread]; + column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); + column_reduce_cross_warp(k_sum_hat, smem_float); + + // sim_k + float ksh_sq = 0.f; + for(index_t k = 0; k < KPerThread; ++k) + ksh_sq += k_sum_hat[k] * k_sum_hat[k]; + ksh_sq = reduce_across_k(ksh_sq); + const float denom_k = static_cast(bs_k) * static_cast(bs_k); + const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1); + + if(tid == 0) + { + if(!sim_k) + { + smem_bmap[kb] = 1; + smem_scores[kb] = -numeric::infinity(); + } + else + { + smem_scores[kb] = dot * scale; + } + } + __syncthreads(); + + move_tile_window(k_window, {kN0, 0}); + } + + // ================================================================== + // Phase 3: Softmax + Selection + // ================================================================== + + // max + float lmax = -numeric::infinity(); + for(index_t i = tid; i < N_k; i += kBlockSize) + lmax = max(lmax, smem_scores[i]); + const float max_score = block_reduce_max(lmax, smem_small); + + // exp + sum + float lsum = 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + float e = (smem_scores[i] > -numeric::infinity()) + ? 
__builtin_expf(smem_scores[i] - max_score) + : 0.f; + smem_scores[i] = e; + lsum += e; + } + const float sum_exp = block_reduce_sum(lsum, smem_small); + + // normalise + const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_scores[i] *= inv_sum; + __syncthreads(); + + // Selection: iterative argmax + index_t num_to_select = + (topk > 0.f) + ? max(static_cast(1), static_cast(topk * static_cast(N_k))) + : N_k; + + float cumulative_prob = 0.f; + for(index_t round = 0; round < num_to_select; ++round) + { + // thread-local argmax + float best_val = -1.f; + index_t best_idx = 0; + for(index_t i = tid; i < N_k; i += kBlockSize) + { + if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx)) + { + best_val = smem_scores[i]; + best_idx = i; + } + } + + // warp argmax + for(index_t stride = 1; stride < WarpSize; stride *= 2) + { + float rv = warp_shuffle(best_val, __lane_id() ^ stride); + index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride); + if(rv > best_val || (rv == best_val && ri < best_idx)) + { + best_val = rv; + best_idx = ri; + } + } + + // cross-warp argmax via LDS + const index_t lane_id = tid % WarpSize; + const index_t warp_id = tid / WarpSize; + if(lane_id == 0) + { + smem_small[warp_id] = best_val; + smem_small[NumWarps + warp_id] = bit_cast(static_cast(best_idx)); + } + __syncthreads(); + + if(tid == 0) + { + float bv = smem_small[0]; + index_t bi = bit_cast(smem_small[NumWarps]); + for(index_t w = 1; w < NumWarps; ++w) + { + float wv = smem_small[w]; + index_t wi = bit_cast(smem_small[NumWarps + w]); + if(wv > bv || (wv == bv && wi < bi)) + { + bv = wv; + bi = wi; + } + } + smem_small[0] = bv; + smem_small[1] = bit_cast(static_cast(bi)); + } + __syncthreads(); + + float g_val = smem_small[0]; + index_t g_idx = bit_cast(smem_small[1]); + + if(g_val <= 0.f) + break; + + if(tid == 0) + { + smem_bmap[g_idx] = 1; + smem_scores[g_idx] = -1.f; + } + __syncthreads(); + + 
if(topk > 0.f) + { + if(round + 1 >= num_to_select) + break; + } + else + { + cumulative_prob += g_val; + if(cumulative_prob >= cdfthreshd) + break; + } + } + + // ================================================================== + // Write outputs to global memory + // ================================================================== + for(index_t i = tid; i < N_k; i += kBlockSize) + block_map_ptr[i] = smem_bmap[i]; + + if(lut_ptr != nullptr && tid == 0) + { + int32_t valid = 0, prev = 0; + for(index_t kb = 0; kb < N_k; ++kb) + { + if(smem_bmap[kb] != 0) + { + lut_ptr[valid] = static_cast(kb) - prev; + prev = static_cast(kb); + ++valid; + } + } + for(index_t i = valid; i < N_k; ++i) + lut_ptr[i] = 0; + *valid_block_num_ptr = valid; + } + } +}; + +} // namespace ck_tile From c7e6e4f616b483f2a9aafd3e8d00238f02de77e5 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Tue, 14 Apr 2026 10:11:00 -0400 Subject: [PATCH 4/4] fix extra host side operations. --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 4 - .../50_sparse_attn/sparge_blockmap.cpp | 156 --------- .../ck_tile/50_sparse_attn/sparge_blockmap.h | 26 -- .../test_sparge_vsa_sparse_attn.cpp | 296 ++++++++++-------- .../50_sparse_attn/vsa_sparge_attention.cpp | 195 ------------ .../50_sparse_attn/vsa_sparge_attention.h | 28 -- 6 files changed, 164 insertions(+), 541 deletions(-) delete mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.cpp delete mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.h delete mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp delete mode 100644 example/ck_tile/50_sparse_attn/vsa_sparge_attention.h diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 169ed87ac3b..f234f631b6b 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -249,14 +249,12 @@ set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") add_library(${SPARGE_VSA_INSTANCES} 
OBJECT EXCLUDE_FROM_ALL ${SPARGE_VSA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp ) target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn ) set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE @@ -273,7 +271,6 @@ set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp ) target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} @@ -281,7 +278,6 @@ target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ) set_source_files_properties( ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp PROPERTIES LANGUAGE HIP ) set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp deleted file mode 100644 index b9ac56c533c..00000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -#include "sparge_blockmap.h" -#include "sparge_blockmap_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include -#include - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "sparge_blockmap_gpu supports fp16/bf16 only."); - - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq); - const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk); - - const float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - // Allocate device memory - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - - const std::size_t bmap_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t); - const std::size_t lut_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t); - const std::size_t valid_bytes = - static_cast(batch) * nhead_q * num_q_blocks * sizeof(int32_t); - - ck_tile::DeviceMem bmap_buf(bmap_bytes); - ck_tile::DeviceMem lut_buf(lut_bytes); - ck_tile::DeviceMem valid_buf(valid_bytes); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - bmap_buf.SetZero(); - lut_buf.SetZero(); - valid_buf.SetZero(); - - // Compute strides (assumes BHSD if i_perm, BSHD otherwise) - const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q; - const ck_tile::index_t stride_k = i_perm ? 
hdim_q : nhead_k * hdim_q; - const ck_tile::index_t nhead_stride_q = - i_perm ? static_cast(seqlen_q) * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_k = - i_perm ? static_cast(seqlen_k) * hdim_q : hdim_q; - const ck_tile::index_t batch_stride_q = - static_cast(nhead_q) * seqlen_q * hdim_q; - const ck_tile::index_t batch_stride_k = - static_cast(nhead_k) * seqlen_k * hdim_q; - - ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false}; - - sparge_blockmap_args args; - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.batch = batch; - args.seqlen_q = seqlen_q; - args.seqlen_k = seqlen_k; - args.hdim_q = hdim_q; - args.nhead_q = nhead_q; - args.nhead_k = nhead_k; - args.stride_q = stride_q; - args.stride_k = stride_k; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.simthreshd1 = simthreshd1; - args.cdfthreshd = cdfthreshd; - args.topk = topk; - args.scale = scale; - args.block_map_ptr = bmap_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); - - sparge_blockmap_traits traits; - traits.data_type = data_type; - traits.hdim_q = hdim_q; - - sparge_blockmap_fwd(traits, args, stream_config); - - // Copy results back to host - bmap_buf.FromDevice(block_map_out.data(), bmap_bytes); - - sparge::VSALut vsa_lut{ - ck_tile::HostTensor({batch, nhead_q, num_q_blocks, num_k_blocks}), - ck_tile::HostTensor({batch, nhead_q, num_q_blocks}), - }; - lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes); - valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes); - - return vsa_lut; -} - -// Explicit template instantiations -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - 
int, - int, - int); - -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h deleted file mode 100644 index 3057257ca14..00000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "sparge_tool.hpp" - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index 638a867b0f3..572b708f9ef 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -1,23 +1,17 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device) #include -#include #include -#include #include -#include -#include -#include - #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "vsa_sparge_attention.h" -#include "sparge_blockmap.h" +#include "sparge_blockmap_trek.hpp" +#include "fmha_fwd_trek.hpp" #include "sparge_tool.hpp" // ============================================================================ @@ -192,7 +186,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) << ", topk=" << topk << ")" << std::endl; std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - // Create host tensors + // Create host tensors and fill with random data ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); @@ -206,119 +200,157 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); // ================================================================== - // GPU: Build block map + VSA LUT in one kernel (always run) + // Allocate device memory once, HtoD once // ================================================================== - std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; - ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); - auto vsa_lut_gpu = sparge_blockmap_gpu(q_host, - k_host, - block_map_gpu, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - i_perm, - simthreshd1, - cdfthreshd, - topk, - static_cast(BLKQ), - static_cast(BLKK), - 0); + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + + const std::size_t bmap_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t); + const std::size_t lut_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t); + const std::size_t valid_bytes = + static_cast(batch) * nhead * num_q_blocks * sizeof(int32_t); + + ck_tile::DeviceMem bmap_buf(bmap_bytes); + ck_tile::DeviceMem lut_buf(lut_bytes); + ck_tile::DeviceMem valid_buf(valid_bytes); + bmap_buf.SetZero(); + lut_buf.SetZero(); + valid_buf.SetZero(); // ================================================================== - // VSA sparse attention kernel (always run) + // Common stride calculations // ================================================================== - std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; + assert(nhead % nhead_k == 0); + const float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q; + const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q; + const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v; + const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v; + const ck_tile::index_t nhead_stride_q = i_perm ? 
seqlen_q * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v; + const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v; + const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q; + const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k; + const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v; + + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + data_type = "bf16"; - try - { - if(kname) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 1); - } + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - for(int i = 0; i < warmup; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + // ================================================================== + // GPU: Build block map + VSA LUT (always run, device-only) + // ================================================================== + std::cout << "Building Sparge block map + VSA LUT (GPU)..." 
<< std::endl; + { + sparge_blockmap_args args; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.simthreshd1 = simthreshd1; + args.cdfthreshd = cdfthreshd; + args.topk = topk; + args.scale = scale_s; + args.block_map_ptr = bmap_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + + sparge_blockmap_traits traits; + traits.data_type = data_type; + traits.hdim_q = hdim_q; + + sparge_blockmap_fwd(traits, args, ck_tile::stream_config{}); + } - [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); + // ================================================================== + // VSA sparse attention kernel (always run, LUT stays on device) + // ================================================================== + std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; - for(int i = 0; i < repeat; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + fmha_vsa_fwd_args fmha_args; + fmha_args.q_ptr = q_buf.GetDeviceBuffer(); + fmha_args.k_ptr = k_buf.GetDeviceBuffer(); + fmha_args.v_ptr = v_buf.GetDeviceBuffer(); + fmha_args.lut_ptr = lut_buf.GetDeviceBuffer(); + fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + fmha_args.o_ptr = o_buf.GetDeviceBuffer(); + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen_q; + 
fmha_args.seqlen_k = seqlen_k; + fmha_args.max_seqlen_q = seqlen_q; + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + fmha_args.scale_s = scale_s; + fmha_args.stride_q = stride_q; + fmha_args.stride_k = stride_k; + fmha_args.stride_v = stride_v; + fmha_args.stride_o = stride_o; + fmha_args.nhead_stride_q = nhead_stride_q; + fmha_args.nhead_stride_k = nhead_stride_k; + fmha_args.nhead_stride_v = nhead_stride_v; + fmha_args.nhead_stride_o = nhead_stride_o; + fmha_args.batch_stride_q = batch_stride_q; + fmha_args.batch_stride_k = batch_stride_k; + fmha_args.batch_stride_v = batch_stride_v; + fmha_args.batch_stride_o = batch_stride_o; + fmha_args.window_size_left = mask.left; + fmha_args.window_size_right = mask.right; + fmha_args.mask_type = static_cast(mask.type); + + fmha_vsa_fwd_traits fmha_traits; + fmha_traits.hdim_q = hdim_q; + fmha_traits.hdim_v = hdim_v; + fmha_traits.data_type = data_type; + fmha_traits.is_v_rowmajor = true; + fmha_traits.mask_type = mask.type; + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ kname ? 
1 : 0, + warmup, + repeat, + false}; + + float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config); + + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; - [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = - std::chrono::duration(end - start).count() / repeat; + // DtoH: attention output (always needed) + o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes()); - std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Error during kernel execution: " << e.what() << std::endl; - return false; - } + // DtoH: block_map (needed for sparsity stats and validation) + ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); + bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes); // ================================================================== - // Sparsity statistics (always run, pure CPU read of HostTensor) + // Sparsity statistics (pure CPU, reads block_map HostTensor) // ================================================================== std::size_t total_blocks = 0; std::size_t active_blocks = 0; @@ -366,6 +398,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "Converting block map to VSA LUT (delta, CPU)..." 
<< std::endl; auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + // DtoH: LUT + valid_block_num (only for validation) + sparge::VSALut vsa_lut_gpu{ + ck_tile::HostTensor({batch, nhead, num_q_blocks, num_k_blocks}), + ck_tile::HostTensor({batch, nhead, num_q_blocks}), + }; + lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes); + valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes); + // Validate block map std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; { @@ -378,20 +418,16 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(block_map_gpu(b, h, qb, kb) != - block_relation_onehot(b, h, qb, kb)) + if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb)) { bmap_mismatches++; if(bmap_mismatches <= 10) { std::cout - << " block_map mismatch at [" << b << "," << h << "," - << qb << "," << kb - << "]: GPU=" - << static_cast(block_map_gpu(b, h, qb, kb)) - << " CPU=" - << static_cast( - block_relation_onehot(b, h, qb, kb)) + << " block_map mismatch at [" << b << "," << h << "," << qb + << "," << kb << "]: GPU=" + << static_cast(block_map_gpu(b, h, qb, kb)) << " CPU=" + << static_cast(block_relation_onehot(b, h, qb, kb)) << std::endl; } } @@ -429,28 +465,24 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_mismatches++; if(valid_mismatches <= 5) { - std::cout - << " valid_block_num mismatch at [" << b << "," << h - << "," << qb - << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) - << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) - << std::endl; + std::cout << " valid_block_num mismatch at [" << b << "," << h + << "," << qb + << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) + << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) + << std::endl; } } for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(vsa_lut_gpu.lut(b, h, qb, kb) != - vsa_lut_cpu.lut(b, h, qb, kb)) + if(vsa_lut_gpu.lut(b, h, qb, 
kb) != vsa_lut_cpu.lut(b, h, qb, kb)) { lut_mismatches++; if(lut_mismatches <= 10) { std::cout << " LUT mismatch at [" << b << "," << h << "," << qb - << "," << kb - << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) - << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) - << std::endl; + << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) + << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl; } } } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp deleted file mode 100644 index 5f9c2676ddb..00000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#include "vsa_sparge_attention.h" -#include "fmha_fwd_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "VSA sparse attention supports fp16/bf16 only."); - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - if(max_seqlen_q == 0) - max_seqlen_q = seqlen_q; - if(max_seqlen_k == 0) - max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t 
shape_seqlen_k = seqlen_k; - - ck_tile::stream_config stream_config{nullptr, - false, // time_kernel - log_level, - 0, - 1, - false}; - - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes()); - ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - v_buf.ToDevice(TV.data()); - lut_buf.ToDevice(TKV_block_idx.data()); - valid_block_num_buf.ToDevice(TKV_blocks.data()); - - const auto init_args = [&](auto& args) { - assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; - }(); - const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqlen_k = shape_seqlen_k; - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - - args.stride_o = stride_o; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - }; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - }; - - fmha_vsa_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_vsa_fwd_args args; - init_args(args); - - sparge_vsa_fwd(fmha_traits, args, stream_config); - - o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); - - return Y; -} - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const 
ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h deleted file mode 100644 index d51a7e8c00b..00000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0);