From 863abad57d94b18648012e3cf8f67e8a90f563b6 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 20 Apr 2026 21:20:28 -0500 Subject: [PATCH] modify for fmha_fwd 192/128 --- example/ck_tile/01_fmha/CMakeLists.txt | 2 +- example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 13 ++++++++----- include/ck_tile/core/arch/arch.hpp | 12 ++++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e0..1cd8a4d3fb1 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -47,7 +47,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 32,64,80,128,256 + --optdim 32,64,80,128,192,256 # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e6..327fa1d3655 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -911,7 +911,8 @@ def check_tile_pipeline( and kernel_ctx.tile.F_bn0 == 128 ) or ( - (problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)] + (problem_ctx.hdim, problem_ctx.hdim_v) + not in [(64, 64), (128, 128), (192, 128)] ) ): return False @@ -1113,9 +1114,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: def get_pipelines( cls, dtype, hdim, hdim_v, receipt, mask_impl ) -> List[FmhaFwdPipeline]: - pipelines = KernelComponentFactoryGfx9.get_pipelines( - dtype, hdim, hdim_v, receipt, mask_impl - ) + pipelines = [] if dtype in cls._DT_FP16_BF16: qscale = "no" for logits, mask, bias, lse, dropout, skip, sink in itertools.product( @@ -1128,7 +1127,7 @@ def get_pipelines( ["t", "f"], ): if ( - (hdim, hdim_v) in [(64, 64), (128, 128)] + (hdim, hdim_v) in [(64, 64), (128, 128), (192, 128)] and logits == "f" and bias == "no" and dropout == "f" @@ -1164,6 +1163,10 @@ def get_pipelines( ): pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip + + pipelines += KernelComponentFactoryGfx9.get_pipelines( + dtype, hdim, hdim_v, receipt, mask_impl + ) return pipelines diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 417ec12c8ca..4874c51afcd 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -935,10 +935,14 @@ struct WaitcntLayoutGfx11 struct WaitcntLayoutLegacy { // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV - CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2 + CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2 +#if defined(__gfx94__) + CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // [13:8] gfx940+ extended to 6 bits +#else CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8] - CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4] - CK_TILE_DEVICE static constexpr bool HAS_EXP = true; +#endif + CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4] + CK_TILE_DEVICE static constexpr bool HAS_EXP = true; CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { @@ -968,7 +972,7 @@ struct waitcnt_arg CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none -#elif defined(__gfx11__) +#elif defined(__gfx11__) || defined(__gfx94__) CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits