From dbab5d31a38e8faef9dba658dc57b2a5b36f58f1 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 20:13:00 +0000 Subject: [PATCH 01/16] Optimizations for index_select kernel --- .../fbgemm_gpu/utils/inclusive_sum_scan.cuh | 229 ++++++++++-------- .../keyed_jagged_index_select_dim1.cu | 183 +++++++++++--- 2 files changed, 275 insertions(+), 137 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index 7e1f8c1e6d..210d55fc64 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -6,104 +6,133 @@ * LICENSE file in the root directory of this source tree. */ -#pragma once + #pragma once -// clang-format off -#ifdef USE_ROCM -#define HIPCUB_ARCH 1 -#include -#else -#include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" -#include -#include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" -#endif -// clang-format on - -namespace fbgemm_gpu { - -#ifdef USE_ROCM -namespace cub = hipcub; -#endif - -/** - * inclusive_sum_scan_kernel performs intra- and inter-thread block sum scan - * (i.e., prefix sum scan). We use cub::BlockScan to do inclusive sum within - * thread block and use a waterfall sync method to perform prefix sum across - * thread block. - * - * @param arr an array of input values. Its length must be fixed to - * ITEMS_PER_THREAD - * @param temp_storage a shared memory struct for cub::BlockScan - * @param block_flags a global flag buffer for inter-block sync (must be - * initialized with zeros) - * @param block_sums a global sum buffer for inter-block sync - * @param block_prev a shared memory pointer for sharing sum from the previous - * block within a block - * @param num_entries_per_block a number of input entries for this block - * @param block_id a relative thread block ID (the first block that contains - * the first set of input entries has block_id = 0) - * @param is_multi_block a boolean to indicate if inter-block sum scan has to - * be performed - * @param signal If the value of block_flags of the previous block is equal to - * signal, it means that the previous block has written its sum - * to block_sums. We have thread blocks increment the value of - * block_flags by one after they write their sums to block_sums. - * We increment the flag instead of setting the flag to a single - * value to support multiple sequential inclusive_sum_scan_kernel - * calls (e.g., in the AUC kernel). signal is the order that - * inclusive_sum_scan_kernel is called. Since we intialize - * block_flags with zeros, the signal of the first call should be - * one. 
- */ -template -__inline__ __device__ void inclusive_sum_scan_kernel( - scalar_t (&arr)[ITEMS_PER_THREAD], - typename cub::BlockScan::TempStorage& - temp_storage, - int* block_flags, - // Declared as volatile to prevent the compiler from register-allocating - // the accesses to block_sums - volatile scalar_t* block_sums, - scalar_t* block_prev, - const int num_entries_per_block, - const int block_id, - const bool is_multi_block, - const int signal) { - // Perform scan within a block - cub::BlockScan(temp_storage) - .InclusiveSum(arr, arr); - - // Perform stream scan across blocks - if (is_multi_block) { - // The thread that holds the last entry in the block does synchronization - if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { - scalar_t block_prev_local = 0; - if (block_id != 0) { - // Spin wait for the previous block to write the sum value - while (atomicAdd(&block_flags[block_id - 1], 0) < signal) - ; - - // Get sum from the previous block - *block_prev = block_prev_local = block_sums[block_id - 1]; - } - - // Write sum to global memory for the next block to consume - const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; - block_sums[block_id] = block_prev_local + arr[scope]; - __threadfence(); - // Set a flag to notify the next block - atomicAdd(&block_flags[block_id], 1); - } - - __syncthreads(); - - if (block_id != 0) { - scalar_t block_prev_local = *block_prev; - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - arr[i] += block_prev_local; - } - } - } -} - -} // namespace fbgemm_gpu + // clang-format off + #ifdef USE_ROCM + #define HIPCUB_ARCH 1 + #include + #else + #include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" + #include + #include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" + #endif + // clang-format on + + namespace fbgemm_gpu { + + #ifdef USE_ROCM + namespace cub = hipcub; + #endif + + /** + * inclusive_sum_scan_kernel performs intra- and inter-thread block sum scan + * (i.e., prefix sum scan). We use cub::BlockScan to do inclusive sum within + * thread block and use a waterfall sync method to perform prefix sum across + * thread block. + * + * @param arr an array of input values. Its length must be fixed to + * ITEMS_PER_THREAD + * @param temp_storage a shared memory struct for cub::BlockScan + * @param block_flags a global flag buffer for inter-block sync (must be + * initialized with zeros) + * @param block_sums a global sum buffer for inter-block sync + * @param block_prev a shared memory pointer for sharing sum from the previous + * block within a block + * @param num_entries_per_block a number of input entries for this block + * @param block_id a relative thread block ID (the first block that contains + * the first set of input entries has block_id = 0) + * @param is_multi_block a boolean to indicate if inter-block sum scan has to + * be performed + * @param signal If the value of block_flags of the previous block is equal to + * signal, it means that the previous block has written its sum + * to block_sums. We have thread blocks increment the value of + * block_flags by one after they write their sums to block_sums. + * We increment the flag instead of setting the flag to a single + * value to support multiple sequential inclusive_sum_scan_kernel + * calls (e.g., in the AUC kernel). signal is the order that + * inclusive_sum_scan_kernel is called. Since we intialize + * block_flags with zeros, the signal of the first call should be + * one. 
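As a minimal sketch of that handshake, reduced to a single partial sum per block and using hypothetical names (the kernel below additionally runs cub::BlockScan and restricts the exchange to the thread that owns the block's last entry), the waterfall step amounts to:

__device__ float waterfall_exchange_sketch(
    int* block_flags,           // zero-initialized before the first call
    volatile float* block_sums,
    float block_total,          // this block's local sum
    int block_id,
    int signal) {               // 1 for the first scan over these buffers
  float prev = 0.0f;
  if (block_id != 0) {
    // Spin until the previous block has raised its flag to `signal`.
    while (atomicAdd(&block_flags[block_id - 1], 0) < signal) {
    }
    prev = block_sums[block_id - 1];
  }
  // Publish the running total, make it globally visible, then raise our flag.
  block_sums[block_id] = prev + block_total;
  __threadfence();
  atomicAdd(&block_flags[block_id], 1);
  return prev;  // offset to add to every element this block scanned
}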
+ */ + template + __inline__ __device__ void inclusive_sum_scan_kernel( + scalar_t (&arr)[ITEMS_PER_THREAD], + typename cub::BlockScan::TempStorage& + temp_storage, + int* block_flags, + // Declared as volatile to prevent the compiler from register-allocating + // the accesses to block_sums + volatile scalar_t* block_sums, + scalar_t* block_prev, + const int num_entries_per_block, + const int block_id, + const bool is_multi_block, + const int signal) { + // ROCm path + #ifdef USE_ROCM + cub::BlockScan(temp_storage) + .InclusiveSum(arr, arr); + + if (is_multi_block) { + const bool is_last_thread = + threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; + + if (is_last_thread) { + scalar_t block_prev_local = 0; + if (block_id != 0) { + volatile int* flags = block_flags; + *block_prev = block_prev_local = block_sums[block_id - 1]; + } + + // Write sum to global memory for the next block to consume + const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; + block_sums[block_id] = block_prev_local + arr[scope]; + __threadfence(); + // Set a flag to notify the next block + atomicExch(&block_flags[block_id], signal); + } + + __syncthreads(); + + if (block_id != 0) { + scalar_t block_prev_local = *block_prev; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + arr[i] += block_prev_local; + } + } + } + #else + // CUDA path + cub::BlockScan(temp_storage) + .InclusiveSum(arr, arr); + + if (is_multi_block) { + if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { + scalar_t block_prev_local = 0; + if (block_id != 0) { + while (atomicAdd(&block_flags[block_id - 1], 0) < signal) + ; + + *block_prev = block_prev_local = block_sums[block_id - 1]; + } + + const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; + block_sums[block_id] = block_prev_local + arr[scope]; + __threadfence(); + atomicAdd(&block_flags[block_id], 1); + } + + __syncthreads(); + + if (block_id != 0) { + scalar_t block_prev_local = *block_prev; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + arr[i] += block_prev_local; + } + } + } + #endif + } + } // namespace fbgemm_gpu + \ No newline at end of file diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index c795d19ecd..104c2dcace 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -6,31 +6,114 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include "common.cuh" - -using Tensor = at::Tensor; - -namespace fbgemm_gpu { - -template < - typename scalar_t, - typename index_t, - typename acc_t, - int NUM_THREADS_PER_BLOCK, - int MAX_ENTRIES_PER_BLOCK> -__global__ void index_select_scalar_cumsum_kernel( - pta::PackedTensorAccessor32 output, - pta::PackedTensorAccessor32 output_cumsum, - const pta::PackedTensorAccessor32 input, - const pta::PackedTensorAccessor32 - indices, - const int num_batches, - const int input_batch_size, - const int last_block_num_entries, - int* block_flags, - acc_t* block_sums) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage bs_temp_storage; + #include "common.cuh" + + using Tensor = at::Tensor; + + namespace fbgemm_gpu { + + template < + typename scalar_t, + typename index_t, + typename acc_t, + int NUM_THREADS_PER_BLOCK, + int MAX_ENTRIES_PER_BLOCK, + int VEC> + + __global__ void index_select_scalar_cumsum_kernel( + pta::PackedTensorAccessor32 output, + pta::PackedTensorAccessor32 output_cumsum, + const pta::PackedTensorAccessor32 input, + const pta::PackedTensorAccessor32 + indices, + const int num_batches, + const int input_batch_size, + const int last_block_num_entries, + int* block_flags, + acc_t* block_sums) { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage bs_temp_storage; + __shared__ acc_t block_prefix; + +// ROCm path +#ifdef USE_ROCM + const int output_batch_size = indices.size(0); + const int num_entries = num_batches * output_batch_size; + const bool multi_block = gridDim.x > 1; + const int block_entries = blockIdx.x == gridDim.x - 1 + ? last_block_num_entries + : MAX_ENTRIES_PER_BLOCK; + const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; + const int remaining_entries = num_entries - block_entry_start; + const int num_entries_per_block = remaining_entries > 0 + ? (remaining_entries < block_entries ? 
remaining_entries : block_entries) + : 0; + + const int base_entry = block_entry_start + threadIdx.x * VEC; + acc_t local_data[VEC]; + + #pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + const int bid = entry / output_batch_size; + const int idx_in_batch = entry - bid * output_batch_size; + const int bid_base = bid * input_batch_size; + const index_t sel_idx = indices[idx_in_batch]; + local_data[i] = + #ifdef __HIP_PLATFORM_AMD__ + __builtin_nontemporal_load( + &input[bid_base + sel_idx]); + #else + input[bid_base + sel_idx]; + #endif + output[entry] = local_data[i]; + } else { + local_data[i] = 0; + } + } + + // Faster path for single block + if (!multi_block) { + if (num_entries_per_block > 0) { + BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); + } + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } + return; + } + + if (num_entries_per_block > 0) { + inclusive_sum_scan_kernel( + local_data, + bs_temp_storage, + block_flags, + block_sums, + &block_prefix, + num_entries_per_block, + blockIdx.x, + multi_block, + 1); + } + + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } +#else + // CUDA path __shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK]; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const int output_batch_size = indices.size(0); @@ -39,7 +122,6 @@ __global__ void index_select_scalar_cumsum_kernel( ? last_block_num_entries : MAX_ENTRIES_PER_BLOCK; - // Load data acc_t local_data[1]; if (tid < num_batches * output_batch_size) { *local_data = @@ -49,7 +131,6 @@ __global__ void index_select_scalar_cumsum_kernel( *local_data = 0; } - // Cumsum inclusive_sum_scan_kernel( local_data, bs_temp_storage, @@ -61,18 +142,26 @@ __global__ void index_select_scalar_cumsum_kernel( gridDim.x > 1, 1); - // Store data if (tid < num_batches * output_batch_size) { output_cumsum[tid] = *local_data; } +#endif } + + template < + typename scalar_t, + typename index_t, + typename offset_t, + typename weight_t, + bool has_weights> + +// Total amount of user embeddings may not fit into GPU memory. +// This kernel gathers a subset of users from a total amount of users. +// Gathers raw user's embeddings from scattered memory locations and +// writes them into contiguous memory locations. +// The kernel takes one big jagged tensor containing all keys stacked +// together, and selects the same indices across all keys in a single operation. -template < - typename scalar_t, - typename index_t, - typename offset_t, - typename weight_t, - bool has_weights> __global__ void keyed_jagged_index_select_dim1_kernel( pta::PackedTensorAccessor64 output, pta::PackedTensorAccessor64 @@ -121,6 +210,7 @@ __global__ void keyed_jagged_index_select_dim1_kernel( } } +// Computes gradients for backpropagation during training. 
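In scalar form, and with a hypothetical helper name, the index-add that the backward kernel below performs is the transpose of the forward gather: walk the output in the same key-major, selected-index-minor order the forward produced, and accumulate each segment back into its original position (accumulation matters if the same batch id appears more than once in indices; grad_input is assumed to start zeroed, offsets taken as the running sum of lengths starting at 0):

#include <cstdint>
#include <vector>

void keyed_jagged_index_add_dim1_ref(
    std::vector<float>& grad_input,         // original jagged layout, zeroed
    const std::vector<float>& grad_output,  // layout produced by the forward
    const std::vector<int64_t>& lengths,    // num_keys * batch_size entries
    const std::vector<int64_t>& offsets,    // running sum of lengths, from 0
    const std::vector<int64_t>& indices,    // selected batch ids
    int num_keys,
    int batch_size) {
  int64_t out_pos = 0;
  for (int k = 0; k < num_keys; ++k) {
    for (const int64_t idx : indices) {
      const int64_t flat = k * static_cast<int64_t>(batch_size) + idx;
      for (int64_t i = 0; i < lengths[flat]; ++i) {
        grad_input[offsets[flat] + i] += grad_output[out_pos++];
      }
    }
  }
}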
template __global__ void keyed_jagged_index_add_dim1_kernel( pta::PackedTensorAccessor64 output, @@ -183,8 +273,24 @@ class KeyedJaggedIndexSelectDim1GPUOp const int num_batches = lengths.numel() / batch_size; const int num_output_lengths = num_batches * indices.numel(); const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; +#ifdef USE_ROCM + const int vec_candidates[] = {4, 2, 1}; + int VEC = 1; + for (int v : vec_candidates) { + if (indices.numel() % v == 0) { + VEC = v; + break; + } + } + const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; + auto grid_size = (num_output_lengths + ENTRIES_PER_BLOCK - 1) / + ENTRIES_PER_BLOCK; +#else + const int VEC = 1; + const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK; auto grid_size = cuda_calc_xblock_count( num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK); +#endif Tensor output_offsets = at::empty({num_batches * indices.numel()}, offsets.options()); @@ -216,7 +322,8 @@ class KeyedJaggedIndexSelectDim1GPUOp index_t, offset_t, MAX_CUMSUM_ENTRIES_PER_BLOCK, - MAX_CUMSUM_ENTRIES_PER_BLOCK>), + ENTRIES_PER_BLOCK, + VEC>), grid_size, MAX_CUMSUM_ENTRIES_PER_BLOCK, 0, @@ -227,8 +334,10 @@ class KeyedJaggedIndexSelectDim1GPUOp PTA_B(indices, index_t, 1, 32), num_batches, batch_size, - num_output_lengths - - MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1), + grid_size == 0 + ? 0 + : num_output_lengths - + ENTRIES_PER_BLOCK * (grid_size - 1), grid_size > 1 ? block_flags.data_ptr() : nullptr, grid_size > 1 ? block_sums.data_ptr() : nullptr); From fe9bfbcb3ae21938eaa181fa74371990bfa48cba Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 20:20:25 +0000 Subject: [PATCH 02/16] Optimizations for index_select kernel --- .../fbgemm_gpu/utils/inclusive_sum_scan.cuh | 257 +++++++++--------- .../keyed_jagged_index_select_dim1.cu | 67 ++--- 2 files changed, 163 insertions(+), 161 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index 210d55fc64..624711b8e4 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -6,133 +6,132 @@ * LICENSE file in the root directory of this source tree. */ - #pragma once +#pragma once - // clang-format off - #ifdef USE_ROCM - #define HIPCUB_ARCH 1 - #include - #else - #include "fbgemm_gpu/utils/cub_namespace_prefix.cuh" - #include - #include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" - #endif - // clang-format on - - namespace fbgemm_gpu { - - #ifdef USE_ROCM - namespace cub = hipcub; - #endif - - /** - * inclusive_sum_scan_kernel performs intra- and inter-thread block sum scan - * (i.e., prefix sum scan). We use cub::BlockScan to do inclusive sum within - * thread block and use a waterfall sync method to perform prefix sum across - * thread block. - * - * @param arr an array of input values. 
Its length must be fixed to - * ITEMS_PER_THREAD - * @param temp_storage a shared memory struct for cub::BlockScan - * @param block_flags a global flag buffer for inter-block sync (must be - * initialized with zeros) - * @param block_sums a global sum buffer for inter-block sync - * @param block_prev a shared memory pointer for sharing sum from the previous - * block within a block - * @param num_entries_per_block a number of input entries for this block - * @param block_id a relative thread block ID (the first block that contains - * the first set of input entries has block_id = 0) - * @param is_multi_block a boolean to indicate if inter-block sum scan has to - * be performed - * @param signal If the value of block_flags of the previous block is equal to - * signal, it means that the previous block has written its sum - * to block_sums. We have thread blocks increment the value of - * block_flags by one after they write their sums to block_sums. - * We increment the flag instead of setting the flag to a single - * value to support multiple sequential inclusive_sum_scan_kernel - * calls (e.g., in the AUC kernel). signal is the order that - * inclusive_sum_scan_kernel is called. Since we intialize - * block_flags with zeros, the signal of the first call should be - * one. - */ - template - __inline__ __device__ void inclusive_sum_scan_kernel( - scalar_t (&arr)[ITEMS_PER_THREAD], - typename cub::BlockScan::TempStorage& - temp_storage, - int* block_flags, - // Declared as volatile to prevent the compiler from register-allocating - // the accesses to block_sums - volatile scalar_t* block_sums, - scalar_t* block_prev, - const int num_entries_per_block, - const int block_id, - const bool is_multi_block, - const int signal) { - // ROCm path - #ifdef USE_ROCM - cub::BlockScan(temp_storage) - .InclusiveSum(arr, arr); - - if (is_multi_block) { - const bool is_last_thread = - threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; - - if (is_last_thread) { - scalar_t block_prev_local = 0; - if (block_id != 0) { - volatile int* flags = block_flags; - *block_prev = block_prev_local = block_sums[block_id - 1]; - } - - // Write sum to global memory for the next block to consume - const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; - block_sums[block_id] = block_prev_local + arr[scope]; - __threadfence(); - // Set a flag to notify the next block - atomicExch(&block_flags[block_id], signal); - } - - __syncthreads(); - - if (block_id != 0) { - scalar_t block_prev_local = *block_prev; - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - arr[i] += block_prev_local; - } - } - } - #else - // CUDA path - cub::BlockScan(temp_storage) - .InclusiveSum(arr, arr); - - if (is_multi_block) { - if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { - scalar_t block_prev_local = 0; - if (block_id != 0) { - while (atomicAdd(&block_flags[block_id - 1], 0) < signal) - ; - - *block_prev = block_prev_local = block_sums[block_id - 1]; - } - - const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; - block_sums[block_id] = block_prev_local + arr[scope]; - __threadfence(); - atomicAdd(&block_flags[block_id], 1); - } - - __syncthreads(); - - if (block_id != 0) { - scalar_t block_prev_local = *block_prev; - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - arr[i] += block_prev_local; - } - } - } - #endif - } - } // namespace fbgemm_gpu - \ No newline at end of file +// clang-format off +#ifdef USE_ROCM +#define HIPCUB_ARCH 1 +#include +#else +#include 
"fbgemm_gpu/utils/cub_namespace_prefix.cuh" +#include +#include "fbgemm_gpu/utils/cub_namespace_postfix.cuh" +#endif +// clang-format on + +namespace fbgemm_gpu { + +#ifdef USE_ROCM +namespace cub = hipcub; +#endif + +/** + * inclusive_sum_scan_kernel performs intra- and inter-thread block sum scan + * (i.e., prefix sum scan). We use cub::BlockScan to do inclusive sum within + * thread block and use a waterfall sync method to perform prefix sum across + * thread block. + * + * @param arr an array of input values. Its length must be fixed to + * ITEMS_PER_THREAD + * @param temp_storage a shared memory struct for cub::BlockScan + * @param block_flags a global flag buffer for inter-block sync (must be + * initialized with zeros) + * @param block_sums a global sum buffer for inter-block sync + * @param block_prev a shared memory pointer for sharing sum from the previous + * block within a block + * @param num_entries_per_block a number of input entries for this block + * @param block_id a relative thread block ID (the first block that contains + * the first set of input entries has block_id = 0) + * @param is_multi_block a boolean to indicate if inter-block sum scan has to + * be performed + * @param signal If the value of block_flags of the previous block is equal to + * signal, it means that the previous block has written its sum + * to block_sums. We have thread blocks increment the value of + * block_flags by one after they write their sums to block_sums. + * We increment the flag instead of setting the flag to a single + * value to support multiple sequential inclusive_sum_scan_kernel + * calls (e.g., in the AUC kernel). signal is the order that + * inclusive_sum_scan_kernel is called. Since we intialize + * block_flags with zeros, the signal of the first call should be + * one. 
+ */ +template +__inline__ __device__ void inclusive_sum_scan_kernel( + scalar_t (&arr)[ITEMS_PER_THREAD], + typename cub::BlockScan::TempStorage& + temp_storage, + int* block_flags, + // Declared as volatile to prevent the compiler from register-allocating + // the accesses to block_sums + volatile scalar_t* block_sums, + scalar_t* block_prev, + const int num_entries_per_block, + const int block_id, + const bool is_multi_block, + const int signal) { +// ROCm path +#ifdef USE_ROCM + cub::BlockScan(temp_storage) + .InclusiveSum(arr, arr); + + if (is_multi_block) { + const bool is_last_thread = + threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; + + if (is_last_thread) { + scalar_t block_prev_local = 0; + if (block_id != 0) { + volatile int* flags = block_flags; + *block_prev = block_prev_local = block_sums[block_id - 1]; + } + + // Write sum to global memory for the next block to consume + const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; + block_sums[block_id] = block_prev_local + arr[scope]; + __threadfence(); + // Set a flag to notify the next block + atomicExch(&block_flags[block_id], signal); + } + + __syncthreads(); + + if (block_id != 0) { + scalar_t block_prev_local = *block_prev; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + arr[i] += block_prev_local; + } + } + } +#else + // CUDA path + cub::BlockScan(temp_storage) + .InclusiveSum(arr, arr); + + if (is_multi_block) { + if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { + scalar_t block_prev_local = 0; + if (block_id != 0) { + while (atomicAdd(&block_flags[block_id - 1], 0) < signal) + ; + + *block_prev = block_prev_local = block_sums[block_id - 1]; + } + + const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; + block_sums[block_id] = block_prev_local + arr[scope]; + __threadfence(); + atomicAdd(&block_flags[block_id], 1); + } + + __syncthreads(); + + if (block_id != 0) { + scalar_t block_prev_local = *block_prev; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + arr[i] += block_prev_local; + } + } + } +#endif +} +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 104c2dcace..69c126a6bf 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -6,34 +6,34 @@ * LICENSE file in the root directory of this source tree. 
*/ - #include "common.cuh" +#include "common.cuh" - using Tensor = at::Tensor; - - namespace fbgemm_gpu { - - template < - typename scalar_t, - typename index_t, - typename acc_t, - int NUM_THREADS_PER_BLOCK, - int MAX_ENTRIES_PER_BLOCK, - int VEC> +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +template < + typename scalar_t, + typename index_t, + typename acc_t, + int NUM_THREADS_PER_BLOCK, + int MAX_ENTRIES_PER_BLOCK, + int VEC> - __global__ void index_select_scalar_cumsum_kernel( - pta::PackedTensorAccessor32 output, - pta::PackedTensorAccessor32 output_cumsum, - const pta::PackedTensorAccessor32 input, - const pta::PackedTensorAccessor32 - indices, - const int num_batches, - const int input_batch_size, - const int last_block_num_entries, - int* block_flags, - acc_t* block_sums) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage bs_temp_storage; - __shared__ acc_t block_prefix; +__global__ void index_select_scalar_cumsum_kernel( + pta::PackedTensorAccessor32 output, + pta::PackedTensorAccessor32 output_cumsum, + const pta::PackedTensorAccessor32 input, + const pta::PackedTensorAccessor32 + indices, + const int num_batches, + const int input_batch_size, + const int last_block_num_entries, + int* block_flags, + acc_t* block_sums) { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage bs_temp_storage; + __shared__ acc_t block_prefix; // ROCm path #ifdef USE_ROCM @@ -122,6 +122,7 @@ ? last_block_num_entries : MAX_ENTRIES_PER_BLOCK; + // Load data acc_t local_data[1]; if (tid < num_batches * output_batch_size) { *local_data = @@ -131,6 +132,7 @@ *local_data = 0; } + // Cumsum inclusive_sum_scan_kernel( local_data, bs_temp_storage, @@ -142,18 +144,19 @@ gridDim.x > 1, 1); + // Store data if (tid < num_batches * output_batch_size) { output_cumsum[tid] = *local_data; } #endif } - template < - typename scalar_t, - typename index_t, - typename offset_t, - typename weight_t, - bool has_weights> +template < + typename scalar_t, + typename index_t, + typename offset_t, + typename weight_t, + bool has_weights> // Total amount of user embeddings may not fit into GPU memory. // This kernel gathers a subset of users from a total amount of users. 
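A scalar reference for the selection semantics described in that comment, simplified (no weights, hypothetical name, offsets taken as the running sum of lengths starting at 0), may help when reading the kernels:

#include <cstdint>
#include <vector>

std::vector<float> keyed_jagged_index_select_dim1_ref(
    const std::vector<float>& values,     // all keys stacked along dim 1
    const std::vector<int64_t>& lengths,  // num_keys * batch_size entries
    const std::vector<int64_t>& offsets,  // running sum of lengths, from 0
    const std::vector<int64_t>& indices,  // selected batch ids
    int num_keys,
    int batch_size) {
  std::vector<float> out;
  for (int k = 0; k < num_keys; ++k) {
    for (const int64_t idx : indices) {
      const int64_t flat = k * static_cast<int64_t>(batch_size) + idx;
      for (int64_t i = 0; i < lengths[flat]; ++i) {
        out.push_back(values[offsets[flat] + i]);
      }
    }
  }
  return out;
}

For the same ordering, index_select_scalar_cumsum_kernel produces the selected lengths and their inclusive prefix sum. With two keys, per-key lengths {2, 1, 3} and {4, 0, 5}, and indices {2, 0}, the selected lengths come out key-major as {3, 2, 5, 4} and the cumsum as {3, 5, 10, 14}, which the host code then uses as the output offsets.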
From 782bf9160c9407c53535fb5e93ba9677c5f7db6b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 20:28:17 +0000 Subject: [PATCH 03/16] Optimizations for index_select kernel --- .../include/fbgemm_gpu/utils/inclusive_sum_scan.cuh | 10 +++++++++- .../keyed_jagged_index_select_dim1.cu | 9 --------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index 624711b8e4..d60e6e8aa7 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -74,14 +74,16 @@ __inline__ __device__ void inclusive_sum_scan_kernel( cub::BlockScan(temp_storage) .InclusiveSum(arr, arr); + // Perform stream scan across blocks if (is_multi_block) { const bool is_last_thread = threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; - + // The thread that holds the last entry in the block does synchronization if (is_last_thread) { scalar_t block_prev_local = 0; if (block_id != 0) { volatile int* flags = block_flags; + // Get sum from the previous block *block_prev = block_prev_local = block_sums[block_id - 1]; } @@ -107,19 +109,25 @@ __inline__ __device__ void inclusive_sum_scan_kernel( cub::BlockScan(temp_storage) .InclusiveSum(arr, arr); + // Perform stream scan across blocks if (is_multi_block) { + // The thread that holds the last entry in the block does synchronization if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { scalar_t block_prev_local = 0; if (block_id != 0) { + // Spin wait for the previous block to write the sum value while (atomicAdd(&block_flags[block_id - 1], 0) < signal) ; + // Get sum from the previous block *block_prev = block_prev_local = block_sums[block_id - 1]; } + // Write sum to global memory for the next block to consume const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD; block_sums[block_id] = block_prev_local + arr[scope]; __threadfence(); + // Set a flag to notify the next block atomicAdd(&block_flags[block_id], 1); } diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 69c126a6bf..84e6d2b179 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -157,14 +157,6 @@ template < typename offset_t, typename weight_t, bool has_weights> - -// Total amount of user embeddings may not fit into GPU memory. -// This kernel gathers a subset of users from a total amount of users. -// Gathers raw user's embeddings from scattered memory locations and -// writes them into contiguous memory locations. -// The kernel takes one big jagged tensor containing all keys stacked -// together, and selects the same indices across all keys in a single operation. - __global__ void keyed_jagged_index_select_dim1_kernel( pta::PackedTensorAccessor64 output, pta::PackedTensorAccessor64 @@ -213,7 +205,6 @@ __global__ void keyed_jagged_index_select_dim1_kernel( } } -// Computes gradients for backpropagation during training. 
template __global__ void keyed_jagged_index_add_dim1_kernel( pta::PackedTensorAccessor64 output, From b9f862583a5803f46d09f801cd50c54de3f0172a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 20:31:44 +0000 Subject: [PATCH 04/16] Updated cumsum --- fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index d60e6e8aa7..032760e21c 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -71,10 +71,11 @@ __inline__ __device__ void inclusive_sum_scan_kernel( const int signal) { // ROCm path #ifdef USE_ROCM + // Perform scan within a block cub::BlockScan(temp_storage) .InclusiveSum(arr, arr); - // Perform stream scan across blocks + // Perform scan across blocks if (is_multi_block) { const bool is_last_thread = threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD; @@ -106,10 +107,11 @@ __inline__ __device__ void inclusive_sum_scan_kernel( } #else // CUDA path + // Perform scan across blocks cub::BlockScan(temp_storage) .InclusiveSum(arr, arr); - // Perform stream scan across blocks + // Perform scan across blocks if (is_multi_block) { // The thread that holds the last entry in the block does synchronization if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) { From 3ef975bf19b50bc5735f9922590107c26f9c2082 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 21:11:51 +0000 Subject: [PATCH 05/16] Updated template for kernel API --- .../keyed_jagged_index_select_dim1.cu | 138 +++++++++++------- 1 file changed, 85 insertions(+), 53 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 84e6d2b179..44da15231b 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. 
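On ROCm, the host-side change below picks a vectorization width at runtime (4, 2, or 1, the first that divides indices.numel()) and then has to hand it to the kernel as a compile-time template parameter; the switch over std::integral_constant tags is what bridges the two. Stripped of the surrounding dispatch macros, and with hypothetical names, the pattern is:

#include <type_traits>

template <int ENTRIES_PER_THREAD>
void launch_cumsum_kernel() {
  // Launch the kernel instantiation for this ENTRIES_PER_THREAD.
}

inline void dispatch_entries_per_thread(const int entries_per_thread) {
  auto call = [&](auto tag) {
    constexpr int kEntries = decltype(tag)::value;
    launch_cumsum_kernel<kEntries>();
  };
  switch (entries_per_thread) {
    case 4:  call(std::integral_constant<int, 4>{}); break;
    case 2:  call(std::integral_constant<int, 2>{}); break;
    default: call(std::integral_constant<int, 1>{}); break;
  }
}

Each thread block still runs MAX_CUMSUM_ENTRIES_PER_BLOCK = 256 threads, so a width of 4 lets one block cover 256 * 4 = 1024 entries; for example, 10,000 output lengths then need ceil(10000 / 1024) = 10 blocks instead of 40.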
*/ +#include + #include "common.cuh" using Tensor = at::Tensor; @@ -17,8 +19,7 @@ template < typename index_t, typename acc_t, int NUM_THREADS_PER_BLOCK, - int MAX_ENTRIES_PER_BLOCK, - int VEC> + int MAX_ENTRIES_PER_BLOCK> __global__ void index_select_scalar_cumsum_kernel( pta::PackedTensorAccessor32 output, @@ -269,18 +270,15 @@ class KeyedJaggedIndexSelectDim1GPUOp const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; #ifdef USE_ROCM const int vec_candidates[] = {4, 2, 1}; - int VEC = 1; + int vec = 1; for (int v : vec_candidates) { if (indices.numel() % v == 0) { - VEC = v; + vec = v; break; } } - const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; - auto grid_size = (num_output_lengths + ENTRIES_PER_BLOCK - 1) / - ENTRIES_PER_BLOCK; #else - const int VEC = 1; + const int vec = 1; const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK; auto grid_size = cuda_calc_xblock_count( num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK); @@ -291,53 +289,87 @@ class KeyedJaggedIndexSelectDim1GPUOp Tensor output_lengths = at::empty({num_batches * indices.numel()}, lengths.options()); - Tensor block_flags, block_sums; - if (grid_size > 1) { - block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); - block_sums = at::empty({grid_size}, output_offsets.options()); - } - // Do index select and cumsum - AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { - using length_t = index_t; - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), - "index_select_scalar_cumsum_wrapper_2", - [&] { - using offset_t = index_t; - AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), - "index_select_scalar_cumsum_wrapper_3", - [&] { - FBGEMM_LAUNCH_KERNEL( - (index_select_scalar_cumsum_kernel< - length_t, - index_t, - offset_t, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - ENTRIES_PER_BLOCK, - VEC>), - grid_size, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - 0, - at::cuda::getCurrentCUDAStream(), - PTA_B(output_lengths, length_t, 1, 32), - PTA_B(output_offsets, offset_t, 1, 32), - PTA_B(lengths, length_t, 1, 32), - PTA_B(indices, index_t, 1, 32), - num_batches, - batch_size, - grid_size == 0 - ? 0 - : num_output_lengths - - ENTRIES_PER_BLOCK * (grid_size - 1), - grid_size > 1 ? block_flags.data_ptr() : nullptr, - grid_size > 1 ? 
block_sums.data_ptr() - : nullptr); - }); - }); + auto dispatch_cumsum = [&](auto vec_tag, auto grid_calc) { + constexpr int VEC = decltype(vec_tag)::value; + constexpr int ENTRIES_PER_BLOCK = + MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; + const auto grid_size = grid_calc(ENTRIES_PER_BLOCK); + + Tensor block_flags, block_sums; + if (grid_size > 1) { + block_flags = + at::zeros({grid_size}, lengths.options().dtype(at::kInt)); + block_sums = at::empty({grid_size}, output_offsets.options()); + } + + AT_DISPATCH_INDEX_TYPES( + lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { + using length_t = index_t; + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), + "index_select_scalar_cumsum_wrapper_2", + [&] { + using offset_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "index_select_scalar_cumsum_wrapper_3", + [&] { + FBGEMM_LAUNCH_KERNEL( + (index_select_scalar_cumsum_kernel< + length_t, + index_t, + offset_t, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + ENTRIES_PER_BLOCK, + VEC>), + grid_size, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + 0, + at::cuda::getCurrentCUDAStream(), + PTA_B(output_lengths, length_t, 1, 32), + PTA_B(output_offsets, offset_t, 1, 32), + PTA_B(lengths, length_t, 1, 32), + PTA_B(indices, index_t, 1, 32), + num_batches, + batch_size, + grid_size == 0 + ? 0 + : num_output_lengths - + ENTRIES_PER_BLOCK * (grid_size - 1), + grid_size > 1 + ? block_flags.data_ptr() + : nullptr, + grid_size > 1 + ? block_sums.data_ptr() + : nullptr); + }); + }); + }); + }; + +#ifdef USE_ROCM + auto rocm_grid = [&](int entries_per_block) { + return (num_output_lengths + entries_per_block - 1) / entries_per_block; + }; + switch (vec) { + case 4: + dispatch_cumsum(std::integral_constant{}, rocm_grid); + break; + case 2: + dispatch_cumsum(std::integral_constant{}, rocm_grid); + break; + default: + dispatch_cumsum(std::integral_constant{}, rocm_grid); + break; + } +#else + dispatch_cumsum( + std::integral_constant{}, + [&](int entries_per_block) { + return cuda_calc_xblock_count(num_output_lengths, entries_per_block); }); +#endif const int64_t num_outputs = (selected_lengths_sum.has_value()) ? 
selected_lengths_sum.value().guard_int(__FILE__, __LINE__) From 20d14456757812a59b9a5a7f48a111acb67fde1f Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 21:14:11 +0000 Subject: [PATCH 06/16] Updated formatting --- .../src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 44da15231b..30999ea741 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -7,7 +7,6 @@ */ #include - #include "common.cuh" using Tensor = at::Tensor; @@ -20,7 +19,6 @@ template < typename acc_t, int NUM_THREADS_PER_BLOCK, int MAX_ENTRIES_PER_BLOCK> - __global__ void index_select_scalar_cumsum_kernel( pta::PackedTensorAccessor32 output, pta::PackedTensorAccessor32 output_cumsum, From 27e55377a7ff2e1ec8a0617f97d4921afe6be89a Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 22:30:13 +0000 Subject: [PATCH 07/16] Updated VEC and formatting --- .../keyed_jagged_index_select_dim1.cu | 281 ++++++++---------- 1 file changed, 122 insertions(+), 159 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 30999ea741..87eb445289 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -36,80 +36,81 @@ __global__ void index_select_scalar_cumsum_kernel( // ROCm path #ifdef USE_ROCM - const int output_batch_size = indices.size(0); - const int num_entries = num_batches * output_batch_size; - const bool multi_block = gridDim.x > 1; - const int block_entries = blockIdx.x == gridDim.x - 1 - ? last_block_num_entries - : MAX_ENTRIES_PER_BLOCK; - const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; - const int remaining_entries = num_entries - block_entry_start; - const int num_entries_per_block = remaining_entries > 0 - ? (remaining_entries < block_entries ? remaining_entries : block_entries) - : 0; - - const int base_entry = block_entry_start + threadIdx.x * VEC; - acc_t local_data[VEC]; - - #pragma unroll - for (int i = 0; i < VEC; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - const int bid = entry / output_batch_size; - const int idx_in_batch = entry - bid * output_batch_size; + // 4 indices per thread + constexpr int VEC = 4; + const int output_batch_size = indices.size(0); + const int num_entries = num_batches * output_batch_size; + const bool multi_block = gridDim.x > 1; + const int block_entries = blockIdx.x == gridDim.x - 1 + ? last_block_num_entries + : MAX_ENTRIES_PER_BLOCK; + const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; + const int remaining_entries = num_entries - block_entry_start; + const int num_entries_per_block = remaining_entries > 0 + ? (remaining_entries < block_entries ? 
remaining_entries : block_entries) + : 0; + + const int base_entry = block_entry_start + threadIdx.x * VEC; + acc_t local_data[VEC]; + +#pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + const int bid = entry / output_batch_size; + const int idx_in_batch = entry - bid * output_batch_size; const int bid_base = bid * input_batch_size; const index_t sel_idx = indices[idx_in_batch]; - local_data[i] = - #ifdef __HIP_PLATFORM_AMD__ - __builtin_nontemporal_load( - &input[bid_base + sel_idx]); - #else + local_data[i] = +#ifdef __HIP_PLATFORM_AMD__ + __builtin_nontemporal_load(&input[bid_base + sel_idx]); +#else input[bid_base + sel_idx]; - #endif - output[entry] = local_data[i]; - } else { - local_data[i] = 0; - } - } - - // Faster path for single block - if (!multi_block) { - if (num_entries_per_block > 0) { - BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); - } - if (base_entry < num_entries) { - #pragma unroll - for (int i = 0; i < VEC; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - output_cumsum[entry] = local_data[i]; - } - } - } - return; - } - - if (num_entries_per_block > 0) { - inclusive_sum_scan_kernel( - local_data, - bs_temp_storage, - block_flags, - block_sums, - &block_prefix, - num_entries_per_block, - blockIdx.x, - multi_block, - 1); - } - - if (base_entry < num_entries) { - #pragma unroll - for (int i = 0; i < VEC; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - output_cumsum[entry] = local_data[i]; - } - } +#endif + output[entry] = local_data[i]; + } else { + local_data[i] = 0; + } + } + + // Faster path for single block + if (!multi_block) { + if (num_entries_per_block > 0) { + BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); + } + if (base_entry < num_entries) { +#pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } + return; + } + + if (num_entries_per_block > 0) { + inclusive_sum_scan_kernel( + local_data, + bs_temp_storage, + block_flags, + block_sums, + &block_prefix, + num_entries_per_block, + blockIdx.x, + multi_block, + 1); + } + + if (base_entry < num_entries) { +#pragma unroll + for (int i = 0; i < VEC; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } } #else // CUDA path @@ -266,19 +267,15 @@ class KeyedJaggedIndexSelectDim1GPUOp const int num_batches = lengths.numel() / batch_size; const int num_output_lengths = num_batches * indices.numel(); const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; + int grid_size = 0; #ifdef USE_ROCM - const int vec_candidates[] = {4, 2, 1}; - int vec = 1; - for (int v : vec_candidates) { - if (indices.numel() % v == 0) { - vec = v; - break; - } - } + constexpr int VEC = 4; + const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; + grid_size = (num_output_lengths + ENTRIES_PER_BLOCK - 1) / + ENTRIES_PER_BLOCK; #else - const int vec = 1; const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK; - auto grid_size = cuda_calc_xblock_count( + grid_size = cuda_calc_xblock_count( num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK); #endif @@ -288,86 +285,52 @@ class KeyedJaggedIndexSelectDim1GPUOp at::empty({num_batches * indices.numel()}, lengths.options()); // Do index select and cumsum - auto dispatch_cumsum = [&](auto vec_tag, auto grid_calc) { - constexpr int VEC = decltype(vec_tag)::value; - constexpr int 
ENTRIES_PER_BLOCK = - MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; - const auto grid_size = grid_calc(ENTRIES_PER_BLOCK); - - Tensor block_flags, block_sums; - if (grid_size > 1) { - block_flags = - at::zeros({grid_size}, lengths.options().dtype(at::kInt)); - block_sums = at::empty({grid_size}, output_offsets.options()); - } - - AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { - using length_t = index_t; - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), - "index_select_scalar_cumsum_wrapper_2", - [&] { - using offset_t = index_t; - AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), - "index_select_scalar_cumsum_wrapper_3", - [&] { - FBGEMM_LAUNCH_KERNEL( - (index_select_scalar_cumsum_kernel< - length_t, - index_t, - offset_t, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - ENTRIES_PER_BLOCK, - VEC>), - grid_size, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - 0, - at::cuda::getCurrentCUDAStream(), - PTA_B(output_lengths, length_t, 1, 32), - PTA_B(output_offsets, offset_t, 1, 32), - PTA_B(lengths, length_t, 1, 32), - PTA_B(indices, index_t, 1, 32), - num_batches, - batch_size, - grid_size == 0 - ? 0 - : num_output_lengths - - ENTRIES_PER_BLOCK * (grid_size - 1), - grid_size > 1 - ? block_flags.data_ptr() - : nullptr, - grid_size > 1 - ? block_sums.data_ptr() - : nullptr); - }); - }); - }); - }; - -#ifdef USE_ROCM - auto rocm_grid = [&](int entries_per_block) { - return (num_output_lengths + entries_per_block - 1) / entries_per_block; - }; - switch (vec) { - case 4: - dispatch_cumsum(std::integral_constant{}, rocm_grid); - break; - case 2: - dispatch_cumsum(std::integral_constant{}, rocm_grid); - break; - default: - dispatch_cumsum(std::integral_constant{}, rocm_grid); - break; + // Do index select and cumsum + Tensor block_flags, block_sums; + if (grid_size > 1) { + block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); + block_sums = at::empty({grid_size}, output_offsets.options()); } -#else - dispatch_cumsum( - std::integral_constant{}, - [&](int entries_per_block) { - return cuda_calc_xblock_count(num_output_lengths, entries_per_block); + + AT_DISPATCH_INDEX_TYPES( + lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { + using length_t = index_t; + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), + "index_select_scalar_cumsum_wrapper_2", + [&] { + using offset_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "index_select_scalar_cumsum_wrapper_3", + [&] { + FBGEMM_LAUNCH_KERNEL( + (index_select_scalar_cumsum_kernel< + length_t, + index_t, + offset_t, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + ENTRIES_PER_BLOCK>), + grid_size, + MAX_CUMSUM_ENTRIES_PER_BLOCK, + 0, + at::cuda::getCurrentCUDAStream(), + PTA_B(output_lengths, length_t, 1, 32), + PTA_B(output_offsets, offset_t, 1, 32), + PTA_B(lengths, length_t, 1, 32), + PTA_B(indices, index_t, 1, 32), + num_batches, + batch_size, + grid_size == 0 + ? 0 + : num_output_lengths - + ENTRIES_PER_BLOCK * (grid_size - 1), + grid_size > 1 ? block_flags.data_ptr() : nullptr, + grid_size > 1 ? block_sums.data_ptr() + : nullptr); + }); + }); }); -#endif const int64_t num_outputs = (selected_lengths_sum.has_value()) ? 
selected_lengths_sum.value().guard_int(__FILE__, __LINE__) From 2182e0bf83dec2cfb5b2d52aa01d403e7d9bd081 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 22:33:02 +0000 Subject: [PATCH 08/16] Updated formatting --- .../src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 87eb445289..693bcb0201 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -150,7 +150,7 @@ __global__ void index_select_scalar_cumsum_kernel( } #endif } - + template < typename scalar_t, typename index_t, @@ -284,7 +284,6 @@ class KeyedJaggedIndexSelectDim1GPUOp Tensor output_lengths = at::empty({num_batches * indices.numel()}, lengths.options()); - // Do index select and cumsum // Do index select and cumsum Tensor block_flags, block_sums; if (grid_size > 1) { From fe929e5d023f4c656b7c5efb6c01101269836b82 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Dec 2025 22:34:08 +0000 Subject: [PATCH 09/16] Updated formatting --- .../src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 693bcb0201..e1ea25f99a 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -284,13 +284,13 @@ class KeyedJaggedIndexSelectDim1GPUOp Tensor output_lengths = at::empty({num_batches * indices.numel()}, lengths.options()); - // Do index select and cumsum Tensor block_flags, block_sums; if (grid_size > 1) { block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); block_sums = at::empty({grid_size}, output_offsets.options()); } + // Do index select and cumsum AT_DISPATCH_INDEX_TYPES( lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { using length_t = index_t; From 489e4b6abea2be7c3f3e0e99d6f28f640ee87576 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 17:57:28 +0000 Subject: [PATCH 10/16] Addressing PR comments --- .../keyed_jagged_index_select_dim1.cu | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index e1ea25f99a..f7ebe5d4b6 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -36,8 +36,8 @@ __global__ void index_select_scalar_cumsum_kernel( // ROCm path #ifdef USE_ROCM - // 4 indices per thread - constexpr int VEC = 4; + // 4 indices/entries per thread + constexpr int ENTRIES_PER_THREAD = 4; const int output_batch_size = indices.size(0); const int num_entries = num_batches * output_batch_size; const bool multi_block = gridDim.x > 1; @@ -50,23 +50,18 @@ __global__ void index_select_scalar_cumsum_kernel( ? (remaining_entries < block_entries ? 
remaining_entries : block_entries) : 0; - const int base_entry = block_entry_start + threadIdx.x * VEC; - acc_t local_data[VEC]; + const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD; + acc_t local_data[ENTRIES_PER_THREAD]; #pragma unroll - for (int i = 0; i < VEC; ++i) { + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { const int entry = base_entry + i; if (entry < num_entries) { const int bid = entry / output_batch_size; const int idx_in_batch = entry - bid * output_batch_size; const int bid_base = bid * input_batch_size; const index_t sel_idx = indices[idx_in_batch]; - local_data[i] = -#ifdef __HIP_PLATFORM_AMD__ - __builtin_nontemporal_load(&input[bid_base + sel_idx]); -#else - input[bid_base + sel_idx]; -#endif + local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]); output[entry] = local_data[i]; } else { local_data[i] = 0; @@ -80,7 +75,7 @@ __global__ void index_select_scalar_cumsum_kernel( } if (base_entry < num_entries) { #pragma unroll - for (int i = 0; i < VEC; ++i) { + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { const int entry = base_entry + i; if (entry < num_entries) { output_cumsum[entry] = local_data[i]; @@ -91,7 +86,7 @@ __global__ void index_select_scalar_cumsum_kernel( } if (num_entries_per_block > 0) { - inclusive_sum_scan_kernel( + inclusive_sum_scan_kernel( local_data, bs_temp_storage, block_flags, @@ -105,7 +100,7 @@ __global__ void index_select_scalar_cumsum_kernel( if (base_entry < num_entries) { #pragma unroll - for (int i = 0; i < VEC; ++i) { + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { const int entry = base_entry + i; if (entry < num_entries) { output_cumsum[entry] = local_data[i]; @@ -269,8 +264,8 @@ class KeyedJaggedIndexSelectDim1GPUOp const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; int grid_size = 0; #ifdef USE_ROCM - constexpr int VEC = 4; - const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * VEC; + constexpr int ENTRIES_PER_THREAD = 4; + const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD; grid_size = (num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK; #else From 6083e78ba7c4f950d24943806860340e31012b29 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 18:59:57 +0000 Subject: [PATCH 11/16] Added back VEC templateparameter --- .../keyed_jagged_index_select_dim1.cu | 280 ++++++++++-------- 1 file changed, 159 insertions(+), 121 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index f7ebe5d4b6..6645100f2e 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -18,7 +18,8 @@ template < typename index_t, typename acc_t, int NUM_THREADS_PER_BLOCK, - int MAX_ENTRIES_PER_BLOCK> + int MAX_ENTRIES_PER_BLOCK, + int ENTRIES_PER_THREAD> __global__ void index_select_scalar_cumsum_kernel( pta::PackedTensorAccessor32 output, pta::PackedTensorAccessor32 output_cumsum, @@ -36,76 +37,74 @@ __global__ void index_select_scalar_cumsum_kernel( // ROCm path #ifdef USE_ROCM - // 4 indices/entries per thread - constexpr int ENTRIES_PER_THREAD = 4; - const int output_batch_size = indices.size(0); - const int num_entries = num_batches * output_batch_size; - const bool multi_block = gridDim.x > 1; - const int block_entries = blockIdx.x == gridDim.x - 1 - ? 
last_block_num_entries - : MAX_ENTRIES_PER_BLOCK; - const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; - const int remaining_entries = num_entries - block_entry_start; - const int num_entries_per_block = remaining_entries > 0 - ? (remaining_entries < block_entries ? remaining_entries : block_entries) - : 0; - - const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD; - acc_t local_data[ENTRIES_PER_THREAD]; - -#pragma unroll - for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - const int bid = entry / output_batch_size; - const int idx_in_batch = entry - bid * output_batch_size; + const int output_batch_size = indices.size(0); + const int num_entries = num_batches * output_batch_size; + const bool multi_block = gridDim.x > 1; + const int block_entries = blockIdx.x == gridDim.x - 1 + ? last_block_num_entries + : MAX_ENTRIES_PER_BLOCK; + const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK; + const int remaining_entries = num_entries - block_entry_start; + const int num_entries_per_block = remaining_entries > 0 + ? (remaining_entries < block_entries ? remaining_entries : block_entries) + : 0; + + const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD; + acc_t local_data[ENTRIES_PER_THREAD]; + + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + const int bid = entry / output_batch_size; + const int idx_in_batch = entry - bid * output_batch_size; const int bid_base = bid * input_batch_size; const index_t sel_idx = indices[idx_in_batch]; - local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]); - output[entry] = local_data[i]; - } else { - local_data[i] = 0; - } - } - - // Faster path for single block - if (!multi_block) { - if (num_entries_per_block > 0) { - BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); - } - if (base_entry < num_entries) { -#pragma unroll - for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - output_cumsum[entry] = local_data[i]; - } - } - } - return; - } - - if (num_entries_per_block > 0) { - inclusive_sum_scan_kernel( - local_data, - bs_temp_storage, - block_flags, - block_sums, - &block_prefix, - num_entries_per_block, - blockIdx.x, - multi_block, - 1); - } - - if (base_entry < num_entries) { -#pragma unroll - for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { - const int entry = base_entry + i; - if (entry < num_entries) { - output_cumsum[entry] = local_data[i]; - } - } + local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]); + output[entry] = local_data[i]; + } else { + local_data[i] = 0; + } + } + + // Faster path for single block + if (!multi_block) { + if (num_entries_per_block > 0) { + BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data); + } + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + output_cumsum[entry] = local_data[i]; + } + } + } + return; + } + + if (num_entries_per_block > 0) { + inclusive_sum_scan_kernel( + local_data, + bs_temp_storage, + block_flags, + block_sums, + &block_prefix, + num_entries_per_block, + blockIdx.x, + multi_block, + 1); + } + + if (base_entry < num_entries) { + #pragma unroll + for (int i = 0; i < ENTRIES_PER_THREAD; ++i) { + const int entry = base_entry + i; + if (entry < num_entries) { + 
output_cumsum[entry] = local_data[i]; + } + } } #else // CUDA path @@ -145,7 +144,7 @@ __global__ void index_select_scalar_cumsum_kernel( } #endif } - + template < typename scalar_t, typename index_t, @@ -262,15 +261,19 @@ class KeyedJaggedIndexSelectDim1GPUOp const int num_batches = lengths.numel() / batch_size; const int num_output_lengths = num_batches * indices.numel(); const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256; - int grid_size = 0; #ifdef USE_ROCM - constexpr int ENTRIES_PER_THREAD = 4; - const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD; - grid_size = (num_output_lengths + ENTRIES_PER_BLOCK - 1) / - ENTRIES_PER_BLOCK; + const int num_entries_per_thread[] = {4, 2, 1}; + int entries_per_thread = 1; + for (int i : num_entries_per_thread) { + if (indices.numel() % i == 0) { + entries_per_thread = i; + break; + } + } #else + const int entries_per_thread = 1; const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK; - grid_size = cuda_calc_xblock_count( + auto grid_size = cuda_calc_xblock_count( num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK); #endif @@ -279,52 +282,87 @@ class KeyedJaggedIndexSelectDim1GPUOp Tensor output_lengths = at::empty({num_batches * indices.numel()}, lengths.options()); - Tensor block_flags, block_sums; - if (grid_size > 1) { - block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt)); - block_sums = at::empty({grid_size}, output_offsets.options()); - } - // Do index select and cumsum - AT_DISPATCH_INDEX_TYPES( - lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] { - using length_t = index_t; - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), - "index_select_scalar_cumsum_wrapper_2", - [&] { - using offset_t = index_t; - AT_DISPATCH_INDEX_TYPES( - indices.scalar_type(), - "index_select_scalar_cumsum_wrapper_3", - [&] { - FBGEMM_LAUNCH_KERNEL( - (index_select_scalar_cumsum_kernel< - length_t, - index_t, - offset_t, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - ENTRIES_PER_BLOCK>), - grid_size, - MAX_CUMSUM_ENTRIES_PER_BLOCK, - 0, - at::cuda::getCurrentCUDAStream(), - PTA_B(output_lengths, length_t, 1, 32), - PTA_B(output_offsets, offset_t, 1, 32), - PTA_B(lengths, length_t, 1, 32), - PTA_B(indices, index_t, 1, 32), - num_batches, - batch_size, - grid_size == 0 - ? 0 - : num_output_lengths - - ENTRIES_PER_BLOCK * (grid_size - 1), - grid_size > 1 ? block_flags.data_ptr() : nullptr, - grid_size > 1 ? 
block_sums.data_ptr<offset_t>()
-                                    : nullptr);
-                    });
-              });
+    auto dispatch_cumsum = [&](auto vec_tag, auto grid_calc) {
+      constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
+      constexpr int ENTRIES_PER_BLOCK =
+          MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
+      const auto grid_size = grid_calc(ENTRIES_PER_BLOCK);
+
+      Tensor block_flags, block_sums;
+      if (grid_size > 1) {
+        block_flags =
+            at::zeros({grid_size}, lengths.options().dtype(at::kInt));
+        block_sums = at::empty({grid_size}, output_offsets.options());
+      }
+
+      AT_DISPATCH_INDEX_TYPES(
+          lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
+            using length_t = index_t;
+            AT_DISPATCH_INDEX_TYPES(
+                offsets.scalar_type(),
+                "index_select_scalar_cumsum_wrapper_2",
+                [&] {
+                  using offset_t = index_t;
+                  AT_DISPATCH_INDEX_TYPES(
+                      indices.scalar_type(),
+                      "index_select_scalar_cumsum_wrapper_3",
+                      [&] {
+                        FBGEMM_LAUNCH_KERNEL(
+                            (index_select_scalar_cumsum_kernel<
+                                length_t,
+                                index_t,
+                                offset_t,
+                                MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                                ENTRIES_PER_BLOCK,
+                                ENTRIES_PER_THREAD>),
+                            grid_size,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                            0,
+                            at::cuda::getCurrentCUDAStream(),
+                            PTA_B(output_lengths, length_t, 1, 32),
+                            PTA_B(output_offsets, offset_t, 1, 32),
+                            PTA_B(lengths, length_t, 1, 32),
+                            PTA_B(indices, index_t, 1, 32),
+                            num_batches,
+                            batch_size,
+                            grid_size == 0
+                                ? 0
+                                : num_output_lengths -
+                                    ENTRIES_PER_BLOCK * (grid_size - 1),
+                            grid_size > 1
+                                ? block_flags.data_ptr<int32_t>()
+                                : nullptr,
+                            grid_size > 1
+                                ? block_sums.data_ptr<offset_t>()
+                                : nullptr);
+                      });
+                });
+          });
+    };
+
+#ifdef USE_ROCM
+    auto rocm_grid = [&](int entries_per_block) {
+      return (num_output_lengths + entries_per_block - 1) / entries_per_block;
+    };
+    switch (entries_per_thread) {
+      case 4:
+        dispatch_cumsum(std::integral_constant<int, 4>{}, rocm_grid);
+        break;
+      case 2:
+        dispatch_cumsum(std::integral_constant<int, 2>{}, rocm_grid);
+        break;
+      default:
+        dispatch_cumsum(std::integral_constant<int, 1>{}, rocm_grid);
+        break;
+    }
+#else
+    dispatch_cumsum(
+        std::integral_constant<int, 1>{},
+        [&](int entries_per_block) {
+          return cuda_calc_xblock_count(num_output_lengths, entries_per_block);
         });
+#endif
 
     const int64_t num_outputs = (selected_lengths_sum.has_value())
       ? 
selected_lengths_sum.value().guard_int(__FILE__, __LINE__) @@ -623,4 +661,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA( "keyed_jagged_index_select_dim1", fbgemm_gpu::keyed_jagged_index_select_dim_1_gpu); -} +} \ No newline at end of file From 1a754605c04b26c3126a74324f9157b6d0360846 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 17 Dec 2025 19:02:52 +0000 Subject: [PATCH 12/16] Fixed formatting --- .../src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index 6645100f2e..c46b44b84d 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -144,7 +144,7 @@ __global__ void index_select_scalar_cumsum_kernel( } #endif } - + template < typename scalar_t, typename index_t, @@ -661,4 +661,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA( "keyed_jagged_index_select_dim1", fbgemm_gpu::keyed_jagged_index_select_dim_1_gpu); -} \ No newline at end of file +} From 2d160b2921b8705785cc8398c36428fb08684fe3 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Dec 2025 16:25:18 +0000 Subject: [PATCH 13/16] Removed volatile and added back spin-lock --- fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh index 032760e21c..837807136a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh @@ -83,7 +83,10 @@ __inline__ __device__ void inclusive_sum_scan_kernel( if (is_last_thread) { scalar_t block_prev_local = 0; if (block_id != 0) { - volatile int* flags = block_flags; + // Spin wait for the previous block to write the sum value + while (atomicAdd(&block_flags[block_id - 1], 0) < signal) + ; + // Get sum from the previous block *block_prev = block_prev_local = block_sums[block_id - 1]; } From 8fe70ea1535e6e70b670aa2d291e952a173874de Mon Sep 17 00:00:00 2001 From: root Date: Thu, 22 Jan 2026 20:47:27 +0000 Subject: [PATCH 14/16] Fixed grid_size is 0 and added early return --- .../jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index c46b44b84d..74917093e4 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -289,6 +289,9 @@ class KeyedJaggedIndexSelectDim1GPUOp MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD; const auto grid_size = grid_calc(ENTRIES_PER_BLOCK); + if (grid_size == 0) + return; + Tensor block_flags, block_sums; if (grid_size > 1) { block_flags = @@ -326,10 +329,8 @@ class KeyedJaggedIndexSelectDim1GPUOp PTA_B(indices, index_t, 1, 32), num_batches, batch_size, - grid_size == 0 - ? 0 - : num_output_lengths - - ENTRIES_PER_BLOCK * (grid_size - 1), + num_output_lengths - + ENTRIES_PER_BLOCK * (grid_size - 1), grid_size > 1 ? 
block_flags.data_ptr<int32_t>()
                                 : nullptr,

From 200de28902ff8c4edf9feee83d5c0c82751e2bb8 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 22 Jan 2026 23:09:58 +0000
Subject: [PATCH 15/16] Cleaner separation between ROCm and CUDA paths

---
 .../keyed_jagged_index_select_dim1.cu         | 82 ++++++++++++-----
 1 file changed, 60 insertions(+), 22 deletions(-)

diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
index 74917093e4..a2e545f8e6 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -271,8 +271,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
       }
     }
 #else
-    const int entries_per_thread = 1;
-    const int ENTRIES_PER_BLOCK = MAX_CUMSUM_ENTRIES_PER_BLOCK;
+    constexpr int ENTRIES_PER_THREAD = 1;
     auto grid_size = cuda_calc_xblock_count(
         num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK);
 #endif
@@ -281,21 +280,22 @@ class KeyedJaggedIndexSelectDim1GPUOp
         at::empty({num_batches * indices.numel()}, offsets.options());
     Tensor output_lengths =
         at::empty({num_batches * indices.numel()}, lengths.options());
+    Tensor block_flags, block_sums;
 
-    // Do index select and cumsum
-    auto dispatch_cumsum = [&](auto vec_tag, auto grid_calc) {
+#ifdef USE_ROCM
+    // ROCm path
+    auto dispatch_cumsum = [&](auto vec_tag) {
       constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
       constexpr int ENTRIES_PER_BLOCK =
           MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
-      const auto grid_size = grid_calc(ENTRIES_PER_BLOCK);
+      const auto grid_size =
+          (num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK;
 
       if (grid_size == 0)
         return;
 
-      Tensor block_flags, block_sums;
       if (grid_size > 1) {
-        block_flags =
-            at::zeros({grid_size}, lengths.options().dtype(at::kInt));
+        block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
         block_sums = at::empty({grid_size}, output_offsets.options());
       }
 
@@ -342,32 +342,70 @@ class KeyedJaggedIndexSelectDim1GPUOp
           });
     };
 
-#ifdef USE_ROCM
-    auto rocm_grid = [&](int entries_per_block) {
-      return (num_output_lengths + entries_per_block - 1) / entries_per_block;
-    };
     switch (entries_per_thread) {
       case 4:
-        dispatch_cumsum(std::integral_constant<int, 4>{}, rocm_grid);
+        dispatch_cumsum(std::integral_constant<int, 4>{});
         break;
       case 2:
-        dispatch_cumsum(std::integral_constant<int, 2>{}, rocm_grid);
+        dispatch_cumsum(std::integral_constant<int, 2>{});
         break;
       default:
-        dispatch_cumsum(std::integral_constant<int, 1>{}, rocm_grid);
+        dispatch_cumsum(std::integral_constant<int, 1>{});
         break;
     }
 #else
-    dispatch_cumsum(
-        std::integral_constant<int, 1>{},
-        [&](int entries_per_block) {
-          return cuda_calc_xblock_count(num_output_lengths, entries_per_block);
+    // CUDA path
+    if (grid_size > 1) {
+      block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
+      block_sums = at::empty({grid_size}, output_offsets.options());
+    }
+
+    AT_DISPATCH_INDEX_TYPES(
+        lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
+          using length_t = index_t;
+          AT_DISPATCH_INDEX_TYPES(
+              offsets.scalar_type(),
+              "index_select_scalar_cumsum_wrapper_2",
+              [&] {
+                using offset_t = index_t;
+                AT_DISPATCH_INDEX_TYPES(
+                    indices.scalar_type(),
+                    "index_select_scalar_cumsum_wrapper_3",
+                    [&] {
+                      FBGEMM_LAUNCH_KERNEL(
+                          (index_select_scalar_cumsum_kernel<
+                              length_t,
+                              index_t,
+                              offset_t,
+                              MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                              MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                              ENTRIES_PER_THREAD>),
+                          grid_size,
+                          MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                          0,
+                          at::cuda::getCurrentCUDAStream(),
+                          PTA_B(output_lengths, length_t, 1, 32),
+                          PTA_B(output_offsets, offset_t, 1, 32),
+                          PTA_B(lengths, length_t, 1, 32),
+                          PTA_B(indices, index_t, 1, 32),
+                          num_batches,
+                          batch_size,
+                          num_output_lengths -
+                              MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
+                          grid_size > 1
+                              ? block_flags.data_ptr<int32_t>()
+                              : nullptr,
+                          grid_size > 1
+                              ? block_sums.data_ptr<offset_t>()
+                              : nullptr);
+                    });
+              });
         });
 #endif
 
     const int64_t num_outputs = (selected_lengths_sum.has_value())
-      ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
-      : output_offsets[output_offsets.numel() - 1].item<int64_t>();
+        ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
+        : output_offsets[output_offsets.numel() - 1].item<int64_t>();
     Tensor output = at::empty({num_outputs}, values.options());
     Tensor output_weights;
     if (weights.has_value()) {

From af5ccbc7fcdeb6594f4ddb2a0d72bacf915209ef Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 22 Jan 2026 23:12:40 +0000
Subject: [PATCH 16/16] Changed variable name grid_size to rocm_grid_size for
 better clarity

---
 .../keyed_jagged_index_select_dim1.cu         | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
index a2e545f8e6..bb05c61206 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -288,15 +288,15 @@ class KeyedJaggedIndexSelectDim1GPUOp
       constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
       constexpr int ENTRIES_PER_BLOCK =
           MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
-      const auto grid_size =
+      const auto rocm_grid_size =
           (num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK;
 
-      if (grid_size == 0)
+      if (rocm_grid_size == 0)
         return;
 
-      if (grid_size > 1) {
-        block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
-        block_sums = at::empty({grid_size}, output_offsets.options());
+      if (rocm_grid_size > 1) {
+        block_flags = at::zeros({rocm_grid_size}, lengths.options().dtype(at::kInt));
+        block_sums = at::empty({rocm_grid_size}, output_offsets.options());
       }
 
       AT_DISPATCH_INDEX_TYPES(
@@ -319,7 +319,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
                                 MAX_CUMSUM_ENTRIES_PER_BLOCK,
                                 ENTRIES_PER_BLOCK,
                                 ENTRIES_PER_THREAD>),
-                            grid_size,
+                            rocm_grid_size,
                             MAX_CUMSUM_ENTRIES_PER_BLOCK,
                             0,
                             at::cuda::getCurrentCUDAStream(),
@@ -330,11 +330,11 @@ class KeyedJaggedIndexSelectDim1GPUOp
                             num_batches,
                             batch_size,
                             num_output_lengths -
-                                ENTRIES_PER_BLOCK * (grid_size - 1),
-                            grid_size > 1
+                                ENTRIES_PER_BLOCK * (rocm_grid_size - 1),
+                            rocm_grid_size > 1
                                 ? block_flags.data_ptr<int32_t>()
                                 : nullptr,
-                            grid_size > 1
+                            rocm_grid_size > 1
                                 ? block_sums.data_ptr<offset_t>()
                                 : nullptr);
                       });