diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp new file mode 100644 index 0000000000..b2a642e768 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp @@ -0,0 +1,167 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DeviceGroupedGemm_Fixed_NK_Common +{ + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + // Workarounds the fact that gridwise gemm implementations not supporting splitk require + // different index mapping. + if constexpr(HasSplitKSupport) + { + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + else + { + return make_tuple(idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index ebe942b4c8..9532f7e76a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -21,6 +21,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -302,149 +303,11 @@ struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK false, false>; - // TODO: Block to tile mappings could potentially moved out to avoid code duplications between - // different device implementations. - - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + using Block2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops; static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1 using KernelArgument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 36e66017c6..9978b62b17 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -268,167 +269,14 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK LoopSched>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple( - // idx_bot[Number<0>{}], - idx_bot[Number<1>{}], - idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); + using Block2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops; - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; - - struct GemmBiasTransKernelArg - { - // pointers - std::array as_ptr_; - std::array bs_ptr_; - std::array ds_ptr_; - void* e_ptr_; - - index_t M_, N_, K_; - std::array StrideAs_; - std::array StrideBs_; - std::array StrideDs_; - index_t StrideE_; - }; + using KernelArgument = GroupedGemmMultiABDKernelArgument; // Argument struct Argument : public BaseArgument @@ -537,7 +385,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK throw std::runtime_error("wrong! block_2_etile_map validation failed"); } - gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + gemm_desc_kernel_arg_.push_back(KernelArgument{ p_as_grid, p_bs_grid, p_ds_grid, @@ -556,7 +404,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK const auto e_grid_desc_sum_m_n = GridwiseGemm64::template MakeEGridDescriptor_M_N( - sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -570,7 +418,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK BElementwiseOperation b_element_op_; CDEElementwiseOperation c_element_op_; - std::vector gemm_desc_kernel_arg_; + std::vector gemm_desc_kernel_arg_; std::vector> a_mtx_mraw_kraw_; std::vector> b_mtx_nraw_kraw_; @@ -596,7 +444,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) { - if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) != has_main_k_block_loop) { throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); @@ -729,7 +577,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK { if(get_warp_size() == 64) { - if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) != true) { supported = false; @@ -737,7 +585,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK } else { - if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K) != true) { supported = false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 8a9afc1733..b652b7d4a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -20,6 +20,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" namespace ck { @@ -328,152 +329,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( 1, 1, 1, 1, 1))>; - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N&) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - const auto total_tiles_per_group = M0 * N0 * KBatch_; - - // wrap block id into this group - block_1d_id = block_1d_id % total_tiles_per_group; - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + using Block2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops; static constexpr index_t DefaultKBatch = 1; using KernelArgument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 311a1c0bf4..1e61b5f8cb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_fixed_nk_common.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -309,164 +310,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; using GridwiseGemm32 = GridwiseGemmBase; - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); + using Block2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = + DeviceGroupedGemm_Fixed_NK_Common::OffsettedBlockToCTileMapMLoops; - index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; - - // TODO: replace with GroupedGemmKernelArgument - struct GemmBiasTransKernelArg - { - // pointers - const void* a_ptr_; - const void* b_ptr_; - std::array ds_ptr_; - void* e_ptr_; - - index_t M_, N_, K_; - index_t StrideA_, StrideB_; - std::array StrideDs_; - index_t StrideE_; - }; + using KernelArgument = GroupedGemmKernelArgument; // Argument struct Argument : public BaseArgument @@ -484,8 +334,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( @@ -626,7 +476,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK( - sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; @@ -659,7 +509,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK gemm_desc_kernel_arg_; + std::vector gemm_desc_kernel_arg_; std::vector> a_mtx_mraw_kraw_; std::vector> b_mtx_nraw_kraw_; @@ -686,7 +536,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK -CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths, - const Strides& strides, - number i, - AccOld acc_old) +CK_TILE_HOST_DEVICE constexpr long_index_t calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + number i, + AccOld acc_old) { - auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i]; + long_index_t acc_new = acc_old + static_cast(lengths[i] - number<1>{}) * + static_cast(strides[i]); if constexpr(i.value < Lengths::size() - 1) { @@ -287,8 +288,12 @@ make_naive_tensor_descriptor(const tuple& lengths, constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = + const long_index_t element_space_size_long = detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{}); + constexpr long_index_t element_space_size_clamp_value = + static_cast(std::numeric_limits::max()); + const index_t element_space_size = + static_cast(std::min(element_space_size_long, element_space_size_clamp_value)); using GuaranteedVectorLengths = typename sequence_merge::type, @@ -323,8 +328,12 @@ make_naive_tensor_descriptor_with_offset(const tuple& lengths, number = number<-1>{}) { const auto desc_0 = [&]() { - const auto element_space_size = detail::calculate_element_space_size_impl( + const auto element_space_size_long = detail::calculate_element_space_size_impl( lengths, strides, number<0>{}, long_number<1>{}); + constexpr long_index_t element_space_size_clamp_value = + static_cast(std::numeric_limits::max()); + const index_t element_space_size = + static_cast(std::min(element_space_size_long, element_space_size_clamp_value)); const auto transforms = make_tuple(make_offset_transform(element_space_size, os));