Skip to content

Commit fa404ca

Browse files
committed
Add fp8 FMHA v3 instances
1 parent 941c7e6 commit fa404ca

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
10481048
if (128, 128) in result.keys():
10491049
result[(128, 128)].append(
10501050
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
1051+
elif dtype in cls._DT_FP8BF16:
1052+
if (128, 128) in result.keys():
1053+
result[(128, 128)].append(
1054+
FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip
10511055
return result
10521056

10531057
@classmethod
@@ -1085,6 +1089,15 @@ def get_pipelines(
10851089
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
10861090
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
10871091

1092+
elif dtype in cls._DT_FP8BF16:
1093+
# lse/dropout kernels are not needed
1094+
# qr_async_trload_v3 only supports (generic) causal mask
1095+
for logits, qscale, mask in itertools.product(
1096+
["t", "f"],
1097+
["no", "pertensor"],
1098+
["no", "causal"],
1099+
):
1100+
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
10881101
return pipelines
10891102

10901103

include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ struct BlockFmhaFwdV3Pipeline
436436
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
437437
"wrong!");
438438

439-
static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
439+
// static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
440440
auto s_lds = make_tensor_view<address_space_enum::lds>(
441441
reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
442442
MakeSimpleLdsDesc<kM0, kN0>());

0 commit comments

Comments
 (0)