Skip to content

Commit 54075df

Browse files
committed
Add persistent async input scheduler for GEMM kernels
Add signal-based synchronization for persistent GEMM kernels where input data becomes available incrementally. Uses modulo wraparound (like PyTorch's AsyncMM) for the chunk index calculation:

    chunk_idx = ((tile_idx + tile_idx_pivot) / tiles_per_chunk) % num_chunks

Key components:
- PersistentAsyncInputScheduler struct with tiles_per_chunk_m, chunk_signals, tile_idx_pivot_m, and num_chunks fields
- wait_eq_wave method using __builtin_amdgcn_s_sleep for power efficiency
- IsSupportedArgument validation for scheduler parameters
- Example demonstrating async input scheduling with a simulated producer
- GTest unit tests covering all layout combinations
1 parent 086a1f8 commit 54075df

File tree

11 files changed

+844
-61
lines changed

11 files changed

+844
-61
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
1515
* Added support for gfx1153 target.
1616
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
1717
* Added gpt-oss sink support for FMHA FWD, including qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
18+
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
1819

1920
### Changed
2021

example/ck_tile/03_gemm/gemm_utils.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,8 @@ inline auto create_args()
456456
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
457457
.insert("jsonfile", "gemm.json", "json file name to dump results")
458458
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
459-
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
459+
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
460+
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
460461
return arg_parser;
461462
}
462463

example/ck_tile/03_gemm/universal_gemm.cpp

Lines changed: 195 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,59 +12,220 @@
1212
#include "run_gemm_example_common.hpp"
1313
#include "universal_gemm_invoker.hpp"
1414

