diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index ea0168acccd..017a9ec2c0b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock + 128, // NPerBlock 32, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index fab188fef1c..65d435bce12 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -47,14 +47,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock + 128, // NPerBlock 32, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp index baf6c469878..04e203c799a 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp @@ -48,14 +48,14 @@ using DeviceGroupedConvNDFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock + 128, // NPerBlock 32, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXdl 16, // NPerXdl 4, // MXdlPerWave - 8, // NXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index 5510c8e001d..b2b236f5a00 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -69,14 +69,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD CDEElementOp, GemmDefault, 256, // BlockSize - 256, // MPerBlock + 128, // MPerBlock 128, // NPerBlock 32, // KPerBlock 8, // AK1 8, // BK1 16, // MPerXDL 16, // NPerXDL - 8, // MXdlPerWave + 4, // MXdlPerWave 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp index a80b1055318..2ac4cb7de9e 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 128, 32, 8, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_conv_bwd_data_bias_relu_example.inc" diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index b043cd878c6..5dbcd83eab8 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git 
a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index 31641955b7d..d850469ca75 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -72,7 +72,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index b0b2d29d987..3d2faf973ee 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -71,7 +71,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, 
CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index e65eed9ffa6..d328889b35b 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -74,7 +74,14 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PipelineVer = []() { +#if defined(CK_USE_WMMA) + return ck::BlockGemmPipelineVersion::v1; +#else + return ck::BlockGemmPipelineVersion::v3; +#endif +}(); using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off @@ -88,7 +95,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp index 077a72b0801..46a2815dbf0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -114,7 +114,14 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PipelineVer = []() { +#if defined(CK_USE_WMMA) + return ck::BlockGemmPipelineVersion::v1; +#else + return ck::BlockGemmPipelineVersion::v3; +#endif +}(); using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off @@ -135,7 +142,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>; + ck::BlockGemmPipelineScheduler::Intrawave, PipelineVer, I8>; // clang-format on int main(int argc, char* argv[]) diff --git a/include/ck/host_utility/device_prop.hpp 
b/include/ck/host_utility/device_prop.hpp index 43e9350f8ff..90f915f0586 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -103,6 +103,20 @@ inline bool is_xdl_wmma_supported() } } +template +inline bool is_xdl_wmma_k_supported() +{ + if(is_gfx12_supported()) + { + return (KPerBlock % 16 == 0) && (KPack % 8 == 0); + } + else if(is_gfx11_supported()) + { + return (KPerBlock % 16 == 0) && (KPack % 16 == 0); + } + return true; +} + inline bool is_lds_direct_load_supported() { // Check if direct loads from global memory to LDS are supported. @@ -139,5 +153,22 @@ inline bool is_tf32_supported() return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } +inline int __host__ get_lds_size() +{ + int device = 0; + int result = 0; + auto status = hipGetDevice(&device); + if(status == hipSuccess) + { + status = hipDeviceGetAttribute(&result, hipDeviceAttributeMaxSharedMemoryPerBlock, device); + if(status == hipSuccess) + { + return result; + } + } + + return 64 * 1024; +} + } // namespace ck #endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 512a019ec80..cd9d6368fa8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -61,7 +61,7 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t BMmaKStride = KPack; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; - static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KRepeat = math::max(KPerThread / KPack, 1); static constexpr index_t KPerInnerLoop = KPack; static constexpr index_t KGroup = []() { @@ -97,7 +97,7 @@ struct BlockwiseGemmXdlops_pipeline_base xdlops_gemm.KPerXdlops>; #if defined(__HIP_DEVICE_COMPILE__) - static_assert(KPerThread % KPack == 0, + static_assert(WaveSize != get_warp_size() || (KPerThread % KPack == 0), "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); #endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 5604a31091e..fda4c8502ed 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -445,7 +445,7 @@ struct BlockwiseGemmXdlops_pipeline_v1 - using is_tuple = decltype(std::declval().IsTuple()); + using is_tuple = decltype(declval().IsTuple()); template __device__ void RunWrite(const DstDescs& dst_descs, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index c3cac6c1d82..f837e71cd54 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -144,7 +144,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter } template - using is_tuple = decltype(std::declval().IsTuple()); + using is_tuple = decltype(declval().IsTuple()); template __device__ void RunWrite(const DstDescs& dst_descs, diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp 
b/include/ck/tensor_operation/gpu/device/device_base.hpp index 361b1167827..759f4e32f3a 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -142,9 +142,10 @@ static constexpr auto GetWarpTileConfig() if constexpr(IsWave64 == false && NXdlPerWave != 0) { constexpr auto CShuffleNXdlPerWavePerShuffle32 = - NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + (NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16) && + (NXdlPerWave % (CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16) == 0) ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 - : CShuffleNXdlPerWavePerShuffle_; + : NXdlPerWave; static_assert(CShuffleNXdlPerWavePerShuffle32 > 0); return Sequence<16, 16, diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 5f60d8787dc..73c8f23bb32 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -107,7 +107,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DsPointer p_ds_grid_grp; @@ -356,13 +356,29 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEElementwiseOperation, ComputeDataType> { - using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr bool isMultiA = is_detected::value; - static constexpr bool isMultiB = is_detected::value; + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; static constexpr index_t NumATensor = GetNumABTensors(); static constexpr index_t NumBTensor = GetNumABTensors(); @@ -446,46 +462,47 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \ - GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ - EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, 
BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, WarpTileConfig::At(0), WarpTileConfig::At(1), WarpTileConfig::At(2), \ + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, WarpTileConfig::At(4), WarpTileConfig::At(5), \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched #define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, \ + WarpTileConfig::At(0), WarpTileConfig::At(1), WarpTileConfig::At(2), \ + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, \ ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + BBlockLdsExtraN, WarpTileConfig::At(4), WarpTileConfig::At(5), \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm - template + template using GridwiseGemmBase = ck::conditional_t< isMultiA || isMultiB, GridwiseGemmMultipleABD_xdl_cshuffle, GridwiseGemmMultipleD_xdl_cshuffle>; #undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS #undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. 
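The macros and aliases above now take the warp-tile parameters positionally from a per-wave-size config instead of the raw MPerXDL/NPerXDL/MXdlPerWave/NXdlPerWave template arguments, and the device op keeps one gridwise GEMM instantiation per wave size. A minimal sketch of that pattern follows; the index order {MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle} is taken from the At(0)..At(5) substitutions in this patch, while the helper names and numeric values are illustrative only, not CK's implementation.

#include <array>

// Illustrative stand-in for GetWarpTileConfig<IsWave64>(); the values are made up.
template <bool IsWave64>
constexpr std::array<int, 6> get_warp_tile_config_sketch()
{
    //      {MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, CShuffleM..., CShuffleN...}
    if constexpr(IsWave64)
        return {16, 16, 4, 4, 1, 1};
    else
        return {16, 16, 4, 8, 1, 1}; // hypothetical wave32 remap
}

// Illustrative stand-in for a gridwise GEMM consuming the config positionally.
template <bool IsWave64>
struct GridwiseGemmSketch
{
    static constexpr auto Cfg         = get_warp_tile_config_sketch<IsWave64>();
    static constexpr int  MPerXdl     = Cfg[0]; // WarpTileConfig::At(0)
    static constexpr int  NPerXdl     = Cfg[1]; // WarpTileConfig::At(1)
    static constexpr int  MXdlPerWave = Cfg[2]; // WarpTileConfig::At(2)
    static constexpr int  NXdlPerWave = Cfg[3]; // WarpTileConfig::At(3)
};

// One instantiation per wave size, mirroring GridwiseGemm64 / GridwiseGemm32 above.
using GridwiseGemm64Sketch = GridwiseGemmSketch<true>;
using GridwiseGemm32Sketch = GridwiseGemmSketch<false>;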
using APointers = ck::conditional_t&, const void*>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 0efed2aafe5..8e3e3ace982 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_contraction_multiple_d_xdl_cshuffle( const FloatAB* __restrict__ p_a_grid, @@ -59,7 +59,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -218,11 +218,27 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -554,7 +570,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle using ComputeDataType = ADataType; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, @@ -573,10 +589,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -593,13 +609,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -879,7 +895,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index c3747b195d9..aad583a1b07 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -179,13 +179,29 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -345,7 +361,7 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, @@ -368,10 +384,10 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index b930c50e3af..e902ef87281 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -39,7 +39,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -189,10 +189,29 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm(); - static constexpr auto MXdlPerWave32 = - GetXdlPerWave2(); + // GEMM0 NXdlPerWave has limiations on GEMM_GEMM kernel, so we have swap M/N to update + // MXdlPerWave. the M/N order in WarpTileConfig64/32 is also permuted. 
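As the comment above notes, GEMM0's NXdlPerWave is constrained in the GEMM+GEMM kernels, so these device ops build the warp-tile config with M and N swapped and read it back with permuted indices. A minimal sketch of that permuted read, matching the At(1), At(0), At(3), At(2) order visible in the hunks below (the numeric values are hypothetical):

#include <array>

// Hypothetical warp-tile config in the usual order
// {MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, ...}.
constexpr std::array<int, 6> warp_tile_cfg_sketch{16, 16, 4, 4, 1, 1};

// GEMM+GEMM device ops consume the config with M and N exchanged,
// i.e. At(1), At(0), At(3), At(2) instead of At(0)..At(3).
constexpr int Gemm0MPerXdl     = warp_tile_cfg_sketch[1];
constexpr int Gemm0NPerXdl     = warp_tile_cfg_sketch[0];
constexpr int Gemm0MXdlPerWave = warp_tile_cfg_sketch[3]; // hence MXdlPerWave64 = WarpTileConfig64.At(3)
constexpr int Gemm0NXdlPerWave = warp_tile_cfg_sketch[2];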
+ static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto MXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto MXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -354,7 +373,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm + template using GridwiseGemmBase = GridwiseBatchedGemmGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -380,11 +399,11 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -605,7 +625,12 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 35e4029b854..4edae0d45d4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -63,7 +63,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid, const ABDataType* __restrict__ p_b_grid, @@ -99,7 +99,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DsPointer p_ds_grid_grp; @@ -200,11 +200,27 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -333,7 +349,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -352,10 +368,10 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -618,7 +634,12 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 4410871ac11..4b54d4bc2dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -42,7 +42,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_gemm_xdl_cshuffle_v1( const A0B0B1DataType* __restrict__ p_a0_grid, @@ -235,20 +235,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; - static constexpr auto Gemm0MXdlPerWave64 = GetXdlPerWave2(); - static constexpr auto Gemm0MXdlPerWave32 = GetXdlPerWave2(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto Gemm0MXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto Gemm0MXdlPerWave32 = WarpTileConfig32.At(3); static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); @@ -462,7 +468,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(1, 1, 1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0DataType, // TODO: distinguish A/B datatype Acc0DataType, @@ -493,11 +499,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle A0K1, B0K1, B1K1, - Gemm0MPerXdl, - Gemm0NPerXdl, - Gemm0MXdlPerWave_, - Gemm0NXdlPerWave, - Gemm1NXdlPerWave, + WarpTileConfig::At(1), + WarpTileConfig::At(0), + WarpTileConfig::At(3), + WarpTileConfig::At(2), + Gemm1NXdlPerWave * Gemm0NPerXdl / WarpTileConfig::At(0), A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, @@ -525,12 +531,13 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle false, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, - C1ShuffleGemm0NXdlPerWavePerShuffle, + math::min(C1ShuffleGemm0NXdlPerWavePerShuffle* Gemm0NPerXdl / WarpTileConfig::At(0), + Gemm1NXdlPerWave* Gemm0NPerXdl / WarpTileConfig::At(0)), CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using A0GridDesc_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 5bebd40c1bd..49da389416d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -33,14 +33,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif 
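Throughout this patch the kernels also switch their launch bounds from the global CK_MAX_THREAD_PER_BLOCK constant to GridwiseGemm::MaxBlockSize, and size their LDS with GetSharedMemoryNumberOfByte(get_device_arch()). A minimal sketch of that pattern, with GridwiseGemmSketch as a placeholder type (not a CK class) and an arbitrary byte count:

#include <hip/hip_runtime.h>

struct GridwiseGemmSketch
{
    static constexpr int MaxBlockSize = 256;

    __host__ __device__ static constexpr int GetSharedMemoryNumberOfByte(int /*arch*/)
    {
        return 32 * 1024; // in CK this can depend on the device architecture
    }
};

template <typename GridwiseGemm>
__global__ void __launch_bounds__(GridwiseGemm::MaxBlockSize, 1) kernel_sketch()
{
    // Bound to the block size this gridwise GEMM was configured for, and
    // LDS sized from the same type rather than a library-wide maximum.
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(0)];
    p_shared[threadIdx.x % GridwiseGemm::MaxBlockSize] = 0;
}

The _2lds kernel variants allocate two such buffers (p_shared_0 and p_shared_1) so that, as the existing comment explains, ds_read/ds_write can operate on different LDS chunks at the same time without an ordering dependency.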
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t g_idx = blockIdx.z % karg.Batch; const index_t k_idx = blockIdx.z / karg.Batch; @@ -82,7 +82,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) { @@ -91,8 +91,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t g_idx = blockIdx.z % karg.Batch; const index_t k_idx = blockIdx.z / karg.Batch; @@ -191,16 +191,32 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 BElementwiseOperation, CElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); static constexpr index_t NumDTensor = DsDataType::Size(); using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; using CDataType_ = CDataType; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, @@ -222,10 +238,10 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -242,8 +258,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -252,8 +268,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using GridwiseGemm = GridwiseGemm64; struct ComputePtrOffsetOfStridedBatch @@ -999,7 +1015,12 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index aa58276253a..9608b729c4d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -39,7 +39,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_reduce_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -81,7 +81,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; }); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid + a_batch_offset, @@ -176,14 +176,30 @@ template struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> { - using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -525,7 +541,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO }; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -552,10 +568,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -572,16 +588,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -802,7 +818,12 @@ struct 
DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 9ece23985a2..1defcdb84e4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -43,7 +43,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -222,10 +222,26 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle C1DEElementwiseOperation, MaskingSpec> { - static constexpr auto MXdlPerWave64 = - GetXdlPerWave2(); - static constexpr auto MXdlPerWave32 = - GetXdlPerWave2(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto MXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto MXdlPerWave32 = WarpTileConfig32.At(3); static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); @@ -386,7 +402,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle }; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -414,11 +430,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle AK1, BK1, B1K1, - MPerXDL, - NPerXDL, - MXdlPerWave_, - NXdlPerWave, - Gemm1NXdlPerWave, + WarpTileConfig::At(1), + WarpTileConfig::At(0), + WarpTileConfig::At(3), + WarpTileConfig::At(2), + Gemm1NXdlPerWave * NPerXDL / WarpTileConfig::At(0), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -444,15 +460,16 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle* NPerXDL / WarpTileConfig::At(0), + Gemm1NXdlPerWave* NPerXDL / WarpTileConfig::At(0)), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument // FIXME: constness @@ -738,7 +755,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 35b2f54f589..41f900d769d 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -44,7 +44,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -209,11 +209,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle MaskOutUpperTriangle> { - using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; - static constexpr auto MXdlPerWave64 = - GetXdlPerWave2(); - static constexpr auto MXdlPerWave32 = - GetXdlPerWave2(); + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto MXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto MXdlPerWave32 = WarpTileConfig32.At(3); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -380,7 +396,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle C0MatrixMask_impl>; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -406,11 +422,11 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle AK1, BK1, B1K1, - MPerXDL, - NPerXDL, - MXdlPerWave_, - NXdlPerWave, - Gemm1NXdlPerWave, + WarpTileConfig::At(1), + WarpTileConfig::At(0), + WarpTileConfig::At(3), + WarpTileConfig::At(2), + Gemm1NXdlPerWave * NPerXDL / WarpTileConfig::At(0), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -436,14 +452,15 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle* NPerXDL / WarpTileConfig::At(0), + Gemm1NXdlPerWave* NPerXDL / WarpTileConfig::At(0)), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; #ifndef __HIPCC_RTC__ // Argument @@ -730,7 +747,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle #ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index ef5413219b9..054eb249b7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -48,7 +48,7 @@ namespace device { template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { @@ -67,7 +67,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const long_index_t c_batch_offset = 
__builtin_amdgcn_readfirstlane( static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const auto a_grid_desc_k0_m_k1 = amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( @@ -139,13 +139,29 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static constexpr auto K1Number = Number{}; @@ -180,7 +196,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -197,11 +213,11 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Problem = typename GridwiseGemm64::Problem; @@ -325,7 +341,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 315752d87d5..52dc5b600dd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -33,14 +33,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t g_idx = blockIdx.z % karg.Batch; const index_t k_idx = blockIdx.z / karg.Batch; @@ -83,8 +83,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t g_idx = blockIdx.z % karg.Batch; const index_t k_idx = blockIdx.z / karg.Batch; @@ -177,12 +177,27 @@ struct 
DeviceBatchedGemm_Xdl_CShuffleV3_BScale BElementwiseOperation, CElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, @@ -204,10 +219,10 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -224,8 +239,8 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, @@ -234,8 +249,8 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale ComputeTypeB, PermuteA, PermuteB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) @@ -877,7 +892,12 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index bd25c56afcb..9d7e4e8ad23 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -74,14 +74,30 @@ template < struct DeviceCGemm_4Gemm_Xdl_CShuffle : public DeviceCGemm { - using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static constexpr index_t MPerThread = MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -121,7 +137,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle } // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, @@ -143,10 +159,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle KPerBlock, AK1, BK1, - 
MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -163,13 +179,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1})); @@ -488,7 +504,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 59e74dcd6cc..0536c914de0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -37,7 +37,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_contraction_multiple_abd_xdl_cshuffle( AsPointer p_as_grid, @@ -58,7 +58,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_as_grid, @@ -164,13 +164,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -180,7 +196,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle using ComputeDataType = EDataType; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, @@ -200,10 +216,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - 
NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -220,14 +236,14 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto matrix_padder = ck::tensor_operation::device::MatrixPadder{ @@ -633,7 +649,12 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle { using A0DataType = remove_cvref_t>; using B0DataType = remove_cvref_t>; - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index fff435f1c25..6bac46b0623 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -35,7 +35,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_contraction_multiple_d_xdl_cshuffle( const FloatAB* __restrict__ p_a_grid, @@ -56,7 +56,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index ad0e01376d4..3e26b9b9e7e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -72,13 +72,29 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using DeviceOp = DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = InDataType; - using CDataType = WeiDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; using 
AElementwiseOperation = OutElementwiseOperation; using BElementwiseOperation = InElementwiseOperation; @@ -285,7 +301,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -301,11 +317,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -328,16 +344,16 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; - template + template using GridwiseGemmAtomicAddBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -353,11 +369,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ MPerBlock, NPerBlock, K0PerBlock, - MPerXdl, - NPerXdl, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -380,14 +396,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true>; - using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; - using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -665,7 +681,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 9b68c4de438..e64f1f49bfc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -69,13 +69,29 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - 
GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = WeiDataType; - using CDataType = InDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -378,7 +394,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype @@ -391,11 +407,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -415,8 +431,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -628,7 +644,12 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 13fa64d1052..6a70edf27a8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -68,13 +68,29 @@ struct using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -471,7 +487,7 @@ struct using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - template + template using GridwiseGemmBase = 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype @@ -489,11 +505,11 @@ struct MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -510,12 +526,12 @@ struct BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -799,7 +815,12 @@ struct static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 959c6e7bc7d..79e737ce1f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -69,13 +69,29 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -452,7 +468,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X using C0GridDesc_M_N = remove_cvref_t; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< BlockSize, ABDataType, // TODO: distinguish A/B datatype @@ -469,11 +485,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -490,12 +506,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X BBlockTransferDstScalarPerVector_K1, false, // 
BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -759,7 +775,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index ef62f6477a3..561a28bfecb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -72,13 +72,29 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -437,7 +453,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W using Block2CTileMap = BlockToCTileMap_M00_N0_M01; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, ABDataType, // TODO: distinguish A/B datatype @@ -456,10 +472,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W K0PerBlock * K1, K1, // AK1 K1, // BK1 - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -476,12 +492,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, CBlockTransferScalarPerVector_NWaveNPerXdl>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ 
-734,7 +750,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 687e111b0ad..8d8fa1f9cf4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -69,13 +69,29 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -328,7 +344,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: distinguish A/B datatype @@ -341,11 +357,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -365,8 +381,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -535,7 +551,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 9245d1fff36..2bb081b808b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -143,13 +143,29 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; - 
GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -270,7 +286,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, @@ -286,11 +302,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -310,8 +326,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 7, CThreadTransferDstScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); @@ -537,7 +553,12 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 022bda3ed04..87a7fcd0792 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -80,13 +80,29 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = WeiDataType; - using CDataType = InDataType; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; // TODO make A/B datatype different using ABDataType = InDataType; @@ -979,7 +995,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, // TODO: 
distinguish A/B datatype @@ -992,11 +1008,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -1016,8 +1032,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, 7, // CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -1330,7 +1346,12 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp index 518f8781a46..fd919efedb7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -81,13 +81,29 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO { using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -387,7 +403,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -419,10 +435,10 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -439,16 +455,16 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, 
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -674,7 +690,12 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp index 0542a3e78f2..c43a59a283d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -80,19 +80,35 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); using ALayout = remove_cvref_t>; using BLayout = remove_cvref_t>; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, @@ -113,10 +129,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -597,7 +613,12 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index d4257c56d24..b46c3c5770f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -38,7 +38,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( const FloatAB* __restrict__ p_a_grid, @@ -63,7 +63,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -189,12 +189,28 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle { using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr index_t NumRTensor = RsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -290,7 +306,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -318,10 +334,10 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -338,14 +354,14 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEReduceThreadTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 714a0420bd6..835ad4e938e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -37,7 +37,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, @@ -57,7 +57,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -163,12 +163,28 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public 
DeviceGemmMultipleD { - using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -253,7 +269,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD(1, 1, 1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, @@ -272,10 +288,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -583,7 +599,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -899,7 +920,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; - __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; if(desc.has_main_k_block_loop) { GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 83e16d1c3de..d27a6fc011f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -74,13 +74,30 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad BElementwiseOperation, CDEElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr auto I1 = Number<1>{}; - static constexpr index_t NumDTensor = DsDataType::Size(); - - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I1 = Number<1>{}; + static constexpr index_t NumDTensor = DsDataType::Size(); + + template using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, @@ -105,10 +122,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, @@ -119,14 +136,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad BBlockTransferSrcVectorDim, BBlockTransferScalarPerVector, 
BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -210,7 +227,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 34b4f1c80e0..45c0f5e96b1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -81,13 +81,28 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK { - static constexpr index_t NumDTensor = DsDataType::Size(); - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, @@ -109,10 +124,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; // Invoker @@ -654,7 +669,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp index 7d53ca00357..cdfdb2abca9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp @@ -92,13 +92,30 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation, CElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< 
ALayout, BLayout, @@ -123,10 +140,10 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -143,8 +160,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -153,8 +170,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 ComputeTypeB, LDSTypeA, LDSTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -352,17 +369,23 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 arg.KBatch = KBatch; if(get_warp_size() == 64) { - arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch); - arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch); - arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch); - arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch); + if constexpr(NXdlPerWave64 > 0) + { + arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch); + } } else { - arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch); - arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch); - arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch); - arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch); + if constexpr(NXdlPerWave32 > 0) + { + arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch); + } } } @@ -381,7 +404,12 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 return false; } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp index c2b713db627..16ca01b3f02 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp @@ -82,13 +82,30 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle BElementwiseOperation, CElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static 
constexpr index_t NumDTensor = DsDataType::Size(); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, @@ -110,10 +127,10 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -130,8 +147,8 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -140,11 +157,14 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; - int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -522,7 +542,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp index 8f1a0f3b3de..c2982dfa952 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -92,13 +92,28 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle BElementwiseOperation, CElementwiseOperation> { - static constexpr index_t NumDTensor = DsDataType::Size(); - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, @@ -123,10 +138,10 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -143,8 +158,8 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - 
CShuffleMXdlPerWavePerShuffle, - math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -153,12 +168,15 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle ComputeTypeB, LDSTypeA, LDSTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; - int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -348,7 +366,12 @@ struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp index ddf6b8ae5ac..c5678c04681 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp @@ -78,13 +78,29 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio { using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -382,7 +398,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -409,10 +425,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -429,16 +445,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, 
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -640,7 +656,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index d9d06d8bd5c..59c1c4281a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -69,18 +69,34 @@ struct DeviceGemmXdl : public DeviceGemm { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static constexpr auto K1Number = Number{}; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -97,11 +113,11 @@ struct DeviceGemmXdl : public DeviceGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 921c477a5a9..86ba1443cfa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -80,16 +80,32 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, @@ -111,10 +127,10 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ 
-204,7 +220,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index 7715183b66a..cfd0ca8e365 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -69,13 +69,29 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I1 = Number<1>{}; - - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I1 = Number<1>{}; + + template using GridwiseGemmBase = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< ALayout, BLayout, @@ -100,10 +116,10 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -207,7 +223,12 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp index 20984c9f3fa..c07c626f188 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -75,12 +75,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, @@ -100,10 +115,10 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; // @@ -481,7 +496,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp index e6e1a0c03d8..1da1da3f715 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp @@ -78,17 +78,33 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm { - using DeviceOp = DeviceGemm_Xdl_CShuffleV2; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v2< ALayout, BLayout, @@ -110,10 +126,10 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -200,7 +216,12 @@ struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index cc59ad741dd..06046226270 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -177,11 +177,27 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 { // GridwiseGemm - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, @@ -201,10 +217,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -759,7 +775,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index b78c44cca87..0bd576a8dea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -77,12 +77,27 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr auto 
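// Editor's illustrative sketch (assumed names, not CK's API): the pattern above keeps two
// gridwise-GEMM instantiations, one tuned per wave size, and selects between them at run time,
// e.g. as GetPreShuffleParameters does with get_warp_size().
#include <cstdio>

template <int WaveSize>
struct GridwiseGemmSketch
{
    static constexpr int kWaveSize    = WaveSize;
    static constexpr int kNXdlPerWave = (WaveSize == 64) ? 4 : 2;
    static void Run() { std::printf("wave%d path, NXdlPerWave=%d\n", kWaveSize, kNXdlPerWave); }
};

int query_warp_size() { return 64; } // stand-in for the runtime device query

void dispatch_by_warp_size()
{
    using Gemm64 = GridwiseGemmSketch<64>;
    using Gemm32 = GridwiseGemmSketch<32>;
    if(query_warp_size() == 64) { Gemm64::Run(); }
    else                        { Gemm32::Run(); }
}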
WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, @@ -102,10 +117,10 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -151,7 +166,10 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp index 359eb629805..6b1264ae630 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -84,12 +84,27 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, @@ -111,10 +126,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -648,7 +663,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index f97fa05b641..ca0d55b0140 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -164,12 +164,27 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - // GridwiseGemm - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); // GridwiseGemm + template using GridwiseGemmMXBase = GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, @@ -192,10 +207,10 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX; - template + template using GridwiseGemmMXBPreshuffleBase = GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, @@ -243,10 +258,10 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public 
DeviceGemmMX, - GridwiseGemmMXBase, - GridwiseGemmMXBPreshuffleBase>; + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; using GridwiseGemm32 = conditional_t< // !is_same_v, - GridwiseGemmMXBase, - GridwiseGemmMXBPreshuffleBase>; + GridwiseGemmMXBase, + GridwiseGemmMXBPreshuffleBase>; using Argument = typename GridwiseGemm64::Argument; @@ -403,14 +418,29 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 1) && tail_num_choice.value == tail_num) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< // - Use2LDS, - GridwiseGemm, - mainloop_choice.value, - CGlobalMemoryDataOperation, - minimum_occupancy, - tail_num_choice.value>; - Run(kernel); + if constexpr(is_same_v) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx_bpreshuffle< // + Use2LDS, + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); + return; + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< // + Use2LDS, + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); + } } }); return ave_time; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp index 98f7a42e8f2..ad003101b64 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp @@ -85,16 +85,32 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); using PassThrough = ck::tensor_operation::element_wise::PassThrough; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, @@ -114,10 +130,10 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; struct Argument : public GridwiseGemm64::Argument { @@ -568,7 +584,12 @@ struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index eb61c3cbd97..5dd1f1975c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -81,13 +81,29 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator { using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = 
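// Editor's illustrative sketch of the if constexpr dispatch used above to pick the
// B-preshuffle kernel when the gridwise type matches the preshuffle instantiation.
// All names here are placeholders, not the real kernel symbols.
#include <type_traits>

struct GemmPlainSketch      { static void Launch() { /* regular mainloop kernel */ } };
struct GemmPreshuffleSketch { static void Launch() { /* B-preshuffle mainloop kernel */ } };

template <typename GridwiseGemm>
void launch_selected_kernel()
{
    if constexpr(std::is_same_v<GridwiseGemm, GemmPreshuffleSketch>)
    {
        GemmPreshuffleSketch::Launch(); // preshuffled-B path, then return
    }
    else
    {
        GemmPlainSketch::Launch(); // default path
    }
}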
Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { @@ -384,7 +400,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -408,10 +424,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -428,15 +444,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - NXdlPerWave_, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadCopySrcDstScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Block2CTileMap = typename GridwiseGemm64::DefaultBlock2CTileMap; @@ -643,7 +659,12 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index c105b98c3c5..a492b54f99a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -63,13 +63,29 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; static constexpr auto K1Number = Number{}; static_assert(BBlockBufferSize >= 2); @@ -192,7 +208,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -208,11 +224,11 @@ struct DeviceGemmXdlSkipBLds : 
public DeviceGemm, // CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument struct Argument : public BaseArgument @@ -405,7 +421,12 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index d2b2cded7df..fa328e8a8dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -75,14 +75,30 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; // TODO: should be exposed as Tparams. static constexpr index_t NumGemmKPrefetchStage = 1; @@ -90,7 +106,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, @@ -108,11 +124,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; struct Argument : public GridwiseGemm64::Argument { @@ -338,7 +354,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp index 366e7fe5c6c..802c6793cc0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -70,16 +70,32 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto 
I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + template using GridwiseGemmBase = GridwiseGemm_xdlops_splitk_lds_direct_load< BlockSize, ADataType, @@ -97,11 +113,11 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; struct Argument : public GridwiseGemm64::Argument { @@ -306,7 +322,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp index a9b319ed682..170c1e3678d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -68,16 +68,32 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< BlockSize, BlockToCTileMap_GemmStreamK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -268,7 +284,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -323,7 +344,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run(p_a_grid, p_b_grid, @@ -136,11 +136,27 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm { static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using DeviceOp = DeviceGemm_Xdl_WaveletModel_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using DeviceOp = DeviceGemm_Xdl_WaveletModel_CShuffle; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -209,7 +225,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm(1, 1, 1)); // 
GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, // TODO: distinguish A/B datatype GemmAcEDataType, @@ -230,10 +246,10 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 95497ad359a..7ebfebd3445 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -28,7 +28,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_contraction_multiple_d_xdl_cshuffle( const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, @@ -40,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t block_id = get_block_1d_id(); @@ -168,10 +168,27 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -363,7 +380,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle using ComputeDataType = ADataType; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -382,10 +399,10 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -402,13 +419,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = 
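// Editor's illustrative HIP-style sketch (assumed names): GetSharedMemoryNumberOfByte now takes
// the device architecture, so the __shared__ scratch can be sized per target while remaining a
// constant expression on device (the arch resolves through the __gfx*__ predefined macros).
enum class DeviceArchSketch { Gfx9, Gfx11, Gfx12 };

__device__ constexpr DeviceArchSketch get_device_arch_sketch()
{
#if defined(__gfx11__)
    return DeviceArchSketch::Gfx11;
#elif defined(__gfx12__)
    return DeviceArchSketch::Gfx12;
#else
    return DeviceArchSketch::Gfx9;
#endif
}

struct GridwiseGemmLdsSketch
{
    __device__ static constexpr unsigned GetSharedMemoryNumberOfByte(DeviceArchSketch arch)
    {
        // in this sketch the gfx9 wave64 layout keeps a larger double-buffered LDS tile
        return arch == DeviceArchSketch::Gfx9 ? 32 * 1024 : 16 * 1024;
    }
};

__global__ void kernel_lds_sketch()
{
    __shared__ char p_shared[GridwiseGemmLdsSketch::GetSharedMemoryNumberOfByte(
        get_device_arch_sketch())];
    (void)p_shared; // the mainloop would run out of p_shared
}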
GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -742,7 +759,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index d33e807828a..5e2dfc598ed 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -80,7 +80,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( const ABDataType* __restrict__ p_a_grid, @@ -124,7 +124,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DsPointer p_ds_grid_grp; @@ -318,12 +318,28 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ? 1 : 32; - using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; static constexpr bool IsSplitKSupported = (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && @@ -439,53 +455,55 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // GridwiseGemm -#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \ - ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ - AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - 
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ +#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, WarpTileConfig::At(0), WarpTileConfig::At(1), \ + WarpTileConfig::At(2), WarpTileConfig::At(3), \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, WarpTileConfig::At(4), \ + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType -#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \ - ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ - BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ - NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ +#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, WarpTileConfig::At(1), WarpTileConfig::At(0), \ + WarpTileConfig::At(3), WarpTileConfig::At(2), \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, WarpTileConfig::At(5), \ + WarpTileConfig::At(4), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>; - template + template using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< 
CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>; #undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS #undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using GridwiseGemmCTranspose64 = std::conditional_t, + GridwiseGemmCTransposeBase, GridwiseGemm64>; using GridwiseGemmCTranspose32 = - std::conditional_t, GridwiseGemm32>; + std::conditional_t, + GridwiseGemm32>; template static auto @@ -1493,7 +1511,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { return false; } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) + { + return false; + } + if(!is_xdl_wmma_k_supported()) + { + return false; + } + if(!is_xdl_wmma_k_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1807dc1d9fe..9357c62cfcf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -50,7 +50,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, @@ -174,14 +174,30 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ComputeTypeA, ComputeTypeB> { - using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = InDataType; - using EDataType = WeiDataType; + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; static constexpr index_t NumDTensor = DsLayout::Size(); @@ -310,7 +326,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - template + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, @@ -327,11 +343,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -354,8 +370,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), WorkspaceInOutScalarPerVector, 
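// Editor's illustrative sketch (assumed helper names): besides checking that the XDL/WMMA
// instruction exists at all, IsSupportedArgument above now also verifies the K extent of the
// chosen instruction for both input element types.
bool is_xdl_wmma_supported_stub() { return true; } // is the instruction available at all?

template <typename ElemT>
bool is_xdl_wmma_k_supported_stub()
{
    return sizeof(ElemT) <= 2; // placeholder rule standing in for the real per-type K check
}

template <typename AType, typename BType>
bool IsSupportedArgumentSketch()
{
    if(!is_xdl_wmma_supported_stub())          { return false; }
    if(!is_xdl_wmma_k_supported_stub<AType>()) { return false; }
    if(!is_xdl_wmma_k_supported_stub<BType>()) { return false; }
    return true;
}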
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, @@ -364,8 +380,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; static constexpr auto MakeElementwiseInputSequence() { @@ -952,7 +968,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle return false; } #endif - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 211496b3ff0..9b1712ccd24 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -82,7 +82,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t e_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DispatchSplitKHack); static_assert(is_same_v); - using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = InDataType; - using EDataType = WeiDataType; + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || @@ -432,7 +448,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using CElementwiseGridDesc_M_N = remove_cvref_t())>; - template + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, @@ -452,10 +468,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle KPerBlock, K1, K1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -472,16 +488,16 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = 
GridwiseGemmBase; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -1888,7 +1904,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } return false; } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 976b6f1ef87..c7e7636be1d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -117,7 +117,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, @@ -262,14 +262,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ComputeTypeA, ComputeTypeB> { - using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = InDataType; - using CDataType = WeiDataType; + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; // If NGCHW then ADataType must be equal to BDataType static_assert(!(is_NGCHW_NGKHW() || @@ -466,7 +482,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle I1, I0>; - template + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, @@ -483,11 +499,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), K1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -510,8 +526,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, @@ -520,8 +536,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle PipelineVersion::v1, ComputeTypeA, ComputeTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1089,7 +1105,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return false; } #endif - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2121be00d17..d45c4489d1a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -80,7 +80,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t e_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DispatchSplitKHack); static_assert(is_same_v); - using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - using ADataType = OutDataType; - using BDataType = InDataType; - using CDataType = WeiDataType; + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; using AElementwiseOperation = OutElementwiseOperation; using BElementwiseOperation = InElementwiseOperation; @@ -371,7 +387,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - template + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, @@ -391,10 +407,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 K0PerBlock, K1, K1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -411,16 +427,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1432,7 +1448,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 7cb0ae20c39..91d201725b9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -85,7 +85,7 @@ template __global__ void #if 
CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( AsPointer p_as_grid, @@ -120,7 +120,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DsPointer p_ds_grid_grp; @@ -1459,7 +1459,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { return false; } - + if(!is_xdl_wmma_k_supported()) + { + return false; + } + if(!is_xdl_wmma_k_supported()) + { + return false; + } // check ConvolutionForwardSpecialization if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index f94311bc353..d8cb5f4a8cf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -116,7 +116,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; @@ -238,8 +238,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index a98738909e8..2597f28b905 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -173,7 +173,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; DsPointer p_ds_grid_grp; @@ -303,13 +303,29 @@ 
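// Editor's illustrative sketch (assumed names): __launch_bounds__ is now driven by the gridwise
// GEMM's own MaxBlockSize rather than the global CK_MAX_THREAD_PER_BLOCK, so the occupancy hint
// matches the block size that particular instantiation actually launches with.
struct GridwiseGemmBoundsSketch
{
    static constexpr int MaxBlockSize = 256;
};

template <typename GridwiseGemm>
__global__ void __launch_bounds__(GridwiseGemm::MaxBlockSize, 1)
    kernel_launch_bounds_sketch()
{
    // body would call GridwiseGemm::Run(...); the bound is evaluated per instantiation,
    // e.g. kernel_launch_bounds_sketch<GridwiseGemmBoundsSketch> at the launch site.
}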
struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle RsElementwiseOperation, QsElementwiseOperation> { - using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr index_t NumRTensor = RsDataType::Size(); + using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -436,7 +452,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle using RGridDesc_M = remove_cvref_t({}, {}))>; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -464,10 +480,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -484,14 +500,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 897773f768e..9912a05617a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -36,7 +36,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count, @@ -48,7 +48,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { 
- __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t KBatch = 1; @@ -205,14 +205,30 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -221,7 +237,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static constexpr index_t NumGemmKPrefetchStage = 1; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, @@ -241,10 +257,10 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -261,13 +277,13 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops { @@ -688,6 +704,15 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_wmma_supported()) + { + return false; + } // Split-K autodeduction is not supported if(arg.k_batch_ < 1) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index aeb6cd6d3fa..22dafcc5085 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -91,12 +91,28 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = 
GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -109,7 +125,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage using WorkspaceDataType = float; // First stage GridwiseGEMM kernel. - template + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, @@ -127,11 +143,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage MPerBlock, NPerBlock, K0PerBlock, - MPerXDL, - NPerXDL, + WarpTileConfig::At(0), + WarpTileConfig::At(1), AK1, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -148,15 +164,15 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage BBlockTransferDstScalarPerVector_BK1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEShuffleBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer, ComputeDataType>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // Use gemm_padder for consistent descriptor creation static constexpr auto gemm_padder = @@ -254,14 +270,15 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; using GemmKernelArgument = typename GridwiseGemm64::Argument; + template struct GemmTransKernelArg { - GemmKernelArgument karg_; + typename GridwiseGemm::Argument karg_; GroupedGemmBlock2ETileMap block_2_ctile_map_; index_t block_start_, block_end_; GemmTransKernelArg() = default; - GemmTransKernelArg(GemmKernelArgument&& karg, + GemmTransKernelArg(typename GridwiseGemm::Argument&& karg, GroupedGemmBlock2ETileMap&& b2c_map, index_t block_start, index_t block_end) @@ -546,7 +563,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage std::vector>& p_Ds_; std::vector> stride_Ds_; - std::vector gemm_kernel_args_; + std::vector> gemm_kernel_args_; std::vector group_grid_size_; std::vector elementwise_c_grid_descs_m_n_; @@ -742,7 +759,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage { const auto gemm_kernel = kernel_grouped_gemm_xdl_splitk, HasMainKBlockLoop, InMemoryDataOperationEnum::AtomicAdd, AElementwiseOperation, @@ -756,15 +773,15 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ck::Tuple, Block2TileMap, CDEElementwiseOperation>; - return LaunchKernel(gemm_kernel, - elementwise_kernel, - arg, - dev_gemm_kargs, - dev_gemm_workspace, - stream_config); + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_kargs, + dev_gemm_workspace, + stream_config); } - template + template float LaunchKernel(const KernelFunction& gemm_kernel, const KernelFunction2& elementwise_kernel, const Argument& arg, @@ -774,12 
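// Annotation (not part of the patch): in the two-stage split-K device op the per-group kernel
// argument wrapper becomes a template over the gridwise GEMM, so the stored karg_ matches whichever
// instantiation (wave64 or wave32) is launched, and the sizeof() used for the host-to-device copy of
// the argument array follows automatically. A minimal sketch of that wrapper pattern; the types and
// fields are placeholders:
#include <cstdint>
#include <vector>

struct Gridwise64Sketch { struct Argument { std::int32_t M, N, K; }; };

template <typename GridwiseGemm>
struct GemmTransKernelArgSketch
{
    typename GridwiseGemm::Argument karg_;
    std::int32_t block_start_, block_end_;
};

// The host-side container then also carries the gridwise type, so
// args.size() * sizeof(GemmTransKernelArgSketch<Gridwise64Sketch>) is exactly the bytes to memcpy.
std::vector<GemmTransKernelArgSketch<Gridwise64Sketch>> gemm_kernel_args_sketch;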
+791,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage { float time{0.f}; - hip_check_error( - hipMemcpyAsync(dev_gemm_kargs, - arg.gemm_kernel_args_.data(), - arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), - hipMemcpyHostToDevice, - stream_config.stream_id_)); + hip_check_error(hipMemcpyAsync(dev_gemm_kargs, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); auto preprocess = [&]() { hip_check_error(hipMemsetAsync( @@ -831,7 +848,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -1026,7 +1048,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage auto arg = dynamic_cast(p_arg); if(arg) { - return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); } else throw std::runtime_error( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 4492e6474fc..7b86ae7cf0b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -60,7 +60,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count, @@ -71,7 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + constexpr index_t shared_size = + GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch()); __shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared1[shared_size]; @@ -469,13 +470,30 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - template + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + + template using GridwiseGemmBase = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, @@ -497,10 +515,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, 
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -517,16 +535,16 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop BBlockTransferDstScalarPerVector_BK1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -748,7 +766,12 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index b6c2030dee3..abc89fbd62c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -205,10 +205,26 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle CElementwiseOperation, MaskingSpec> { - static constexpr auto MXdlPerWave64 = - GetXdlPerWave2(); - static constexpr auto MXdlPerWave32 = - GetXdlPerWave2(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto MXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto MXdlPerWave32 = WarpTileConfig32.At(3); static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); @@ -350,7 +366,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle }; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, @@ -376,11 +392,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle AK1, BK1, B1K1, - MPerXDL, - NPerXDL, - MXdlPerWave_, - NXdlPerWave, - Gemm1NXdlPerWave, + WarpTileConfig::At(1), + WarpTileConfig::At(0), + WarpTileConfig::At(3), + WarpTileConfig::At(2), + Gemm1NXdlPerWave * NPerXDL / WarpTileConfig::At(0), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -406,14 +422,15 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + math::min(CShuffleNXdlPerWavePerShuffle* NPerXDL / WarpTileConfig::At(0), + Gemm1NXdlPerWave* NPerXDL / WarpTileConfig::At(0)), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; 
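// Annotation (not part of the patch): in the grouped gemm+softmax+gemm device op the Gemm1 per-wave
// N split is rescaled when the selected warp tile changes. The NPerXdl slot receives
// WarpTileConfig::At(0), and Gemm1NXdlPerWave is multiplied by NPerXDL / WarpTileConfig::At(0); the
// CShuffle N count is clamped the same way. This keeps the per-wave Gemm1 N extent
// (Gemm1NXdlPerWave * NPerXdl) unchanged, which appears to be the intent. A small constexpr check of
// that invariant; the concrete numbers are illustrative:
#include <algorithm>

constexpr int NPerXDL_old          = 32; // value baked into the device-op template arguments
constexpr int NPerXdl_selected     = 16; // WarpTileConfig::At(0) for the active wave size (assumed)
constexpr int Gemm1NXdlPerWave_old = 4;

constexpr int Gemm1NXdlPerWave_new = Gemm1NXdlPerWave_old * NPerXDL_old / NPerXdl_selected;

static_assert(Gemm1NXdlPerWave_new * NPerXdl_selected == Gemm1NXdlPerWave_old * NPerXDL_old,
              "per-wave Gemm1 N extent is preserved by the rescale");

constexpr int CShuffleNXdlPerWavePerShuffle_old = 2;
constexpr int CShuffleNXdlPerWavePerShuffle_new =
    std::min(CShuffleNXdlPerWavePerShuffle_old * NPerXDL_old / NPerXdl_selected, Gemm1NXdlPerWave_new);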
using Block2CTileMap = OffsettedBlockToCTileMap; @@ -715,7 +732,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 1083da3e7f8..64544ca5433 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -30,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count, @@ -41,7 +41,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t block_id = get_block_1d_id(); @@ -148,11 +148,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm { - using DeviceOp = DeviceGroupedGemm_Xdl; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGroupedGemm_Xdl; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -235,7 +252,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(1, 1, 1)); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -254,10 +271,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using AGridDesc_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 71f2d737e62..30cfc7756c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -38,7 +38,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, uint32_t* barrier_count, @@ -53,7 +53,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) 
|| defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t block_id = get_block_1d_id(); @@ -243,12 +243,28 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK { - using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -258,7 +274,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -278,10 +294,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; template struct OffsettedBlockToCTileMapMLoops @@ -843,6 +859,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK()) + { + return false; + } if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 1db9fd45b8c..737071addb4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -33,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, const index_t group_count, @@ -44,7 +44,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + constexpr index_t shared_size = + GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch()); __shared__ uint8_t p_shared[shared_size]; const index_t block_id = get_block_1d_id(); @@ -147,10 +148,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = 
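// Annotation (not part of the patch): every IsSupportedArgument() touched in this diff now begins by
// rejecting the argument when the XDL/WMMA instruction shape is unavailable on the running device
// (ck::is_xdl_wmma_supported<...>, whose template arguments were lost in this extraction and likely
// involve the compute types and the MPerXdl/NPerXdl pair of the active config). The shape of that
// guard, with a placeholder predicate standing in for the real check:
struct ArgumentSketch { int k_batch_ = 1; };

// Placeholder for the real device-capability query.
static bool is_xdl_wmma_supported_sketch() { return true; }

static bool IsSupportedArgumentSketch(const ArgumentSketch& arg)
{
    if(!is_xdl_wmma_supported_sketch())
    {
        return false; // device cannot execute the chosen warp tile at all
    }
    if(arg.k_batch_ < 1)
    {
        return false; // split-K autodeduction is not supported (check kept from the original code)
    }
    return true;
}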
WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -159,7 +177,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + template using GridwiseGemmBase = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, @@ -177,11 +195,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using CGridDesc_M_N = typename GridwiseGemm64::CGridDesc_M_N; using Block2ETileMapKSplit = @@ -591,7 +609,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index a53cfb3ed0d..aa09f61b03a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -87,11 +87,28 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + template using GridwiseGemmBase = GridwiseMoeGemm; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -168,7 +185,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index df7179efe50..d4bec0832dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -99,11 +99,28 @@ struct DeviceMoeGemmBlockScale BElementwiseOperation, CElementwiseOperation> { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + template using GridwiseGemmBase = GridwiseMoeGemmBlockScale< ALayout, BLayout, @@ -128,10 +145,10 @@ struct DeviceMoeGemmBlockScale KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + 
WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -148,8 +165,8 @@ struct DeviceMoeGemmBlockScale BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - math::min(CShuffleNXdlPerWavePerShuffle, NXdlPerWave_), + WarpTileConfig::At(4), + WarpTileConfig::At(5), CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, @@ -164,8 +181,8 @@ struct DeviceMoeGemmBlockScale ComputeTypeB, LDSTypeA, LDSTypeB>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; @@ -183,7 +200,10 @@ struct DeviceMoeGemmBlockScale return 1; }(); - int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -426,7 +446,12 @@ struct DeviceMoeGemmBlockScale { return false; } - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp index 7f22fc72cfc..314914c10d5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp @@ -90,11 +90,28 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + template using GridwiseGemmBase = GridwiseMoeGemmMX; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; - int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? 
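// Annotation (not part of the patch): GetPreShuffleParameters() previously returned the compile-time
// NPerXDL; across the MoE device ops it now returns the NPerXdl of whichever warp-tile config matches
// the runtime wave size: get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1).
// A standalone sketch of that dispatch; get_warp_size() and the stored values are stand-ins:
struct MoeDeviceOpSketch
{
    static constexpr int NPerXdl64 = 16; // WarpTileConfig64.At(1), illustrative
    static constexpr int NPerXdl32 = 16; // WarpTileConfig32.At(1), illustrative

    static int get_warp_size_sketch() { return 64; } // placeholder for the runtime wave-size query

    int GetPreShuffleParameters() const
    {
        return get_warp_size_sketch() == 64 ? NPerXdl64 : NPerXdl32;
    }
};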
WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -368,7 +388,12 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp index d81e73630b2..8d6ff3807af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp @@ -90,11 +90,28 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + template using GridwiseGemmBase = GridwiseMoeGemmMXBNS; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; - int GetPreShuffleParameters() override { return NPerXDL; } + int GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -367,7 +387,12 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp index e64970145bd..c4e6a60484f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp @@ -90,11 +90,28 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle { - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); - template + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); + template using GridwiseGemmBase = GridwiseMoeGemmMX_BPreshuffle< ALayout, BLayout, @@ -119,10 +136,10 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; using Argument = typename GridwiseGemm64::Argument; static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; - int GetPreShuffleParameters() override { return NPerXDL; } + int 
GetPreShuffleParameters() override + { + return get_warp_size() == 64 ? WarpTileConfig64.At(1) : WarpTileConfig32.At(1); + } // Invoker struct Invoker : public BaseInvoker @@ -394,7 +414,12 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index c413befd80d..a7b7974f9ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,7 +36,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_contraction_multiple_d_xdl_cshuffle( const FloatAB* __restrict__ p_a_grid, @@ -59,7 +59,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -194,11 +194,28 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle BElementwiseOperation, CDEElementwiseOperation> { - using DeviceOp = DeviceSplitKContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); + using DeviceOp = DeviceSplitKContractionMultipleD_Xdl_CShuffle; + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -528,7 +545,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle }; // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -550,10 +567,10 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -570,16 +587,16 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; 
// GridwiseGemm - template + template using GridwiseGemmAtomicAddBase = GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -601,10 +618,10 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -621,13 +638,13 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; - using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd64 = GridwiseGemmAtomicAddBase; + using GridwiseGemmAtomicAdd32 = GridwiseGemmAtomicAddBase; using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t()) + if(!ck::is_xdl_wmma_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index da731ead2f2..dff40da080e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -77,30 +78,167 @@ template struct GridwiseBatchedGemmGemm_Xdl_CShuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + 
BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base0 = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, // ignored + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using ThisThreadBlock = typename Base::ThisThreadBlock; + static_assert(LoopSched == LoopScheduler::Default, "Non-default loop scheduler is currently not supported"); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> // Gemm0 - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; + static constexpr auto AK0 = Base0::AK0Number; + static constexpr auto BK0 = Base0::BK0Number; + static constexpr auto AK1 = Base0::AK1Number; + static constexpr auto BK1 = Base0::BK1Number; // Gemm1 - static constexpr auto B1K0 = Number{}; - static constexpr auto B1K1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto B1K0 = Base::BK0Number; + static constexpr auto B1K1 = Base::BK1Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -141,46 +279,6 @@ 
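// Annotation (not part of the patch): GridwiseBatchedGemmGemm_Xdl_CShuffle no longer carries its own
// I0..I7 constants, ThisThreadBlock alias and LDS descriptor helpers. It derives from
// GridwiseGemm_xdl_cshuffle_base instantiated for the second GEMM (Base: MPerBlock x Gemm1NPerBlock x
// Gemm1KPerBlock) and also names a second instantiation for the first GEMM (Base0: MPerBlock x
// NPerBlock x KPerBlock), pulling AK0/BK0/AK1/BK1 from Base0 and B1K0/B1K1 from Base. A minimal
// sketch of reusing one base template for the two problem shapes; all names and the simplified K0
// arithmetic are placeholders:
template <int MPerBlock_, int NPerBlock_, int KPerBlock_, int K1_>
struct GemmBaseSketch
{
    static constexpr int K1Number = K1_;
    static constexpr int K0Number = KPerBlock_ / K1_;
};

template <int MPerBlock_, int NPerBlock_, int KPerBlock_,
          int Gemm1NPerBlock_, int Gemm1KPerBlock_, int BK1_, int B1K1_>
struct FusedGemmGemmSketch
    : public GemmBaseSketch<MPerBlock_, Gemm1NPerBlock_, Gemm1KPerBlock_, B1K1_> // "Base": gemm1 shape
{
    using Base  = GemmBaseSketch<MPerBlock_, Gemm1NPerBlock_, Gemm1KPerBlock_, B1K1_>;
    using Base0 = GemmBaseSketch<MPerBlock_, NPerBlock_, KPerBlock_, BK1_>;      // "Base0": gemm0 shape

    static constexpr int BK0  = Base0::K0Number; // gemm0 K0, as AK0/BK0 are taken from Base0 above
    static constexpr int B1K0 = Base::K0Number;  // gemm1 K0, as B1K0/B1K1 are taken from Base above
};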
struct GridwiseBatchedGemmGemm_Xdl_CShuffle BBlockDesc_BK0_N_BK1{}); } - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B1 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(B1K0, Number{}, B1K1), - make_tuple(Number{} * B1K1, B1K1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + @@ -319,11 +417,11 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle { // LDS allocation for A and B: be careful of alignment static constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); static constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto b1_block_desc_bk0_n_bk1 = - GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); @@ -340,7 +438,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); static constexpr auto c_block_space_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; @@ -369,8 +467,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); const auto b1_grid_buf = make_dynamic_buffer( p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = @@ -392,10 +488,12 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - 
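// Annotation (not part of the patch): the per-gridwise LDS descriptor helpers
// (GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1, GetBBlockDescriptor_..., GetB1BlockDescriptor_...,
// GetCShuffleBlockDescriptor_...) are deleted here; SharedMemTrait now obtains them from Base0/Base
// with get_device_arch(), so the LDS layout can be arch-specific. A hedged sketch of how such
// SharedMemTrait-style sizing composes arch-parameterized base helpers; every name and number below
// is a placeholder, not the CK API:
enum class ArchSketch { gfx9, gfx11, gfx12 };

struct LdsBaseSketch
{
    static constexpr int GetABlockSpaceSize(ArchSketch) { return 32 * 128 * 8; } // AK0*M*AK1, illustrative
    static constexpr int GetBBlockSpaceSize(ArchSketch) { return 32 * 128 * 8; }
    static constexpr int GetCShuffleSpaceSize(ArchSketch) { return 128 * 64; }
};

struct SharedMemTraitSketch
{
    static constexpr ArchSketch arch = ArchSketch::gfx9; // stand-in for get_device_arch()

    static constexpr int a_block_space_size_aligned = LdsBaseSketch::GetABlockSpaceSize(arch);
    static constexpr int b_block_space_size_aligned = LdsBaseSketch::GetBBlockSpaceSize(arch);
    static constexpr int c_block_space_size         = LdsBaseSketch::GetCShuffleSpaceSize(arch);

    // Total LDS is the larger of the mainloop requirement and the epilogue (C shuffle) requirement.
    static constexpr int lds_size =
        (a_block_space_size_aligned + b_block_space_size_aligned > c_block_space_size)
            ? a_block_space_size_aligned + b_block_space_size_aligned
            : c_block_space_size;
};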
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // // set up Gemm0 @@ -608,7 +706,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle n4>{acc_element_op}; #endif // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b1_block_desc_bk0_n_bk1 = + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // B1 matrix blockwise copy auto b1_blockwise_copy = @@ -789,201 +888,15 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + gemm1_blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index a3c27c95557..9590c18fb76 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -83,22 +84,160 @@ template struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + A0B0B1DataType, + A0B0B1DataType, + Acc1DataType, + C1ShuffleDataType, + D1sDataType, + E1DataType, + A0ElementwiseOperation, + B1ElementwiseOperation, + BlockSize, + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + A0K1Value, + B1K1Value, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0BlockTransferSrcAccessOrder, + A0BlockTransferSrcVectorDim, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + A0ThreadTransferSrcResetCoordinateAfterRun, // ignored + A0BlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + C1ShuffleGemm0MXdlPerWavePerShuffle, + C1ShuffleGemm0NXdlPerWavePerShuffle, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + A0B0B1DataType, + A0B0B1DataType, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + A0B0B1DataType, + A0B0B1DataType, + Acc1DataType, + C1ShuffleDataType, + D1sDataType, + E1DataType, + A0ElementwiseOperation, + B1ElementwiseOperation, + BlockSize, + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + A0K1Value, + B1K1Value, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0BlockTransferSrcAccessOrder, + A0BlockTransferSrcVectorDim, + A0BlockTransferSrcScalarPerVector, + 
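// Annotation (not part of the patch): the roughly 200 removed lines above were the hand-written
// C-shuffle epilogue (VGPR -> LDS via ThreadwiseTensorSliceTransfer_v1r3, LDS -> global via
// ThreadGroupTensorSliceTransfer_v6r1, iterating a space-filling curve with block_sync_lds() between
// the two copies). They are replaced by a single call into the shared base, roughly of this shape
// (argument names as in the surrounding code; the template argument list was lost in extraction):
//
//   Base::template RunEpilogue<...>(gemm1_blockwise_gemm,
//                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
//                                   c_thread_buf,
//                                   block_work_idx[I0],
//                                   block_work_idx[I1],
//                                   p_shared,
//                                   p_c_grid,
//                                   c_element_op);
//
// Conceptually the base routine still performs, per shuffle slab:
//   block_sync_lds();            // LDS safe to overwrite
//   copy C tile  VGPR -> LDS;    // per-thread
//   block_sync_lds();            // LDS safe to read
//   copy C tile  LDS  -> global; // per-block, vectorized along N
//   advance the global-memory slice window and repeat for the next access of the curve.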
A0BlockTransferDstScalarPerVector_AK1, + A0ThreadTransferSrcResetCoordinateAfterRun, // ignored + A0BlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + C1ShuffleGemm0MXdlPerWavePerShuffle, + C1ShuffleGemm0NXdlPerWavePerShuffle, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + A0B0B1DataType, + A0B0B1DataType, + true>; // ForceNaiveLayout + + using Base0 = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + A0B0B1DataType, + A0B0B1DataType, + Acc0DataType, + C1ShuffleDataType, + Tuple<>, + E1DataType, + A0ElementwiseOperation, + B0ElementwiseOperation, + BlockSize, + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0KPerBlock, + A0K1Value, + B0K1Value, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0BlockTransferSrcAccessOrder, + A0BlockTransferSrcVectorDim, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + A0ThreadTransferSrcResetCoordinateAfterRun, // ignored + A0BlockLdsExtraM, + B0BlockTransferThreadClusterLengths_BK0_N_BK1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_BK1, + B0ThreadTransferSrcResetCoordinateAfterRun, // ignored + B0BlockLdsExtraN, + C1ShuffleGemm0MXdlPerWavePerShuffle, + C1ShuffleGemm0NXdlPerWavePerShuffle, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + A0B0B1DataType, + A0B0B1DataType, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static_assert(LoopSched == LoopScheduler::Default, "Non-default loop scheduler is currently not supported"); static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> // Gemm0 static constexpr auto A0K1 = Number{}; @@ -114,8 +253,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static constexpr auto B1K1 = Number{}; static constexpr auto B1K0PerBlock = Number{}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; // ck::Tuple @@ -201,46 +338,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle BBlockDesc_BK0_N_BK1{}); } - __host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A0 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(A0K0PerBlock, Number{}, A0K1), - make_tuple(Number{} * A0K1, A0K1, I1)); - } - - __host__ __device__ static constexpr auto 
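// Annotation (not part of the patch): the multiple-D variant of the fused gemm+gemm follows the same
// consolidation; Base is the gemm1-shaped instantiation and carries the D1s tuple (D1sDataType) so the
// shared epilogue can fuse the extra operands, while Base0 is the gemm0-shaped instantiation with an
// empty Tuple<>. A small sketch of using a type list to size the fused-operand machinery; the types
// are placeholders:
#include <cstddef>
#include <tuple>

template <typename DsTuple>
struct EpilogueSketch
{
    static constexpr std::size_t NumDTensor = std::tuple_size<DsTuple>::value;
    // A real epilogue would hold NumDTensor grid descriptors and buffers, and apply one element-wise
    // op over (c_shuffle, d[0], ..., d[NumDTensor-1]) while writing E out.
};

using Gemm0EpilogueSketch = EpilogueSketch<std::tuple<>>;      // Base0: no fused D operands
using Gemm1EpilogueSketch = EpilogueSketch<std::tuple<float>>; // Base: e.g. one fused D1 operand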
GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B0 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(B0K0PerBlock, Number{}, B0K1), - make_tuple(Number{} * B0K1, B0K1, I1)); - } - - __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B1 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(B1K0PerBlock, Number{}, B1K1), - make_tuple(Number{} * B1K1, B1K1, I1)); - } - - __host__ __device__ static constexpr auto - GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); - - constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned + @@ -524,11 +621,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { // LDS allocation for A0 and B0: be careful of alignment static constexpr auto a0_block_desc_ak0_m_ak1 = - GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); static constexpr auto b0_block_desc_bk0_n_bk1 = - GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto b1_block_desc_bk0_n_bk1 = - GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1); @@ -545,7 +642,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // LDS allocation for C1 shuffle in LDS static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); static constexpr auto c1_block_space_size = c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; @@ -588,8 +685,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize()); const auto b1_grid_buf = make_dynamic_buffer( p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto e1_grid_buf = make_dynamic_buffer( - p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto d0s_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -597,13 +692,6 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize()); }, Number{}); - const auto d1s_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_d1s_grid[i], - d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); // divide block work by [M, N] const auto block_work_idx = @@ -625,10 +713,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); // A0 matrix in LDS memory, dst of blockwise copy - constexpr auto a0_block_desc_ak0_m_ak1 = 
GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a0_block_desc_ak0_m_ak1 = + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B0 matrix in LDS memory, dst of blockwise copy - constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b0_block_desc_bk0_n_bk1 = + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // // set up Gemm0 @@ -908,7 +998,8 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle #endif // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b1_block_desc_bk0_n_bk1 = + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // B1 matrix blockwise copy auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1( - static_cast(p_shared), - c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (Gemm0MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = Gemm0MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (Gemm0NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2))), // N2 = Gemm0NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM C1 matrix starting index - const auto c1_thread_mtx_on_block = - blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c1_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - tensor_operation::element_wise::PassThrough{}}; - - // tuple of reference to C/Ds tensor descriptors - const auto c1_d1s_desc_refs = concat_tuple_of_reference( - tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - 
generate_tie([&](auto i) -> const auto& // return type should be reference - { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c1_d1s_buf_refs = concat_tuple_of_reference( - tie(c1_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return d1s_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c1_d1s_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - // shuffle: blockwise copy C from LDS to global - auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})), - Tuple, - decltype(c1_d1s_desc_refs), - decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)), - CDE1ElementwiseOperation, - Sequence(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray - // type - Sequence<1, - C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, - 1, - C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * - Gemm0NPerXdl>, // BlockSliceLengths, - CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c1_d1s_desc_refs, - idx_c1_d1s_block_begin, - tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - cde1_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c1_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_e1_global = SpaceFillingCurve< - Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>, - Sequence<0, 2, 1, 3>, - Sequence<1, - C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, - 1, - C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{}; - - constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c1_vgpr.GetIndexTupleOfNumber(access_id), - c1_thread_buf, - c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c1_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde1_shuffle_block_copy_lds_to_global.Run( - c1_d1s_desc_refs, - c1_d1s_buf_refs, - tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e1_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id); - - // move on D1s - static_for<0, NumD1Tensor, 1>{}([&](auto i) { - cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( - c1_d1s_desc_refs, i + I1, e1_global_step); - }); - - // 
move on C - cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step); - } - }); - } + Base::template RunMultiDEpilogue( + blockwise_gemm1, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_d1s_grid, + p_e1_grid, + cde1_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index fd0f77aad77..5722bbc1469 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -83,37 +84,172 @@ template struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + 
B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base0 = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, // ignored + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static_assert(LoopSched == LoopScheduler::Default, "Non-default loop scheduler is currently not supported"); static constexpr index_t NumD0Tensor = D0sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> // Gemm0 - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; + static constexpr auto AK0 = Base0::AK0Number; + static constexpr auto BK0 = Base0::BK0Number; + static constexpr auto AK1 = Base0::AK1Number; + static constexpr auto BK1 = Base0::BK1Number; + // Gemm1 + static constexpr auto B1K0 = Base::BK0Number; + static constexpr auto B1K1 = Base::BK1Number; static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); - // Gemm1 - static constexpr auto B1K0 = Number{}; - static constexpr auto B1K1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -152,47 +288,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( BBlockDesc_BK0_N_BK1{}); } - - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, 
dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B1 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(B1K0, Number{}, B1K1), - make_tuple(Number{} * B1K1, B1K1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + @@ -412,11 +507,11 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle { // LDS allocation for A and B: be careful of alignment static constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); static constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto b1_block_desc_bk0_n_bk1 = - GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); @@ -439,7 +534,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); static constexpr auto c_block_space_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; @@ -473,8 +568,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); const auto b1_grid_buf = make_dynamic_buffer( p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto d0s_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -503,10 +596,12 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // // set up Gemm0 @@ -781,7 +876,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle n4>{tensor_operation::element_wise::PassThrough{}}; #endif // B1 matrix 
in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b1_block_desc_bk0_n_bk1 = + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // B1 matrix blockwise copy auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - C1DEElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename 
DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c1de_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + gemm1_blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_c_grid, + c1de_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index bc2c1978473..67e9b9b5456 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -11,10 +11,11 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -84,35 +85,170 @@ template struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + 
FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1Value, + B1K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + B1ThreadTransferSrcResetCoordinateAfterRun, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base0 = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, // ignored + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; 
+ using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static_assert(LoopSched == LoopScheduler::Default, "Non-default loop scheduler is currently not supported"); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> // Gemm0 - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; + static constexpr auto AK0 = Base0::AK0Number; + static constexpr auto BK0 = Base0::BK0Number; + static constexpr auto AK1 = Base0::AK1Number; + static constexpr auto BK1 = Base0::BK1Number; + // Gemm1 + static constexpr auto B1K0 = Base::BK0Number; + static constexpr auto B1K1 = Base::BK1Number; static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); - // Gemm1 - static constexpr auto B1K0 = Number{}; - static constexpr auto B1K1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -152,46 +288,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle BBlockDesc_BK0_N_BK1{}); } - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B1 matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(B1K0, Number{}, B1K1), - make_tuple(Number{} * B1K1, B1K1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + @@ -213,6 +309,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { + if constexpr(KPerBlock % + MfmaSelector::GetKPerXdlops() != + 0) + { + return false; + } return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, @@ -253,7 +355,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle { return false; } - +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + 
if(!is_xdl_wmma_k_supported()) + { + return false; + } +#endif // check gemm0 gridwise gemm pipeline const auto num_gemm0_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) @@ -327,11 +434,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle { // LDS allocation for A and B: be careful of alignment static constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); static constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto b1_block_desc_bk0_n_bk1 = - GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); @@ -354,7 +461,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); static constexpr auto c_block_space_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); }; @@ -384,8 +491,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); const auto b1_grid_buf = make_dynamic_buffer( p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = @@ -407,10 +512,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + Base0::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + Base0::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // // set up Gemm0 @@ -625,7 +732,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle #endif // B1 matrix in LDS memory, dst of blockwise copy - constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b1_block_desc_bk0_n_bk1 = + Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // B1 matrix blockwise copy auto b1_blockwise_copy = @@ -980,201 +1088,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = - gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = - gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); - constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + gemm1_blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index c168ca9d18f..ccf4b04a6c2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" 
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -40,7 +41,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -70,7 +71,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -177,15 +178,108 @@ template struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLdsLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // 
ForceNaiveLdsLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using ThisThreadBlock = typename Base::ThisThreadBlock; // K1 should be Number<...> static constexpr auto AK0 = Number{}; @@ -193,70 +287,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -444,10 +477,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 
= + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -592,115 +627,35 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 
4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - // space filling curve for threadwise C in VGPR + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c3858b967a5..6260d3b9980 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -10,12 +10,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -74,29 +73,120 @@ template struct GridwiseGemmMultipleABD_xdl_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + 
BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true> { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size(); using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -114,38 +204,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle conditional_t, float, BComputeDataType_>; #endif - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } 
- - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - static constexpr auto MakeAsGridPointer() { return generate_tuple( @@ -180,33 +238,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(AElementDataType) + - b_block_space_size_aligned * sizeof(BElementDataType), - c_block_size * sizeof(CShuffleDataType)); - } - // A desc for source in blockwise copy template __host__ __device__ static constexpr auto @@ -579,17 +610,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle }, Number{}); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -612,10 +632,12 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); const auto idx_as_block_begin = generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, @@ -756,241 +778,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle c_thread_buf, num_k_block_main_loop); - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * 
NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], 
- n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - // blockwise copy C/D/E between LDS and global - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - cde_element_op}; - - // space filling curve for threadwise C in VGPR before shuffle - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - 
constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); } template struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 + : public GridwiseGemm_xdl_cshuffle_base, + FloatAB, + FloatAB, + true> // ForceNaiveLdsLayout { - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr index_t NumRTensor = RsDataType::Size(); + using Base = + GridwiseGemm_xdl_cshuffle_base, + FloatAB, + FloatAB, + true>; // ForceNaiveLdsLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::NumDTensor; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + static constexpr index_t NumRTensor = RsDataType::Size(); // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - // ck::Tuple template static constexpr auto MakeTsGridPointer() @@ -142,33 +201,6 @@ struct 
GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); - } - // A desc for source in blockwise copy __host__ __device__ static constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) @@ -392,10 +424,12 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -544,109 +578,32 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - 
ck::tensor_operation::element_wise::PassThrough{}}; + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 3523da6c461..d5b5236bacd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -11,12 +13,11 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -75,28 +76,115 @@ template struct GridwiseGemmMultipleD_xdl_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true> // ForceNaiveLdsLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + 
tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true>; // ForceNaiveLdsLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static constexpr index_t NumDTensor = DsDataType::Size(); static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0); using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -113,38 +201,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle conditional_t, float, BComputeDataType_>; #endif - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - // ck::Tuple static constexpr auto MakeDsGridPointer() { @@ -157,33 +213,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + - b_block_space_size_aligned * sizeof(BComputeDataType), - c_block_size * sizeof(CShuffleDataType)); - } - // A desc for source in blockwise copy template __host__ __device__ static constexpr auto @@ -324,7 +353,97 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return true; } - IS_VALID_COMPILATION_PARAMETER_IMPL(EDataType) + template + static constexpr index_t GetEstimateVgprCount() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ADataType) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(BDataType) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? 
BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t); + + if constexpr(PipelineVer == PipelineVersion::v1) + { + return InputVgprCount + AccVgprCount + BaseInputVgprCount * (NumGemmKPrefetchStage - 1); + } + else if constexpr(PipelineVer == PipelineVersion::v2) + { + return InputVgprCount + AccVgprCount + BaseInputVgprCount; + } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return InputVgprCount + AccVgprCount; + } + else if constexpr(PipelineVer == PipelineVersion::v4) + { + return InputVgprCount * 2 + AccVgprCount; + } + else + { + // invalid pipeline version + static_assert(0); + } + } + + __host__ static index_t GetSharedMemoryNumberOfByteOnHost() + { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + if(ck::get_device_name() == "gfx950") + { + return Base::GetSharedMemoryNumberOfByte(gfx950_t{}); + } + else +#endif + { + return Base::GetSharedMemoryNumberOfByte(gfx_invalid_t{}); + } + } + + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { +#if defined(__gfx11__) || defined(__gfx12__) + if constexpr(is_same_v) + { + + return false; + } +#endif + + if constexpr(Base::GetSharedMemoryNumberOfByte(get_device_arch()) > + get_lds_size(get_device_arch())) + { + return false; + } + + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch()); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } + + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + EDataType, + CGlobalMemoryDataOperation_>(); + } template ()) + { + return false; + } + + if(GetSharedMemoryNumberOfByteOnHost() > get_lds_size()) + { + return false; + } + const auto availableVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx9_t{}); + } + }(); + const auto estimateVgprCount = + ck::is_gfx11_supported() ? GetEstimateVgprCount() : GetEstimateVgprCount(); + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Estimated VGPR count (" << estimateVgprCount + << ") exceeds available VGPR count (" << availableVgprCount << ")! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } +#endif return true; } @@ -538,17 +695,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -576,10 +722,12 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -724,263 +872,18 @@ struct GridwiseGemmMultipleD_xdl_cshuffle c_thread_buf, num_k_block_main_loop); - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - tensor_operation::element_wise::PassThrough pass_through{}; - const auto& vpgr_to_lds_element_op = [&] { - if constexpr(DoElementwiseBeforeCShuffle) - { - return cde_element_op; - } - else - { - return pass_through; - } - }; - const auto& lds_to_global_element_op = [&] { - if constexpr(!DoElementwiseBeforeCShuffle) - { - return cde_element_op; - } - else - { - return pass_through; - } - }; - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - 
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - vpgr_to_lds_element_op()}; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - // blockwise copy C/D/E between LDS and global - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - conditional_t, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - lds_to_global_element_op()}; - - // space filling curve for threadwise C in VGPR before shuffle - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each 
block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + // Shuffle C and write out. + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); } template struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + ELayout, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockTransferScalarPerVector, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockTransferScalarPerVector, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + false, // ForceNaiveLayout + true> // DirectLoad { - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + ELayout, + AComputeDataType_, + BComputeDataType_, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockTransferScalarPerVector, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockTransferScalarPerVector, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + false, // ForceNaiveLayout + 
true>; // DirectLoad + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::I0; + using Base::I1; + using Base::I2; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; + static constexpr index_t NumDTensor = DsDataType::Size(); - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -178,54 +266,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad conditional_t, float, BComputeDataType_>; #endif - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - if constexpr(is_same_v) - { - // FIXME: our support to non-K contiguous layout is limited, only work in some specific - // setting - return make_naive_tensor_descriptor_packed( - make_tuple(AK0PerBlock, Number{}, AK1)); - } - else - { - return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(AK1, Number{}, I1)); - } - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - if constexpr(is_same_v) - { - // FIXME: our support to non-K contiguous layout is limited, only work in some specific - // setting - return make_naive_tensor_descriptor_packed( - make_tuple(BK0PerBlock, Number{}, BK1)); - } - else - { - return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(BK1, Number{}, I1)); - } - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - // ck::Tuple static constexpr auto MakeDsGridPointer() { @@ -240,29 +280,8 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { - // LDS allocation for A and B: be careful of alignment. - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle. 
- constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max( - NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) + - NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType), - c_block_size * sizeof(CShuffleDataType)); + return Base::template GetSharedMemoryNumberOfByte( + get_device_arch()); } __host__ __device__ static auto @@ -550,17 +569,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // Divide block work by [M, N]. const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -582,10 +590,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, destination of blockwise copy. - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, destination of blockwise copy. - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_DirectLoad( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // Calculate the origin of thread output tensor on global memory. 
- const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // Shuffle: threadwise copy C from VGPR to LDS. - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // A tuple of reference to C/Ds tensor descriptors. - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // A tuple of reference to C/Ds grid buffers. - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // A tuple of starting index of C/Ds blockwise copy. - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - // Blockwise copy C/D/E between LDS and global. 
- auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - cde_element_op}; - - // Space filling curve for threadwise C in VGPR before shuffle. - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // Space filling curve for shuffled blockwise C/D/E. - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // Make sure it's safe to write to LDS. - block_sync_lds(); - - // Each thread write its data from VGPR to LDS. - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // Make sure it's safe to read from LDS. - block_sync_lds(); - - // Each block copy its data from LDS to global. - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // Move on Ds. - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // Move on E. 
- cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); } struct Argument : public tensor_operation::device::BaseArgument diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 6307f649df6..5fb769f3555 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -10,13 +10,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -76,83 +74,142 @@ template struct GridwiseGemmMultipleD_xdl_splitk_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ALDSType, + BLDSType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeType, + BComputeType, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ALDSType, + BLDSType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 
ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeType, + BComputeType, + true>; // ForceNaiveLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; static constexpr index_t NumDTensor = DsDataType::Size(); using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, AK0PerBlock, Number{}, AK1), - make_tuple(AK0PerBlock * Number{} * AK1, - Number{} * AK1, - AK1, - I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, BK0PerBlock, Number{}, BK1), - make_tuple(BK0PerBlock * Number{} * BK1, - Number{} * BK1, - BK1, - I1)); - } - - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + template + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1( + const ABlockDescriptor_AK0PerBlock_MPerBlock_AK1&) { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); + return transform_tensor_descriptor( + ABlockDescriptor_AK0PerBlock_MPerBlock_AK1{}, + make_tuple(make_unmerge_transform(make_tuple(I1, AK0PerBlock)), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); } - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + template + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1( + const 
BBlockDescriptor_BK0PerBlock_NPerBlock_BK1&) { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + return transform_tensor_descriptor( + BBlockDescriptor_BK0PerBlock_NPerBlock_BK1{}, + make_tuple(make_unmerge_transform(make_tuple(I1, BK0PerBlock)), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); } // ck::Tuple @@ -167,33 +224,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(ALDSType) + - b_block_space_size_aligned * sizeof(BLDSType), - c_block_size * sizeof(CShuffleDataType)); - } - __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); @@ -491,14 +521,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - auto e_grid_buf = make_dynamic_buffer( p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -518,13 +540,21 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // lds max alignment constexpr auto max_lds_align = math::lcm(AK1, BK1); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); + // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_kbatch_ak0_m_ak1 = - GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + 
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(a_block_desc_ak0_m_ak1); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_kbatch_bk0_n_bk1 = - GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(b_block_desc_bk0_n_bk1); // A matrix blockwise copy auto a_blockwise_copy = @@ -588,12 +618,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle make_multi_index(0, 0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx // a_mtx[K0PerBlock, MPerBlock] is in LDS @@ -731,259 +755,42 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle blockwise_gemm, c_thread_buf, num_k_block_main_loop); - - // shuffle C and write out + if constexpr(Zeroing) { - if constexpr(Zeroing) + if(threadIdx.x == 0) { - if(threadIdx.x == 0) - { - while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} - } - __builtin_amdgcn_s_barrier(); + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} } + __builtin_amdgcn_s_barrier(); + } - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
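// NOTE: the KBatch descriptors are now derived from the arch-selected 3-D LDS
// descriptors via transform_tensor_descriptor, unmerging AK0/BK0 into
// (KBatch = 1, K0) with pass-throughs on the other dimensions, instead of being
// rebuilt from scratch. A standalone sketch of why the unit KBatch dimension
// leaves element offsets unchanged, assuming a packed (K0, M, K1) layout (the
// real descriptors may carry LDS padding such as ABlockLdsExtraM, so treat this
// as a simplification):
#include <cstddef>

constexpr std::size_t offset_k0_m_k1(std::size_t k0, std::size_t m, std::size_t k1,
                                     std::size_t M, std::size_t K1)
{
    return (k0 * M + m) * K1 + k1; // packed (K0, M, K1) descriptor
}

constexpr std::size_t offset_kb_k0_m_k1(std::size_t kb, std::size_t k0, std::size_t m,
                                        std::size_t k1, std::size_t K0, std::size_t M,
                                        std::size_t K1)
{
    // unmerge with a unit KBatch dimension: kb is always 0, so the offset is
    // identical to the 3-D descriptor's offset
    return ((kb * K0 + k0) * M + m) * K1 + k1;
}

static_assert(offset_kb_k0_m_k1(0, 3, 5, 1, 8, 128, 8) == offset_k0_m_k1(3, 5, 1, 128, 8),
              "unit KBatch dimension must not change the underlying offset");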
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = 
concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); - }, - Number{})); - - // space filling curve for threadwise C in VGPR before shuffle - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - // blockwise copy C/D/E between LDS and global - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation_, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), - cde_element_op}; - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor_, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - 
cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + // Shuffle C and write out. + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); - if constexpr(Zeroing) + if constexpr(Zeroing) + { + if(threadIdx.x == 0) { - if(threadIdx.x == 0) - { - index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - if(k_id_finished_t == KBatch) - { - *barrier_count_finished = 0; - } + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; } } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 6ce04e858b4..f1e97455b72 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -35,7 +36,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_reduce_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -58,7 +59,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run(p_a_grid, p_b_grid, @@ -149,15 +150,108 @@ template struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLdsLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 
= Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLdsLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using ThisThreadBlock = typename Base::ThisThreadBlock; // K1 should be Number<...> static constexpr auto AK0 = Number{}; @@ -165,70 +259,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto 
a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -396,10 +429,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -541,116 +576,36 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // shuffle C + reduction + write out { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
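// NOTE: GetSharedMemoryNumberOfByte() is removed from this struct and is assumed
// to be provided by the common base class, now parameterized on the device arch
// (GetSharedMemoryNumberOfByte(get_device_arch()) at the kernel launch above).
// The sizing rule itself matches the removed code: LDS must hold either the
// aligned A and B tiles for the main loop or the C-shuffle tile for the
// epilogue, whichever is larger. A standalone constexpr sketch using a single
// A/B element size, as in the removed reduce-kernel code:
#include <algorithm>
#include <cstddef>
#include <numeric>

constexpr std::size_t round_up(std::size_t x, std::size_t align)
{
    return (x + align - 1) / align * align;
}

constexpr std::size_t lds_bytes_sketch(std::size_t a_elems, std::size_t b_elems,
                                       std::size_t c_shuffle_elems,
                                       std::size_t ak1, std::size_t bk1,
                                       std::size_t sizeof_ab, std::size_t sizeof_cshuffle)
{
    // lds max alignment
    const std::size_t max_lds_align = std::lcm(ak1, bk1);
    const std::size_t a_aligned     = round_up(a_elems, max_lds_align);
    const std::size_t b_aligned     = round_up(b_elems, max_lds_align);
    // main-loop A+B tiles and the epilogue C-shuffle tile reuse the same LDS
    return std::max((a_aligned + b_aligned) * sizeof_ab, c_shuffle_elems * sizeof_cshuffle);
}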
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - 
ck::tensor_operation::element_wise::PassThrough{}}; + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 12d07ca23ae..d9b4a1d76c1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -10,11 +10,9 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -73,82 +71,118 @@ template struct GridwiseGemmSplitKMultipleD_xdl_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ABDataType, + ABDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ABDataType, + ABDataType, + true> // ForceNaiveLayout { - static constexpr index_t NumDTensor = DsDataType::Size(); + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ABDataType, + ABDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + 
BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ABDataType, + ABDataType, + true>; // ForceNaiveLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + static constexpr index_t NumDTensor = DsDataType::Size(); // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = GridwiseGemmPipeline_v1; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, src of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, AK0PerBlock, Number{}, AK1), - make_tuple(AK0PerBlock * Number{} * AK1, - Number{} * AK1, - AK1, - I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, src of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, BK0PerBlock, Number{}, BK1), - make_tuple(BK0PerBlock * Number{} * BK1, - Number{} * BK1, - BK1, - I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * 
MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - // ck::Tuple static constexpr auto MakeDsGridPointer() { @@ -161,33 +195,6 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), - c_block_size * sizeof(CShuffleDataType)); - } - // A desc for source in blockwise copy __host__ __device__ static constexpr auto MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, @@ -469,17 +476,6 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], - ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); - }, - Number{}); - - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -505,14 +501,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); constexpr auto a_block_desc_akb_ak0_m_ak1 = - GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(a_block_desc_ak0_m_ak1); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); constexpr auto b_block_desc_bkb_bk0_n_bk1 = - GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(b_block_desc_bk0_n_bk1); // A matrix blockwise copy auto a_blockwise_copy = @@ -656,239 +654,17 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle num_k_block_main_loop); // shuffle C and write out - { - 
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - 
InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - { - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); - }, - Number{})); - - // blockwise copy C/D/E between LDS and global - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t, - uniform_sequence_gen_t< - NumDTensor, - false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), - cde_element_op}; - - // space filling curve for threadwise C in VGPR before shuffle - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - 
c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } - } + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I1], + block_work_idx[I2], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); } template ( p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); - auto e_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // divide block work by [M, N] const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -942,14 +715,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); constexpr auto a_block_desc_akb_ak0_m_ak1 = - GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(a_block_desc_ak0_m_ak1); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); constexpr auto b_block_desc_bkb_bk0_n_bk1 = - GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(b_block_desc_bk0_n_bk1); // A matrix blockwise copy auto a_blockwise_copy = @@ -1091,204 +866,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle blockwise_gemm, c_thread_buf, num_k_block_main_loop); - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
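// NOTE: RunMultiDEpilogue and RunEpilogue are defined in
// gridwise_gemm_xdl_cshuffle_common.hpp, which is not part of this patch. The
// shape sketched below is inferred only from the call sites in this diff; treat
// it as an approximation, not the actual declarations:
//
//   template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
//             typename CDEElementwiseOperation, /* ... */>
//   __device__ static void RunMultiDEpilogue(BlockwiseGemm&  blockwise_gemm,
//                                            const DsGridDesc& ds_grid_desc_mblock_mperblock_nblock_nperblock,
//                                            const EGridDesc&  e_grid_desc_mblock_mperblock_nblock_nperblock,
//                                            CThreadBuf&       c_thread_buf,
//                                            index_t           block_m_id,
//                                            index_t           block_n_id,
//                                            void*             p_shared,
//                                            const DsGridPointer& p_ds_grid,
//                                            EDataType*        p_e_grid,
//                                            const CDEElementwiseOperation& cde_element_op);
//
// RunEpilogue appears to be the single-output variant: it drops the Ds
// descriptors and pointers and takes only the E descriptor, the E pointer and
// an elementwise operation (PassThrough at the call site further below).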
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - { - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation, - EGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - EDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - e_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), - ck::tensor_operation::element_wise::PassThrough{}}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - e_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } + Base::template RunEpilogue( + blockwise_gemm, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I1], + block_work_idx[I2], + p_shared, + p_e_grid, + ck::tensor_operation::element_wise::PassThrough{}); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index 5af9d97c9fd..9c007e6a694 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -10,13 +10,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include 
"ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -73,85 +71,119 @@ template struct GridwiseGemmMultipleD_xdl_splitk_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + AComputeDataType_, + BComputeDataType_, + true>; // ForceNaiveLayout + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; static constexpr index_t NumDTensor = DsDataType::Size(); using 
GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, AK0PerBlock, Number{}, AK1), - make_tuple(AK0PerBlock * Number{} * AK1, - Number{} * AK1, - AK1, - I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(I1, BK0PerBlock, Number{}, BK1), - make_tuple(BK0PerBlock * Number{} * BK1, - Number{} * BK1, - BK1, - I1)); - } - - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - // ck::Tuple static constexpr auto MakeDsGridPointer() { @@ -164,33 +196,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock 
= - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType), - c_block_size * sizeof(CShuffleDataType)); - } - __host__ __device__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); @@ -514,13 +519,21 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // lds max alignment constexpr auto max_lds_align = math::lcm(AK1, BK1); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); + // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_kbatch_ak0_m_ak1 = - GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(a_block_desc_ak0_m_ak1); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_kbatch_bk0_n_bk1 = - GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(b_block_desc_bk0_n_bk1); // A matrix blockwise copy auto a_blockwise_copy = @@ -584,12 +597,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle make_multi_index(0, 0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx // a_mtx[K0PerBlock, MPerBlock] is in LDS @@ -723,255 +730,32 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle blockwise_gemm, c_thread_buf, num_k_block_main_loop); - - // shuffle C and write out + if(threadIdx.x == 0) { - if(threadIdx.x == 0) - { - while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} - } + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } - __syncthreads(); + __syncthreads(); - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
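// Illustrative sketch (not part of this patch): the split-K handshake used in the Run()
// body above, reduced to two hypothetical free functions. The real code inlines this
// logic and keeps the counter (`barrier_count_finished`) in workspace memory; the names
// below are assumptions made only to show the pattern.
__device__ inline void wait_for_partials(index_t* barrier_count)
{
    // one lane polls until a producer workgroup has published its partial result
    if(threadIdx.x == 0)
    {
        while(__atomic_load_n(barrier_count, __ATOMIC_RELAXED) == 0) {}
    }
    __syncthreads(); // make the observed counter value visible to the whole block
}

__device__ inline void signal_partial_done(index_t* barrier_count, index_t k_batch)
{
    if(threadIdx.x == 0)
    {
        // the last contributing workgroup resets the counter for the next tile
        if(atomicAdd(barrier_count, 1) == k_batch)
        {
            *barrier_count = 0;
        }
    }
}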
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = 
concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); - }, - Number{})); - - // space filling curve for threadwise C in VGPR before shuffle - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - // blockwise copy C/D/E between LDS and global - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation_, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), - cde_element_op}; - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(e_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor_, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - 
cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + // shuffle C and write out + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I1], + block_work_idx[I2], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - if(threadIdx.x == 0) + if(k_id_finished_t == KBatch) { - index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - - if(k_id_finished_t == KBatch) - { - *barrier_count_finished = 0; - } + *barrier_count_finished = 0; } } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp new file mode 100644 index 00000000000..6e047dd64a0 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp @@ -0,0 +1,1843 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#endif + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +struct GridwiseGemm_xdl_cshuffle_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto MaxBlockSize = BlockSize; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + using ThisThreadBlock = ThisThreadBlock; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr index_t APackedSize = []() { + if constexpr(IsMxGemm) + { + // KPerBlock is based on 
packed data type in MxGemm + return 1; + } + else if constexpr(is_same_v, pk_i4_t>) + { + return 2; + } + else + { + return packed_size_v; + } + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(IsMxGemm) + { + // KPerBlock is based on packed data type in MxGemm + return 1; + } + else if constexpr(is_same_v, pk_i4_t>) + { + return 2; + } + else + { + return packed_size_v; + } + }(); + template + __device__ __host__ static constexpr auto + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch) + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + constexpr index_t KPerBlockInByte = KPerBlock * sizeof(ADataType) / APackedSize; + + // A matrix in LDS memory, dst of blockwise copy + if constexpr(DirectLoad && + (is_same_v || is_same_v)) + { + if constexpr(is_same_v) + { + // FIXME: our support to non-K contiguous layout is limited, only work in some + // specific setting + return make_naive_tensor_descriptor_packed( + make_tuple(AK0Number, Number{}, AK1Number)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + } + else if constexpr(ABlockLdsExtraM || ForceNaiveLdsLayout) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlockInByte; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
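// Illustrative sketch (not part of this patch): what the padded ("naive") LDS layout in
// the ABlockLdsExtraM / ForceNaiveLdsLayout branch above amounts to with concrete numbers.
// Assuming the extra-M padding widens the row stride from MPerBlock * AK1 to
// (MPerBlock + ABlockLdsExtraM) * AK1, then e.g. AK0 = 4, MPerBlock = 128, AK1 = 8 with one
// element of padding gives a stride of 1032 instead of 1024 elements, so consecutive K0
// rows no longer start at the same LDS bank offset. `kPad` is a hypothetical stand-in for
// ABlockLdsExtraM; the values are examples only.
constexpr index_t kAK0 = 4, kMPerBlock = 128, kAK1 = 8, kPad = 1; // example tile sizes
constexpr auto a_lds_desc_padded_example = make_naive_tensor_descriptor(
    make_tuple(Number<kAK0>{}, Number<kMPerBlock>{}, Number<kAK1>{}),
    make_tuple(Number<(kMPerBlock + kPad) * kAK1>{}, Number<kAK1>{}, Number<1>{}));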
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + template + __device__ __host__ static constexpr auto + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch) + { + constexpr index_t KPerBlockInByte = KPerBlock * sizeof(BDataType) / BPackedSize; + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + // B matrix in LDS memory, dst of blockwise copy + if constexpr(DirectLoad && + (is_same_v || is_same_v)) + { + if constexpr(is_same_v) + { + // FIXME: our support to non-K contiguous layout is limited, only work in some + // specific setting + return make_naive_tensor_descriptor_packed( + make_tuple(BK0Number, Number{}, BK1Number)); + } + else + { + return make_naive_tensor_descriptor( + 
make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + } + else if constexpr(BBlockLdsExtraN || ForceNaiveLdsLayout) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlockInByte; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + template + __device__ __host__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceArch) + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + template + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(DeviceArch) + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + template + __host__ __device__ static constexpr auto GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1( + const ABlockDescriptor_AK0PerBlock_MPerBlock_AK1&) + { + return transform_tensor_descriptor( + ABlockDescriptor_AK0PerBlock_MPerBlock_AK1{}, + make_tuple(make_unmerge_transform(make_tuple(I1, AK0Number)), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + } + + template + 
__host__ __device__ static constexpr auto GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1( + const BBlockDescriptor_BK0PerBlock_NPerBlock_BK1&) + { + return transform_tensor_descriptor( + BBlockDescriptor_BK0PerBlock_NPerBlock_BK1{}, + make_tuple(make_unmerge_transform(make_tuple(I1, BK0Number)), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + + template + __device__ __host__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + BPreshuffle ? 0 + : math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_size = [&]() { + if constexpr(CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock:: + Size() == 0) + { + return 0; + } + else if constexpr( + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::Size() == 6) + { + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + DeviceArch{}); + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + } + else + { + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceArch{}); + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + } + }(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize) * + NumLdsBuffer, + c_block_size * sizeof(CShuffleDataType)); + } + + template + __device__ static constexpr auto GetCThreadDescriptor() + { + if constexpr(TransposeC) + { + // TODO: Support transposed MXGEMM + static_assert(IsMxGemm == false); + return BlockwiseGemmPipe::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + } + else + { + if constexpr(IsMxGemm) + { + return BlockwiseGemmPipe::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + } + else + { + return BlockwiseGemmPipe::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + } + } + } + + template + __device__ static constexpr auto GetCBlockThreadDescriptor() + { + if constexpr(TransposeC) + { + static_assert(IsMxGemm == false); + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + 
BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + if constexpr(CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock:: + Size() == 6) + { + return transform_tensor_descriptor( + CBlockDescriptor{}, + make_tuple(make_freeze_transform(I0), + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform(make_tuple(M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform(make_tuple(N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 5, 6, 7>{})); + } + else + { + return transform_tensor_descriptor( + CBlockDescriptor{}, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4>{}, + Sequence<>{}, + Sequence<1, 3, 5, 6, 7>{})); + } + } + else + { + if constexpr(IsMxGemm) + { + constexpr auto MXdlPack = BlockwiseGemmPipe::MXdlPack; + constexpr auto NXdlPack = BlockwiseGemmPipe::NXdlPack; + static_assert( + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::Size() == + 4); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + return transform_tensor_descriptor( + CBlockDescriptor{}, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) + // per shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 
6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + } + else + { + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + if constexpr(CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock:: + Size() == 6) + { + return transform_tensor_descriptor( + CBlockDescriptor{}, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) + // per shuffle + make_unmerge_transform(make_tuple( + M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) + // per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + } + else + { + return transform_tensor_descriptor( + CBlockDescriptor{}, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_unmerge_transform( + make_tuple(CShuffleMXdlPerWavePerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform( + make_tuple(CShuffleNXdlPerWavePerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + } + } + } + } + + template + __device__ static auto + GetCThreadCopyVgprToLds(const BlockwiseGemmPipe& blockwise_gemm, + const CThreadDescriptor&, + const CBlockThreadDescriptor& c_block_thread_desc, + const CDEElementwiseOperation& cde_element_op) + { + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + if constexpr(TransposeC) + { + static_assert(IsMxGemm == false); + // TODO: hacky, fix it! 
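// Illustrative sketch (not part of this patch): the single-stage adaptors built below only
// invert a mixed-radix merge. For merged lengths (M0, M1, M2) with M0 most significant,
// CalculateBottomIndex recovers the per-dimension coordinates exactly as this hypothetical
// helper does; the real code applies the same idea to the 5- and 3-dimensional M/N splits.
__host__ __device__ inline void
unmerge_3d(index_t m, index_t M1, index_t M2, index_t& m0, index_t& m1, index_t& m2)
{
    m0 = m / (M1 * M2);  // most significant digit
    m1 = (m / M2) % M1;  // middle digit
    m2 = m % M2;         // least significant digit
}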
+ // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + return ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_thread_desc, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + cde_element_op}; + } + else + { + if constexpr(IsMxGemm) + { + constexpr auto MXdlPack = BlockwiseGemmPipe::MXdlPack; + constexpr auto NXdlPack = BlockwiseGemmPipe::NXdlPack; + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I9); + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + return ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + CThreadDescriptor, + CBlockThreadDescriptor, + CDEElementwiseOperation, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_thread_desc, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + cde_element_op}; + } + else + { + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + return ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_thread_desc, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + 
cde_element_op}; + } + } + } + + template + __device__ static constexpr auto GetCThreadWiseSpaceFillingCurve() + { + if constexpr(TransposeC) + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + return SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + } + else + { + if constexpr(IsMxGemm) + { + constexpr auto MXdlPack = BlockwiseGemmPipe::MXdlPack; + constexpr auto NXdlPack = BlockwiseGemmPipe::NXdlPack; + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_tmp.GetLength(I7); + + return SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + } + else + { + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + return SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + } + } + } + template + __device__ static void RunEpilogueNoShuffle( + BlockwiseGemmPipe& blockwise_gemm, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + CThreadBuffer& c_thread_buf, + index_t block_m_id, + index_t block_n_id, + EDataType* p_c_grid, + const CDEElementwiseOperation& cde_element_op) + { + static_assert(IsMxGemm == false); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + static_assert(TransposeC == false); + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + BlockwiseGemmPipe::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + cde_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + + template + __device__ static void RunEpilogue(const BlockwiseGemmPipe& blockwise_gemm_pipeline, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + CThreadBuffer& c_thread_buf, + index_t block_m_id, + index_t block_n_id, + void* p_shared, + EDataType* p_c_grid, + const CDEElementwiseOperation& cde_element_op) + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_thread_desc = GetCThreadDescriptor(); + constexpr auto c_block_thread_desc = GetCBlockThreadDescriptor< + TransposeC, + BlockwiseGemmPipe, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + auto c_thread_copy_vgpr_to_lds = GetCThreadCopyVgprToLds( + blockwise_gemm_pipeline, c_thread_desc, c_block_thread_desc, vpgr_to_lds_element_op()); + + // const auto c_thread_mtx_on_block = + // blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + // const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + // const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + // shuffle: 
blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + conditional_t, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + EDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + lds_to_global_element_op()}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + GetCThreadWiseSpaceFillingCurve(); + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_thread_desc, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + + template + __device__ static void + RunMultiDEpilogue(BlockwiseGemmPipe& blockwise_gemm_pipeline, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + CThreadBuffer& c_thread_buf, + index_t block_m_id, + index_t block_n_id, + void* p_shared, + DsGridPointer& p_ds_grid, + EDataType* p_c_grid, + const CDEElementwiseOperation& cde_element_op) + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if 
constexpr(!DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_thread_desc = GetCThreadDescriptor(); + constexpr auto c_block_thread_desc = GetCBlockThreadDescriptor< + TransposeC, + BlockwiseGemmPipe, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + auto c_thread_copy_vgpr_to_lds = GetCThreadCopyVgprToLds( + blockwise_gemm_pipeline, c_thread_desc, c_block_thread_desc, vpgr_to_lds_element_op()); + + // const auto c_thread_mtx_on_block = + // blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + // const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + // const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple([&](auto) { return make_multi_index(block_m_id, 0, block_n_id, 0); }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde_block_copy_lds_and_global = [&]() { + if constexpr(IsLegacy) + { + return ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(c_grid_desc_mblock_mperblock_nblock_nperblock)), + conditional_t, + Sequence(CGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor_, + false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(c_grid_desc_mblock_mperblock_nblock_nperblock), + 
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + lds_to_global_element_op()}; + } + else if constexpr(CDEShuffleBlockTransferScalarPerVectors::Size() == 1 && + NumDTensor_ != 0) + { + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(c_grid_desc_mblock_mperblock_nblock_nperblock)), + conditional_t, + Sequence(CGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor_, + false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(c_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + lds_to_global_element_op()}; + } + else + { + return ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(c_grid_desc_mblock_mperblock_nblock_nperblock)), + conditional_t, + Sequence(CGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor_, + false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(c_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + lds_to_global_element_op()}; + } + }(); + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + GetCThreadWiseSpaceFillingCurve(); + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc, + 
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_thread_desc, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run(c_ds_desc_refs, + c_ds_buf_refs, + tie(c_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(c_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + + template + __device__ static void RunMoeEpilogue(BlockwiseGemmPipe& blockwise_gemm_pipeline, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + CThreadBuffer& c_thread_buf, + index_t block_m_id, + index_t block_n_id, + void* p_shared, + const index_t* p_sorted_token_ids, + EDataType* p_c_grid, + DsGridPointer& p_ds_grid, + const CDEElementwiseOperation& cde_element_op, + index_t problemTopK, + index_t problemN) + { + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr auto c_thread_desc = GetCThreadDescriptor(); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_thread_desc = GetCBlockThreadDescriptor< + TransposeC, + BlockwiseGemmPipe, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + auto c_thread_copy_vgpr_to_lds = + GetCThreadCopyVgprToLds(blockwise_gemm_pipeline, + c_thread_desc, + c_block_thread_desc, + ck::tensor_operation::element_wise::PassThrough{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + const DDataType* ptr_ = p_ds_grid[i]; + // hack logic here to support different kind of strides. todo fix it. 
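+                // The D tensors handled here are presumably the MoE scale factors, whose
+                // layouts differ (the comment below suggests a per-token a-scale of shape
+                // (tokens, 1) and a per-expert b-scale of shape (num_experts, N, 1)).
+                // Each buffer is therefore sized from its own descriptor's element space
+                // rather than from a shared stride assumption.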
+                // ascale t, 1; bscale E, N, 1, move ptr to E
+                return make_dynamic_buffer(
+                    ptr_, ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+            },
+            Number{});
+
+        // tuple of reference to C/Ds tensor descriptors
+        const auto c_ds_desc_refs = concat_tuple_of_reference(
+            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
+                         Number{}));
+
+        // tuple of reference to C/Ds buffers
+        const auto c_ds_buf_refs = concat_tuple_of_reference(
+            tie(c_shuffle_block_buf),
+            generate_tie([&](auto i) -> const auto& // return type should be reference
+                         { return ds_grid_buf[i]; },
+                         Number{}));
+
+        // tuple of starting index of C/Ds blockwise copy
+        const auto idx_c_ds_block_begin =
+            container_concat(make_tuple(make_multi_index(0, 0, 0, 0)),
+                             generate_tuple(
+                                 [&](auto) {
+                                     return make_multi_index(block_m_id, 0, block_n_id, 0);
+                                     // return make_multi_index(block_work_idx[I0], 0,
+                                     // block_work_idx[I1], 0);
+                                 },
+                                 Number{}));
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            c_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        using CDEBlockTransferCluster =
+            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
+        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
+        constexpr index_t scatter_weight_idx =
+            TransposeC ? 1 : 3; // TODO: revisit how the scatter-weight index is chosen
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
+            ThisThreadBlock,
+            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+            Tuple,
+            decltype(c_ds_desc_refs),
+            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+            CDEElementwiseOperation,
+            Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
+                                                   // support arbitrary type
+            Sequence<1,
+                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                     1,
+                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+            CDEBlockTransferCluster,
+            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+            Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
+            Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
+            3,                    // index_t SrcVectorDim,
+            3,                    // index_t DstVectorDim,
+            CDEShuffleBlockTransferScalarPerVectors,
+            CShuffleBlockTransferScalarPerVector_NPerBlock,
+            sequence_merge_t<
+                Sequence,
+                uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+            Sequence,                    // ThreadTransferDstResetCoordinateAfterRunFlags
+            IndexType,
+            1,                 // ScatterDim
+            true,              // OutputScatter (when false, only the scatter weights are used)
+            scatter_weight_idx // ScatterWeightIdx: a-scale
+            >{c_ds_desc_refs,
+              idx_c_ds_block_begin,
+              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+              make_tuple(make_multi_index(0, 0, block_n_id, 0)),
+              cde_element_op};
+
+        auto c_grid_buf = make_dynamic_buffer(
+            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+        // space filling curve for threadwise C in VGPR
+        constexpr auto sfc_c_vgpr =
+            GetCThreadWiseSpaceFillingCurve();
+
+        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+        // space filling curve for shuffled blockwise C/D/E
+        constexpr auto sfc_cde_block =
+            SpaceFillingCurve,
+                              Sequence<0, 2, 1, 3>,
+                              Sequence<1,
+                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                       1,
+                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
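+        // How the scatter offsets below are built (worked example, values for illustration
+        // only): each entry of p_sorted_token_ids packs a token index in its low 24 bits and,
+        // presumably, the top-k slot index in its high 8 bits. For the input GEMM the output
+        // rows are laid out per (token, slot), so the destination row is
+        //     row = (fused_token & 0xffffff) * problemTopK + (fused_token >> 24);
+        // otherwise the row is just the token index. The per-row element offset handed to the
+        // scatter copy is then row * problemN. For example, with fused_token = (2 << 24) | 5,
+        // problemTopK = 4 and problemN = 7168, the row is 5 * 4 + 2 = 22 and the scatter
+        // offset is 22 * 7168 = 157696.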
+        constexpr auto EMThreads =
+            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
+        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
+        constexpr auto ENThreads =
+            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
+        static_for<0, num_access, 1>{}([&](auto access_id) {
+            StaticallyIndexedArray scatter_offsets;
+
+            auto dstidx = sfc_cde_block.GetIndex(access_id);
+            const index_t c_token_pos =
+                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
+            static_for<0, EMRepeats, 1>{}([&](auto m0) {
+                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
+                index_t token_offset = fused_token & 0xffffff;
+                if constexpr(IsInputGemm)
+                {
+                    token_offset = token_offset * problemTopK + (fused_token >> 24);
+                }
+                scatter_offsets(m0) = token_offset * problemN;
+            });
+
+            // make sure it's safe to write to LDS
+            block_sync_lds();
+
+            // each thread writes its data from VGPR to LDS
+            c_thread_copy_vgpr_to_lds.Run(c_thread_desc,
+                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                          c_thread_buf,
+                                          c_block_thread_desc,
+                                          c_shuffle_block_buf);
+
+            // make sure it's safe to read from LDS
+            block_sync_lds();
+
+            // each block copies its data from LDS to global
+            cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
+                                              c_ds_buf_refs,
+                                              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                                              tie(c_grid_buf),
+                                              scatter_offsets);
+
+            if constexpr(access_id < num_access - 1)
+            {
+                constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
+
+                // move on Ds
+                static_for<0, NumDTensor, 1>{}([&](auto i) {
+                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                });
+
+                // move on E
+                cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                    I0,
+                    cde_lds_and_global_step);
+            }
+        });
+    }
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp
index 8188c42ca5e..cfbfaf3262a 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp
@@ -10,9 +10,9 @@
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"

 namespace ck {

@@ -63,21 +63,106 @@ template
 struct GridwiseGemm_xdl_cshuffle_conv_v3
+    : public GridwiseGemm_xdl_cshuffle_base<
+          ALayout,
+          BLayout,
+          CLayout,
+          ADataType,
+          BDataType,
+          AccDataType,
+          CShuffleDataType,
+          Tuple<>,
+          CDataType,
+          AElementwiseOperation,
+          BElementwiseOperation,
+          BlockSize,
+          MPerBlock,
+          NPerBlock,
+          KPerBlock,
+          AK1Value,
+          BK1Value,
+          MPerXdl,
+          NPerXdl,
+          MXdlPerWave,
+          NXdlPerWave,
+          ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          ABlockTransferThreadClusterArrangeOrder,
+          ABlockTransferSrcAccessOrder,
+          ABlockTransferSrcVectorDim,
+          ABlockTransferSrcScalarPerVector,
+          ABlockTransferDstScalarPerVector_AK1,
+          AThreadTransferSrcResetCoordinateAfterRun,
+          ABlockLdsExtraMCustom,
+
BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraNCustom, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraMCustom, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraNCustom, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = @@ -98,8 +183,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); @@ -267,343 +350,75 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 CDataType* p_c_grid; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + template + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch) { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); -#if defined(__gfx950__) - // Force use padded layout on gfx950 to reduce bank conflicts - constexpr index_t ABlockLdsExtraM = 1; -#else - constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom; 
-#endif - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) + if constexpr(is_same_v) { + // Force use padded layout on gfx950 to reduce bank conflicts + constexpr index_t ABlockLdsExtraM = 1; return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), make_tuple(Number{} * AK1Number, AK1Number, I1)); } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) + else { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); } } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + template + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch) { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); -#if defined(__gfx950__) - // Force use padded layout on gfx950 to reduce bank conflicts - constexpr index_t BBlockLdsExtraN = 1; -#else - constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom; -#endif - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) + if constexpr(is_same_v) { + constexpr index_t BBlockLdsExtraN = 1; return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), make_tuple(Number{} * BK1Number, BK1Number, I1)); } - else if constexpr(is_same::value) + else { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); } } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; + + template + __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) { // LDS allocation for A and B: be careful of alignment - constexpr auto 
a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -616,7 +431,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceArch{}); constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); @@ -684,8 +499,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; @@ -719,10 +532,12 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -827,201 +642,15 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); } template ( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - 
make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index 3b5610865cb..4b679adc8db 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -11,12 +11,11 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/workgroup_barrier.hpp" #include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -33,14 +32,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); @@ -57,7 +56,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { @@ -66,8 +65,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char 
p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run_2Lds( karg.p_a_grid, @@ -130,21 +129,113 @@ template struct GridwiseGemm_xdl_cshuffle_streamk_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using 
Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetSharedMemoryNumberOfByte; static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = @@ -165,7 +256,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; __host__ static auto CalculateMPadded(index_t M) { return math::integer_least_multiple(M, MPerBlock); @@ -660,363 +750,31 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 index_t b_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - return make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})); - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1566,11 +1324,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< @@ -1687,29 +1445,15 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
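The removed GetSharedMemoryNumberOfByte() above sizes LDS as the larger of the A+B tile footprint and the C-shuffle tile footprint, since the epilogue reuses the same scratch space after the main loop. Below is a minimal host-side sketch of that calculation; the helper name, element counts and element sizes are illustrative examples, not the kernel's actual template parameters.

// Sketch only: mirrors the shape of the removed GetSharedMemoryNumberOfByte().
#include <algorithm>
#include <cstddef>

constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align; // round x up to a multiple of align
}

constexpr std::size_t shared_memory_bytes(std::size_t a_elems,         // A tile element count
                                          std::size_t b_elems,         // B tile element count
                                          std::size_t c_shuffle_elems, // C-shuffle tile element count
                                          std::size_t sizeof_a,
                                          std::size_t sizeof_b,
                                          std::size_t sizeof_cshuffle,
                                          std::size_t max_lds_align)   // lcm(AK1, BK1)
{
    // A and B tiles live in LDS at the same time, each padded to the alignment;
    // the C-shuffle tile reuses the same LDS later, so only the larger footprint
    // has to be allocated.
    const std::size_t a_bytes = integer_least_multiple(a_elems, max_lds_align) * sizeof_a;
    const std::size_t b_bytes = integer_least_multiple(b_elems, max_lds_align) * sizeof_b;
    const std::size_t c_bytes = c_shuffle_elems * sizeof_cshuffle;
    return std::max(a_bytes + b_bytes, c_bytes);
}

// Example: 32x128 fp16 A tile, 32x128 fp16 B tile, 64x128 fp16 C-shuffle tile, AK1 = BK1 = 8.
static_assert(shared_memory_bytes(32 * 128, 32 * 128, 64 * 128, 2, 2, 2, 8) == 16384, "example");

The alignment argument stands in for the lcm(AK1, BK1) padding the original helper applies to the A and B element counts before summing them.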
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = - GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + Base::GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), @@ -1723,87 +1467,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 .GetElementSpaceSize()); constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - 
ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + Base::template GetCBlockThreadDescriptor< + false, + BlockwiseGemmPipe, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm_pipeline, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< @@ -2305,11 +1981,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = - GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = - GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< @@ -2436,30 +2112,15 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
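This hunk repeats the same change in the second Run() path: the inline C-shuffle epilogue (thread descriptor, block descriptor, VGPR-to-LDS copy) is replaced by static helpers on the new common base struct. The sketch below illustrates that deduplication pattern only; the class and helper names are made up and are not the actual CK types.

// Illustrative pattern, not the CK implementation: duplicated epilogue setup is
// hoisted from the derived gridwise GEMM structs into helpers on a shared base.
#include <cstdio>

template <int MPerBlock, int NPerBlock>
struct GridwiseGemmBase
{
    // Shared helper: every derived kernel builds the same C-shuffle tile.
    static constexpr int GetCShuffleTileElements(int m_per_shuffle, int n_per_shuffle)
    {
        return m_per_shuffle * n_per_shuffle;
    }
};

template <int MPerBlock, int NPerBlock>
struct GridwiseGemmV1 : GridwiseGemmBase<MPerBlock, NPerBlock>
{
    using Base = GridwiseGemmBase<MPerBlock, NPerBlock>;

    static void RunEpilogue()
    {
        // Instead of re-deriving the descriptor locally, reuse the base helper.
        constexpr int elems = Base::GetCShuffleTileElements(MPerBlock / 2, NPerBlock);
        std::printf("epilogue shuffles %d elements per pass\n", elems);
    }
};

int main()
{
    GridwiseGemmV1<128, 128>::RunEpilogue();
    return 0;
}

The motivation visible in this diff is that v1, v2 and the stream-k variant previously carried near-identical epilogue code, so hoisting it into one base keeps later fixes in a single place.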
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = - GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + Base::GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); auto c_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared_0), @@ -2473,87 +2134,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 .GetElementSpaceSize()); constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - 
ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + Base::template GetCBlockThreadDescriptor< + false, + BlockwiseGemmPipe, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm_pipeline, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); + // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< ThisThreadBlock, // ThreadGroup @@ -2711,6 +2305,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 } }); } + // exit condition iter_end -= current_iter_length; if(iter_end <= iter_start) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 716fe6f41d7..e26d80db5f3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -4,6 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -11,16 +13,16 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { @@ -28,7 +30,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); @@ -45,7 +47,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, 
const FloatB* __restrict__ p_b_grid, @@ -56,7 +58,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, p_shared, problem); @@ -118,23 +120,109 @@ template struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using 
Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetSharedMemoryNumberOfByte; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; __host__ static auto CalculateGridSize(index_t M, index_t N) { @@ -493,69 +581,51 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + template + static constexpr index_t GetEstimateVgprCount() { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? 
BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(FloatGemmAcc) / sizeof(uint32_t); - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + - b_block_space_size_aligned * sizeof(ComputeTypeB)), - c_block_size * sizeof(FloatCShuffle)); + if constexpr(PipelineVer == PipelineVersion::v1) + { + return InputVgprCount + AccVgprCount + BaseInputVgprCount * (NumGemmKPrefetchStage - 1); + } + else if constexpr(PipelineVer == PipelineVersion::v2) + { + return InputVgprCount + AccVgprCount + BaseInputVgprCount; + } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return InputVgprCount + AccVgprCount; + } + else if constexpr(PipelineVer == PipelineVersion::v4) + { + return InputVgprCount * 2 + AccVgprCount; + } + else + { + // invalid pipeline version + static_assert(0); + } } template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { - return ck::tensor_operation::device::IsValidGemmCompilationParameter< + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, NPerBlock, @@ -565,6 +635,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 NXdlPerWave, FloatC, CGlobalMemoryDataOperation>(); + if constexpr(!valid) + { + return false; + } + + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch()); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } + return true; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -574,6 +657,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + const auto availableVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx9_t{}); + } + }(); + const auto estimateVgprCount = + ck::is_gfx11_supported() ? GetEstimateVgprCount() : GetEstimateVgprCount(); + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Estimated VGPR count (" << estimateVgprCount + << ") exceeds available VGPR count (" << availableVgprCount << ")! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -718,8 +829,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; @@ -750,10 +859,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -897,201 +1008,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index 34f0a975867..f7656629045 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -11,16 +11,16 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include 
"ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(GridwiseGemm::MaxBlockSize, 1) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) @@ -31,8 +31,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); @@ -49,7 +49,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(GridwiseGemm::MaxBlockSize, 1) #endif kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, const FloatB* p_b_grid, @@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); @@ -123,23 +123,109 @@ template struct GridwiseGemm_xdl_cshuffle_v2 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - 
static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetSharedMemoryNumberOfByte; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; __host__ static auto CalculateGridSize(index_t M, index_t N) { @@ -531,64 +617,6 @@ struct GridwiseGemm_xdl_cshuffle_v2 using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto 
b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + - b_block_space_size_aligned * sizeof(ComputeTypeB)), - c_block_size * sizeof(FloatCShuffle)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -768,8 +796,6 @@ struct GridwiseGemm_xdl_cshuffle_v2 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; @@ -810,10 +836,12 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -978,202 +1006,15 @@ struct GridwiseGemm_xdl_cshuffle_v2 c_thread_buf, num_k_block_main_loop); - // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
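The removed descriptor helpers above all start from the same wave decomposition: MWave and NWave count how many waves tile the block in the M and N directions, and WaveSize is the block size divided by the number of waves. A small self-contained example of that arithmetic is given below; the tile parameters are typical but assumed values, not read from this diff.

// Example derivation of the wave layout used throughout these descriptors.
#include <cstdio>

int main()
{
    constexpr int BlockSize   = 256;
    constexpr int MPerBlock   = 128, NPerBlock   = 128;
    constexpr int MPerXdl     = 16,  NPerXdl     = 16;
    constexpr int MXdlPerWave = 4,   NXdlPerWave = 4;

    // Each wave covers MXdlPerWave x NXdlPerWave XDL tiles of MPerXdl x NPerXdl each.
    constexpr int MWave    = MPerBlock / (MXdlPerWave * MPerXdl); // 2
    constexpr int NWave    = NPerBlock / (NXdlPerWave * NPerXdl); // 2
    constexpr int WaveSize = BlockSize / (MWave * NWave);         // 64 lanes per wave

    static_assert(MWave * NWave * WaveSize == BlockSize, "waves must tile the thread block");
    std::printf("MWave=%d NWave=%d WaveSize=%d\n", MWave, NWave, WaveSize);
    return 0;
}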
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared_0, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 93f4059f0a0..9f1a14b09ed 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include 
"ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -38,7 +39,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -71,8 +72,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // operate on different lds chunk at same time without order dependecy if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -244,22 +245,123 @@ template struct GridwiseGemm_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 
AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetSharedMemoryNumberOfByte; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + static constexpr index_t TransposeC = false; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && @@ -280,22 +382,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -816,299 +902,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - using BlockwiseGemmPipe = remove_cvref_t())>; - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + template + static constexpr index_t GetEstimateVgprCount() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + 
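// Editorial note: MWave and NWave describe the per-block wave tiling; the WaveSize computed below assumes BlockSize is an exact multiple of MWave * NWave.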
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return InputVgprCount + AccVgprCount; + } + else if constexpr((BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)) + { + return 2 * InputVgprCount + AccVgprCount; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return 3 * InputVgprCount + AccVgprCount; + } + else + { + // invalid pipeline version + static_assert(0); + } } template @@ -1187,6 +991,34 @@ struct GridwiseGemm_xdl_cshuffle_v3 return false; } + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = [&]() { + if constexpr(MinimumOccupancy != 0) + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + return get_max_vgpr_count(get_device_arch()) / MinimumOccupancy / + (math::integer_divide_ceil(BlockSize, WaveSize * 4)); + } + else + { + return get_max_vgpr_count(get_device_arch()); + } + }(); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } + constexpr index_t LdsSize = BlkGemmPipelineVer == BlockGemmPipelineVersion::v4 + ? GetSharedMemoryNumberOfByte(get_device_arch()) * 2 + : GetSharedMemoryNumberOfByte(get_device_arch()); + if constexpr(LdsSize > get_lds_size(get_device_arch())) + { + return false; + } + return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, @@ -1198,6 +1030,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 CDataType, CGlobalMemoryDataOperation>(); } + __host__ static index_t GetSharedMemoryNumberOfByteOnHost() + { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + if(ck::get_device_name() == "gfx950") + { + return GetSharedMemoryNumberOfByte(gfx950_t{}); + } + else +#endif + { + return GetSharedMemoryNumberOfByte(gfx_invalid_t{}); + } + } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -1391,6 +1236,48 @@ struct GridwiseGemm_xdl_cshuffle_v3 } } + constexpr index_t ldsBufferCount = + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4 ? 
2 : 1; + if(GetSharedMemoryNumberOfByteOnHost() * ldsBufferCount > get_lds_size()) + { + return false; + } + + const auto maxVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx9_t{}); + } + }(); + + const index_t availableVgprCount = [&]() { + if constexpr(MinimumOccupancy != 0) + { + return maxVgprCount / math::max(MinimumOccupancy, 1) / + (math::integer_divide_ceil(BlockSize, ck::get_warp_size() * 4)); + } + else + { + return maxVgprCount; + } + }(); + + const auto estimateVgprCount = + ck::is_gfx11_supported() ? GetEstimateVgprCount() : GetEstimateVgprCount(); + + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -1448,8 +1335,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1479,10 +1364,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -1588,226 +1473,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - tensor_operation::element_wise::PassThrough pass_through{}; - const auto& vpgr_to_lds_element_op = [&] { - if constexpr(DoElementwiseBeforeCShuffle) - { - return problem.c_element_op_; - } - else - { - return pass_through; - } - }; - const auto& lds_to_global_element_op = [&] { - if constexpr(!DoElementwiseBeforeCShuffle) - { - return problem.c_element_op_; - } - else - { - return pass_through; - } - }; - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - 
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - vpgr_to_lds_element_op()}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - conditional_t, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - lds_to_global_element_op()}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue(blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + problem.c_element_op_); } template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, 
c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1896,10 +1569,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -2014,201 +1687,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = 
c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - problem.c_element_op_}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from 
LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + problem.c_element_op_); } template ()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; // Full K needed for matrix B const index_t Kt = karg.K; @@ -79,8 +80,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared_0[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared_1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; // Full K needed for matrix B const index_t Kt = karg.K; @@ -154,22 +157,107 @@ template struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + 
CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 // TODO: explore optimization opportunity by using new mfma instructions on gfx950 @@ -189,8 +277,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr index_t NLane = NPerXdl; static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; - using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -697,124 +783,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. 
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -825,67 +793,31 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle Number{})); //??? BK1Value same as KPack? 
} - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(ADataType) / APackedSize, - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1142,8 +1074,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; // const BElementwiseOperation b_element_op{}; @@ -1175,7 +1105,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1264,200 +1195,15 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! 
- constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - 
n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; // const BElementwiseOperation b_element_op{}; @@ -1554,7 +1298,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = 
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1651,200 +1396,15 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 
2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + 
c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); } template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -38,7 +38,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -62,7 +62,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -72,8 +72,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -142,24 +142,113 @@ template struct GridwiseGemm_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4> { - using BScaleType = ck::half_t; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static 
constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + using BScaleType = ck::half_t; static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && @@ -179,22 +268,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -711,347 +784,31 @@ struct GridwiseGemm_xdl_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * 
sizeof(BDataType) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1308,8 +1065,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // B Scale buffer const auto b_scale_grid_buf = make_dynamic_buffer( @@ -1347,10 +1102,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -1513,201 +1270,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_per_scale); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
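// A minimal, self-contained sketch (not taken from this patch) of the LDS
// budgeting that GetSharedMemoryNumberOfByte() performs: the A and B input
// tiles and the C-shuffle staging tile reuse the same LDS, so the allocation
// is the larger of the two uses. In the real code the element counts come from
// the block descriptors (now parameterized on get_device_arch()), the
// alignment is lcm(AK1, BK1), and packed types such as pk_i4_t divide by
// APackedSize/BPackedSize; the helper names and tile numbers below are
// illustrative only.
#include <cstddef>
#include <cstdio>

constexpr std::size_t align_up(std::size_t n, std::size_t a) { return (n + a - 1) / a * a; }

constexpr std::size_t lds_bytes(std::size_t a_tile_elems, std::size_t b_tile_elems,
                                std::size_t c_shuffle_elems, std::size_t a_bytes,
                                std::size_t b_bytes, std::size_t c_bytes, std::size_t lds_align)
{
    // A and B tiles live in LDS at the same time; the C-shuffle buffer reuses that space later.
    const std::size_t ab = align_up(a_tile_elems, lds_align) * a_bytes +
                           align_up(b_tile_elems, lds_align) * b_bytes;
    const std::size_t c = c_shuffle_elems * c_bytes;
    return ab > c ? ab : c;
}

int main()
{
    // Hypothetical 128x128x32 fp16 tile with an fp32 C-shuffle slice of 32x128 elements.
    constexpr std::size_t bytes = lds_bytes(128 * 32, 128 * 32, 32 * 128, 2, 2, 4, 8);
    std::printf("estimated LDS bytes per block: %zu\n", bytes);
    return 0;
}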
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // B Scale buffer const auto b_scale_grid_buf = make_dynamic_buffer( @@ -1817,10 +1386,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - 
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -1994,201 +1565,15 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_per_scale); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 
1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + 
c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); } template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -37,7 +38,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( karg.p_as_grid, @@ -62,7 +63,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) @@ -72,8 +73,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run_2Lds( karg.p_as_grid, @@ -140,28 +141,115 @@ template struct GridwiseGemm_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ComputeTypeA, + ComputeTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, 
+ BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + false>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumDTensor; using LDSTypeA = ComputeTypeA; using LDSTypeB = ComputeTypeB; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr index_t NumATensor = AsDataType::Size(); static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto MakeAsGridPointer() { @@ -219,8 +307,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -773,353 +859,93 @@ struct GridwiseGemm_xdl_cshuffle_v3 }; #endif - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + LDSTypeA, + LDSTypeB, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; + + template + static constexpr index_t GetEstimateVgprCount() { constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would 
cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeA); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A + + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return InputVgprCount + AccVgprCount; } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) + else if constexpr((BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)) { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + return 2 * InputVgprCount + AccVgprCount; } - else if constexpr(is_same::value) + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeB); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return 3 * InputVgprCount + AccVgprCount; } - else // RowMajor B + else { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + // invalid pipeline version + static_assert(0); } } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = 
- c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch()); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + - b_block_space_size_aligned * sizeof(LDSTypeB)), - c_block_size * sizeof(CShuffleDataType)); + return ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); } - IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -1292,6 +1118,34 @@ struct GridwiseGemm_xdl_cshuffle_v3 } } + const auto availableVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx9_t{}); + } + }(); + const auto estimateVgprCount = + ck::is_gfx11_supported() ? GetEstimateVgprCount() : GetEstimateVgprCount(); + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Estimated VGPR count (" << estimateVgprCount + << ") exceeds available VGPR count (" << availableVgprCount << ")! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -1411,16 +1265,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 }, Number{}); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -1449,10 +1293,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); #if 0 // A matrix blockwise copy @@ -1620,311 +1466,17 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
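Both the compile-time guard in IsValidCompilationParameter and the run-time guard in CheckValidity apply the same rule: reject a tile configuration once the estimated VGPR count exceeds the architecture's budget by more than a quarter. The standalone sketch below only illustrates that rule; the architecture tags, the 256-register budgets, and the fixed estimate are placeholders, not the values the library actually queries through get_max_vgpr_count and GetEstimateVgprCount.

#include <cstdint>
#include <iostream>

// Hypothetical architecture tags standing in for CK's gfx9_t / gfx11_t / gfx12_t.
struct gfx9_tag {};
struct gfx11_tag {};
struct gfx12_tag {};

// Placeholder per-architecture VGPR budgets; the real values come from
// get_max_vgpr_count() in the library.
constexpr std::uint32_t max_vgpr_count(gfx9_tag)  { return 256; }
constexpr std::uint32_t max_vgpr_count(gfx11_tag) { return 256; }
constexpr std::uint32_t max_vgpr_count(gfx12_tag) { return 256; }

// The acceptance rule used by this patch: allow up to 25% over the nominal
// budget before rejecting a tile configuration.
constexpr bool fits_vgpr_budget(std::uint32_t estimated, std::uint32_t available)
{
    return estimated <= available + available / 4;
}

int main()
{
    // Illustrative estimate only; the kernel derives it from tile sizes,
    // pipeline version and data types (GetEstimateVgprCount in the patch).
    constexpr std::uint32_t estimated = 300;
    constexpr std::uint32_t available = max_vgpr_count(gfx9_tag{});

    static_assert(fits_vgpr_budget(estimated, available),
                  "estimated VGPR count exceeds 125% of the architecture budget");

    std::cout << "estimate " << estimated << " fits budget " << available << "\n";
}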
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - -#if 0 - // shuffle: blockwise copy C from LDS to global - auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; -#else - using EDataType = CDataType; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = - CShuffleBlockTransferScalarPerVector_NPerBlock; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - -#endif - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); -#if 0 - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - -#else - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); -#endif - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - -#if 0 - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } -#else - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } -#endif - }); - } + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } #if 1 @@ -1982,16 +1534,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 }, Number{}); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - // divide block work by [M, N] const auto block_2_ctile_map = 
Block2CTileMap{problem.M, problem.N, 4}; @@ -2020,10 +1562,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); #if 0 // A matrix blockwise copy @@ -2201,308 +1745,17 @@ struct GridwiseGemm_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - 
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - -#if 0 - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; -#else - using EDataType = CDataType; - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - 
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = - CShuffleBlockTransferScalarPerVector_NPerBlock; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - c_element_op}; - -#endif - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - -#if 1 - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; -#endif - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - -#if 0 - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } -#else - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if 
constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } -#endif - }); - } + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_ds_grid, + p_c_grid, + c_element_op); } #endif }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index e7e24b148a2..16595faab80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -4,18 +4,19 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -34,7 +35,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) @@ -42,7 +43,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -69,7 +70,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) @@ -79,8 +80,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - 
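The 2-LDS kernel variant allocates two separate __shared__ buffers and hands both pointers to the gridwise GEMM so the compiler can overlap ds_read/ds_write on the two chunks, and the launch bounds are now derived from GridwiseGemm::MaxBlockSize rather than CK_MAX_THREAD_PER_BLOCK. The HIP-style skeleton below only sketches that shape; the buffer size, block size, and consume() body are placeholders rather than the library's real values.

#include <hip/hip_runtime.h>

constexpr int kMaxBlockSize = 256;   // stands in for GridwiseGemm::MaxBlockSize
constexpr int kLdsBytes     = 4096;  // stands in for GetSharedMemoryNumberOfByte(arch)

__device__ void consume(char* lds, int stage)
{
    // Placeholder for one stage of the GEMM pipeline working out of LDS.
    (void)lds;
    (void)stage;
}

__global__ void __launch_bounds__(kMaxBlockSize, 1)
double_buffered_kernel(int num_stages)
{
    // Two distinct LDS allocations: passing two separate pointers lets the
    // compiler overlap accesses to the two chunks without inserting an
    // artificial ordering between them.
    __shared__ char lds_ping[kLdsBytes];
    __shared__ char lds_pong[kLdsBytes];

    for(int s = 0; s < num_stages; ++s)
    {
        char* lds = (s % 2 == 0) ? lds_ping : lds_pong;
        consume(lds, s);
        __syncthreads();
    }
}

int main()
{
    double_buffered_kernel<<<dim3(1), dim3(kMaxBlockSize)>>>(4);
    (void)hipDeviceSynchronize();
    return 0;
}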
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -154,33 +155,119 @@ template struct GridwiseGemmMultiD_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraMCustom, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraNCustom, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + DirectLoad> { static_assert((is_same_v && is_same_v) || !DirectLoad); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraMCustom, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraNCustom, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + DirectLoad>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using 
Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::NumDTensor; // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; static constexpr bool DirectLoadEnabled = DirectLoad; - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MakeDsGridPointer() { return generate_tuple( @@ -214,8 +301,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); @@ -765,360 +850,90 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 index_t b_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + template + __device__ __host__ static constexpr auto + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch) { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); -#if defined(__gfx950__) - // Force use padded layout on gfx950 to reduce bank conflicts - constexpr index_t ABlockLdsExtraM = 1; -#else - constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom; -#endif - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(DirectLoad) + if constexpr(is_same_v) { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - else if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeA); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + if constexpr(DirectLoad == false) + { + // Force use padded layout on gfx950 to reduce bank conflicts + constexpr index_t ABlockLdsExtraM = 1; + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + else + { + return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); + } } - else // ColumnMajor A + else { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); } } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + template + __device__ __host__ static constexpr auto + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch) { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); -#if defined(__gfx950__) - // Force use padded layout on gfx950 to reduce bank conflicts - constexpr index_t BBlockLdsExtraN = 1; -#else - constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom; -#endif - - // B matrix in LDS memory, dst of blockwise copy - if constexpr(DirectLoad) - { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + if constexpr(is_same_v) { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 - ? 
1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeB); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + if constexpr(DirectLoad == false) + { + constexpr index_t BBlockLdsExtraN = 1; + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else + { + return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); + } } - else // RowMajor B + else { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; + return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); } } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + LDSTypeA, + LDSTypeB, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + DirectLoad>())>; + + template + __device__ __host__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = 
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{}); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch{}); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1131,7 +946,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceArch{}); constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); @@ -1141,6 +956,57 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 c_block_size * sizeof(CShuffleDataType)); } + __host__ static index_t GetSharedMemoryNumberOfByteOnHost() + { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + if(ck::get_device_name() == "gfx950") + { + return GetSharedMemoryNumberOfByte(gfx950_t{}); + } + else +#endif + { + return GetSharedMemoryNumberOfByte(gfx_invalid_t{}); + } + } + + template + static constexpr index_t GetEstimateVgprCount() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t WaveSize = BlockSize / (MWave * NWave); + + // VGPR used in LDS loading and WMMA + constexpr index_t BaseInputVgprCount = + MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) + + NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t); + // WMMA input is duplicated in GFX11 + constexpr index_t InputVgprCount = IsGfx11 ? BaseInputVgprCount * 2 : BaseInputVgprCount; + // VGPR used in Accumulator + constexpr index_t AccVgprCount = + MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return InputVgprCount + AccVgprCount; + } + else if constexpr((BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) || + (BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)) + { + return 2 * InputVgprCount + AccVgprCount; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return 3 * InputVgprCount + AccVgprCount; + } + else + { + // invalid pipeline version + static_assert(0); + } + } template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -1195,6 +1061,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { return false; } + + constexpr index_t LdsBufferCount = + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4 ? 
2 : 1; + if constexpr(GetSharedMemoryNumberOfByte(get_device_arch()) * LdsBufferCount > + get_lds_size(get_device_arch())) + { + return false; + } + + constexpr bool IsGfx11 = is_same_v; + constexpr auto EstimateVgprCount = GetEstimateVgprCount(); + constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch()); + if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4)) + { + return false; + } return true; } @@ -1209,7 +1091,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { return false; } - + constexpr index_t ldsBufferCount = + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4 ? 2 : 1; + if(GetSharedMemoryNumberOfByteOnHost() * ldsBufferCount > get_lds_size()) + { + return false; + } + if(!is_xdl_wmma_k_supported()) + { + return false; + } if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1383,6 +1274,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 return false; } + const auto availableVgprCount = []() { + if(ck::is_gfx12_supported()) + { + return get_max_vgpr_count(gfx12_t{}); + } + else if(ck::is_gfx11_supported()) + { + return get_max_vgpr_count(gfx11_t{}); + } + else + { + return get_max_vgpr_count(gfx9_t{}); + } + }(); + + const auto estimateVgprCount = + ck::is_gfx11_supported() ? GetEstimateVgprCount() : GetEstimateVgprCount(); + if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Estimated VGPR count (" << estimateVgprCount + << ") exceeds available VGPR count (" << availableVgprCount << ")! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } @@ -1521,9 +1441,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -1549,10 +1466,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) @@ -1709,284 +1628,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! 
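The host-side CheckValidity added above also gates the launch on LDS capacity: the per-buffer LDS footprint is doubled for the v4 pipeline, which keeps two buffers alive, and compared against the device's LDS size. Below is a small host-only sketch of that gate, assuming a 64 KiB capacity and placeholder footprints; the real numbers come from GetSharedMemoryNumberOfByteOnHost() and get_lds_size().

#include <cstddef>
#include <iostream>

// Placeholder capacity; the library queries this per architecture.
constexpr std::size_t kLdsCapacityBytes = 64 * 1024;

enum class PipelineVersion { v1, v2, v3, v4, v5 };

bool lds_fits(std::size_t bytes_per_buffer, PipelineVersion ver, bool verbose = false)
{
    // The v4 pipeline keeps two LDS buffers alive at once (the *_2lds kernels),
    // so its footprint is doubled before the comparison.
    const std::size_t buffer_count = (ver == PipelineVersion::v4) ? 2 : 1;
    const std::size_t total        = bytes_per_buffer * buffer_count;

    if(total > kLdsCapacityBytes)
    {
        if(verbose)
            std::cout << "LDS request " << total << " B exceeds capacity "
                      << kLdsCapacityBytes << " B\n";
        return false;
    }
    return true;
}

int main()
{
    std::cout << std::boolalpha << lds_fits(40 * 1024, PipelineVersion::v1) << "\n";
    std::cout << std::boolalpha << lds_fits(40 * 1024, PipelineVersion::v4, true) << "\n";
}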
- constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - tensor_operation::element_wise::PassThrough pass_through{}; - const auto& vpgr_to_lds_element_op = [&] { - if constexpr(DoElementwiseBeforeCShuffle) - { - return c_element_op; - } - else - { - return pass_through; - } - }; - const auto& lds_to_global_element_op = [&] { - if constexpr(!DoElementwiseBeforeCShuffle) - { - return c_element_op; - } - else - { - return pass_through; - } - }; - - // shuffle: threadwise copy C from VGPR to LDS - auto 
c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - vpgr_to_lds_element_op()}; - - using EDataType = CDataType; - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - conditional_t, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - lds_to_global_element_op()}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = 
- SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue(blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -2123,10 +1778,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) @@ -2293,284 +1950,22 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! 
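Note on the refactor above: each gridwise GEMM variant previously carried its own copy of this "shuffle C and write out" epilogue (VGPR-to-LDS shuffle followed by the fused C/D/E write-out). The diff deletes those copies and has every Run method build only the Ds grid descriptor and forward it to the shared Base::RunMultiDEpilogue. A minimal sketch of that hoisting pattern follows; the struct and helper names are hypothetical stand-ins, only RunMultiDEpilogue itself appears in the real code.

#include <cstdio>
#include <vector>

// Illustrative only: one shared epilogue in a base template, reused by two
// GEMM variants instead of being duplicated in each Run method.
template <int MPerBlock, int NPerBlock>
struct GemmEpilogueBase
{
    template <typename ElementOp>
    static void RunEpilogue(const std::vector<float>& acc, std::vector<float>& e, ElementOp op)
    {
        // Single implementation of the tile write-out.
        for(int i = 0; i < MPerBlock * NPerBlock; ++i)
            e[i] = op(acc[i]);
    }
};

template <int MPerBlock, int NPerBlock>
struct GemmVariantA : GemmEpilogueBase<MPerBlock, NPerBlock>
{
    using Base = GemmEpilogueBase<MPerBlock, NPerBlock>;
    static void Run(const std::vector<float>& acc, std::vector<float>& e)
    {
        // ... main loop specific to this variant ...
        Base::RunEpilogue(acc, e, [](float x) { return x; }); // PassThrough
    }
};

template <int MPerBlock, int NPerBlock>
struct GemmVariantB : GemmEpilogueBase<MPerBlock, NPerBlock>
{
    using Base = GemmEpilogueBase<MPerBlock, NPerBlock>;
    static void Run(const std::vector<float>& acc, std::vector<float>& e)
    {
        Base::RunEpilogue(acc, e, [](float x) { return 2.0f * x; }); // fused elementwise op
    }
};

int main()
{
    std::vector<float> acc(4 * 4, 1.0f), e(4 * 4, 0.0f);
    GemmVariantB<4, 4>::Run(acc, e);
    std::printf("%f\n", e[0]); // prints 2.0
}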
- constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - tensor_operation::element_wise::PassThrough pass_through{}; - const auto& vpgr_to_lds_element_op = [&] { - if constexpr(DoElementwiseBeforeCShuffle) - { - return c_element_op; - } - else - { - return pass_through; - } - }; - const auto& lds_to_global_element_op = [&] { - if constexpr(!DoElementwiseBeforeCShuffle) - { - return c_element_op; - } - else - { - return pass_through; - } - }; - - // shuffle: threadwise copy C from VGPR to LDS - auto 
c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - conditional_t, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - vpgr_to_lds_element_op()}; - - using EDataType = CDataType; - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - conditional_t, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - lds_to_global_element_op()}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = 
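For orientation, the block being deleted here implements the multi-D epilogue data path: accumulators are staged from VGPR into an LDS tile, then combined with the extra D tensors and the CDE elementwise operation on the way to the E output; the DoElementwiseBeforeCShuffle switch visible above decides whether the op runs on the VGPR-to-LDS copy or on the LDS-to-global copy. A host-side sketch of that flow, with placeholder buffers and a simple add standing in for the real elementwise op:

#include <array>
#include <cstdio>

int main()
{
    constexpr int TileM = 2, TileN = 4;
    constexpr bool apply_op_before_shuffle = false; // mirrors DoElementwiseBeforeCShuffle

    std::array<float, TileM * TileN> acc{};   // per-thread accumulators (VGPR)
    std::array<float, TileM * TileN> stage{}; // staging tile (LDS in the kernel)
    std::array<float, TileM * TileN> d{};     // one D tensor, e.g. a bias
    std::array<float, TileM * TileN> e{};     // output E

    for(int i = 0; i < TileM * TileN; ++i) { acc[i] = float(i); d[i] = 1.0f; }

    // 1) shuffle: copy accumulators into the staging tile,
    //    optionally applying the elementwise op already here
    for(int i = 0; i < TileM * TileN; ++i)
        stage[i] = apply_op_before_shuffle ? acc[i] + d[i] : acc[i];

    // 2) write-out: read the staging tile, fuse D and the op, store E
    for(int i = 0; i < TileM * TileN; ++i)
        e[i] = apply_op_before_shuffle ? stage[i] : stage[i] + d[i];

    std::printf("e[3] = %f\n", e[3]); // 4.0
}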
- SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue(blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_ds_grid, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 13061c7cd19..36895f55eab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -10,11 +10,10 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -33,7 +32,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) @@ -41,7 +40,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || 
defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -117,30 +116,114 @@ template struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false> // ForceNaiveLayout { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumDTensor; + using AScaleType = float; using BScaleType = float; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = 
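The ABScale variant now derives from GridwiseGemm_xdl_cshuffle_base and pulls the shared compile-time constants (I0, I1, AK0Number, ...) and the LDS-descriptor builders in with using-declarations, while the descriptor getters take an architecture tag via get_device_arch(). A small sketch of the pattern; the tag types, sizes, and member names below are placeholders, not the CK API:

#include <cstdio>

struct Gfx9Tag  {};
struct Gfx12Tag {};

template <int BlockSize>
struct GridwiseBase
{
    static constexpr int I0           = 0;
    static constexpr int I1           = 1;
    static constexpr int MaxBlockSize = BlockSize;

    // Arch-dependent layout decisions live behind a tag-dispatched getter.
    static constexpr int GetLdsTileBytes(Gfx9Tag)  { return 64 * 1024; }
    static constexpr int GetLdsTileBytes(Gfx12Tag) { return 32 * 1024; }
};

template <int BlockSize>
struct GridwiseDerived : GridwiseBase<BlockSize>
{
    using Base = GridwiseBase<BlockSize>;
    using Base::I0;
    using Base::I1;
    using Base::MaxBlockSize;
    using Base::GetLdsTileBytes;
};

int main()
{
    std::printf("%d %d\n",
                GridwiseDerived<256>::MaxBlockSize,
                GridwiseDerived<256>::GetLdsTileBytes(Gfx9Tag{}));
}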
Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MakeDsGridPointer() { return generate_tuple( @@ -171,8 +254,6 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 is_single_rate_mfma, is_scale_mfma>::selected_mfma.k_per_blk); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -731,306 +812,31 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t scale_b_k_split_offset; // B scale matrix offset }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return b_lds_block_desc_permuted; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) - ? 
1 - : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 - ? N0 - : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - 
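The descriptors deleted above route LDS writes through make_xor_with_modulo_transform so that threads reading the same logical column hit different LDS banks; the inline comment notes that this swizzle can cost extra VGPRs, which is why the padded layouts (ABlockLdsExtraM / BBlockLdsExtraN) remain as an alternative. A purely conceptual illustration of what an xor swizzle buys, with an arbitrary tile shape:

#include <cstdio>

// Swizzle the column index by the row index so that same-column accesses
// from different rows land in different physical columns (banks).
constexpr int swizzled_col(int row, int col, int cols) { return (col ^ row) % cols; }

int main()
{
    constexpr int Rows = 4, Cols = 8;
    for(int r = 0; r < Rows; ++r)
        std::printf("row %d: logical col 0 -> physical col %d\n",
                    r, swizzled_col(r, 0, Cols));
}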
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + - b_block_space_size_aligned * sizeof(LDSTypeB)), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmABScalePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + LDSTypeA, + LDSTypeB, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1278,8 +1084,6 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); @@ -1315,10 +1119,12 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -1496,262 +1302,23 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // transposed XDL - // // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - // // TODO: hacky, fix it! 
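The deleted GetSharedMemoryNumberOfByte makes the LDS budget explicit: the A/B staging tiles used by the main loop and the C-shuffle tile used by the epilogue share one allocation, so the kernel reserves the larger of the two footprints. The refactor keeps that formula but computes it in the base class, keyed on the device architecture. A simplified restatement, with placeholder extents and element sizes:

#include <cstdio>

constexpr int align_up(int x, int a) { return ((x + a - 1) / a) * a; }

constexpr int shared_memory_bytes(int a_elems, int b_elems, int c_shuffle_elems,
                                  int a_bytes, int b_bytes, int c_bytes, int lds_align)
{
    const int ab = align_up(a_elems, lds_align) * a_bytes +
                   align_up(b_elems, lds_align) * b_bytes;
    const int c  = c_shuffle_elems * c_bytes;
    return ab > c ? ab : c; // main loop and epilogue reuse the same LDS block
}

int main()
{
    // e.g. fp16 A/B tiles vs. an fp32 C-shuffle tile
    std::printf("%d bytes\n",
                shared_memory_bytes(128 * 32, 128 * 32, 128 * 128, 2, 2, 4, 8));
}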
- // only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); - constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 5fe77f2b711..e810a467e76 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -10,11 +10,10 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -41,7 +40,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; // Full K needed for matrix B const index_t Kt = karg.K; @@ -84,8 +84,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ 
char + p_shared1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; // Full K needed for matrix B const index_t Kt = karg.K; @@ -165,26 +167,111 @@ template struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr auto BlockSizeNumber = Number{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using 
Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::NumDTensor; - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto BlockSizeNumber = Number{}; static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = @@ -232,8 +319,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); @@ -720,124 +805,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle index_t a_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -859,52 +826,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; } - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + LDSTypeA, + LDSTypeB, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + 
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> @@ -1225,8 +1171,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -1250,7 +1194,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1339,261 +1284,23 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
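The BlockwiseGemmPipe alias above picks the pipeline implementation by invoking a constexpr selector function template and capturing the type of its result with remove_cvref_t<decltype(...)>; after the refactor the selector is fed the arch-specific LDS descriptors. A minimal sketch of that selection idiom; PipelineV1/V2 and pick_pipeline are illustrative names, not CK's selectors:

#include <cstdio>
#include <type_traits>

struct PipelineV1 { static constexpr int prefetch_stages = 1; };
struct PipelineV2 { static constexpr int prefetch_stages = 2; };

// A constexpr factory: only the taken branch is instantiated, so the
// deduced return type is exactly the selected pipeline.
template <int Version>
constexpr auto pick_pipeline()
{
    if constexpr(Version == 1)
        return PipelineV1{};
    else
        return PipelineV2{};
}

// Capture the selected type (remove_cvref_t spelled out for pre-C++20).
template <int Version>
using BlockwisePipe =
    std::remove_cv_t<std::remove_reference_t<decltype(pick_pipeline<Version>())>>;

int main()
{
    std::printf("%d\n", BlockwisePipe<2>::prefetch_stages); // prints 2
}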
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = 
MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to 
write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -1691,7 +1396,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1786,261 +1492,23 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
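// The replacement above delegates the whole C-shuffle / multi-D epilogue to
// Base::RunMultiDEpilogue from gridwise_gemm_xdl_cshuffle_common.hpp. Below is a
// minimal, hypothetical sketch of that delegation pattern in plain host C++;
// EpilogueBase, GemmSketch and the raw-pointer "tiles" are illustrative
// stand-ins, not CK symbols.
#include <array>
#include <cstddef>

template <int MPerBlock, int NPerBlock>
struct EpilogueBase
{
    // Fuses every D tensor onto the accumulator tile while writing E out;
    // this is the part that used to be duplicated in each Run()/Run_2Lds().
    template <std::size_t NumD>
    static void RunMultiDEpilogue(const float* c_tile,
                                  const std::array<const float*, NumD>& ds_tiles,
                                  float* e_tile)
    {
        for(int i = 0; i < MPerBlock * NPerBlock; ++i)
        {
            float acc = c_tile[i];
            for(std::size_t d = 0; d < NumD; ++d)
                acc += ds_tiles[d][i];
            e_tile[i] = acc;
        }
    }
};

template <int MPerBlock, int NPerBlock>
struct GemmSketch : EpilogueBase<MPerBlock, NPerBlock>
{
    using Base = EpilogueBase<MPerBlock, NPerBlock>;

    static void Run(const float* c_tile, const float* bias, float* e_tile)
    {
        // main loop producing c_tile omitted; only the epilogue delegation remains
        Base::template RunMultiDEpilogue<1>(c_tile, {bias}, e_tile);
    }
};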
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = 
MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to 
write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index b02d9c7f77a..fa0f401743d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -10,11 +10,10 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -33,7 +32,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( @@ -42,7 +41,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( karg.p_a_grid, @@ -69,7 +69,7 @@ template __global__ 
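// The kernel hunks above (and the 2-LDS variant that follows) switch to
// __launch_bounds__(GridwiseGemm::MaxBlockSize) and size the __shared__ buffer
// with GetSharedMemoryNumberOfByte(get_device_arch()), i.e. the LDS footprint
// becomes a constexpr function of the target architecture. A hypothetical,
// simplified sketch of that idea follows; DeviceArch, LdsSizing and the padding
// numbers are illustrative only, only get_device_arch()/MaxBlockSize come from
// the diff itself.
enum class DeviceArch { Gfx9, Gfx11, Gfx12 };

template <int MPerBlock, int KPerBlock, typename ADataType>
struct LdsSizing
{
    static constexpr int GetSharedMemoryNumberOfByte(DeviceArch arch)
    {
        // pretend one target needs extra padding to dodge LDS bank conflicts
        const int padding = (arch == DeviceArch::Gfx12) ? 16 : 0;
        return (MPerBlock + padding) * KPerBlock * static_cast<int>(sizeof(ADataType));
    }
};

// conceptually what
//   __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
// evaluates to inside the kernels above:
constexpr int gfx9_bytes =
    LdsSizing<128, 32, float>::GetSharedMemoryNumberOfByte(DeviceArch::Gfx9);
constexpr int gfx12_bytes =
    LdsSizing<128, 32, float>::GetSharedMemoryNumberOfByte(DeviceArch::Gfx12);
static_assert(gfx12_bytes > gfx9_bytes, "the gfx12 variant reserves extra padding in this sketch");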
void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( @@ -78,8 +78,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run_2Lds( karg.p_a_grid, @@ -154,30 +156,115 @@ template struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using 
Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::NumDTensor; + using AScaleType = float; using BScaleType = float; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; static constexpr auto BlockSizeNumber = Number{}; - static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); @@ -211,8 +298,6 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -677,124 +762,6 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle index_t b_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -816,55 +783,34 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; } - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmBlockScaleBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + LDSTypeA, + LDSTypeB, + 
ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1119,8 +1065,6 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); @@ -1153,7 +1097,8 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1317,261 +1262,25 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // transposed XDL - // // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - // // TODO: hacky, fix it! 
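// BlockwiseGemmPipe above is now defined by taking decltype of a constexpr
// selector call that is fed the arch-dependent block descriptors. A small,
// self-contained sketch of that idiom follows; PipelineV1/PipelineV3 and
// select_pipeline are made-up stand-ins for the CK pipeline selectors.
#include <type_traits>

template <int MPerBlock> struct PipelineV1 { static constexpr int prefetch_stages = 1; };
template <int MPerBlock> struct PipelineV3 { static constexpr int prefetch_stages = 2; };

template <int Version, int MPerBlock>
constexpr auto select_pipeline()
{
    // the selector returns a fully configured pipeline object; callers only
    // ever use its type, never the value
    if constexpr(Version == 1)
        return PipelineV1<MPerBlock>{};
    else
        return PipelineV3<MPerBlock>{};
}

template <int Version, int MPerBlock>
struct GridwiseGemmSketch
{
    // same shape as:
    //   using BlockwiseGemmPipe = remove_cvref_t<decltype(BlockGemm..._Selector<...>())>;
    using BlockwiseGemmPipe = std::remove_cv_t<
        std::remove_reference_t<decltype(select_pipeline<Version, MPerBlock>())>>;
};

static_assert(GridwiseGemmSketch<3, 128>::BlockwiseGemmPipe::prefetch_stages == 2,
              "the v3 pipeline was selected at compile time");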
- // only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); - constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto a_scale_grid_buf = make_dynamic_buffer( p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); @@ -1651,7 +1358,8 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1820,261 +1528,24 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // transposed XDL - // // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - // // TODO: hacky, fix it! 
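// The *_2lds kernels above keep two distinct __shared__ buffers (p_shared and
// p_shared1) so consecutive K tiles land in different LDS regions. A heavily
// simplified, hypothetical HIP sketch of that ping-pong idea follows; the kernel
// name, tile size and the drain-less loop are illustrative only, and the launch
// is assumed to use 256 threads per block.
__global__ void mainloop_2lds_sketch(const float* a, float* out, int num_k_tiles)
{
    __shared__ float tile0[256];
    __shared__ float tile1[256];

    float acc = 0.0f;
    for(int k = 0; k < num_k_tiles; ++k)
    {
        float* write_buf = (k % 2 == 0) ? tile0 : tile1; // tile being loaded
        float* read_buf  = (k % 2 == 0) ? tile1 : tile0; // tile loaded last iteration

        write_buf[threadIdx.x] = a[k * blockDim.x + threadIdx.x];
        __syncthreads();
        if(k > 0)
            acc += read_buf[threadIdx.x]; // consume the previous tile
        __syncthreads(); // reads finish before the next iteration overwrites read_buf
    }
    // a real pipeline would also drain the final tile; omitted in this sketch
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}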
- // only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); - constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = container_concat( - make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags - {c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Base::template RunMultiDEpilogue( + blockwise_gemm_pipeline, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_ds_grid, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index a41f096abb1..d2dd1d243cb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -15,11 +15,10 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/env.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { -#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX -#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX // Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same // kernel function Blockers: // 1. 
Two separted declaration of __shared__ pointer is the key to make sure data access operate on @@ -42,7 +41,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx950__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -78,8 +77,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -97,7 +96,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = karg; #endif // end of if (defined(__gfx9__)) } -#endif template struct GridwiseGemmMX_xdl_cshuffle_v3 + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + 
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -166,12 +266,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static constexpr auto I8 = Number<8>{}; static constexpr auto I9 = Number<9>{}; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; @@ -185,7 +279,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // Should be a multiple of k_per_blk. // TODO: Move this to blockwise pipeline base // KPack in packed data types for pk A/B - static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -199,8 +292,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 is_scale_mfma>::selected_mfma.k_per_blk / APackedSize); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); @@ -809,313 +900,34 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in LDS - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. 
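// Among the members this struct now inherits instead of redeclaring are the
// K-split constants removed above (AK0Number, BK0Number, ...). A small numeric
// illustration of what they encode, assuming an illustrative tile of
// KPerBlock = 32, AK1 = 8, MPerBlock = 128 (values chosen for the example, not
// taken from this header):
constexpr int KPerBlock = 32;
constexpr int MPerBlock = 128;
constexpr int AK1       = 8;
constexpr int AK0       = KPerBlock / AK1; // 4: K is re-blocked as AK0 x AK1 in LDS
static_assert(AK0 * AK1 == KPerBlock, "AK1 must divide KPerBlock");

constexpr int a_tile_elems = AK0 * MPerBlock * AK1; // elements of A staged per block
static_assert(a_tile_elems == MPerBlock * KPerBlock, "same element count, just re-blocked");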
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in lds - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - 
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return b_lds_block_desc_permuted; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmMXPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1389,8 +1201,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // A Scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( @@ -1430,10 +1240,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_DirectLoad( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - 
make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 = MXdlPack - M3, // M3 * M4 * M5 = MPerXdl - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - 
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // A Scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( @@ -1941,10 +1540,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_DirectLoad( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 = MXdlPack - M3, // M3 * M4 * M5 = MPerXdl - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in 
global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); } template __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx_bpreshuffle(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -71,15 +70,17 @@ __global__ enable_if_t __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx_bpreshuffle(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared_0[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared_1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -97,7 +98,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = karg; #endif // end of if (defined(__gfx9__)) } -#endif template struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, 
+ NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -166,12 +267,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle static constexpr auto I8 = Number<8>{}; static constexpr auto I9 = Number<9>{}; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; @@ -185,7 +280,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle // Should be a multiple of k_per_blk. 
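The shape of this refactor — shared constexpr helpers and constants hoisted into GridwiseGemm_xdl_cshuffle_base and re-exported in the derived gridwise structs via using-declarations — reduces to the following stripped-down sketch (all names and tile sizes here are invented for illustration, not the actual base-class interface):

    #include <cstddef>

    // Shared tile helpers live in a base template...
    template <int MPerBlock, int KPerBlock, int AK1>
    struct GridwiseGemmBaseSketch
    {
        static constexpr int AK0Number = KPerBlock / AK1;

        static constexpr std::size_t GetABlockElementCount()
        {
            return static_cast<std::size_t>(AK0Number) * MPerBlock * AK1;
        }
    };

    // ...and each derived gridwise GEMM pulls them in instead of redefining them.
    template <int MPerBlock, int KPerBlock, int AK1>
    struct GridwiseGemmDerivedSketch : GridwiseGemmBaseSketch<MPerBlock, KPerBlock, AK1>
    {
        using Base = GridwiseGemmBaseSketch<MPerBlock, KPerBlock, AK1>;
        using Base::AK0Number;             // re-export the shared constant
        using Base::GetABlockElementCount; // re-export the shared helper

        static constexpr std::size_t GetSharedMemoryBytes(std::size_t elem_size)
        {
            return GetABlockElementCount() * elem_size;
        }
    };

    static_assert(GridwiseGemmDerivedSketch<128, 32, 8>::AK0Number == 4, "sanity check");

    int main() { return 0; }

The payoff is that the per-variant structs (plain, bpreshuffle, layernorm, waveletmodel) only keep what genuinely differs, while descriptor and LDS-size logic has a single definition.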
// TODO: Move this to blockwise pipeline base // KPack in packed data types for pk A/B - static constexpr index_t APackedSize = packed_size_v; static constexpr index_t BPackedSize = packed_size_v; @@ -204,8 +298,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - using ThisThreadBlock = ThisThreadBlock; - using mx_scale_t = e8m0_bexp_t; static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); @@ -835,126 +927,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in LDS - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave/NXdlPack -> NWave -> NXdlPack -> KLane -> NLane -> KPack @@ -965,74 +937,69 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle Number{})); } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmMXBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>())>; + + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + CDataType, + CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } - return math::max(a_block_space_size_aligned * sizeof(ADataType), - c_block_size * sizeof(CShuffleDataType)); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(MXdlPerWave < 4) + { + // return false; + } + } + return true; } - IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ static constexpr bool CheckValidity(const Argument& karg) { @@ -1332,7 +1299,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1485,225 +1453,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - - // TODO: hacky, fix it! 
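The IsValidCompilationParameter added above gates instantiation on tile-shape divisibility before any kernel body is emitted. A condensed, self-contained example of that style of compile-time check (parameter names follow the patch, values are examples only):

    #include <cstdio>

    // Reject tile configurations whose shuffle granularity does not divide the
    // per-wave tiling; evaluated entirely at compile time via if constexpr.
    template <int NXdlPerWave, int CShuffleNXdlPerWavePerShuffle>
    constexpr bool IsValidShuffleConfig()
    {
        if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
        {
            return false;
        }
        return true;
    }

    static_assert(IsValidShuffleConfig<4, 1>(), "shuffle 1 of 4 XDL tiles at a time: ok");
    static_assert(!IsValidShuffleConfig<4, 3>(), "3 does not divide 4: rejected");

    int main()
    {
        std::printf("valid<8,2> = %d\n", IsValidShuffleConfig<8, 2>());
        return 0;
    }

Because the check returns false instead of asserting, invalid combinations can be filtered out of instance lists rather than breaking the whole build.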
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 = MXdlPack - M3, // M3 * M4 * M5 = MPerXdl - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - 
InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_c_grid, + c_element_op); } template ( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); 
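With the C-shuffle epilogue consolidated behind Base::RunEpilogue, each Run variant no longer constructs its own c_grid_buf; the shared epilogue receives the raw output pointer plus the grid descriptor and wraps the buffer internally. A much-simplified stand-in for that hand-off (all types below are invented; only the call shape mirrors the hunks above):

    #include <cstddef>
    #include <vector>

    // Stand-in for the grid descriptor: only the element space size matters here.
    struct GridDescSketch
    {
        std::size_t element_space_size;
    };

    // Stand-in for make_dynamic_buffer's result.
    template <typename T>
    struct DynamicBufferSketch
    {
        T* p;
        std::size_t n;
    };

    template <typename T>
    void RunEpilogueSketch(T* p_c_grid, const GridDescSketch& c_desc, T value)
    {
        // The buffer over global C is created inside the shared epilogue...
        DynamicBufferSketch<T> c_grid_buf{p_c_grid, c_desc.element_space_size};
        for(std::size_t i = 0; i < c_grid_buf.n; ++i)
            c_grid_buf.p[i] = value; // ...stand-in for the LDS->global shuffle copy
    }

    int main()
    {
        std::vector<float> c(16, 0.f);
        RunEpilogueSketch(c.data(), GridDescSketch{c.size()}, 1.f);
        return c[0] == 1.f ? 0 : 1;
    }

This is why the `auto c_grid_buf = make_dynamic_buffer(...)` lines disappear from every Run/Run_2Lds body in this patch.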
// A Scale buffer const auto a_scale_grid_buf = make_dynamic_buffer( @@ -1832,7 +1588,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle // constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1990,278 +1747,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle num_k_block_main_loop); // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - - // TODO: hacky, fix it! - // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 = MXdlPack - M3, // M3 * M4 * M5 = MPerXdl - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per - // shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); 
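The removed block that follows decomposes the flat per-thread M offset returned by CalculateCThreadOriginDataIndex into (M0, M1, ..., M5) via a merge-transform adaptor. CalculateBottomIndex on such an adaptor is ordinary mixed-radix decomposition; a plain-C++ sketch (the lengths are example values, not a real tile shape):

    #include <array>
    #include <cstdio>

    // Split a flat offset into per-level indices, least-significant level last,
    // exactly the way a merge-transform adaptor recovers (M0..M5) from m.
    template <std::size_t N>
    std::array<int, N> decompose(int flat, const std::array<int, N>& lengths)
    {
        std::array<int, N> idx{};
        for(int d = static_cast<int>(N) - 1; d >= 0; --d)
        {
            idx[d] = flat % lengths[d];
            flat /= lengths[d];
        }
        return idx;
    }

    int main()
    {
        const std::array<int, 6> m_lengths{4, 2, 2, 4, 2, 4}; // example M0..M5 extents
        const auto m_idx = decompose(37, m_lengths);
        for(int v : m_idx)
            std::printf("%d ", v);
        std::printf("\n");
        return 0;
    }

In the base-class version of the epilogue this decomposition happens once, inside the shared helper, instead of being repeated in every gridwise struct.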
- - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // calculate C grid descriptor - constexpr auto DWORD_BYTES = 4; - constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType); - - constexpr auto CShuffleBlockTransferClusterLengths = [&]() { - if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) - { - return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}; - } - // Atomic operation - else - { - return generate_sequence_v2( - [&](auto i) { - if constexpr(i == 3) - { - return Number< - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} - .At(i) * - CShuffleBlockTransferScalarPerVector_NPerBlock / - atomic_vector_size>{}; - } - else if constexpr(i == 1) - { - return Number< - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} - .At(i) / - CShuffleBlockTransferScalarPerVector_NPerBlock * - atomic_vector_size>{}; - } - else - { - return Number< - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} - .At(i)>{}; - } - }, - Number<4>{}); - } - }(); - - constexpr auto CShuffleBlockTransferScalarPerVector = [&]() { - if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) - { - return CShuffleBlockTransferScalarPerVector_NPerBlock; - } - else - { - return atomic_vector_size; - } - }(); - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - decltype(CShuffleBlockTransferClusterLengths), - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - 
CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_0, + p_c_grid, + c_element_op); } template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_layernorm_xdl_cshuffle_v1( const FloatAB* __restrict__ p_a_grid, @@ -157,64 +157,127 @@ template struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + 
BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatGemmAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumDTensor; // K1 should be Number<...> - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK0 = Base::AK0Number; + static constexpr auto BK0 = Base::BK0Number; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1, BK1); @@ -227,7 +290,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); // Align 16 bytes (maximum LDS read/write width) constexpr auto c_block_size_aligned = @@ -433,10 +496,12 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -585,109 +650,32 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
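The get_device_arch() argument now threaded through GetABlockDescriptor_*, GetBBlockDescriptor_*, GetCShuffleBlockDescriptor_* and GetSharedMemoryNumberOfByte lets one helper pick an arch-specific layout at compile time. The mechanism is plain tag dispatch on a constexpr value; roughly (enum, helper name, and padding values below are invented for the sketch):

    #include <cstdio>

    // Rough illustration of compile-time arch dispatch in the style of
    // get_device_arch(): the tag is constexpr, so the selection costs nothing
    // at run time. Enum values and tile parameters are examples only.
    enum class DeviceArch { gfx90a, gfx942, gfx950 };

    constexpr DeviceArch get_device_arch_sketch()
    {
    #if defined(__gfx950__)
        return DeviceArch::gfx950;
    #else
        return DeviceArch::gfx90a;
    #endif
    }

    constexpr int GetABlockLdsPadding(DeviceArch arch)
    {
        if(arch == DeviceArch::gfx950)
            return 0; // e.g. a padding-free LDS layout on the newer architecture
        return 4;     // e.g. pad the M stride to dodge bank conflicts elsewhere
    }

    int main()
    {
        constexpr auto arch = get_device_arch_sketch();
        constexpr int pad   = GetABlockLdsPadding(arch);
        std::printf("pad = %d\n", pad);
        return 0;
    }

Passing the arch tag explicitly, rather than branching on macros inside each getter, is what allows the descriptor types to appear in the BlockwiseGemmPipe selector's decltype expressions above.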
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + Base::template GetCThreadDescriptor(); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), + static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - 
tensor_operation::element_wise::PassThrough{}}; + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock)>(); + + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 2903b219b49..83ac6f73d09 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -11,9 +11,9 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -61,22 +61,114 @@ template struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ABDataType, + ABDataType, + FloatGemmAcc, + EDataTypeShuffle, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize), + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ABDataType, + ABDataType, + true> // ForceNaiveLayout { - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + 
tensor_layout::gemm::RowMajor, + ABDataType, + ABDataType, + FloatGemmAcc, + EDataTypeShuffle, + Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize), + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ABDataType, + ABDataType, + true>; // ForceNaiveLayout + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::NumDTensor; // K1 should be Number<...> - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - static constexpr auto AK0PerBlock = Number{}; - static constexpr auto BK0PerBlock = Number{}; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + static constexpr auto AK0PerBlock = Base::AK0Number; + static constexpr auto BK0PerBlock = Base::BK0Number; static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize); struct TileLoadThreadGroup @@ -113,65 +205,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle using GridwiseGemmLoad = GridwiseGemmLoadWave; using GridwiseGemmMath = GridwiseGemmMathWave; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(AK0PerBlock, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(BK0PerBlock, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align
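
The hunk above replaces per-kernel tile constants with aliases into the shared base class. A minimal, self-contained sketch of that pattern, with illustrative names rather than CK's actual classes; note that BK1 has to alias BK1Number rather than AK1Number so that mixed A/B vector widths keep their own values:

// Illustrative sketch only: a shared base owns the tile constants and each
// gridwise kernel re-exports what it needs under its historical names.
template <int AK1Value, int BK1Value>
struct GridwiseBaseSketch
{
    static constexpr int AK1Number = AK1Value;
    static constexpr int BK1Number = BK1Value;
};

template <int AK1Value, int BK1Value>
struct GridwiseDerivedSketch : GridwiseBaseSketch<AK1Value, BK1Value>
{
    using Base = GridwiseBaseSketch<AK1Value, BK1Value>;

    // Re-export the base constants so existing member code keeps compiling unchanged.
    static constexpr int AK1 = Base::AK1Number;
    static constexpr int BK1 = Base::BK1Number; // BK1 must track BK1Number, not AK1Number
};

static_assert(GridwiseDerivedSketch<8, 4>::BK1 == 4, "BK1 follows BK1Value");
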
= math::lcm(AK1, BK1); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), - c_block_size * sizeof(EDataTypeShuffle)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -375,10 +408,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle // loadWave and MathWave synchronized through LDS // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1, BK1); @@ -537,8 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle KPack>{}; auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); - auto c_grid_buf = make_dynamic_buffer( - p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // TODO re-architect LDS+math stages // Writing data to GMEM: only math wave is doing the work in cshuffle @@ -554,223 +587,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle // sanity check // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
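
For orientation, the LDS budget that the removed GetSharedMemoryNumberOfByte() computed, and that the base class now provides, follows one rule: align the A and B tile sizes to lcm(AK1, BK1), then take the larger of the main-loop stage and the C-shuffle stage, since the two phases reuse the same LDS. A small host-side model of that rule; helper names and the numbers in the assert are illustrative, not CK's:

#include <cstddef>
#include <numeric>

constexpr std::size_t round_up(std::size_t n, std::size_t align)
{
    return ((n + align - 1) / align) * align;
}

constexpr std::size_t lds_bytes_sketch(std::size_t a_elems, std::size_t b_elems, std::size_t c_elems,
                                       std::size_t ak1, std::size_t bk1,
                                       std::size_t ab_bytes, std::size_t c_bytes)
{
    const std::size_t align = std::lcm(ak1, bk1);        // lds max alignment
    const std::size_t a_sz  = round_up(a_elems, align);  // A tile, aligned
    const std::size_t b_sz  = round_up(b_elems, align);  // B tile, aligned
    const std::size_t gemm  = (a_sz + b_sz) * ab_bytes;  // LDS used by the main loop
    const std::size_t shfl  = c_elems * c_bytes;         // LDS reused by the C shuffle
    return gemm > shfl ? gemm : shfl;                    // the two phases share the allocation
}

static_assert(lds_bytes_sketch(4096, 4096, 2048, 8, 8, 2, 2) == 16384, "example only");
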
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - FloatGemmAcc, - EDataTypeShuffle, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - 
m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - CShuffleBlockTransferThreadGroup, // ThreadGroup - EElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - EDataTypeShuffle, // typename SrcData, - EDataType, // typename DstData, - decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - e_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - e_element_op}; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - // Different way of getting coalesced writes: - // We can get rid of doing cshuffle. Instead of reading A rows in contiguous manner - // do it interleaved, then mfma can have nice c-mat layout as below: - // - // TODO - // We do not need to do LDS swizzle to align global writes writing cache lines: - // v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN - // elments (N is vertical or strided - // dimension) - // v_mfma cmat, bmat, amat, cmat - c-mat register layout are Mx1 - // elments (M is coalescing - // dimension) by enumerating M index in - // amat, bmat you can align cmat - // register(s) to contiguous M elements - // for example - // 1st mfma instruction output space : 0 4 8 12 16 .... - // 2nd mfma instruction output space : 1 5 9 13 17 .... - // 3rd mfma instruction output space : 2 6 10 14 18 .... - // 4th mfma instruction output space : 3 7 11 15 19 .... 
- // you can pack 4 registers output space into 2WORD and do global write - // (no LDS swizzling required) - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - e_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_e_grid, + e_element_op); } } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index e6f055d1832..92d4e4a847b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -11,9 +11,10 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" namespace ck { @@ -154,7 +155,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid, @@ -263,21 +264,111 @@ template struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatA, + FloatB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
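
The epilogue that Base::RunEpilogue now centralizes always follows the same choreography that the removed block spelled out by hand: barrier, thread-wise VGPR-to-LDS copy, barrier, block-wise LDS-to-global copy, then advance the destination window along the space-filling curve. A hypothetical host-side model of that control flow, where the callbacks stand in for CK's transfer objects:

#include <functional>

// Sketch of the shared epilogue control flow; not CK's actual RunEpilogue signature.
void run_epilogue_sketch(int num_access,
                         const std::function<void()>& block_sync_lds,
                         const std::function<void(int)>& copy_vgpr_to_lds,
                         const std::function<void()>& copy_lds_to_global,
                         const std::function<void(int)>& move_dst_window)
{
    for(int access_id = 0; access_id < num_access; ++access_id)
    {
        block_sync_lds();               // make sure it's safe to write to LDS
        copy_vgpr_to_lds(access_id);    // each thread writes its accumulators to LDS
        block_sync_lds();               // make sure it's safe to read from LDS
        copy_lds_to_global();           // the whole block streams the shuffled tile out
        if(access_id < num_access - 1)
            move_dst_window(access_id); // advance the C window for the next shuffle slab
    }
}
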
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatA, + FloatB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatA, + FloatB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatA, + FloatB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; // K1 should be Number<...> static constexpr auto K1 = Number{}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -308,8 +399,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { - constexpr auto max_lds_align = K1; - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_k0_m_k1 = [&]() { if constexpr(ABlockLdsExtraM) @@ -337,15 +426,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight } else { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); } }(); @@ -354,8 +440,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() { - constexpr auto max_lds_align = K1; - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_b_k0_m_k1 = [&]() { if constexpr(ABlockLdsExtraM) @@ -389,19 +473,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight } else { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - 
make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); + return Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1( + GetABlockDescriptor_K0PerBlock_MPerBlock_K1()); } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + return Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1( + GetABlockDescriptor_K0PerBlock_MPerBlock_K1()); } }(); @@ -410,8 +489,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() { - constexpr auto max_lds_align = K1; - // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_k0_n_k1 = [&]() { if constexpr(BBlockLdsExtraN) @@ -439,15 +516,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight } else { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + return Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); } }(); @@ -456,8 +530,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() { - constexpr auto max_lds_align = K1; - // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_b_k0_n_k1 = [&]() { if constexpr(BBlockLdsExtraN) @@ -491,19 +563,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight } else { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); + return Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1( + GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()); } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); + return Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1( + GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()); } }(); @@ -528,7 +595,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); constexpr auto c_block_size = - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()) + .GetElementSpaceSize(); return math::max((a_block_space_size * sizeof(FloatAAdjusted) + b_block_space_size * sizeof(FloatBAdjusted)), @@ -539,7 +607,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { - return ck::tensor_operation::device::IsValidGemmCompilationParameter< + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, NPerBlock, @@ -549,6 +617,18 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight NRepeat, FloatC, CGlobalMemoryDataOperation_>(); + if constexpr(!valid) + { + return false; + } + + if constexpr(K1Value % MfmaSelector:: + selected_mfma.k_per_blk != + 0) + { + return false; + } + return true; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template @@ -641,19 +721,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight c_m_n_grid_desc, 
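
The added compile-time guard rejects tile configurations whose K1 is not a whole multiple of the selected MFMA instruction's k_per_blk, on top of the existing generic GEMM parameter check. A stand-alone sketch of the same idea, where MfmaKPerBlk is a placeholder for the value the MFMA selector would report:

template <int K1Value, int MfmaKPerBlk>
constexpr bool is_valid_k1_sketch()
{
    if constexpr(K1Value % MfmaKPerBlk != 0)
    {
        return false; // the instruction cannot consume K1 in whole k_per_blk chunks
    }
    return true;
}

static_assert(is_valid_k1_sketch<8, 4>(), "8 is a multiple of 4");
static_assert(!is_valid_k1_sketch<6, 4>(), "6 is not a multiple of 4");
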
M01, N01, KBatch); } - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); - - return make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - } - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); @@ -703,8 +770,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight p_a_grid + split_k_offset_a, a_buffer_size); const auto b_grid_buf = make_dynamic_buffer( p_b_grid + split_k_offset_b, b_buffer_size); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); if(!c_block_cluster_adaptor.ValidCTileIndex( make_tuple(block_work_idx[I1], block_work_idx[I2]), @@ -867,205 +932,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight K0BlockMainLoop); // output: register to global memory - { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); - constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); - constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); - constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); - constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); - constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); - constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); - constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - - constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - static_assert(M1 == MWave, ""); - static_assert(N1 == NWave, ""); - static_assert(M2 * M3 * M4 == MPerXdl, ""); - static_assert(N2 == NPerXdl, ""); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, - M1, - M2, - M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, - N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto 
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // LDS to global - auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXdl, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatC, // typename SrcData, - FloatC, // typename DstData, - decltype(c_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun - {c_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), - c_element_op}; - - constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); - constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); - constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); - - static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? 
nxdlperwave_iter - : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, - c_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I1], + block_work_idx[I2], + p_shared, + p_c_grid, + c_element_op); } template @@ -1086,8 +961,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); @@ -1258,205 +1131,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight K0BlockMainLoop); // output: register to global memory - { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); - constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); - constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); - constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); - constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); - constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); - constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); - constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); - - constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - static_assert(M1 == MWave, ""); - static_assert(N1 == NWave, ""); - static_assert(M2 * M3 * M4 == MPerXdl, ""); - static_assert(N2 == NPerXdl, ""); - - constexpr auto 
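
The removed loops above sweep the C tile in a zig-zag order, reversing the N direction on every other M step so the destination slice window only ever moves by a single step between shuffles. A small stand-alone illustration of that traversal order; the tile counts are made up:

#include <cstdio>

int main()
{
    constexpr int MRepeat = 4, NRepeat = 4;
    constexpr int MShuffle = 2, NShuffle = 2; // CShuffleM/NRepeatPerShuffle
    for(int m = 0; m < MRepeat; m += MShuffle)
    {
        const bool forward = (m % (2 * MShuffle)) == 0; // sweep N forward on even M steps
        for(int n_iter = 0; n_iter < NRepeat; n_iter += NShuffle)
        {
            const int n = forward ? n_iter : (NRepeat - n_iter - NShuffle);
            std::printf("shuffle tile (m=%d, n=%d)\n", m, n); // visits (0,0)(0,2)(2,2)(2,0)
        }
    }
    return 0;
}
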
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, - M1, - M2, - M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, - N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // LDS to global - auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXdl, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatC, // typename SrcData, - FloatC, // typename DstData, - decltype(c_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun - {c_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), - c_element_op}; - - constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 
0); - constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); - constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); - - static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? nxdlperwave_iter - : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, - c_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I1], + block_work_idx[I2], + p_shared, + p_c_grid, + c_element_op); } }; // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index b19c5d8fb69..150fbce3f75 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -29,7 +30,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_skip_b_lds_v1(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -46,7 +47,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + 
p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); @@ -114,15 +116,102 @@ template struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 + : public GridwiseGemm_xdl_cshuffle_base, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + Sequence<>, + Sequence<>, + Sequence<>, + 1, + BBlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + 0, + 1, + 1, + Sequence<>, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + + using Base = GridwiseGemm_xdl_cshuffle_base, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + Sequence<>, + Sequence<>, + Sequence<>, + 1, + BBlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + 0, + 1, + 1, + Sequence<>, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -134,44 +223,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; - using ThisThreadBlock = ThisThreadBlock; - - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), - max_lds_align); - } - }(); - - return a_block_desc_k0_m_k1; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = 
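
Deriving __launch_bounds__ from the gridwise struct, as the kernels in this patch now do, lets the occupancy hint follow the tile configuration instead of a global maximum. A minimal sketch of that idiom; GridwiseSketch and its MaxBlockSize are placeholders, not CK's base class:

// Illustrative HIP/CUDA sketch: the launch-bounds value is taken from the gridwise type.
struct GridwiseSketch
{
    static constexpr int MaxBlockSize = 256;
};

template <typename GridwiseGemm>
__global__ void __launch_bounds__(GridwiseGemm::MaxBlockSize, 1)
kernel_sketch(float* p_out)
{
    // the occupancy hint now tracks the block size implied by the tile configuration
    if(threadIdx.x == 0)
        p_out[blockIdx.x] = static_cast<float>(GridwiseGemm::MaxBlockSize);
}
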
- math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size_aligned) * sizeof(FloatAB); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() @@ -402,8 +453,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); @@ -415,11 +464,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0_m_k1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< @@ -618,82 +665,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 } // output: register to global memory - { - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, - 
CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf); - } + Base::template RunEpilogueNoShuffle( + blockwise_gemm, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 05de26b2ae2..3134096899f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -18,6 +18,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -30,7 +31,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, const Block2CTileMap& b2c_map, @@ -96,15 +97,110 @@ template struct GridwiseGemm_xdlops_splitk_lds_direct_load + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatA, + FloatB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferSrcScalarPerVector, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeType, + ComputeType, + true, // ForceNaiveLayout + true> // DirectLoad { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, 
+ FloatA, + FloatB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferSrcScalarPerVector, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeType, + ComputeType, + true, // ForceNaiveLayout + true>; // DirectLoad + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -116,8 +212,6 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load tensor_operation::device::GemmPadder{ MPerBlock, NPerBlock, K1* K0PerBlock}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -350,51 +444,8 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_size = - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - - return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * - sizeof(ComputeType), - c_block_size * sizeof(FloatC)); + return Base::template GetSharedMemoryNumberOfByte( + get_device_arch()); } static constexpr index_t MXdlPerWave = MRepeat; @@ -577,8 +628,6 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto 
c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [KBatch, M, N] const auto block_work_idx = @@ -607,73 +656,17 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(K1, Number{}, I1)); - } - }(); + constexpr auto a_k0_m_k1_block_desc = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); - constexpr auto a_b_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple( - Number{} * Number{}, K1, Number{}, I1)); - } - }(); + constexpr auto a_b_k0_m_k1_block_desc = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(a_k0_m_k1_block_desc); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(K1, Number{}, I1)); - } - }(); + constexpr auto b_k0_n_k1_block_desc = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); - constexpr auto b_b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple( - Number{} * Number{}, K1, Number{}, I1)); - } - }(); + constexpr auto b_b_k0_n_k1_block_desc = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(b_k0_n_k1_block_desc); auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_DirectLoad( - static_cast(p_shared_block), - c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, - M1, - M2, - M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, - N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto 
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // LDS to global - auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXdl, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatC, // typename SrcData, - FloatC, // typename DstData, - decltype(c_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun - {c_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); - constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); - constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); - - static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? 
nxdlperwave_iter - : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, - c_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_block, + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index d4ea7a21498..e14b92aa679 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -17,13 +17,15 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp" #include "ck/utility/workgroup_barrier.hpp" #include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, const typename GridwiseGemm::FloatAB* p_b_grid, @@ -40,7 +42,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + constexpr index_t shared_size = + GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch()); __shared__ uint8_t p_shared[shared_size]; @@ -112,15 +115,105 @@ template struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB_, + FloatAB_, + FloatAcc_, + FloatAcc_, + Tuple<>, + FloatC_, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, 
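The structural move repeated in every file touched here is that the gridwise struct now inherits its tile configuration from a shared GridwiseGemm_xdl_cshuffle_base and re-exports the handful of pieces its body needs (index constants, the thread-block type, the LDS block-descriptor getters) with using-declarations instead of redefining them locally. A much-reduced sketch of that pattern follows, with invented names (TileBase, TileDerived) and a toy shared-memory helper rather than CK's real signatures.

#include <cstdio>

// Sketch only: factoring shared compile-time helpers into a base template and
// re-exporting them from the derived kernel policy via using-declarations.
template <int MPerBlock, int NPerBlock, int KPerBlock>
struct TileBase
{
    static constexpr int I0 = 0;
    static constexpr int I1 = 1;

    // A helper that every derived policy previously duplicated.
    static constexpr int GetSharedMemoryNumberOfByte()
    {
        // A tile + B tile staged in LDS, 2 bytes per element (assumption for this sketch).
        return (MPerBlock * KPerBlock + KPerBlock * NPerBlock) * 2;
    }
};

template <int MPerBlock, int NPerBlock, int KPerBlock>
struct TileDerived : public TileBase<MPerBlock, NPerBlock, KPerBlock>
{
    using Base = TileBase<MPerBlock, NPerBlock, KPerBlock>;

    // Re-export the pieces the kernel body uses, as the patch does with
    // `using Base::I0;`, `using Base::GetABlockDescriptor_...;` and friends.
    using Base::I0;
    using Base::I1;
    using Base::GetSharedMemoryNumberOfByte;
};

int main()
{
    // The derived policy answers with the base's single implementation.
    std::printf("LDS bytes: %d\n", TileDerived<128, 128, 32>::GetSharedMemoryNumberOfByte());
    return 0;
}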
+ ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB_, + FloatAB_, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB_, + FloatAB_, + FloatAcc_, + FloatAcc_, + Tuple<>, + FloatC_, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + FloatAB_, + FloatAB_, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetSharedMemoryNumberOfByte; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -128,9 +221,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk static constexpr auto N01 = 1; static constexpr auto KPerBlock = K0PerBlock * K1; - using ThisThreadBlock = ThisThreadBlock; - using FloatAcc = FloatAcc_; - using FloatCShuffle = FloatAcc; + using FloatAcc = FloatAcc_; + using FloatCShuffle = FloatAcc; using Block2CTileMap = Block2CTileMap_; using FloatAB = FloatAB_; @@ -189,6 +281,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk karg.k_batch); } + __host__ static constexpr index_t GetSharedMemoryNumberOfByteOnHost() + { + if(ck::get_device_name() == "gfx950") + { + return GetSharedMemoryNumberOfByte(gfx950_t{}); + } + else + { + return GetSharedMemoryNumberOfByte(gfx_invalid_t{}); + } + } __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; } __host__ __device__ static auto @@ -270,47 +373,34 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk make_tuple(Sequence<0>{}, Sequence<1>{})); } - __host__ __device__ 
static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_size = - GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); - } - static constexpr index_t MXdlPerWave = MRepeat; static constexpr index_t NXdlPerWave = NRepeat; - IS_VALID_COMPILATION_PARAMETER_IMPL(FloatC) + template < + InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> + __device__ static bool constexpr IsValidCompilationParameter() + { + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< + BlockSize, + MPerBlock, + NPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + FloatC, + CGlobalMemoryDataOperation_>(); + if(!valid) + { + return false; + } + if constexpr(K1 % + MfmaSelector::GetK1PerXdlops() != + 0) + { + return false; + } + return true; + } __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) { @@ -347,6 +437,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk return false; } + if(!is_xdl_wmma_k_supported()) + { + return false; + } + return true; } @@ -485,10 +580,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_k0_m_k1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_k0_n_k1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(); constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle = - GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + get_device_arch()); constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = - GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle(); + Base::GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); auto c_block_buf = make_dynamic_buffer( - reinterpret_cast(p_shared_block), + 
static_cast(p_shared_block), c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize()); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle)>(); + auto c_partial_acc_buf = make_dynamic_buffer( reinterpret_cast(p_workspace) + block_acc_offset, c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle .GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mpershuffle_nblock_npershuffle, - make_tuple(make_freeze_transform(I0), // freeze mblock - make_unmerge_transform( - make_tuple(CShuffleMRepeatPerShuffle, - M1, - M2, - M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_unmerge_transform( - make_tuple(CShuffleNRepeatPerShuffle, - N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatCShuffle, - decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // LDS to global auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index bc5e31013ef..8a5a451d5b4 100644 --- 
a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -26,7 +27,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) @@ -42,7 +43,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run(p_a_grid, p_b_grid, @@ -65,7 +66,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) @@ -76,7 +77,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; const auto a_grid_desc_k0_m_k1 = amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( @@ -139,15 +140,102 @@ template struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 + : public GridwiseGemm_xdl_cshuffle_base, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + 1, + 1, + Sequence<>, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + 
NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + 1, + 1, + Sequence<>, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> static constexpr bool is_single_rate_mfma = @@ -162,8 +250,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 MfmaSelector:: selected_mfma.k_per_blk)>{}; - using ThisThreadBlock = ThisThreadBlock; - using ElementDataTypeAB = conditional_t, float, FloatAB>; __host__ static auto CalculateGridSize(index_t M, index_t N) @@ -272,73 +358,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using FloatABAdjusted = FloatAB; #endif - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return a_block_desc_k0_m_k1; - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return b_block_desc_k0_n_k1; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ElementDataTypeAB); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { + if constexpr(K1Value % MfmaSelector::selected_mfma + .k_per_blk != + 0) + { + return false; + } return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, 
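The early-out added to IsValidCompilationParameter encodes a simple rule: if the K1 vector length is not a multiple of the K extent consumed by one MFMA instruction, the instance cannot be tiled correctly, so the specialization is rejected at compile time instead of being instantiated into a broken kernel. The self-contained sketch below shows the same kind of guard; mfma_k_per_blk is a made-up stand-in for MfmaSelector's selected_mfma.k_per_blk and its value is illustrative.

#include <cstdio>

// Sketch only: rejecting a tuning configuration at compile time when K1 is not
// a multiple of the K extent one matrix instruction consumes.
template <typename ComputeType>
constexpr int mfma_k_per_blk = 4; // illustrative, not a real MFMA property table

template <typename ComputeType, int K1Value>
constexpr bool IsValidCompilationParameter()
{
    if constexpr(K1Value % mfma_k_per_blk<ComputeType> != 0)
    {
        return false; // e.g. K1 = 6 with a 4-wide instruction cannot be tiled evenly
    }
    return true; // further checks (block size, tile divisibility, ...) would follow here
}

int main()
{
    std::printf("K1=8 valid: %d\n", IsValidCompilationParameter<float, 8>());
    std::printf("K1=6 valid: %d\n", IsValidCompilationParameter<float, 6>());
    return 0;
}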
MPerBlock, @@ -364,6 +393,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if(!is_xdl_wmma_k_supported()) + { + return false; + } + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); @@ -415,60 +449,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n) - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - } - // return block_id to C matrix tile idx (m0, n0) mapping using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; @@ -484,15 +464,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n) { - const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); const AElementwiseOperation a_element_op{}; const BElementwiseOperation b_element_op{}; @@ -505,14 +480,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) - { - return; - } - // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -524,10 +491,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0_m_k1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto b_block_desc_k0_n_k1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = 
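Block descriptors and LDS sizing are now parameterized by an architecture tag obtained with get_device_arch() inside the kernel, and the stream-K variant adds a host-side counterpart that selects the tag by comparing ck::get_device_name() against "gfx950". The snippet below sketches that tag-dispatch idea in isolation; the tag type names and the byte figures are assumptions made for illustration, not CK's actual values.

#include <cstdio>
#include <string>

// Sketch only: selecting a compile-time variant by architecture tag, with a
// host-side helper that maps a runtime device name onto the same tags.
struct gfx9_default_tag {};
struct gfx950_like_tag {};

// Device-side style overloads: the tag decides the LDS budget (values illustrative).
constexpr int max_lds_bytes(gfx9_default_tag) { return 64 * 1024; }
constexpr int max_lds_bytes(gfx950_like_tag)  { return 160 * 1024; }

// Host-side counterpart: choose the tag from the reported device name, mirroring
// the GetSharedMemoryNumberOfByteOnHost() pattern added earlier in the patch.
inline int max_lds_bytes_on_host(const std::string& device_name)
{
    if(device_name == "gfx950")
    {
        return max_lds_bytes(gfx950_like_tag{});
    }
    return max_lds_bytes(gfx9_default_tag{});
}

int main()
{
    std::printf("gfx90a budget: %d\n", max_lds_bytes_on_host("gfx90a"));
    std::printf("gfx950 budget: %d\n", max_lds_bytes_on_host("gfx950"));
    return 0;
}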
@@ -614,6 +583,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatABAdjusted, FloatABAdjusted>(); + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment @@ -651,82 +630,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 num_k_block_main_loop); // output: register to global memory - { - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_grid_idx = - m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_grid)); - - const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_grid_idx = - n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_grid)); - - auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2]), - c_element_op}; - - c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_grid_buf); - } + Base::template RunEpilogueNoShuffle( + blockwise_gemm, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_buf, + 
block_work_idx[I0], + block_work_idx[I1], + p_c_grid, + c_element_op); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 3c1e1ca27bb..74e20646a79 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -29,7 +29,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -179,6 +179,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { + if constexpr(K1Value % MfmaSelector::selected_mfma + .k_per_blk != + 0) + { + return false; + } return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, @@ -206,6 +212,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 (NPerBlock % (NRepeat * NPerXDL)) == 0, "Invalid tuning param!"); + if(!is_xdl_wmma_k_supported()) + { + return false; + } + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 5869caf19be..360847ecd46 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -12,10 +12,10 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -28,7 +28,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, const Block2CTileMap& b2c_map, @@ -39,7 +39,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + constexpr index_t shared_size = + GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch()); __shared__ uint8_t p_shared[shared_size]; @@ -103,15 +104,106 @@ template struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + LDSTypeA, + LDSTypeB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, 
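Both epilogue flavors collapse into single calls on the base: RunEpilogueNoShuffle writes each thread's accumulators straight to global memory, while RunEpilogue stages them through the LDS C-shuffle tile first so the final global store can be wide and coalesced. What follows is only a schematic host-side sketch of that two-path split, with trivial copies standing in for CK's tensor-transfer machinery and the buffer layout invented for the example.

#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch only: one epilogue either stores accumulators directly or stages them
// through a shared-memory tile first (the "shuffle" path).
struct EpilogueBuffers
{
    std::vector<float> thread_acc;  // per-thread accumulators (register-file stand-in)
    std::vector<float> lds_tile;    // shared-memory staging tile stand-in
    std::vector<float>* c_global;   // destination
};

// Direct path: accumulators go to global memory as-is (RunEpilogueNoShuffle analogue).
inline void run_epilogue_no_shuffle(EpilogueBuffers& b, std::size_t offset)
{
    for(std::size_t i = 0; i < b.thread_acc.size(); ++i)
        (*b.c_global)[offset + i] = b.thread_acc[i];
}

// Shuffle path: stage into the LDS tile, then copy the tile out contiguously,
// which is where the real kernels gain vectorized, coalesced global stores.
inline void run_epilogue(EpilogueBuffers& b, std::size_t offset)
{
    for(std::size_t i = 0; i < b.thread_acc.size(); ++i)
        b.lds_tile[i] = b.thread_acc[i];
    for(std::size_t i = 0; i < b.lds_tile.size(); ++i)
        (*b.c_global)[offset + i] = b.lds_tile[i];
}

int main()
{
    std::vector<float> c(8, 0.f);
    EpilogueBuffers b{{1, 2, 3, 4}, std::vector<float>(4, 0.f), &c};
    run_epilogue_no_shuffle(b, 0);
    run_epilogue(b, 4);
    std::printf("c[0]=%g c[7]=%g\n", c[0], c[7]);
    return 0;
}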
+ BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + LDSTypeA, + LDSTypeB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + ComputeTypeA, + ComputeTypeB, + true>; // ForceNaiveLayout + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -122,8 +214,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 tensor_operation::device::GemmPadder{ MPerBlock, NPerBlock, K1* K0PerBlock}; - using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; @@ -384,55 +474,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - 
return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_size = - GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - - return math::max(a_block_space_size * sizeof(LDSTypeA) + - b_block_space_size * sizeof(LDSTypeB), - c_block_size * sizeof(FloatC)); - } - static constexpr auto MXdlPerWave = MRepeat; static constexpr auto NXdlPerWave = NRepeat; IS_VALID_COMPILATION_PARAMETER_IMPL(FloatC) @@ -687,8 +728,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // divide block work by [KBatch, M, N] const auto block_work_idx = @@ -717,69 +756,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - constexpr auto a_b_k0_m_k1_block_desc = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); - } - }(); + constexpr auto a_k0_m_k1_block_desc = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); + constexpr auto a_b_k0_m_k1_block_desc = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(a_k0_m_k1_block_desc); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - constexpr auto b_b_k0_n_k1_block_desc = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number<1>{}, Number{}, Number{}, K1), - make_tuple(Number{} * Number{} * K1, - Number{} * K1, - K1, - I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - 
make_tuple(Number<1>{}, Number{}, Number{}, K1), - max_lds_align); - } - }(); + constexpr auto b_k0_n_k1_block_desc = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); + constexpr auto b_b_k0_n_k1_block_desc = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(b_k0_n_k1_block_desc); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1( - static_cast(p_shared_block), - c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, - M1, - M2, - M3, - M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, - N1, - N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // LDS to global - auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // index_t BlockSize, - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerXdl, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatC, // typename SrcData, - FloatC, // typename DstData, - decltype(c_block_desc_mblock_mperblock_nblock_nperblock), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - 
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun - {c_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; - - constexpr auto mxdlperwave_forward_step = - make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); - constexpr auto nxdlperwave_forward_step = - make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); - constexpr auto nxdlperwave_backward_step = - make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); - - static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { - constexpr auto mxdlperwave = mxdlperwave_iter; - - static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { - constexpr bool nxdlperwave_forward_sweep = - (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); - - constexpr index_t nxdlperwave_value = - nxdlperwave_forward_sweep - ? nxdlperwave_iter - : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); - - constexpr auto nxdlperwave = Number{}; - - // make sure it's safe to do ds_write - block_sync_lds(); - - // VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_block_buf); - - // make sure it's safe to do ds_read - block_sync_lds(); - - // LDS to global - c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, - c_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - // move on nxdlperwave dimension - if constexpr(nxdlperwave_forward_sweep && - (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_forward_step); - } - else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, - nxdlperwave_backward_step); - } - }); - - // move on mxdlperwave dimension - if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) - { - c_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); - } - }); - } + Base::template RunEpilogue( + blockwise_gemm, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared_block, + p_c_grid, + c_element_op); } static std::string GetTypeString() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8c9b7492ccc..9451ad2cf67 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -12,9 +12,11 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include 
"ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -31,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v3r1( const FloatAB* __restrict__ p_a_grid, @@ -50,7 +52,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -124,121 +126,120 @@ template < index_t NumGemmKPrefetchStage = 1, PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatCShuffle, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 
BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true>; + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK0 = Base::AK0Number; + static constexpr auto BK0 = Base::BK0Number; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr auto max_lds_align = AK1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(AK0, Number{}, AK1), max_lds_align); - } - }(); - - return a_block_desc_ak0_m_ak1; - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr auto max_lds_align = BK1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(BK0, Number{}, BK1), max_lds_align); - } - }(); - - return b_block_desc_bk0_n_bk1; - } - - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); - - // LDS allocation for C shuffle in LDS - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); - - constexpr auto c_block_size = - 
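The GetSharedMemoryNumberOfByte bodies being deleted here all compute the same quantity: the aligned A and B staging tiles share LDS with the C-shuffle tile, so the allocation is the maximum of the two, not their sum, because the GEMM main loop and the epilogue never occupy LDS at the same time. Below is a worked constexpr sketch with illustrative tile sizes; integer_least_multiple mirrors math::integer_least_multiple, and the element sizes are assumptions (fp16 A/B, fp32 shuffle tile).

#include <cstdio>

// Sketch only: LDS budget = max(A-tile + B-tile, C-shuffle tile).
constexpr int integer_least_multiple(int x, int align) { return ((x + align - 1) / align) * align; }

template <int K0PerBlock, int MPerBlock, int NPerBlock, int K1,
          int CShuffleRows, int CShuffleCols,
          int SizeofAB, int SizeofCShuffle>
constexpr int shared_memory_bytes()
{
    constexpr int a_elems = integer_least_multiple(K0PerBlock * MPerBlock * K1, K1);
    constexpr int b_elems = integer_least_multiple(K0PerBlock * NPerBlock * K1, K1);
    constexpr int c_elems = CShuffleRows * CShuffleCols;

    constexpr int ab_bytes = (a_elems + b_elems) * SizeofAB;
    constexpr int c_bytes  = c_elems * SizeofCShuffle;
    return ab_bytes > c_bytes ? ab_bytes : c_bytes;
}

int main()
{
    // Illustrative numbers only.
    std::printf("LDS bytes: %d\n", shared_memory_bytes<4, 128, 128, 8, 64, 128, 2, 4>());
    return 0;
}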
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatCShuffle)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { - return ck::tensor_operation::device::IsValidGemmCompilationParameter< + constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, NPerBlock, @@ -248,15 +249,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 NXdlPerWave, FloatC, CGlobalMemoryDataOperation>(); + if constexpr(!valid) + { + return false; + } + + if constexpr(KPerBlock % + MfmaSelector::GetKPerXdlops() != + 0) + { + return false; + } + return true; } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static bool CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { // static_assert(is_known_at_compile_time>::value && // is_known_at_compile_time>::value, @@ -276,6 +288,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; + if(!is_xdl_wmma_k_supported()) + { + return false; + } // check gridwise gemm pipeline const auto num_k_loop = K / KPerBlock; @@ -364,7 +380,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 p_c_grid, c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetElementSpaceSize()); - // divide block work by [M, N] const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -391,10 +406,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr auto max_lds_align = math::lcm(AK1, BK1); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -539,117 +556,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
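// --- Editor's sketch (not CK code): the LDS-size rule that this hunk deletes from the
// --- v3r1 struct and inherits from GridwiseGemm_xdl_cshuffle_base instead. A and B tiles
// --- occupy LDS during the main loop and the C-shuffle tile reuses the same allocation in
// --- the epilogue, so the byte count is the max of the two footprints. Helper names and
// --- the tile sizes below are illustrative placeholders; the real helper now also takes
// --- get_device_arch() so the descriptor can be arch-specific.
#include <algorithm>
#include <cstddef>

constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align; // round x up to a multiple of align
}

constexpr std::size_t shared_memory_bytes(std::size_t a_elems,
                                          std::size_t b_elems,
                                          std::size_t c_shuffle_elems,
                                          std::size_t ab_elem_bytes,
                                          std::size_t c_elem_bytes,
                                          std::size_t lds_align)
{
    const std::size_t a_aligned = integer_least_multiple(a_elems, lds_align);
    const std::size_t b_aligned = integer_least_multiple(b_elems, lds_align);
    // A+B footprint during the GEMM loop vs. C-shuffle footprint during the epilogue.
    return std::max((a_aligned + b_aligned) * ab_elem_bytes, c_shuffle_elems * c_elem_bytes);
}

// e.g. 32x128 fp16 tiles for A and B vs. a 64x128 fp32 C-shuffle tile:
static_assert(shared_memory_bytes(32 * 128, 32 * 128, 64 * 128, 2, 4, 8) == 64 * 128 * 4);
// --- end of editor's sketch ---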
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + Base::template GetCThreadDescriptor(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = Base:: + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + get_device_arch()); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), + static_cast(p_shared), c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<>{}, - Sequence<0>{}, - Sequence<2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = Base::template GetCBlockThreadDescriptor< + false, + 
decltype(blockwise_gemm), + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl)>(); // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); // LDS to global auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index ec3e7abf21e..654b09aca8f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -14,6 +14,8 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -31,7 +33,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v3r2( const FloatAB* __restrict__ p_a_grid, @@ -53,7 +55,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -130,119 +132,127 @@ template < index_t NumGemmKPrefetchStage = 1, PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + 
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK0 = Base::AK0Number; + static constexpr auto BK0 = Base::BK0Number; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + // K1 should be Number<...> + static constexpr auto K1 = Base::AK1Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return a_block_desc_k0_m_k1; - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return b_block_desc_k0_n_k1; - } - - __host__ __device__ static constexpr 
auto - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); - - constexpr auto c_block_size = - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatC)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { + if constexpr(K1Value % MfmaSelector::selected_mfma + .k_per_blk != + 0) + { + return false; + } return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, @@ -270,6 +280,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if(!is_xdl_wmma_k_supported()) + { + return false; + } + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); @@ -414,10 +429,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0_m_k1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto b_block_desc_k0_n_k1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -546,119 +563,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
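// --- Editor's sketch (not CK code): the compile-time guard this hunk adds to
// --- IsValidCompilationParameter() in v3r2/v3r3. A tile configuration is rejected when the
// --- K tiling does not divide evenly by the K extent consumed per MFMA instruction; the
// --- real code queries MfmaSelector (selected_mfma.k_per_blk / GetKPerXdlops()), the
// --- parameter names below are invented for illustration.
#include <cstdint>

template <std::int32_t KPerBlock, std::int32_t KPerXdlops>
constexpr bool is_valid_k_tiling()
{
    if constexpr(KPerBlock % KPerXdlops != 0)
    {
        return false; // the main loop could not issue whole MFMA instructions along K
    }
    return true;
}

static_assert(is_valid_k_tiling<32, 16>());  // 32 is a multiple of 16 -> accepted
static_assert(!is_valid_k_tiling<32, 24>()); // 24 does not divide 32  -> rejected
// --- end of editor's sketch ---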
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + Base::template GetCThreadDescriptor(); + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = Base:: + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + get_device_arch()); auto c_block_buf = make_dynamic_buffer( static_cast(p_shared), c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<>{}, - Sequence<0>{}, - Sequence<2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<3, 7>{}) - - ); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + 
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl)>(); // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< ThisThreadBlock, // index_t BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index f7570adfc14..86b95421ee6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -14,6 +14,8 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -32,7 +34,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v3r3( const FloatAB* __restrict__ p_a_grid, @@ -57,7 +59,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; GridwiseGemm::template Run( p_a_grid, @@ -139,119 +141,127 @@ template < index_t NumGemmKPrefetchStage = 1, PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 + : public GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + 
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true> // ForceNaiveLayout { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + FloatAB, + FloatAB, + FloatAcc, + FloatC, + Tuple<>, + FloatC, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock * K1Value, + K1Value, + K1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence, + FloatAB, + FloatAB, + true>; // ForceNaiveLayout + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; // K1 should be Number<...> - static constexpr auto K1 = Number{}; - - using ThisThreadBlock = ThisThreadBlock; + static constexpr auto AK0 = Base::AK0Number; + static constexpr auto BK0 = Base::BK0Number; + static constexpr auto AK1 = Base::AK1Number; + static constexpr auto BK1 = Base::BK1Number; + // K1 should be Number<...> + static constexpr auto K1 = Base::AK1Number; using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = [&]() { - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return a_block_desc_k0_m_k1; - } - - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() - { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = [&]() { - if constexpr(BBlockLdsExtraN) - { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); - } - else - { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); - } - }(); - - return b_block_desc_k0_n_k1; - } - - __host__ __device__ static constexpr 
auto - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; - } - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = - math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = - math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); - - constexpr auto c_block_size = - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), - c_block_size * sizeof(FloatC)); - } - template < InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set> __device__ static bool constexpr IsValidCompilationParameter() { + if constexpr(K1Value % MfmaSelector::selected_mfma + .k_per_blk != + 0) + { + return false; + } return ck::tensor_operation::device::IsValidGemmCompilationParameter< BlockSize, MPerBlock, @@ -279,6 +289,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if(!is_xdl_wmma_k_supported()) + { + return false; + } + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); @@ -434,10 +449,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto a_block_desc_k0_m_k1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto b_block_desc_k0_n_k1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = @@ -564,120 +581,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - // TODO: hacky, fix it! constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
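// --- Editor's sketch (not CK code, simplified names): the refactoring pattern this diff
// --- applies to the v3r1/v3r2/v3r3 gridwise GEMMs and to GridwiseMoeGemm. The duplicated
// --- constexpr helpers move into one templated base struct, and each derived struct only
// --- re-exports what it still uses via using-declarations; MaxBlockSize is the base member
// --- the kernels now feed to __launch_bounds__ instead of CK_MAX_THREAD_PER_BLOCK.
#include <cstdint>

template <std::int32_t BlockSize, std::int32_t MPerBlock, std::int32_t NPerBlock>
struct GridwiseGemmBaseSketch
{
    static constexpr std::int32_t MaxBlockSize = BlockSize; // consumed by __launch_bounds__

    static constexpr std::int32_t GetCTileElements() { return MPerBlock * NPerBlock; }
};

template <std::int32_t BlockSize, std::int32_t MPerBlock, std::int32_t NPerBlock>
struct GridwiseGemmV3r1Sketch : GridwiseGemmBaseSketch<BlockSize, MPerBlock, NPerBlock>
{
    using Base = GridwiseGemmBaseSketch<BlockSize, MPerBlock, NPerBlock>;

    // Re-export the shared helpers instead of keeping a private copy of each one.
    using Base::GetCTileElements;
    using Base::MaxBlockSize;
};

static_assert(GridwiseGemmV3r1Sketch<256, 128, 128>::MaxBlockSize == 256);
static_assert(GridwiseGemmV3r1Sketch<256, 128, 128>::GetCTileElements() == 128 * 128);
// --- end of editor's sketch ---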
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + Base::template GetCThreadDescriptor(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = - GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = Base:: + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + get_device_arch()); auto c_block_buf = make_dynamic_buffer( static_cast(p_shared), c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl .GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple(make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per - // shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per - // shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<>{}, - Sequence<0>{}, - Sequence<2, 4, 5, 6>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<3, 7>{}) - - ); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = Base::template GetCBlockThreadDescriptor< + false, + decltype(blockwise_gemm), + 
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl)>(); // VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = Base::template GetCThreadCopyVgprToLds( + blockwise_gemm, + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + ck::tensor_operation::element_wise::PassThrough{}); auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< ThisThreadBlock, // ThreadGroup diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 97f6dc30840..5e95d3c55be 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -14,7 +14,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -27,12 +27,6 @@ namespace ck { // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe -enum Activation -{ - gelu_and_mul = 0, - silu_and_mul = 1 -}; - template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) @@ -48,7 +42,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -78,7 +73,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) @@ -86,8 +81,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename 
GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -167,28 +164,117 @@ template -struct GridwiseMoeGemm +struct GridwiseMoeGemm : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetSharedMemoryNumberOfByte; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename 
Base::ThisThreadBlock; + using Base::NumDTensor; static constexpr auto BlockSizeNumber = Number{}; - static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); @@ -227,8 +313,6 @@ struct GridwiseMoeGemm using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -758,125 +842,6 @@ struct GridwiseMoeGemm index_t b_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -884,67 +849,32 @@ struct GridwiseMoeGemm make_tuple(Number{}, I1, Number{}, Number{})); } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize, - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + 
decltype(BlockGemmBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + IsInputGemm>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1257,7 +1187,8 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy // dummy @@ -1398,475 +1329,224 @@ struct GridwiseMoeGemm } // shuffle C and write out + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + // constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + // constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + // constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4 || M4 == 8); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - // TODO: hacky, fix it! 
- // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); - - // mul scales - const float* p_sorted_weights_0 = p_ds_grid[I0]; - const float* p_scale_b = p_ds_grid[I1]; - - static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); - static_assert(M4 == 4 || M4 == 8); - const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; - - if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + if constexpr(PerTokenQuant) { - if constexpr(PerTokenQuant) - { - constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); - p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + - get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; - } - else - { - p_scale_b += expert_id; - } - - vector_type scale_token_ids; - vector_type topk_weights; - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(PerTokenQuant) - { - scale_token_ids = - *c_style_pointer_cast*>( - p_sorted_token_ids + m_pos); - } - if constexpr(MulRoutedWeight) + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = + scale_token_ids.template AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[IsInputGemm + ? 
token_offset + : token_offset * + problem.TopK + + (fused_token >> 24)] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - float scale_a = [&]() { - if constexpr(PerTokenQuant) + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - index_t fused_token = - scale_token_ids.template AsType()[m4]; - const index_t token_offset = fused_token & 0xffffff; - return token_offset < problem.NumTokens - ? p_sorted_weights_0[IsInputGemm - ? token_offset - : token_offset * - problem.TopK + - (fused_token >> - 24)] - : 0.0; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } - else + if constexpr(is_same_v, pk_i4_t>) { - return p_sorted_weights_0[0]; + gate *= 16; + up *= 16; } - }(); - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } - else if(ActivationOperation == Activation::gelu_and_mul) + if constexpr(is_same_v, pk_i4_t>) { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - else + } + else + { + c_thread_buf_fp32(cidx) = scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) { c_thread_buf_fp32(cidx) = - scale_a * scale_b * c_thread_buf[cidx]; - if 
constexpr(MulRoutedWeight) - { - c_thread_buf_fp32(cidx) = - c_thread_buf_fp32(cidx) * - topk_weights.template AsType()[m4]; - } + c_thread_buf_fp32(cidx) * + topk_weights.template AsType()[m4]; } - }); + } }); }); }); - } - else - { - vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion + if constexpr(ActivationOperation == Activation::silu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - else + else if(ActivationOperation == Activation::gelu_and_mul) { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.template AsType()[m4] * - c_thread_buf_fp32[cidx]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - }); + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.template AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } }); }); }); - } - - constexpr 
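As a plain-scalar reference for the gate/up ("gu") fusion written out above, the sketch below models the silu_and_mul branch under stated assumptions (names are illustrative; the gelu_and_mul branch only swaps the activation, and the pk_i4_t path additionally rescales both halves by 16).

// Scalar model of one accumulator element of the silu_and_mul epilogue above.
__device__ inline float gate_up_silu(float acc_gate, float acc_up,
                                      float scale_a, float scale_b, float scale_up,
                                      float topk_weight, bool mul_routed_weight)
{
    float gate = scale_a * scale_b  * acc_gate; // first half of the fused output
    float up   = scale_a * scale_up * acc_up;   // second half; its scale is read problem.N further on
    if(mul_routed_weight)
    {
        gate *= topk_weight;
        up   *= topk_weight;
    }
    gate = gate / (1.0f + __expf(-gate)); // SiLU(x) = x * sigmoid(x)
    return gate * up;
}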
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = 
concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = 
p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } }); } + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } template scale_token_ids; - vector_type topk_weights; - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(PerTokenQuant) - { - scale_token_ids = - *c_style_pointer_cast*>( - p_sorted_token_ids + m_pos); - } - if constexpr(MulRoutedWeight) + constexpr index_t scale_stride = (IsInputGemm ? 
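A side note on the Base::template RunMoeEpilogue spelling used in the replacement code: the template keyword is ordinary C++ disambiguation, not something CK-specific, needed because the base class is a dependent type. A minimal standalone illustration with toy types (not the real CK classes):

// Toy types only; shows why the dependent-base call is written Base::template member<...>(...).
template <typename T>
struct EpilogueBase
{
    template <bool Fused>
    static void run(T) {}
};

template <typename T>
struct DerivedGemm : EpilogueBase<T>
{
    static void run_all(T v)
    {
        // The template keyword tells the compiler that run names a member template,
        // so the following < opens a template argument list.
        EpilogueBase<T>::template run<true>(v);
    }
};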
2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = + scale_token_ids.template AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[IsInputGemm + ? token_offset + : token_offset * + problem.TopK + + (fused_token >> 24)] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - float scale_a = [&]() { - if constexpr(PerTokenQuant) + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - index_t fused_token = - scale_token_ids.template AsType()[m4]; - const index_t token_offset = fused_token & 0xffffff; - return token_offset < problem.NumTokens - ? p_sorted_weights_0[IsInputGemm - ? 
token_offset - : token_offset * - problem.TopK + - (fused_token >> - 24)] - : 0.0; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } - else + if constexpr(is_same_v, pk_i4_t>) { - return p_sorted_weights_0[0]; + gate *= 16; + up *= 16; } - }(); - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } - else if(ActivationOperation == Activation::gelu_and_mul) + if constexpr(is_same_v, pk_i4_t>) { - const float scale_up = - p_scale_b[(n0 * NWave * NPerXdl + problem.N) * - PerTokenQuant]; - float gate = scale_a * scale_b * c_thread_buf[cidx]; - float up = scale_a * scale_up * c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - else + } + else + { + c_thread_buf_fp32(cidx) = scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) { c_thread_buf_fp32(cidx) = - scale_a * scale_b * c_thread_buf[cidx]; - if constexpr(MulRoutedWeight) - { - c_thread_buf_fp32(cidx) = - c_thread_buf_fp32(cidx) * - topk_weights.template AsType()[m4]; - } + c_thread_buf_fp32(cidx) * + topk_weights.template AsType()[m4]; } - }); + } }); }); }); - } - else - { - vector_type topk_weights; // for gemm2 only - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + - m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; - if constexpr(MulRoutedWeight) + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( 
+ p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion { - topk_weights = *c_style_pointer_cast*>( - p_ds_grid[I2] + m_pos); - } - static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, m2 * M4 + m4)); - constexpr auto cidx = Number{}; - - if constexpr(IsInputGemm) // gu fusion + if constexpr(ActivationOperation == Activation::silu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) - { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; - } - else if(ActivationOperation == Activation::gelu_and_mul) + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weights.template AsType()[m4]; - up = up * topk_weights.template AsType()[m4]; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf_fp32(cidx) = gate * up; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; } - else + else if(ActivationOperation == Activation::gelu_and_mul) { - c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf_fp32(cidx) = - topk_weights.template AsType()[m4] * - c_thread_buf_fp32[cidx]; + gate = gate * topk_weights.template AsType()[m4]; + up = up * topk_weights.template AsType()[m4]; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.template AsType()[m4] * + c_thread_buf_fp32[cidx]; } - }); + } }); }); }); - } - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const 
auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - 
decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - 
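For reference, the scatter offsets built in the loop above reduce to the small helper sketched below (name and types are mine): each row of the shuffled C tile is routed to the output row owned by its (token, top-k) pair, using the same 24/8-bit fused-token encoding as the scale lookup, with problem.N as the row stride.

// Sketch: per-row scatter offset fed to the LDS-to-global copy above.
__device__ inline long scatter_row_offset(int fused_token, int topk, int n_stride,
                                          bool is_input_gemm)
{
    int token_offset = fused_token & 0xffffff;
    if(is_input_gemm)
        token_offset = token_offset * topk + (fused_token >> 24); // expand to the (token, slot) row
    return static_cast<long>(token_offset) * n_stride;            // offset in elements
}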
c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } }); } + + // shuffle C and write out + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index c556dbec106..d70139eea2a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -10,12 +10,9 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" - +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 namespace ck { @@ -27,12 +24,6 @@ namespace ck { // 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe -enum Activation -{ - gelu_and_mul = 0, - silu_and_mul = 1 -}; - template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) @@ -48,7 +39,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -80,7 +72,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) @@ -88,8 +80,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -175,30 +169,120 @@ template struct GridwiseMoeGemmBlockScale + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + LDSTypeA, + LDSTypeB, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + 
ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetSharedMemoryNumberOfByte; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::NumDTensor; using AScaleType = float; using BScaleType = float; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; static constexpr auto BlockSizeNumber = Number{}; - static constexpr index_t NumDTensor = DsDataType::Size(); - using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); @@ -234,8 +318,6 @@ struct GridwiseMoeGemmBlockScale using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -785,124 +867,6 @@ struct GridwiseMoeGemmBlockScale index_t bscale_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. 
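The descriptor code removed above chooses between two standard LDS bank-conflict strategies for the A tile: pad the tile (ABlockLdsExtraM) or xor-permute the row index, with the original comment noting that the xor form costs extra VGPRs for row-major A. Purely as an illustration of the second idea (this is not the exact mapping of make_xor_with_modulo_transform):

// Illustrative xor swizzle; assumes m_extent is a power of two so the result stays in range.
__device__ inline int xor_swizzle_row(int m, int k0, int m_extent)
{
    return m ^ (k0 % m_extent); // bijective in m for fixed k0, so the tile still fits exactly
}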
- else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -910,70 +874,34 @@ struct GridwiseMoeGemmBlockScale make_tuple(Number{}, I1, Number{}, Number{})); } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t ())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize, - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector < BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + IsInputGemm && !IsSplitK > ())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1311,7 +1239,8 @@ struct GridwiseMoeGemmBlockScale b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy // dummy @@ -1575,368 +1504,110 @@ struct GridwiseMoeGemmBlockScale } // shuffle C and write out - { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // transposed XDL - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - // TODO: hacky, fix it! 
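The pipeline selector above now feeds get_device_arch() into the block-descriptor helpers, so LDS layout and sizing become architecture-parameterized compile-time choices. A toy of that dispatch pattern (the enum, names and values below are invented for illustration and are not the real CK API):

// Toy only: a constexpr arch-parameterized size query, mirroring the shape of
// GetSharedMemoryNumberOfByte(get_device_arch()) used in the kernels above.
enum class ToyArch { ArchA, ArchB };
constexpr int toy_lds_tile_rows(ToyArch a)
{
    return a == ToyArch::ArchA ? 128 : 64; // illustrative numbers only
}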
- // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); - constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); - - static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); - static_assert(M0 * M1 * M2 == MPerBlock); - static_assert(N4 == 4 || N4 == 8); - const index_t m1 = get_warp_local_1d_id() / NWave; - const index_t m2 = threadIdx.x % get_warp_size() % M2; - - float topk_weight; - static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - if constexpr(MulRoutedWeight) - { - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; - topk_weight = p_ds_grid[I0][m_pos]; - } - static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk - static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, n2 * N4 + n4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4 || N4 == 8); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + BlockwiseGemmPipe::GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) { - if 
constexpr(ActivationOperation == Activation::silu_and_mul) + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weight; - up = up * topk_weight; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf(cidx) = gate * up; + gate = gate * topk_weight; + up = up * topk_weight; } - else if(ActivationOperation == Activation::gelu_and_mul) + if constexpr(is_same_v, pk_i4_t>) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weight; - up = up * topk_weight; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf(cidx) = gate * up; + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } - else + else if(ActivationOperation == Activation::gelu_and_mul) { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } - }); + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } }); }); }); + }); - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 
2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = - MakeDsGridDescriptor_M_N(problem.M, - problem.MPadded, - problem.N * (IsInputGemm && IsSplitK ? 2 : 1), - problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), - problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 1; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = - token_offset * problem.N * (IsInputGemm && IsSplitK ? 
2 : 1); - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + auto problemN = problem.N * (IsInputGemm && IsSplitK ? 2 : 1); + auto problemNPadded = problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problemN, problemNPadded, problem.StrideDs); + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problemN); } template {}([&](auto m0) { // MXDLPerWave - static_for<0, NXdlPerWave, 1>{}([&](auto n0) { - if constexpr(MulRoutedWeight) - { - const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; - topk_weight = p_ds_grid[I0][m_pos]; - } - static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk - static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size - constexpr index_t c_offset = - blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( - make_tuple(m0, n0, n2 * N4 + n4)); - constexpr auto cidx = Number{}; - if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + BlockwiseGemmPipe::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4 || N4 == 8); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // 
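// The scatter epilogue removed above derives each output row from a fused token id read
// from p_sorted_token_ids: the low 24 bits carry the sorted token index and the high 8 bits
// the top-k slot. A minimal host-side sketch of that decode, using hypothetical plain-integer
// parameters in place of the kernel's Problem struct (moe_row_offset is an illustrative name,
// not a CK API):
#include <cstdint>

// Element offset of the output row addressed by one fused token id.
// is_input_gemm mirrors IsInputGemm; gate_up_doubled mirrors the
// (IsInputGemm && IsSplitK ? 2 : 1) widening of N seen in the diff.
inline std::int64_t moe_row_offset(std::uint32_t fused_token,
                                   std::int64_t topk,
                                   std::int64_t n,
                                   bool is_input_gemm,
                                   bool gate_up_doubled)
{
    std::int64_t row = fused_token & 0xffffff; // low 24 bits: token index
    if(is_input_gemm)
    {
        // the input (gate/up) GEMM writes one row per (token, top-k slot) pair
        row = row * topk + (fused_token >> 24); // high 8 bits: top-k slot
    }
    return row * n * (gate_up_doubled ? 2 : 1);
}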
MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + BlockwiseGemmPipe::GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) { - if constexpr(ActivationOperation == Activation::silu_and_mul) + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weight; - up = up * topk_weight; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Silu{}(gate, gate); - c_thread_buf(cidx) = gate * up; + gate = gate * topk_weight; + up = up * topk_weight; } - else if(ActivationOperation == Activation::gelu_and_mul) + if constexpr(is_same_v, pk_i4_t>) { - float gate = c_thread_buf[cidx]; - float up = c_thread_buf_up[cidx]; - if constexpr(MulRoutedWeight) - { - gate = gate * topk_weight; - up = up * topk_weight; - } - if constexpr(is_same_v, pk_i4_t>) - { - gate *= 16; - up *= 16; - } - tensor_operation::element_wise::Gelu{}(gate, gate); - c_thread_buf(cidx) = gate * up; + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } - else + else if(ActivationOperation == Activation::gelu_and_mul) { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; if constexpr(MulRoutedWeight) { - c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; } - - }); + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } }); }); }); + }); - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2)), // M2 = MPerXdl - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2, // N2 * N3 * N4 = NPerXdl - N3, - N4))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - 
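// The gate/up fusion above applies SiLU (or GELU) to the gate accumulator and multiplies by
// the up accumulator, optionally scaling both by the routed top-k weight first (the extra *16
// rescale for pk_i4_t inputs is omitted here). A scalar reference sketch, assuming the
// exact-erf GELU; CK's element_wise::Gelu may use a tanh approximation:
#include <cmath>

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }                     // x * sigmoid(x)
inline float gelu(float x) { return 0.5f * x * (1.0f + std::erf(x * 0.70710678f)); } // exact-erf form

// One output element, mirroring the per-lane loop body above.
inline float gated_activation(float gate, float up, float topk_weight,
                              bool mul_routed_weight, bool use_silu)
{
    if(mul_routed_weight)
    {
        gate *= topk_weight;
        up *= topk_weight;
    }
    return (use_silu ? silu(gate) : gelu(gate)) * up;
}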
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I3], - n_thread_data_on_block_idx[I4]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 1; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = - token_offset * problem.N * (IsInputGemm && IsSplitK ? 
2 : 1); - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); - } + // shuffle C and write out + auto problemN = problem.N * (IsInputGemm && IsSplitK ? 2 : 1); + auto problemNPadded = problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problemN, problemNPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + Base::template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problemN); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index bfe13987ac8..1d9742099d0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -17,6 +17,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -29,12 +30,6 @@ namespace ck { // 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe -enum Activation -{ - gelu_and_mul = 0, - silu_and_mul = 1 -}; - #if 0 template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) @@ -84,7 +79,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) @@ -176,35 +171,129 @@ template struct GridwiseMoeGemmMX + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + BlkGemmPipelineVer == BlockGemmPipelineVersion::v4, + true, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using 
Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using Base::I8; + using Base::I9; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::NumDTensor; + using LDSTypeA = ADataType; using LDSTypeB = BDataType; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I9 = Number<9>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MXdlPack = 2; static constexpr auto NXdlPack = 2; static constexpr auto KXdlPack = 2; @@ -243,8 +332,6 @@ struct GridwiseMoeGemmMX using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N) { const index_t nblock = math::integer_divide_ceil(N, NPerBlock); @@ -854,294 +941,43 @@ struct GridwiseMoeGemmMX index_t b_scale_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in LDS - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. 
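// The derived kernel now pulls its index constants, LDS block descriptors and the C-shuffle
// epilogue from GridwiseGemm_xdl_cshuffle_base instead of redefining them. A compilable,
// much-reduced sketch of that base/derived pattern (all names below are illustrative, not the
// real CK signatures):
#include <cstddef>

template <std::size_t BlockSize, std::size_t MPerBlock>
struct gridwise_base
{
    static constexpr std::size_t kMaxBlockSize = BlockSize; // cf. Base::MaxBlockSize
    static constexpr std::size_t kMPerBlock    = MPerBlock;

    // shared epilogue reused by every derived kernel (cf. RunMoeEpilogue)
    template <bool IsInputGemm>
    static void run_epilogue() { /* common C-shuffle + scattered store path */ }
};

template <std::size_t BlockSize, std::size_t MPerBlock>
struct gridwise_moe_gemm : gridwise_base<BlockSize, MPerBlock>
{
    using Base = gridwise_base<BlockSize, MPerBlock>;
    using Base::kMaxBlockSize; // same role as the `using Base::I0;` re-exports above

    static void run()
    {
        // ... kernel-specific main loop ...
        Base::template run_epilogue<true>(); // cf. Base::template RunMoeEpilogue(...)
    }
};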
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in lds - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - 
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return b_lds_block_desc_permuted; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmMXPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // lds max alignment constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); @@ -1154,7 +990,7 @@ struct GridwiseMoeGemmMX // LDS allocation for C shuffle in LDS constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + Base::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(get_device_arch()); constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); @@ -2296,10 +2132,12 @@ struct GridwiseMoeGemmMX constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise direct to LDS copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< @@ -2554,38 +2392,20 @@ struct GridwiseMoeGemmMX // shuffle C and write out { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - // TODO: hacky, fix it! 
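// The LDS block descriptors and the shared-memory size query are now obtained through
// overloads that take get_device_arch(), so one call site can select an arch-specific layout
// at compile time. A hypothetical sketch of that tag-dispatch pattern (tag names and padding
// values are placeholders; the real getters return tensor descriptors, not integers):
struct gfx9_tag
{
};
struct gfx95_tag
{
};

constexpr int lds_extra_pad(gfx9_tag) { return 0; }  // assumed value, illustration only
constexpr int lds_extra_pad(gfx95_tag) { return 4; } // assumed value, illustration only

// Call sites stay arch-agnostic; overload resolution on the tag picks the padding.
template <class ArchTag>
constexpr int a_block_row_stride(int ak1, ArchTag arch)
{
    return ak1 + lds_extra_pad(arch);
}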
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; @@ -2657,285 +2477,27 @@ struct GridwiseMoeGemmMX }); }); }); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) - // per shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - 
CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim 
- true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + Base:: + template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared_0, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 34c65b4626c..8cf03e3f5c0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -17,6 +17,7 @@ #include 
"ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -29,12 +30,6 @@ namespace ck { // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe -enum Activation -{ - gelu_and_mul = 0, - silu_and_mul = 1 -}; - template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) @@ -83,7 +78,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) @@ -176,31 +171,123 @@ template struct GridwiseMoeGemmMXBNS + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false, + true, + true> { - using LDSTypeA = ADataType; - using LDSTypeB = BDataType; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I9 = Number<9>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - - static constexpr index_t NumDTensor = DsDataType::Size(); - + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + 
NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false, + true, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using Base::I8; + using Base::I9; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::NumDTensor; + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; static constexpr auto MXdlPack = 2; static constexpr auto NXdlPack = 2; static constexpr auto KXdlPack = 2; @@ -235,8 +322,6 @@ struct GridwiseMoeGemmMXBNS using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N) { const index_t nblock = math::integer_divide_ceil(N, NPerBlock); @@ -787,322 +872,40 @@ struct GridwiseMoeGemmMXBNS index_t b_scale_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. 
- constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto 
b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return b_lds_block_desc_permuted; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / NPerXdl; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmMXNBSPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - if constexpr(IsInputGemm) - { - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)) * - 2, - c_block_size * sizeof(CShuffleDataType)); - } - else - { - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_size * sizeof(CShuffleDataType)); - } + return Base::template GetSharedMemoryNumberOfByte < false, + IsInputGemm ? 
2 : 1 > (get_device_arch()); } IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1459,10 +1262,12 @@ struct GridwiseMoeGemmMXBNS constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); // A matrix blockwise copy auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< @@ -1731,37 +1536,21 @@ struct GridwiseMoeGemmMXBNS // shuffle C and write out { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales + constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id @@ -1844,285 +1633,27 @@ struct GridwiseMoeGemmMXBNS }); }); }); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) - // per shuffle - M1, // M1 = MWave - M2, // M2 = MXdlPack - M3, // M3 * M4 * M5 = MPerXdl - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) - // per shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = 
NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // 
block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + Base:: + template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 2075f7e8d54..c8a917a1156 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -17,6 +17,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" #define DEBUG_LOG 0 @@ -29,12 +30,6 @@ namespace ck { // 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe -enum Activation -{ - gelu_and_mul = 0, - silu_and_mul = 1 -}; - template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { #if defined(__gfx9__) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -79,7 +75,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(GridwiseGemm::MaxBlockSize, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) @@ -87,8 +83,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #if defined(__gfx9__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char + p_shared_0[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; + __shared__ char + p_shared_1[GridwiseGemm::template GetSharedMemoryNumberOfByte(get_device_arch())]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -171,35 +169,128 @@ template struct GridwiseMoeGemmMX_BPreshuffle + : public GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false, + true, + true> { + using Base = GridwiseGemm_xdl_cshuffle_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 
AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + ComputeTypeA, + ComputeTypeB, + false, + true, + true>; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + using Base::CShuffleBlockTransferScalarPerVector_NPerBlock; + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + using Base::I8; + using Base::I9; + using ThisThreadBlock = typename Base::ThisThreadBlock; + using Base::NumDTensor; + using LDSTypeA = ADataType; using LDSTypeB = BDataType; - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I9 = Number<9>{}; - - static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = - CDEShuffleBlockTransferScalarPerVectors{}[I0]; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; - static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto MXdlPack = 2; static constexpr auto NXdlPack = 2; static constexpr auto KXdlPack = 2; @@ -251,8 +342,6 @@ struct GridwiseMoeGemmMX_BPreshuffle using DsGridPointer = decltype(MakeDsGridPointer()); - using ThisThreadBlock = ThisThreadBlock; - __host__ static auto CalculateGridSize(index_t M, index_t N) { const index_t nblock = math::integer_divide_ceil(N, NPerBlock); @@ -867,126 +956,6 @@ struct GridwiseMoeGemmMX_BPreshuffle index_t b_scale_k_split_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - constexpr index_t WaveSize = BlockSize / (MWave * NWave); - - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // contiguous in LDS - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. 
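The descriptor bodies deleted in this region rely on make_xor_with_modulo_transform to swizzle the LDS layout so that threads reading down a column land in distinct banks. A minimal standalone sketch of that swizzling idea, using hypothetical names and a fixed fold factor in place of the kfold/KThreadReadPerm terms computed in the deleted code:

#include <cstdint>

// Illustrative only: XOR the column index with the row index (mod Fold) so that
// threads walking down a column touch different LDS banks. The CK descriptors
// express the same permutation declaratively via make_xor_with_modulo_transform.
template <int Fold>
constexpr std::uint32_t swizzle_col(std::uint32_t row, std::uint32_t col)
{
    return col ^ (row % Fold);
}

static_assert(swizzle_col<8>(0, 5) == 5, "row 0 keeps its column");
static_assert(swizzle_col<8>(3, 5) == 6, "other rows are permuted");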
- else if constexpr(is_same::value) - { - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - return a_lds_block_desc_permuted; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = WaveSize / MPerXdl; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - __device__ static constexpr auto 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack @@ -997,70 +966,35 @@ struct GridwiseMoeGemmMX_BPreshuffle Number{})); } - __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() - { - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; - } - - using BlockwiseGemmPipe = - remove_cvref_t())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - - return math::max(a_block_space_size_aligned * sizeof(ADataType), - c_block_size * sizeof(CShuffleDataType)); - } + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmMXBPreshufflePipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K( + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K( + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + IsInputGemm>())>; IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) @@ -1419,7 +1353,8 @@ struct GridwiseMoeGemmMX_BPreshuffle b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -1656,35 +1591,16 @@ struct GridwiseMoeGemmMX_BPreshuffle // shuffle C and write out { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - - // TODO: hacky, fix it! 
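The hand-written GetSharedMemoryNumberOfByte bodies removed in this hunk (and in the MXBNS struct earlier) reduce to one formula: take the larger of the A/B staging area, scaled by the factor of two that the deleted code applies when IsInputGemm, and the C-shuffle staging area; the B-preshuffle variant counts only the A staging area since it does not stage B in LDS. A simplified sketch of that arithmetic under those assumptions, with hypothetical names rather than the CK implementation:

#include <algorithm>
#include <cstddef>

// Illustrative only: LDS bytes = max((A + B staging) * buffers, C-shuffle staging).
// 'buffers' is 2 when IsInputGemm, matching the factor in the deleted code, else 1;
// the B-preshuffle kernels effectively pass b_bytes == 0.
constexpr std::size_t lds_bytes(std::size_t a_bytes,
                                std::size_t b_bytes,
                                std::size_t c_shuffle_bytes,
                                std::size_t buffers)
{
    return std::max((a_bytes + b_bytes) * buffers, c_shuffle_bytes);
}

static_assert(lds_bytes(16384, 8192, 32768, 2) == 49152, "double-buffered A/B dominates");
static_assert(lds_bytes(8192, 0, 32768, 1) == 32768, "C-shuffle dominates the preshuffled case");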
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); - // mul scales static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); @@ -1758,285 +1674,27 @@ struct GridwiseMoeGemmMX_BPreshuffle }); }); }); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) - // per shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - 
ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - 
>{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + Base:: + template RunMoeEpilogue( + blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } } @@ -2185,7 +1843,8 @@ struct GridwiseMoeGemmMX_BPreshuffle b_scale_grid_desc_bn_ak.GetElementSpaceSize()); // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()); // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); @@ -2425,37 +2084,19 @@ struct GridwiseMoeGemmMX_BPreshuffle // shuffle C and write out { - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && - NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, - "wrong!"); - static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && - CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, - "wrong!"); - - constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); - - // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); - // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); - constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); // mul scales - static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); static_assert(M5 == 4); const index_t m1 = get_warp_local_1d_id() / NWave; @@ -2527,285 +2168,27 @@ struct GridwiseMoeGemmMX_BPreshuffle }); }); }); - - constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared_0), - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per - // shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4, - M5)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) - // per shuffle - N1, // N1 = NWave - N2, // N2 = NXdlPack - N3))), // N3 = NPerXdl - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 2, 4, 6, 7, 8>{}, - Sequence<>{}, - Sequence<1, 3, 5, 9>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), - make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - 
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( - make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), - decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - m_thread_data_on_block_idx[I5], - n_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); - - const auto ds_grid_buf = generate_tuple( - [&](auto i) { - return make_dynamic_buffer( - p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); - }, - Number{}); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_desc_refs = concat_tuple_of_reference( - tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); - - // tuple of reference to C/Ds tensor descriptors - const auto c_ds_buf_refs = concat_tuple_of_reference( - tie(c_shuffle_block_buf), - generate_tie([&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); - - // tuple of starting index of C/Ds blockwise copy - const auto idx_c_ds_block_begin = - container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), - generate_tuple( - [&](auto) { - return make_multi_index(block_m_id, 0, block_n_id, 0); - // return make_multi_index(block_work_idx[I0], 0, - // block_work_idx[I1], 0); - }, - Number{})); - - const auto e_grid_desc_mblock_mperblock_nblock_nperblock = - c_grid_desc_mblock_mperblock_nblock_nperblock; - - using CDEBlockTransferCluster = - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; - const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 3; // hack fix felix - auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * 
MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; - - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - Sequence>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // space filling curve for shuffled blockwise C/D/E - constexpr auto sfc_cde_block = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; - - static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto EMThreads = - CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; - constexpr auto ENThreads = - CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - - auto dstidx = sfc_cde_block.GetIndex(access_id); - const index_t c_token_pos = - block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - IndexType token_offset = fused_token & 0xffffff; - if constexpr(IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } - scatter_offsets(m0) = static_cast(token_offset) * problem.N; - }); - - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf_fp32, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - cde_block_copy_lds_and_global.Run( - c_ds_desc_refs, - c_ds_buf_refs, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf), - scatter_offsets); - - if constexpr(access_id < num_access - 1) - { - constexpr auto cde_lds_and_global_step = - sfc_cde_block.GetForwardStep(access_id); - - // move on Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - cde_block_copy_lds_and_global.MoveSrcSliceWindow( - c_ds_desc_refs, i + I1, cde_lds_and_global_step); - }); - - // move on E - cde_block_copy_lds_and_global.MoveDstSliceWindow( - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - I0, - cde_lds_and_global_step); - } - }); + Base:: + template RunMoeEpilogue( + 
blockwise_gemm_pipeline, + c_grid_desc_mblock_mperblock_nblock_nperblock, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf_fp32, + block_m_id, + block_n_id, + p_shared_0, + p_sorted_token_ids, + p_c_grid, + p_ds_grid, + c_element_op, + problem.TopK, + problem.N); } } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index 262702c693e..cbca8629c38 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -293,7 +293,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3 // convolution forward. For some reason for that specific type there is an ambiguity // in the type resolution for the ternary expression. I added an explicit cast to // disambiguate and only use it for f8 just in case it affects performance. - if constexpr(std::is_same_v) + if constexpr(is_same_v) { elm_vectors(i).template AsType()(I0) = oob_val ? elm_vector_t{elm_vectors(i).template AsType()[I0]} diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 67712be4830..f2f49f78730 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1322,7 +1322,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; @@ -1338,7 +1338,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; diff --git a/include/ck/utility/amd_arch.hpp b/include/ck/utility/amd_arch.hpp new file mode 100644 index 00000000000..c42ddce012a --- /dev/null +++ b/include/ck/utility/amd_arch.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
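The new ck/utility/amd_arch.hpp whose contents follow defines per-architecture tag types plus constexpr queries (LDS size, LDS bank count, VGPR budget) keyed off get_device_arch(). A minimal usage sketch, assuming only what the header declares; the surrounding namespace and variable names are hypothetical:

#include "ck/utility/amd_arch.hpp"

namespace arch_probe {

// Resolves at compile time to the properties of the architecture this
// translation unit is compiled for (gfx_invalid_t on the host pass).
inline constexpr auto arch         = ck::get_device_arch();
inline constexpr ck::index_t lds   = ck::get_lds_size(arch);
inline constexpr ck::index_t banks = ck::get_n_lds_banks(arch);
inline constexpr ck::index_t vgprs = ck::get_max_vgpr_count(arch);

// Every target listed in the header exposes at least the gfx9 baseline.
static_assert(lds >= 64 * 1024 && banks >= 32 && vgprs >= 256, "baseline capabilities");

} // namespace arch_probe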
+ +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +// Architecture tags +struct gfx9_t +{ +}; +struct gfx950_t +{ +}; +struct gfx103_t +{ +}; +struct gfx11_t +{ +}; +struct gfx12_t +{ +}; +struct gfx_invalid_t +{ +}; + +static constexpr auto get_device_arch() +{ +#if defined(__gfx950__) + return gfx950_t{}; +#elif defined(__gfx9__) + return gfx9_t{}; +#elif defined(__gfx10__) + return gfx103_t{}; +#elif defined(__gfx11__) + return gfx11_t{}; +#elif defined(__gfx12__) + return gfx12_t{}; +#else + return gfx_invalid_t{}; +#endif +} + +template +static constexpr index_t get_lds_size(DeviceArch) +{ + return 64 * 1024; +} +template <> +constexpr index_t get_lds_size(gfx950_t) +{ + return 160 * 1024; +} + +template +static constexpr index_t get_n_lds_banks(DeviceArch) +{ + return 32; +} +template <> +constexpr index_t get_n_lds_banks(gfx950_t) +{ + return 64; +} + +template +static constexpr index_t get_max_vgpr_count(DeviceArch) +{ + return 256; +} +template <> +constexpr index_t get_max_vgpr_count(gfx950_t) +{ + return 512; +} +template <> +constexpr index_t get_max_vgpr_count(gfx9_t) +{ + return 512; +} + +} // namespace ck diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 78c3b78de15..cf8e664e35a 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -33,6 +33,7 @@ #include "ck/utility/thread_group.hpp" #include "ck/utility/debug.hpp" +#include "ck/utility/amd_arch.hpp" #if __clang_major__ >= 20 #include "amd_buffer_addressing_builtins.hpp" #else diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp index c9eddd1553c..6e869532bd2 100644 --- a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -152,6 +152,12 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); break; + case 4: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); + break; default: a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); @@ -362,8 +368,16 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, std::is_same_v && std::is_same_v) { - rtol = 1e-2; - atol = 1e-2; + if(ck::is_gfx11_supported()) + { + rtol = 5e-2; + atol = 5e-2; + } + else + { + rtol = 1e-2; + atol = 1e-2; + } } pass = pass & ck::utils::check_err(c_gs_ms_os_device_result, diff --git a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp index 249e80309b8..978717e78b5 100644 --- a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp @@ -318,9 +318,21 @@ bool profile_batched_gemm_reduce_impl(int do_verification, reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); - bool c_error = ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); - bool d0_error = 
ck::utils::check_err(d0_g_m_device_result, d0_g_m_host_result); - bool d1_error = ck::utils::check_err(d1_g_m_device_result, d1_g_m_host_result); + bool c_error = ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Incorrect results!", + 1e-3, + 1e-3); + bool d0_error = ck::utils::check_err(d0_g_m_device_result, + d0_g_m_host_result, + "Error: Incorrect results!", + 1e-3, + 1e-3); + bool d1_error = ck::utils::check_err(d1_g_m_device_result, + d1_g_m_host_result, + "Error: Incorrect results!", + 1e-3, + 1e-3); pass = pass && (c_error == true); pass = pass && (d0_error == true); diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp index 13654b1db8c..cfe5e24791d 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -54,6 +54,7 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test void RunSingle(int M, int N, int K, int O, int G0, int G1) { + int init_method = std::is_same_v ? 4 : 2; bool pass = ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl( - verify_, 2, false, bench_, M, N, K, O, G0, G1, -1, instance_index); + verify_, init_method, false, bench_, M, N, K, O, G0, G1, -1, instance_index); EXPECT_TRUE(pass); } diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index bce6da4b688..1f7ec247baf 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -48,7 +48,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test //##########| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| //##########| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - < NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 8, 16, 16, 2, 4, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 4>; + < NDimSpatial, InLayout, WeiLayout,OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 4, 16, 16, 16, 2, 4, S<1, 4, 4, 8>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 1, true, S<1, 4, 16, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 8, 4, true, 1, 1, S<1, 16, 1, 8>, 4>; // clang-format on ck::utils::conv::ConvParam conv_param;
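The profiler and test hunks above loosen the verification thresholds (1e-3/1e-3 for the batched GEMM+reduce comparisons, 5e-2 on gfx11 for the fp16 batched GEMM+bias+softmax+GEMM permute path) and switch some configurations of that test to the new init_method 4, which fills the tensors with GeneratorTensor_3 floating-point values instead of the integer-style generators. These threshold pairs feed ck::utils::check_err; as a point of reference, the conventional mixed relative/absolute criterion such pairs parameterize looks like the sketch below (the exact formula inside check_err may differ, and the function name here is hypothetical):

#include <cmath>

// Illustrative only: accept 'out' if it lies within atol plus rtol times the
// magnitude of the reference value. With rtol = atol = 1e-3 and a reference of
// 10.0, deviations up to 0.011 are accepted.
inline bool within_tolerance(double out, double ref, double rtol, double atol)
{
    return std::abs(out - ref) <= atol + rtol * std::abs(ref);
}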