@@ -17,12 +17,12 @@ template <typename OffsetVecType,
1717 typename CoordVecType,
1818 index_t kCoordAxis ,
1919 index_t kPageBlockSize ,
20- index_t kLog2PageSize ,
2120 index_t kLoopStart ,
2221 index_t kLoopCount ,
2322 index_t kLoopStride ,
2423 BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout ,
2524 bool kIsKcache ,
25+ index_t kN0 ,
2626 index_t kVectorSize >
2727CK_TILE_HOST_DEVICE void kv_offset_array_transform (const index_t * page_idx,
2828 const index_t & stride_token,
@@ -31,6 +31,17 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
3131 OffsetVecType& kv_offset_vec,
3232 index_t global_seq_offset = 0 )
3333{
34+ static constexpr index_t kLog2PageSize = [] {
35+ index_t shift = 0 ;
36+ index_t val = kPageBlockSize ;
37+ while (val > 1 )
38+ {
39+ val >>= 1 ;
40+ shift++;
41+ }
42+ return shift;
43+ }();
44+
3445 const index_t & thread_coord_start = coord_vec[kCoordAxis ];
3546 constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize ) - 1 ;
3647 if constexpr (kIsKcache )
@@ -48,7 +59,10 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
4859 else
4960 {
5061 // for v offsets
51- if constexpr (kLog2PageSize == 0 &&
62+ // for page_size > 1, the V tile crosses pages when page_size is not a multiple of kN0.
63+ static constexpr bool kVTileCrossesPages =
64+ (kPageBlockSize > 1 ) && (kPageBlockSize % kN0 != 0 );
65+ if constexpr (kPageBlockSize == 1 &&
5266 kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT)
5367 {
5468 // page size = 1, per-token page lookup.
@@ -64,11 +78,42 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
6478 kv_offset_vec[k0] = page_base_offset;
6579 });
6680 }
67- else
81+ else if constexpr (kVTileCrossesPages )
82+ {
83+ // V tile crosses multiple pages (e.g., page_size < kN0), so page_id must be computed
84+ // per token.
85+ static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
86+ const index_t global_token_idx =
87+ global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value ;
88+ const index_t page_id = global_token_idx >> kLog2PageSize ;
89+ const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask ;
90+
91+ const long_index_t page_base_offset =
92+ static_cast <long_index_t >(page_idx[page_id]) * stride_page_block;
93+
94+ if constexpr (kKVMemoryLayout ==
95+ BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT)
96+ {
97+ // Vectorized layout uses a packed [token/kVectorSize, head_dim, kVectorSize]
98+ // address pattern.
99+ const long_index_t token_offset =
100+ static_cast <long_index_t >((token_idx_in_page / kVectorSize ) *
101+ (stride_token * kVectorSize )) +
102+ (token_idx_in_page % kVectorSize );
103+
104+ kv_offset_vec[k0] = page_base_offset + token_offset;
105+ }
106+ else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT
107+ {
108+ kv_offset_vec[k0] = page_base_offset +
109+ static_cast <long_index_t >(token_idx_in_page) * stride_token;
110+ }
111+ });
112+ }
113+ else // !kVTileCrossesPages
68114 {
69- // This path handles page_size > 1 and/or non-linear KV layout, where page_idx is
70- // indexed by page_id (token_idx >> log2_page_size) with an in-page offset.
71- // Assumes the V tile stays within a single page so lane0 can broadcast the page id.
115+ // V tile is fully contained in one page, so page_id is shared.
116+ // Use lane0 to compute page_id once and broadcast page_base_offset.
72117 const index_t lane0_start = __builtin_amdgcn_readfirstlane (thread_coord_start);
73118 const index_t lane0_page_id =
74119 (global_seq_offset + lane0_start + kLoopStart ) >> kLog2PageSize ;
@@ -77,8 +122,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_idx,
77122 static_cast <long_index_t >(page_idx[lane0_page_id]) * stride_page_block;
78123
79124 static_for<0 , kLoopCount , 1 >{}([&](auto k0) {
125+ // kLoopStride allows non-unit token spacing in the tile distribution.
80126 const index_t token_idx_in_page =
81- (global_seq_offset + thread_coord_start + kLoopStart + k0.value ) &
127+ (global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value ) &
82128 kInPageOffsetMask ;
83129
84130 if constexpr (kKVMemoryLayout ==
@@ -142,17 +188,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
142188 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim ;
143189 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim ;
144190 static constexpr index_t kPageBlockSize = Problem::kPageBlockSize ;
145- static constexpr index_t kLog2PageSize = Problem::kLog2PageSize ;
146191 static constexpr index_t kVectorSize = Problem::kVectorSize ;
147192 static constexpr auto I0 = number<0 >{};
148193 static constexpr auto I1 = number<1 >{};
149194 static constexpr auto I2 = number<2 >{};
150195 static constexpr auto I3 = number<3 >{};
151196
152197 static_assert (kSubQKHeaddim <= 256 , " hdim bigger than 256 is not suitable for this pipeline!" );
153- static_assert (kPageBlockSize % kN0 == 0 || kLog2PageSize == 0 ,
154- " Page size must be 1, or a multiple of the tile size (kN0)." );
155-
156198 static constexpr bool kIsGroupMode = Problem::kIsGroupMode ;
157199 // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
158200 // only need special care about seq_k padding (oob need set -INF of p instead of zero)
@@ -456,12 +498,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
456498 decltype (k_coord),
457499 0 ,
458500 kPageBlockSize ,
459- kLog2PageSize ,
460501 0 ,
461502 NRepeat,
462503 kN0 / NRepeat,
463504 kKVMemoryLayout ,
464505 true ,
506+ kN0 ,
465507 kVectorSize >(
466508 page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
467509
@@ -501,12 +543,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
501543 decltype (v_coord),
502544 VPageIndexDim,
503545 kPageBlockSize ,
504- kLog2PageSize ,
505546 0 ,
506547 V_KRepeat,
507548 1 ,
508549 kKVMemoryLayout ,
509550 false ,
551+ kN0 ,
510552 kVectorSize >(
511553 page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
512554
@@ -587,12 +629,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
587629 decltype (v_coord),
588630 VPageIndexDim,
589631 kPageBlockSize ,
590- kLog2PageSize ,
591632 kK1 ,
592633 V_KRepeat,
593634 1 ,
594635 kKVMemoryLayout ,
595636 false ,
637+ kN0 ,
596638 kVectorSize >(
597639 page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
598640 v_dram_window.update_page_idx (v_offsets);
@@ -761,12 +803,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
761803 decltype (v_coord),
762804 VPageIndexDim,
763805 kPageBlockSize ,
764- kLog2PageSize ,
765806 2 * kK1 ,
766807 V_KRepeat,
767808 1 ,
768809 kKVMemoryLayout ,
769810 false ,
811+ kN0 ,
770812 kVectorSize >(
771813 page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
772814 v_dram_window.update_page_idx (v_offsets);
@@ -900,12 +942,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
900942 decltype (v_coord),
901943 VPageIndexDim,
902944 kPageBlockSize ,
903- kLog2PageSize ,
904945 (2 + i_k1.value ) * kK1 ,
905946 V_KRepeat,
906947 1 ,
907948 kKVMemoryLayout ,
908949 false ,
950+ kN0 ,
909951 kVectorSize >(
910952 page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k);
911953 v_dram_window.update_page_idx (v_offsets);
@@ -957,12 +999,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
957999 decltype (k_coord),
9581000 0 ,
9591001 kPageBlockSize ,
960- kLog2PageSize ,
9611002 0 ,
9621003 NRepeat,
9631004 kN0 / NRepeat,
9641005 kKVMemoryLayout ,
9651006 true ,
1007+ kN0 ,
9661008 kVectorSize >(
9671009 page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
9681010 k_dram_window.update_page_idx (k_offsets);
0 commit comments