Skip to content

Commit 993d3e2

Browse files
authored
[FMHA] Enable page size 16 for batch prefill kernel (#3568)
* [FMHA] Enable page size 16 for batch prefill kernel

* Refactor batch prefill KV offset logic to simplify template arguments

  - Remove redundant `kLog2PageSize` and `kIsVTileFitsInPage` from template args.
  - Add static assert to forbid `page_size=1` with vectorized layout.
1 parent 5122637 commit 993d3e2

File tree

3 files changed

+62
-28
lines changed

3 files changed

+62
-28
lines changed

example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
3838

39-
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
39+
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
4040
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
4141
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
4242
KV_MEMORY_LAYOUT_ENUM_MAP = {

include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ template <typename OffsetVecType,
1717
typename CoordVecType,
1818
index_t kCoordAxis,
1919
index_t kPageBlockSize,
20-
index_t kLog2PageSize,
2120
index_t kLoopStart,
2221
index_t kLoopCount,
2322
index_t kLoopStride,
2423
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout,
2524
bool kIsKcache,
25+
index_t kN0,
2626
index_t kVectorSize>
2727
CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
2828
const index_t& stride_token,
@@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
3131
OffsetVecType& kv_offset_vec,
3232
index_t global_seq_offset = 0)
3333
{
34+
static constexpr index_t kLog2PageSize = [] {
35+
index_t shift = 0;
36+
index_t val = kPageBlockSize;
37+
while(val > 1)
38+
{
39+
val >>= 1;
40+
shift++;
41+
}
42+
return shift;
43+
}();
44+
3445
const index_t& thread_coord_start = coord_vec[kCoordAxis];
3546
constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1;
3647
if constexpr(kIsKcache)
@@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
4859
else
4960
{
5061
// for v offsets
51-
if constexpr(kLog2PageSize == 0 &&
62+
// for page_size > 1, the V tile crosses pages when page_size is not a multiple of kN0.
63+
static constexpr bool kVTileCrossesPages =
64+
(kPageBlockSize > 1) && (kPageBlockSize % kN0 != 0);
65+
if constexpr(kPageBlockSize == 1 &&
5266
kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
5367
{
5468
// page size = 1, per-token page lookup.
@@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
6478
kv_offset_vec[k0] = page_base_offset;
6579
});
6680
}
67-
else
81+
else if constexpr(kVTileCrossesPages)
82+
{
83+
// V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed
84+
// per token.
85+
static_for<0, kLoopCount, 1>{}([&](auto k0) {
86+
const index_t global_token_idx =
87+
global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value;
88+
const index_t page_id = global_token_idx >> kLog2PageSize;
89+
const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask;
90+
91+
const long_index_t page_base_offset =
92+
static_cast<long_index_t>(page_idx[page_id]) * stride_page_block;
93+
94+
if constexpr(kKVMemoryLayout ==
95+
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
96+
{
97+
// Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
98+
// address pattern.
99+
const long_index_t token_offset =
100+
static_cast<long_index_t>((token_idx_in_page / kVectorSize) *
101+
(stride_token * kVectorSize)) +
102+
(token_idx_in_page % kVectorSize);
103+
104+
kv_offset_vec[k0] = page_base_offset + token_offset;
105+
}
106+
else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
107+
{
108+
kv_offset_vec[k0] = page_base_offset +
109+
static_cast<long_index_t>(token_idx_in_page) * stride_token;
110+
}
111+
});
112+
}
113+
else // !kVTileCrossesPages
68114
{
69-
// This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
70-
// indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
71-
// Assumes the V tile stays within a single page so lane0 can broadcast the page id.
115+
// V tile is fully contained in one page, so page_id is shared.
116+
// Use lane0 to compute page_id once and broadcast page_base_offset.
72117
const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start);
73118
const index_t lane0_page_id =
74119
(global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize;
@@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
77122
static_cast<long_index_t>(page_idx[lane0_page_id]) * stride_page_block;
78123

79124
static_for<0, kLoopCount, 1>{}([&](auto k0) {
125+
// kLoopStride allows non-unit token spacing in the tile distribution.
80126
const index_t token_idx_in_page =
81-
(global_seq_offset + thread_coord_start + kLoopStart + k0.value) &
127+
(global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value) &
82128
kInPageOffsetMask;
83129

84130
if constexpr(kKVMemoryLayout ==
@@ -142,17 +188,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
142188
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
143189
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
144190
static constexpr index_t kPageBlockSize = Problem::kPageBlockSize;
145-
static constexpr index_t kLog2PageSize = Problem::kLog2PageSize;
146191
static constexpr index_t kVectorSize = Problem::kVectorSize;
147192
static constexpr auto I0 = number<0>{};
148193
static constexpr auto I1 = number<1>{};
149194
static constexpr auto I2 = number<2>{};
150195
static constexpr auto I3 = number<3>{};
151196

152197
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
153-
static_assert(kPageBlockSize % kN0 == 0 || kLog2PageSize == 0,
154-
"Page size must be 1, or a multiple of the tile size (kN0).");
155-
156198
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
157199
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
158200
// only need special care about seq_k padding (oob need set -INF of p instead of zero)
@@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
456498
decltype(k_coord),
457499
0,
458500
kPageBlockSize,
459-
kLog2PageSize,
460501
0,
461502
NRepeat,
462503
kN0 / NRepeat,
463504
kKVMemoryLayout,
464505
true,
506+
kN0,
465507
kVectorSize>(
466508
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
467509

@@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
501543
decltype(v_coord),
502544
VPageIndexDim,
503545
kPageBlockSize,
504-
kLog2PageSize,
505546
0,
506547
V_KRepeat,
507548
1,
508549
kKVMemoryLayout,
509550
false,
551+
kN0,
510552
kVectorSize>(
511553
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
512554

@@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
587629
decltype(v_coord),
588630
VPageIndexDim,
589631
kPageBlockSize,
590-
kLog2PageSize,
591632
kK1,
592633
V_KRepeat,
593634
1,
594635
kKVMemoryLayout,
595636
false,
637+
kN0,
596638
kVectorSize>(
597639
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
598640
v_dram_window.update_page_idx(v_offsets);
@@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
761803
decltype(v_coord),
762804
VPageIndexDim,
763805
kPageBlockSize,
764-
kLog2PageSize,
765806
2 * kK1,
766807
V_KRepeat,
767808
1,
768809
kKVMemoryLayout,
769810
false,
811+
kN0,
770812
kVectorSize>(
771813
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
772814
v_dram_window.update_page_idx(v_offsets);
@@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
900942
decltype(v_coord),
901943
VPageIndexDim,
902944
kPageBlockSize,
903-
kLog2PageSize,
904945
(2 + i_k1.value) * kK1,
905946
V_KRepeat,
906947
1,
907948
kKVMemoryLayout,
908949
false,
950+
kN0,
909951
kVectorSize>(
910952
page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
911953
v_dram_window.update_page_idx(v_offsets);
@@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
957999
decltype(k_coord),
9581000
0,
9591001
kPageBlockSize,
960-
kLog2PageSize,
9611002
0,
9621003
NRepeat,
9631004
kN0 / NRepeat,
9641005
kKVMemoryLayout,
9651006
true,
1007+
kN0,
9661008
kVectorSize>(
9671009
page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
9681010
k_dram_window.update_page_idx(k_offsets);

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,6 @@ struct BlockFmhaBatchPrefillPipelineProblem
107107
static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive");
108108
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
109109
"kPageBlockSize must be power of two");
110-
static constexpr index_t kLog2PageSize = []() constexpr {
111-
index_t shift = 0;
112-
index_t val = kPageBlockSize_;
113-
while(val > 1)
114-
{
115-
val >>= 1;
116-
shift++;
117-
}
118-
return shift;
119-
}();
120110

121111
static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4
122112
static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout;
@@ -126,6 +116,8 @@ struct BlockFmhaBatchPrefillPipelineProblem
126116

127117
static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0,
128118
"kQKHeaddim must be divisible by kVectorSize");
119+
static_assert(!(kPageBlockSize == 1 && kIsVectorizedLayout),
120+
"page_size=1 only supports linear KV cache layout");
129121
static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0,
130122
"kPageBlockSize must be divisible by kVectorSize for vectorized layout");
131123
static_assert(kIsGroupMode_, "Batch prefill requires group mode");

0 commit comments

Comments (0)