15+
// Universal GEMM-specific wrapper that handles test_async flag
16+
template <typename GemmConfig,
17+
typename ADataType,
18+
typename BDataType = ADataType,
19+
typename CDataType = ADataType,
20+
typename ALayout,
21+
typename BLayout,
22+
typename CLayout>
23+
int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
24+
const ALayout a_layout = ALayout{},
25+
const BLayout b_layout = BLayout{},
26+
const CLayout c_layout = CLayout{})
27+
{
28+
using Invoker = UniversalInvoker;
29+
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
30+
31+
// Check for async input scheduler test mode
32+
bool test_async = arg_parser.get_int("test_async");
33+
if(test_async)
34+
{
35+
// Extract parameters for async test (same as shared implementation)
36+
const ck_tile::index_t M = arg_parser.get_int("m");
37+
const ck_tile::index_t N = arg_parser.get_int("n");
38+
const ck_tile::index_t K = arg_parser.get_int("k");
39+
const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
40+
41+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
42+
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
43+
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
44+
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
45+
46+
const ck_tile::index_t stride_A = is_a_row_major ? K : M;
47+
const ck_tile::index_t stride_B = is_b_row_major ? N : K;
48+
const ck_tile::index_t stride_C = is_c_row_major ? N : M;
49+
50+
// Allocate and initialize tensors
51+
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
52+
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
53+
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
54+
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
55+
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
56+
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
57+
58+
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
59+
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
60+
61+
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
62+
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
63+
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
64+
65+
a_m_k_dev_buf.ToDevice(a_m_k.data());
66+
b_k_n_dev_buf.ToDevice(b_k_n.data());
67+
c_m_n_dev_buf.SetZero();
68+
c_m_n_dev_result.SetZero();
69+
70+
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
71+
b_k_n_dev_buf.GetDeviceBuffer(),
72+
c_m_n_dev_buf.GetDeviceBuffer(),
73+
kbatch,
74+
M,
75+
N,
76+
K,
77+
stride_A,
78+
stride_B,
79+
stride_C};
80+
81+
Invoker::template test_async_input_scheduler<GemmConfig,
82+
ADataType,
83+
BDataType,
84+
ck_tile::tuple<>,
85+
AccDataType,
86+
CDataType,
87+
ALayout,
88+
BLayout,
89+
ck_tile::tuple<>,
90+
CLayout,
91+
ck_tile::element_wise::PassThrough>(
92+
args, ck_tile::stream_config{nullptr, false, 1});
93+
94+
// Copy result from device for verification
95+
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
96+
97+
// Compute CPU reference
98+
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
99+
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
100+
c_m_n_ref.SetZero();
101+
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
102+
a_m_k, b_k_n, c_m_n_ref);
103+
104+
// Verify results
105+
const float max_accumulated_value =
106+
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
107+
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
108+
K, kbatch, max_accumulated_value);
109+
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
110+
111+
std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
112+
return pass;
113+
}
114+
115+
// Normal path - delegate to shared implementation
116+
return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
117+
arg_parser, a_layout, b_layout, c_layout);
118+
}
119+
120+
// Universal GEMM-specific prec_type dispatcher that uses the wrapper
121+
template <typename GemmConfig,
122+
typename APrecType,
123+
typename BPrecType = APrecType,
124+
typename CPrecType = APrecType>
125+
int run_gemm_example_prec_type_universal(std::string a_layout,
126+
std::string b_layout,
127+
ck_tile::ArgParser& arg_parser)
128+
{
129+
using Row = ck_tile::tensor_layout::gemm::RowMajor;
130+
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
131+
bool preshuffle = GemmConfig::Preshuffle;
132+
133+
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
134+
{
135+
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
136+
}
137+
138+
if(preshuffle && a_layout != "R" && b_layout != "C")
139+
{
140+
throw std::runtime_error(
141+
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
142+
}
143+
144+
using LayoutVariant = std::variant<Row, Col>;
145+
146+
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
147+
if(layout == "R")
148+
return Row{};
149+
if(layout == "C")
150+
return Col{};
151+
throw std::runtime_error("Unsupported layout: " + layout);
152+
};
153+
154+
auto a_layout_variant = string_to_layout(a_layout);
155+
auto b_layout_variant = string_to_layout(b_layout);
156+
157+
return std::visit(
158+
[&](auto a_layout_type, auto b_layout_type) -> int {
159+
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
160+
std::is_same_v<decltype(b_layout_type), Row>)
161+
{
162+
throw std::runtime_error("Unsupported memory layout for the input matrices when "
163+
"BPrecType is ck_tile::pk_int4_t!");
164+
}
165+
else
166+
{
167+
return run_gemm_example_with_layouts_universal<GemmConfig,
168+
APrecType,
169+
BPrecType,
170+
CPrecType>(
171+
arg_parser, a_layout_type, b_layout_type, Row{});
172+
}
173+
},
174+
a_layout_variant,
175+
b_layout_variant);
176+
}
177+
15178
template <template <typename PrecType> typename GemmConfig>
16179
int run_gemm_example(ck_tile::ArgParser& arg_parser)
17180
{
18181
std::string data_type = arg_parser.get_str("prec");
19182
std::string a_layout = arg_parser.get_str("a_layout");
20183
std::string b_layout = arg_parser.get_str("b_layout");
21184

22-
using Invoker = UniversalInvoker;
23-
24185
if(data_type == "fp16")
25186
{
26-
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
187+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
27188
a_layout, b_layout, arg_parser);
28189
}
29190
else if(data_type == "bf16")
30191
{
31-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
192+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
32193
a_layout, b_layout, arg_parser);
33194
}
34195
else if(data_type == "fp8")
35196
{
36-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
37-
Invoker,
38-
ck_tile::fp8_t,
39-
ck_tile::fp8_t,
40-
ck_tile::half_t>(a_layout, b_layout, arg_parser);
197+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
198+
ck_tile::fp8_t,
199+
ck_tile::fp8_t,
200+
ck_tile::half_t>(
201+
a_layout, b_layout, arg_parser);
41202
}
42203
else if(data_type == "bf8")
43204
{
44-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
45-
Invoker,
46-
ck_tile::bf8_t,
47-
ck_tile::bf8_t,
48-
ck_tile::half_t>(a_layout, b_layout, arg_parser);
205+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
206+
ck_tile::bf8_t,
207+
ck_tile::bf8_t,
208+
ck_tile::half_t>(
209+
a_layout, b_layout, arg_parser);
49210
}
50211
else if(data_type == "int8")
51212
{
52-
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
53-
Invoker,
54-
ck_tile::int8_t,
55-
ck_tile::int8_t,
56-
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
213+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
214+
ck_tile::int8_t,
215+
ck_tile::int8_t,
216+
ck_tile::int32_t>(
217+
a_layout, b_layout, arg_parser);
57218
}
58219
else if(data_type == "fp16i4")
59220
{
60221
// TODO: Add support for bhalf_t ADataType
61222
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
62223
{
63-
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
64-
Invoker,
65-
ck_tile::half_t,
66-
ck_tile::pk_int4_t,
67-
ck_tile::half_t>(a_layout, b_layout, arg_parser);
224+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
225+
ck_tile::half_t,
226+
ck_tile::pk_int4_t,
227+
ck_tile::half_t>(
228+
a_layout, b_layout, arg_parser);
68229
}
69230
else
70231
{
@@ -75,11 +236,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
75236
{
76237
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
77238
{
78-
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
79-
Invoker,
80-
ck_tile::fp8_t,
81-
ck_tile::pk_int4_t,
82-
ck_tile::half_t>(a_layout, b_layout, arg_parser);
239+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
240+
ck_tile::fp8_t,
241+
ck_tile::pk_int4_t,
242+
ck_tile::half_t>(
243+
a_layout, b_layout, arg_parser);
83244
}
84245
else
85246
{
@@ -90,11 +251,11 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
90251
{
91252
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
92253
{
93-
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
94-
Invoker,
95-
ck_tile::bf8_t,
96-
ck_tile::pk_int4_t,
97-
ck_tile::half_t>(a_layout, b_layout, arg_parser);
254+
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
255+
ck_tile::bf8_t,
256+
ck_tile::pk_int4_t,
257+
ck_tile::half_t>(
258+
a_layout, b_layout, arg_parser);
98259
}
99260
else
100261
{

0 commit comments

Comments (0)