Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,80,128,256
--optdim 32,64,80,128,192,256
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
Expand Down
13 changes: 8 additions & 5 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,8 @@ def check_tile_pipeline(
and kernel_ctx.tile.F_bn0 == 128
)
or (
(problem_ctx.hdim, problem_ctx.hdim_v) not in [(64, 64), (128, 128)]
(problem_ctx.hdim, problem_ctx.hdim_v)
not in [(64, 64), (128, 128), (192, 128)]
)
):
return False
Expand Down Expand Up @@ -1113,9 +1114,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
def get_pipelines(
cls, dtype, hdim, hdim_v, receipt, mask_impl
) -> List[FmhaFwdPipeline]:
pipelines = KernelComponentFactoryGfx9.get_pipelines(
dtype, hdim, hdim_v, receipt, mask_impl
)
pipelines = []
if dtype in cls._DT_FP16_BF16:
qscale = "no"
for logits, mask, bias, lse, dropout, skip, sink in itertools.product(
Expand All @@ -1128,7 +1127,7 @@ def get_pipelines(
["t", "f"],
):
if (
(hdim, hdim_v) in [(64, 64), (128, 128)]
(hdim, hdim_v) in [(64, 64), (128, 128), (192, 128)]
and logits == "f"
and bias == "no"
and dropout == "f"
Expand Down Expand Up @@ -1164,6 +1163,10 @@ def get_pipelines(
):
pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip

pipelines += KernelComponentFactoryGfx9.get_pipelines(
dtype, hdim, hdim_v, receipt, mask_impl
)
return pipelines


Expand Down
12 changes: 8 additions & 4 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,10 +935,14 @@ struct WaitcntLayoutGfx11

struct WaitcntLayoutLegacy
{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2
CK_TILE_DEVICE static constexpr index_t VM_MASK = 0x3F; // split: low4 + hi2
#if defined(__gfx94__)
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // [13:8] gfx940+ extended to 6 bits
#else
CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4]
CK_TILE_DEVICE static constexpr bool HAS_EXP = true;
#endif
CK_TILE_DEVICE static constexpr index_t EXP_MASK = 0x07; // [6:4]
CK_TILE_DEVICE static constexpr bool HAS_EXP = true;

CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c)
{
Expand Down Expand Up @@ -968,7 +972,7 @@ struct waitcnt_arg
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x0; // none
#elif defined(__gfx11__)
#elif defined(__gfx11__) || defined(__gfx94__)
CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0x07; // 3 bits
Expand Down