From d4236de1ba6c254e8c074e2d5a94cea3fb1f3bd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Zolt=C3=A1n=20Lakatos?= <153429852+zsotakal@users.noreply.github.com>
Date: Mon, 20 Apr 2026 12:25:45 +0000
Subject: [PATCH 1/2] [rocm-libraries] ROCm/rocm-libraries#4961 (commit
 6c3969a) [CK] Remove code duplications in grouped gemm fixed nk
 implementations (#4961)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Different flavours of the grouped gemm fixed_nk implementations share the same
block-to-tile mapping logic, yet the code responsible for it is duplicated in
each device struct implementation.

- Move `BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops` and
  `OffsettedBlockToCTileMapMLoops` from the device struct implementations to a
  common header file.
- Use the generic kernel argument structures in the xdl versions of the
  fixed_nk implementations.

## Technical Details

## Test Plan

CI in general. The relevant tests and examples are all fixed_nk versions of
grouped gemm multiple D and ABD.

## Test Result

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---
 .../device_grouped_gemm_fixed_nk_common.hpp   | 167 ++++++++++++++++
 ...e_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 149 +--------------
 ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 178 ++----------------
 .../device_grouped_gemm_wmma_fixed_nk.hpp     | 152 +--------------
 .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 176 ++---------------
 5 files changed, 205 insertions(+), 617 deletions(-)
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp
new file mode 100644
index 00000000000..b2a642e768d
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp
@@ -0,0 +1,167 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+struct DeviceGroupedGemm_Fixed_NK_Common
+{
+    template <typename UnderlyingBlockToCTileMap, bool HasSplitKSupport>
+    struct OffsettedBlockToCTileMapMLoops
+    {
+        using underlying_type = UnderlyingBlockToCTileMap;
+
+        __host__ __device__ OffsettedBlockToCTileMapMLoops(
+            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
+        {
+            block_to_ctile_map_ = block_to_ctile_map;
+            block_start_        = block_start;
+            id_off_             = id_off;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
+                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
+
+            // Works around the fact that gridwise gemm implementations not supporting splitk
+            // require a different index mapping.
+            if constexpr(HasSplitKSupport)
+            {
+                return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
+            }
+            else
+            {
+                return make_tuple(idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
+            }
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                 const CTileDim& c_tile_dim) const
+        {
+            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+        }
+
+        UnderlyingBlockToCTileMap block_to_ctile_map_;
+        index_t block_start_;
+        index_t id_off_;
+    };
+
+    template <index_t MPerBlock, index_t NPerBlock>
+    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
+    {
+        static constexpr auto I0 = Number<0>{};
+        static constexpr auto I1 = Number<1>{};
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
+                                                                          index_t N,
+                                                                          index_t KBatch,
+                                                                          index_t M01 = 8)
+            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
+        {
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
+            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
+        {
+        }
+
+        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+        {
+            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+            return M0 * N0 * KBatch_;
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ __device__ constexpr index_t
+        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+        {
+            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+        }
+
+        template <typename CGridDesc_M_N>
+        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+        {
+            return true;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            auto block_1d_id = idx_top[I0];
+
+            const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
+            const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
+
+            const index_t idx_ksplit = block_1d_id / (M0 * N0);
+            block_1d_id              = block_1d_id % (M0 * N0);
+
+            index_t idx_N0 = block_1d_id % N0;
+            index_t idx_M0 = block_1d_id / N0;
+
+            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
+
+            index_t idx_M00          = idx_M0 / M01_;
+            index_t idx_M01          = idx_M0 % M01_;
+            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+            return make_tuple(idx_ksplit,
+                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                              idx_N0_M01_local / M01_adapt);
+        }
+
+        template <typename CTileIdx, typename CTileDim>
+        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                 const CTileDim& /* c_tile_dim */) const
+        {
+            return true; // always valid provided that user gets grid size from CalculateGridSize()
+        }
+
+    private:
+        index_t M_;
+        index_t N_;
+        index_t KBatch_;
+        index_t M01_;
+    };
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
index ebe942b4c8b..9532f7e76a6 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp
@@ -21,6 +21,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -302,149 +303,11 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
                                                    false,
                                                    false>;
 
-    // TODO: Block to tile mappings could potentially moved out to avoid code duplications between
-    // different device implementations.
-
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_        = block_start;
-            id_off_             = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id              = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00          = idx_M0 / M01_;
-            index_t idx_M01          = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-    private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap            = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops;
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap, true>;
 
     static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1
 
     using KernelArgument = typename GridwiseGemm::Argument;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
index 36e66017c68..9978b62b173 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -268,167 +269,14 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                                     LoopSched>;
 
     using GridwiseGemm64 = GridwiseGemmBase;
     using GridwiseGemm32 = GridwiseGemmBase;
 
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_        = block_start;
-            id_off_             = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(
-                // idx_bot[Number<0>{}],
-                idx_bot[Number<1>{}],
-                idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        static constexpr auto I0 = Number<0>{};
-        static constexpr auto I1 = Number<1>{};
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap, false>;
 
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id              = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00          = idx_M0 / M01_;
-            index_t idx_M01          = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-    private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap            = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops;
-
-    struct GemmBiasTransKernelArg
-    {
-        // pointers
-        std::array<const void*, NumATensor> as_ptr_;
-        std::array<const void*, NumBTensor> bs_ptr_;
-        std::array<const void*, NumDTensor> ds_ptr_;
-        void* e_ptr_;
-
-        index_t M_, N_, K_;
-        std::array<index_t, NumATensor> StrideAs_;
-        std::array<index_t, NumBTensor> StrideBs_;
-        std::array<index_t, NumDTensor> StrideDs_;
-        index_t StrideE_;
-    };
+    using KernelArgument = GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
 
     // Argument
     struct Argument : public BaseArgument
@@ -537,7 +385,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
                 throw std::runtime_error("wrong! block_2_etile_map validation failed");
             }
 
-            gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
+            gemm_desc_kernel_arg_.push_back(KernelArgument{
                 p_as_grid,
                 p_bs_grid,
                 p_ds_grid,
@@ -556,7 +404,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
         const auto e_grid_desc_sum_m_n =
             GridwiseGemm64::template MakeEGridDescriptor_M_N(
-                sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+                sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE);
 
         const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
 
@@ -570,7 +418,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation c_element_op_;
 
-        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<KernelArgument> gemm_desc_kernel_arg_;
 
         std::vector> a_mtx_mraw_kraw_;
         std::vector> b_mtx_nraw_kraw_;
@@ -596,7 +444,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
         for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
         {
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
               has_main_k_block_loop)
            {
                throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
            }
 
@@ -729,7 +577,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
         {
             if(get_warp_size() == 64)
             {
-                if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+                if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
                    true)
                 {
                     supported = false;
@@ -737,7 +585,7 @@
             }
             else
             {
-                if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
+                if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) !=
                    true)
                 {
                     supported = false;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
index 8a9afc1733e..b652b7d4a0b 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
@@ -20,6 +20,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
 
 namespace ck {
@@ -328,152 +329,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK
             (1, 1, 1, 1, 1))>;
 
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMapMLoops
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMapMLoops(
-            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
-        {
-            block_to_ctile_map_ = block_to_ctile_map;
-            block_start_        = block_start;
-            id_off_             = id_off;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
-
-            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
-        }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t block_start_;
-        index_t id_off_;
-    };
-
-    template <index_t MPerBlock_, index_t NPerBlock_>
-    struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
-    {
-        static constexpr auto I0 = Number<0>{};
-        static constexpr auto I1 = Number<1>{};
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
-        operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
-
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
-                                                                          index_t N,
-                                                                          index_t KBatch,
-                                                                          index_t M01 = 8)
-            : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
-        {
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-            const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
-            : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
-                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
-        {
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
-            const auto N0 = math::integer_divide_ceil(N, NPerBlock);
-
-            return M0 * N0 * KBatch_;
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ __device__ constexpr index_t
-        CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ bool CheckValidity(const CGridDesc_M_N&) const
-        {
-            return true;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            auto block_1d_id = idx_top[I0];
-
-            const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            const auto total_tiles_per_group = M0 * N0 * KBatch_;
-
-            // wrap block id into this group
-            block_1d_id = block_1d_id % total_tiles_per_group;
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id              = block_1d_id % (M0 * N0);
-
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00          = idx_M0 / M01_;
-            index_t idx_M01          = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-    private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap            = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops;
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap, true>;
 
     static constexpr index_t DefaultKBatch = 1;
 
     using KernelArgument = typename GridwiseGemm::Argument;
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
index 311a1c0bf46..1e61b5f8cbb 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -309,164 +310,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; using GridwiseGemm32 = GridwiseGemmBase; - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = 
-            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
-
-            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
-
-            const index_t idx_ksplit = block_1d_id / (M0 * N0);
-            block_1d_id              = block_1d_id % (M0 * N0);
+    using Block2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+    using GroupedGemmBlock2ETileMap =
+        DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops<Block2ETileMap, true>;
 
-            index_t idx_N0 = block_1d_id % N0;
-            index_t idx_M0 = block_1d_id / N0;
-
-            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
-
-            index_t idx_M00          = idx_M0 / M01_;
-            index_t idx_M01          = idx_M0 % M01_;
-            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
-
-            return make_tuple(idx_ksplit,
-                              idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
-                              idx_N0_M01_local / M01_adapt);
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                                 const CTileDim& /* c_tile_dim */) const
-        {
-            return true; // always valid provided that user gets grid size from CalculateGridSize()
-        }
-
-    private:
-        index_t M_;
-        index_t N_;
-        index_t KBatch_;
-        index_t M01_;
-    };
-
-    using Block2ETileMap            = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops;
-    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops;
-
-    // TODO: replace with GroupedGemmKernelArgument
-    struct GemmBiasTransKernelArg
-    {
-        // pointers
-        const void* a_ptr_;
-        const void* b_ptr_;
-        std::array<const void*, NumDTensor> ds_ptr_;
-        void* e_ptr_;
-
-        index_t M_, N_, K_;
-        index_t StrideA_, StrideB_;
-        std::array<index_t, NumDTensor> StrideDs_;
-        index_t StrideE_;
-    };
+    using KernelArgument = GroupedGemmKernelArgument<NumDTensor>;
 
     // Argument
     struct Argument : public BaseArgument
@@ -484,8 +334,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK
 (
@@ -626,7 +476,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK
 (
-                sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+                sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE);
 
         const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
 
@@ -659,7 +509,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK
-        std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
+        std::vector<KernelArgument> gemm_desc_kernel_arg_;
 
         std::vector> a_mtx_mraw_kraw_;
         std::vector> b_mtx_nraw_kraw_;
@@ -686,7 +536,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK

From 0b1f9e76eb2b3503ed33f93cfdea14d44c489e79 Mon Sep 17 00:00:00 2001
Date: Mon, 20 Apr 2026 15:33:18 +0000
Subject: [PATCH 2/2] [rocm-libraries] ROCm/rocm-libraries#6168 (commit
 2968835) [CK][CK Tile] Clamp element space size to max int32 value (#6168)

## Motivation

Fix the out-of-bounds check by clamping the element space size, so that it
does not overflow when a tensor is larger than 2 GB.

## Technical Details

- A tensor can be larger than 2 GB even though the individual offsets into it
  are not, so the element space size must be clamped to the maximum int32
  value whenever it exceeds it (see the sketch below).
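For illustration, a minimal standalone sketch of the clamping idea. The helper
name `clamped_element_space_size` and the `index_t`/`long_index_t` aliases are
stand-ins for this example only, not the CK Tile API:

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>

using index_t      = int32_t; // 32-bit offsets used by the kernels
using long_index_t = int64_t; // wide type for the intermediate accumulation

// Accumulate (length - 1) * stride per dimension in 64 bits, then clamp the
// result to the maximum int32 value so a tensor larger than 2 GB cannot
// overflow the 32-bit element space size used by the out-of-bounds check.
template <std::size_t N>
index_t clamped_element_space_size(const std::array<index_t, N>& lengths,
                                   const std::array<index_t, N>& strides)
{
    long_index_t acc = 1; // the descriptor seeds the accumulation with 1
    for(std::size_t i = 0; i < N; ++i)
        acc += static_cast<long_index_t>(lengths[i] - 1) * static_cast<long_index_t>(strides[i]);

    constexpr long_index_t clamp = std::numeric_limits<index_t>::max();
    return static_cast<index_t>(std::min(acc, clamp));
}
```

Individual offsets remain representable in 32 bits, so clamping only the
element space size keeps the bounds check conservative without widening every
offset computation.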
## Test Plan

CI

## Test Result

Pending

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

https://github.com/ROCm/composable_kernel/issues/3722

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
---
 .../ck_tile/core/tensor/tensor_descriptor.hpp | 23 +++++++++++++------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp
index cda2fb0bb52..0ec975441f4 100644
--- a/include/ck_tile/core/tensor/tensor_descriptor.hpp
+++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp
@@ -236,12 +236,13 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
 namespace detail {
 
 template <typename Lengths, typename Strides, index_t I, typename AccOld>
-CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
-                                                                     const Strides& strides,
-                                                                     number<I> i,
-                                                                     AccOld acc_old)
+CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths,
+                                                                             const Strides& strides,
+                                                                             number<I> i,
+                                                                             AccOld acc_old)
 {
-    auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];
+    long_index_t acc_new = acc_old + static_cast<long_index_t>(lengths[i] - number<1>{}) *
+                                         static_cast<long_index_t>(strides[i]);
 
     if constexpr(i.value < Lengths::size() - 1)
     {
@@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
     constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
 
-    const auto element_space_size =
+    const long_index_t element_space_size_long =
         detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});
+    constexpr long_index_t element_space_size_clamp_value =
+        static_cast<long_index_t>(std::numeric_limits<index_t>::max());
+    const index_t element_space_size =
+        static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
 
     using GuaranteedVectorLengths =
         typename sequence_merge::type,
@@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple<Lengths...>& lengths,
                                          number = number<-1>{})
 {
     const auto desc_0 = [&]() {
-        const auto element_space_size = detail::calculate_element_space_size_impl(
+        const auto element_space_size_long = detail::calculate_element_space_size_impl(
             lengths, strides, number<0>{}, long_number<1>{});
+        constexpr long_index_t element_space_size_clamp_value =
+            static_cast<long_index_t>(std::numeric_limits<index_t>::max());
+        const index_t element_space_size =
+            static_cast<index_t>(std::min(element_space_size_long, element_space_size_clamp_value));
 
         const auto transforms = make_tuple(make_offset_transform(element_space_size, os));