Skip to content

Commit ad8995e

Browse files
committed
Fix grouped gemm tile loop
1 parent e05be35 commit ad8995e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
6666
const CDEElementwiseOperation cde_element_op)
6767
{
6868
#if(defined(__gfx11__) || defined(__gfx12__))
69-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
70-
typename GridwiseGemm::EpilogueCShuffle>();
69+
using EpilogueType = typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
70+
GridwiseGemm::UseDirectStore,
71+
typename GridwiseGemm::EpilogueDirectStore,
72+
typename GridwiseGemm::EpilogueCShuffle>::type;
73+
74+
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
7175
__shared__ uint8_t p_shared[LDS_size];
7276

7377
const auto gemm_desc_ptr =
@@ -150,7 +154,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
150154
gemm_desc_ptr[group_id].StrideE,
151155
1);
152156

153-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
157+
auto epilogue_args = EpilogueType{};
154158
constexpr TailNumber TailNum = TailNumber::Full;
155159

156160
if(has_main_k_block_loop)

0 commit comments

Comments
 (0)