1212#include " run_gemm_example_common.hpp"
1313#include " universal_gemm_invoker.hpp"
1414
15+ // Universal GEMM-specific wrapper that handles test_async flag
16+ template <typename GemmConfig,
17+ typename ADataType,
18+ typename BDataType = ADataType,
19+ typename CDataType = ADataType,
20+ typename ALayout,
21+ typename BLayout,
22+ typename CLayout>
23+ int run_gemm_example_with_layouts_universal (ck_tile::ArgParser& arg_parser,
24+ const ALayout a_layout = ALayout{},
25+ const BLayout b_layout = BLayout{},
26+ const CLayout c_layout = CLayout{})
27+ {
28+ using Invoker = UniversalInvoker;
29+ using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
30+
31+ // Check for async input scheduler test mode
32+ bool test_async = arg_parser.get_int (" test_async" );
33+ if (test_async)
34+ {
35+ // Extract parameters for async test (same as shared implementation)
36+ const ck_tile::index_t M = arg_parser.get_int (" m" );
37+ const ck_tile::index_t N = arg_parser.get_int (" n" );
38+ const ck_tile::index_t K = arg_parser.get_int (" k" );
39+ const ck_tile::index_t kbatch = arg_parser.get_int (" split_k" );
40+
41+ using Row = ck_tile::tensor_layout::gemm::RowMajor;
42+ constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
43+ constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
44+ constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
45+
46+ const ck_tile::index_t stride_A = is_a_row_major ? K : M;
47+ const ck_tile::index_t stride_B = is_b_row_major ? N : K;
48+ const ck_tile::index_t stride_C = is_c_row_major ? N : M;
49+
50+ // Allocate and initialize tensors
51+ ck_tile::HostTensor<ADataType> a_m_k (ck_tile::host_tensor_descriptor (
52+ M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
53+ ck_tile::HostTensor<BDataType> b_k_n (ck_tile::host_tensor_descriptor (
54+ K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
55+ ck_tile::HostTensor<CDataType> c_m_n_dev_result (ck_tile::host_tensor_descriptor (
56+ M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
57+
58+ ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5 , 5 }(a_m_k);
59+ ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5 , 5 }(b_k_n);
60+
61+ ck_tile::DeviceMem a_m_k_dev_buf (a_m_k.get_element_space_size_in_bytes ());
62+ ck_tile::DeviceMem b_k_n_dev_buf (b_k_n.get_element_space_size_in_bytes ());
63+ ck_tile::DeviceMem c_m_n_dev_buf (c_m_n_dev_result.get_element_space_size_in_bytes ());
64+
65+ a_m_k_dev_buf.ToDevice (a_m_k.data ());
66+ b_k_n_dev_buf.ToDevice (b_k_n.data ());
67+ c_m_n_dev_buf.SetZero ();
68+ c_m_n_dev_result.SetZero ();
69+
70+ ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer (),
71+ b_k_n_dev_buf.GetDeviceBuffer (),
72+ c_m_n_dev_buf.GetDeviceBuffer (),
73+ kbatch,
74+ M,
75+ N,
76+ K,
77+ stride_A,
78+ stride_B,
79+ stride_C};
80+
81+ Invoker::template test_async_input_scheduler<GemmConfig,
82+ ADataType,
83+ BDataType,
84+ ck_tile::tuple<>,
85+ AccDataType,
86+ CDataType,
87+ ALayout,
88+ BLayout,
89+ ck_tile::tuple<>,
90+ CLayout,
91+ ck_tile::element_wise::PassThrough>(
92+ args, ck_tile::stream_config{nullptr , false , 1 });
93+
94+ // Copy result from device for verification
95+ c_m_n_dev_buf.FromDevice (c_m_n_dev_result.data ());
96+
97+ // Compute CPU reference
98+ ck_tile::HostTensor<CDataType> c_m_n_ref (ck_tile::host_tensor_descriptor (
99+ M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
100+ c_m_n_ref.SetZero ();
101+ ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
102+ a_m_k, b_k_n, c_m_n_ref);
103+
104+ // Verify results
105+ const float max_accumulated_value =
106+ *std::max_element (c_m_n_ref.mData .begin (), c_m_n_ref.mData .end ());
107+ const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
108+ K, kbatch, max_accumulated_value);
109+ bool pass = do_verify (c_m_n_dev_result, c_m_n_ref, rtol_atol, " CPU" );
110+
111+ std::cout << " Async input scheduler test: " << (pass ? " PASS" : " FAIL" ) << std::endl;
112+ return pass;
113+ }
114+
115+ // Normal path - delegate to shared implementation
116+ return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
117+ arg_parser, a_layout, b_layout, c_layout);
118+ }
119+
120+ // Universal GEMM-specific prec_type dispatcher that uses the wrapper
121+ template <typename GemmConfig,
122+ typename APrecType,
123+ typename BPrecType = APrecType,
124+ typename CPrecType = APrecType>
125+ int run_gemm_example_prec_type_universal (std::string a_layout,
126+ std::string b_layout,
127+ ck_tile::ArgParser& arg_parser)
128+ {
129+ using Row = ck_tile::tensor_layout::gemm::RowMajor;
130+ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
131+ bool preshuffle = GemmConfig::Preshuffle;
132+
133+ if (preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t >)
134+ {
135+ throw std::runtime_error (" Preshuffle is not supported for this int4 datatype!" );
136+ }
137+
138+ if (preshuffle && a_layout != " R" && b_layout != " C" )
139+ {
140+ throw std::runtime_error (
141+ " Preshuffle is supported only for A(Row major), B(column major) input matrices!" );
142+ }
143+
144+ using LayoutVariant = std::variant<Row, Col>;
145+
146+ auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
147+ if (layout == " R" )
148+ return Row{};
149+ if (layout == " C" )
150+ return Col{};
151+ throw std::runtime_error (" Unsupported layout: " + layout);
152+ };
153+
154+ auto a_layout_variant = string_to_layout (a_layout);
155+ auto b_layout_variant = string_to_layout (b_layout);
156+
157+ return std::visit (
158+ [&](auto a_layout_type, auto b_layout_type) -> int {
159+ if constexpr (std::is_same_v<BPrecType, ck_tile::pk_int4_t > &&
160+ std::is_same_v<decltype (b_layout_type), Row>)
161+ {
162+ throw std::runtime_error (" Unsupported memory layout for the input matrices when "
163+ " BPrecType is ck_tile::pk_int4_t!" );
164+ }
165+ else
166+ {
167+ return run_gemm_example_with_layouts_universal<GemmConfig,
168+ APrecType,
169+ BPrecType,
170+ CPrecType>(
171+ arg_parser, a_layout_type, b_layout_type, Row{});
172+ }
173+ },
174+ a_layout_variant,
175+ b_layout_variant);
176+ }
177+
15178template <template <typename PrecType> typename GemmConfig>
16179int run_gemm_example (ck_tile::ArgParser& arg_parser)
17180{
18181 std::string data_type = arg_parser.get_str (" prec" );
19182 std::string a_layout = arg_parser.get_str (" a_layout" );
20183 std::string b_layout = arg_parser.get_str (" b_layout" );
21184
22- using Invoker = UniversalInvoker;
23-
24185 if (data_type == " fp16" )
25186 {
26- return run_gemm_example_prec_type <GemmConfig<ck_tile::half_t >, Invoker , ck_tile::half_t >(
187+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::half_t >, ck_tile::half_t >(
27188 a_layout, b_layout, arg_parser);
28189 }
29190 else if (data_type == " bf16" )
30191 {
31- return run_gemm_example_prec_type <GemmConfig<ck_tile::bf16_t >, Invoker , ck_tile::bf16_t >(
192+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::bf16_t >, ck_tile::bf16_t >(
32193 a_layout, b_layout, arg_parser);
33194 }
34195 else if (data_type == " fp8" )
35196 {
36- return run_gemm_example_prec_type <GemmConfig<ck_tile::fp8_t >,
37- Invoker ,
38- ck_tile::fp8_t ,
39- ck_tile::fp8_t ,
40- ck_tile:: half_t >( a_layout, b_layout, arg_parser);
197+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::fp8_t >,
198+ ck_tile:: fp8_t ,
199+ ck_tile::fp8_t ,
200+ ck_tile::half_t >(
201+ a_layout, b_layout, arg_parser);
41202 }
42203 else if (data_type == " bf8" )
43204 {
44- return run_gemm_example_prec_type <GemmConfig<ck_tile::bf8_t >,
45- Invoker ,
46- ck_tile::bf8_t ,
47- ck_tile::bf8_t ,
48- ck_tile:: half_t >( a_layout, b_layout, arg_parser);
205+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::bf8_t >,
206+ ck_tile:: bf8_t ,
207+ ck_tile::bf8_t ,
208+ ck_tile::half_t >(
209+ a_layout, b_layout, arg_parser);
49210 }
50211 else if (data_type == " int8" )
51212 {
52- return run_gemm_example_prec_type <GemmConfig<ck_tile::int8_t >,
53- Invoker ,
54- ck_tile::int8_t ,
55- ck_tile::int8_t ,
56- ck_tile:: int32_t >( a_layout, b_layout, arg_parser);
213+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::int8_t >,
214+ ck_tile:: int8_t ,
215+ ck_tile::int8_t ,
216+ ck_tile::int32_t >(
217+ a_layout, b_layout, arg_parser);
57218 }
58219 else if (data_type == " fp16i4" )
59220 {
60221 // TODO: Add support for bhalf_t ADataType
61222 if constexpr (GemmConfig<ck_tile::half_t >::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
62223 {
63- return run_gemm_example_prec_type <GemmConfig<ck_tile::half_t >,
64- Invoker ,
65- ck_tile::half_t ,
66- ck_tile::pk_int4_t ,
67- ck_tile:: half_t >( a_layout, b_layout, arg_parser);
224+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::half_t >,
225+ ck_tile:: half_t ,
226+ ck_tile::pk_int4_t ,
227+ ck_tile::half_t >(
228+ a_layout, b_layout, arg_parser);
68229 }
69230 else
70231 {
@@ -75,11 +236,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
75236 {
76237 if constexpr (GemmConfig<ck_tile::fp8_t >::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
77238 {
78- return run_gemm_example_prec_type <GemmConfig<ck_tile::fp8_t >,
79- Invoker ,
80- ck_tile::fp8_t ,
81- ck_tile::pk_int4_t ,
82- ck_tile:: half_t >( a_layout, b_layout, arg_parser);
239+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::fp8_t >,
240+ ck_tile:: fp8_t ,
241+ ck_tile::pk_int4_t ,
242+ ck_tile::half_t >(
243+ a_layout, b_layout, arg_parser);
83244 }
84245 else
85246 {
@@ -90,11 +251,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
90251 {
91252 if constexpr (GemmConfig<ck_tile::bf8_t >::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
92253 {
93- return run_gemm_example_prec_type <GemmConfig<ck_tile::bf8_t >,
94- Invoker ,
95- ck_tile::bf8_t ,
96- ck_tile::pk_int4_t ,
97- ck_tile:: half_t >( a_layout, b_layout, arg_parser);
254+ return run_gemm_example_prec_type_universal <GemmConfig<ck_tile::bf8_t >,
255+ ck_tile:: bf8_t ,
256+ ck_tile::pk_int4_t ,
257+ ck_tile::half_t >(
258+ a_layout, b_layout, arg_parser);
98259 }
99260 else
100261 {
0 commit comments