Skip to content

Commit 941c7e6

Browse files
committed
Add fp8 QK block gemm config
1 parent 5876cd8 commit 941c7e6

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
239 239
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
240 240
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
241 241

242-
constexpr auto warp_gemm = []() {
243-
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
244-
std::is_same_v<typename Problem::KDataType, half_t> &&
242+
constexpr auto warp_gemm = [] {
243+
if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
244+
std::is_same_v<typename Problem::KDataType, fp8_t> &&
245 245
std::is_same_v<typename Problem::SaccDataType, float>)
246+
{
247+
constexpr index_t swizzle_factor = 4;
248+
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<
249+
swizzle_factor>{};
250+
}
251+
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
252+
std::is_same_v<typename Problem::KDataType, half_t> &&
253+
std::is_same_v<typename Problem::SaccDataType, float>)
246 254
{
247 255
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
248 256
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here

0 commit comments

Comments
 (0)