45 changes: 43 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/utils/inclusive_sum_scan.cuh
@@ -69,11 +69,52 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
const int block_id,
const bool is_multi_block,
const int signal) {
// ROCm path
#ifdef USE_ROCM
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform scan across blocks
if (is_multi_block) {
const bool is_last_thread =
threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD;
// The thread that holds the last entry in the block does synchronization
if (is_last_thread) {
scalar_t block_prev_local = 0;
if (block_id != 0) {
// Spin wait for the previous block to write the sum value
while (atomicAdd(&block_flags[block_id - 1], 0) < signal)
;

// Get sum from the previous block
*block_prev = block_prev_local = block_sums[block_id - 1];
}

// Write sum to global memory for the next block to consume
// Index of the block's last entry within this thread's per-thread items
const int scope = (num_entries_per_block - 1) % ITEMS_PER_THREAD;
block_sums[block_id] = block_prev_local + arr[scope];
__threadfence();
// Set a flag to notify the next block
atomicExch(&block_flags[block_id], signal);
}

__syncthreads();

if (block_id != 0) {
scalar_t block_prev_local = *block_prev;
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
arr[i] += block_prev_local;
}
}
}
#else
// CUDA path
// Perform scan within a block
cub::BlockScan<scalar_t, NUM_THREADS_PER_BLOCK>(temp_storage)
.InclusiveSum(arr, arr);

// Perform scan across blocks
if (is_multi_block) {
// The thread that holds the last entry in the block does synchronization
if (threadIdx.x == (num_entries_per_block - 1) / ITEMS_PER_THREAD) {
@@ -104,6 +145,6 @@ __inline__ __device__ void inclusive_sum_scan_kernel(
}
}
}
#endif
}

} // namespace fbgemm_gpu
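The ROCm branch above chains blocks together with a stream-scan handshake: the thread holding the block's last entry spin-waits on the previous block's flag, reads that block's running total from block_sums, publishes its own total, and then every thread folds in the carried-in prefix. Below is a minimal, self-contained sketch of the same handshake for one int per thread; the kernel name, launch shape, and single-element-per-thread layout are illustrative assumptions, not part of this diff, and it relies on the same cross-block forward-progress assumption as the code above.

#include <cub/block/block_scan.cuh>

// Illustrative sketch only (not from this PR): a chained inter-block inclusive
// scan over one int per thread, using the same block_sums/block_flags handshake
// as inclusive_sum_scan_kernel. block_flags must be zero-initialized.
template <int THREADS>
__global__ void chained_inclusive_scan_sketch(
    int* data, int* block_sums, int* block_flags, int n) {
  using BlockScan = cub::BlockScan<int, THREADS>;
  __shared__ typename BlockScan::TempStorage temp_storage;
  __shared__ int block_prev;

  const int gid = blockIdx.x * THREADS + threadIdx.x;
  int val = gid < n ? data[gid] : 0;

  // 1) Inclusive scan within the block.
  BlockScan(temp_storage).InclusiveSum(val, val);

  // 2) The last thread waits for the previous block's total, then publishes
  //    this block's total and raises the flag for the next block.
  if (threadIdx.x == THREADS - 1) {
    int prev = 0;
    if (blockIdx.x != 0) {
      while (atomicAdd(&block_flags[blockIdx.x - 1], 0) < 1)
        ; // spin until the previous block has published its sum
      prev = block_sums[blockIdx.x - 1];
    }
    block_prev = prev;
    block_sums[blockIdx.x] = prev + val;
    __threadfence(); // make the sum visible before the flag flips
    atomicExch(&block_flags[blockIdx.x], 1);
  }
  __syncthreads();

  // 3) Every thread adds the running total of all previous blocks.
  if (gid < n) {
    data[gid] = val + block_prev;
  }
}

As in the kernel above, correctness depends on block b - 1 eventually publishing its sum, which is why the flag is read with atomicAdd(..., 0) and written with atomicExch only after a __threadfence().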
259 changes: 213 additions & 46 deletions fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <type_traits>
#include "common.cuh"

using Tensor = at::Tensor;
@@ -17,7 +18,8 @@ template <
typename index_t,
typename acc_t,
int NUM_THREADS_PER_BLOCK,
int MAX_ENTRIES_PER_BLOCK>
int MAX_ENTRIES_PER_BLOCK,
int ENTRIES_PER_THREAD>
__global__ void index_select_scalar_cumsum_kernel(
pta::PackedTensorAccessor32<scalar_t, 1, at::RestrictPtrTraits> output,
pta::PackedTensorAccessor32<acc_t, 1, at::RestrictPtrTraits> output_cumsum,
Expand All @@ -31,6 +33,81 @@ __global__ void index_select_scalar_cumsum_kernel(
acc_t* block_sums) {
typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan;
__shared__ typename BlockScan::TempStorage bs_temp_storage;
__shared__ acc_t block_prefix;

// ROCm path
#ifdef USE_ROCM
const int output_batch_size = indices.size(0);
const int num_entries = num_batches * output_batch_size;
const bool multi_block = gridDim.x > 1;
// Entries owned by this block; the last block may be partial
const int block_entries = blockIdx.x == gridDim.x - 1
? last_block_num_entries
: MAX_ENTRIES_PER_BLOCK;
const int block_entry_start = blockIdx.x * MAX_ENTRIES_PER_BLOCK;
const int remaining_entries = num_entries - block_entry_start;
const int num_entries_per_block = remaining_entries > 0
? (remaining_entries < block_entries ? remaining_entries : block_entries)
: 0;

const int base_entry = block_entry_start + threadIdx.x * ENTRIES_PER_THREAD;
acc_t local_data[ENTRIES_PER_THREAD];

#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
const int bid = entry / output_batch_size;
const int idx_in_batch = entry - bid * output_batch_size;
const int bid_base = bid * input_batch_size;
const index_t sel_idx = indices[idx_in_batch];
local_data[i] = __builtin_nontemporal_load(&input[bid_base + sel_idx]);
output[entry] = local_data[i];
} else {
local_data[i] = 0;
}
}

// Faster path for single block
if (!multi_block) {
if (num_entries_per_block > 0) {
BlockScan(bs_temp_storage).InclusiveSum(local_data, local_data);
}
if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
return;
}

if (num_entries_per_block > 0) {
inclusive_sum_scan_kernel<acc_t, ENTRIES_PER_THREAD, NUM_THREADS_PER_BLOCK>(
local_data,
bs_temp_storage,
block_flags,
block_sums,
&block_prefix,
num_entries_per_block,
blockIdx.x,
multi_block,
1);
}

if (base_entry < num_entries) {
#pragma unroll
for (int i = 0; i < ENTRIES_PER_THREAD; ++i) {
const int entry = base_entry + i;
if (entry < num_entries) {
output_cumsum[entry] = local_data[i];
}
}
}
#else
// CUDA path
__shared__ acc_t smem[MAX_ENTRIES_PER_BLOCK];
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int output_batch_size = indices.size(0);
@@ -65,6 +142,7 @@ __global__ void index_select_scalar_cumsum_kernel(
if (tid < num_batches * output_batch_size) {
output_cumsum[tid] = *local_data;
}
#endif
}
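In both branches, index_select_scalar_cumsum_kernel gathers output[b * output_batch_size + j] = input[b * input_batch_size + indices[j]] and writes the inclusive prefix sum of the gathered values to output_cumsum. A serial host-side reference, useful for checking either path, could look like the following; the function name and std::vector containers are assumptions for illustration, not code from this PR.

#include <cstddef>
#include <vector>

// Serial reference (assumed helper, not from this PR): gather selected entries
// per batch, then take an inclusive prefix sum over the gathered array.
template <typename scalar_t, typename index_t, typename acc_t>
void index_select_scalar_cumsum_ref(
    const std::vector<scalar_t>& input,   // num_batches * input_batch_size
    const std::vector<index_t>& indices,  // output_batch_size selected columns
    int num_batches,
    int input_batch_size,
    std::vector<scalar_t>& output,        // num_batches * output_batch_size
    std::vector<acc_t>& output_cumsum) {  // inclusive cumsum of output
  const int output_batch_size = static_cast<int>(indices.size());
  output.resize(static_cast<std::size_t>(num_batches) * output_batch_size);
  output_cumsum.resize(output.size());
  acc_t running = 0;
  for (int b = 0; b < num_batches; ++b) {
    for (int j = 0; j < output_batch_size; ++j) {
      const scalar_t v = input[b * input_batch_size + indices[j]];
      output[b * output_batch_size + j] = v;
      running += static_cast<acc_t>(v);
      output_cumsum[b * output_batch_size + j] = running;
    }
  }
}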

template <
@@ -183,62 +261,151 @@ class KeyedJaggedIndexSelectDim1GPUOp
const int num_batches = lengths.numel() / batch_size;
const int num_output_lengths = num_batches * indices.numel();
const int MAX_CUMSUM_ENTRIES_PER_BLOCK = 256;
#ifdef USE_ROCM
// Pick the largest per-thread entry count that evenly divides indices.numel()
const int num_entries_per_thread[] = {4, 2, 1};
int entries_per_thread = 1;
for (int i : num_entries_per_thread) {
if (indices.numel() % i == 0) {
entries_per_thread = i;
break;
}
}
#else
constexpr int ENTRIES_PER_THREAD = 1;
auto grid_size = cuda_calc_xblock_count(
    num_output_lengths, MAX_CUMSUM_ENTRIES_PER_BLOCK);

Contributor: based on the changes below, where is this used?

Author: This commit 200de28 fixed this issue. Now the variables entries_per_thread and ENTRIES_PER_BLOCK are used only in the ROCm path. This commit also provides a cleaner separation between the ROCm and CUDA paths.
#endif

Tensor output_offsets =
at::empty({num_batches * indices.numel()}, offsets.options());
Tensor output_lengths =
at::empty({num_batches * indices.numel()}, lengths.options());

Tensor block_flags, block_sums;
if (grid_size > 1) {
block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

// Do index select and cumsum
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
MAX_CUMSUM_ENTRIES_PER_BLOCK>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
grid_size > 1 ? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
#ifdef USE_ROCM
// ROCm path
auto dispatch_cumsum = [&](auto vec_tag) {
constexpr int ENTRIES_PER_THREAD = decltype(vec_tag)::value;
constexpr int ENTRIES_PER_BLOCK =
MAX_CUMSUM_ENTRIES_PER_BLOCK * ENTRIES_PER_THREAD;
const auto rocm_grid_size =
(num_output_lengths + ENTRIES_PER_BLOCK - 1) / ENTRIES_PER_BLOCK;

if (rocm_grid_size == 0)
return;

if (rocm_grid_size > 1) {
block_flags = at::zeros({rocm_grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({rocm_grid_size}, output_offsets.options());
}

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
ENTRIES_PER_BLOCK,
ENTRIES_PER_THREAD>),
rocm_grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
ENTRIES_PER_BLOCK * (rocm_grid_size - 1),
rocm_grid_size > 1
? block_flags.data_ptr<int>()
: nullptr,
rocm_grid_size > 1
? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
};

switch (entries_per_thread) {
case 4:
dispatch_cumsum(std::integral_constant<int, 4>{});
break;
case 2:
dispatch_cumsum(std::integral_constant<int, 2>{});
break;
default:
dispatch_cumsum(std::integral_constant<int, 1>{});
break;
}
#else
// CUDA path
if (grid_size > 1) {
block_flags = at::zeros({grid_size}, lengths.options().dtype(at::kInt));
block_sums = at::empty({grid_size}, output_offsets.options());
}

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
using length_t = index_t;
AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(),
"index_select_scalar_cumsum_wrapper_2",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(),
"index_select_scalar_cumsum_wrapper_3",
[&] {
FBGEMM_LAUNCH_KERNEL(
(index_select_scalar_cumsum_kernel<
length_t,
index_t,
offset_t,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
ENTRIES_PER_THREAD>),
grid_size,
MAX_CUMSUM_ENTRIES_PER_BLOCK,
0,
at::cuda::getCurrentCUDAStream(),
PTA_B(output_lengths, length_t, 1, 32),
PTA_B(output_offsets, offset_t, 1, 32),
PTA_B(lengths, length_t, 1, 32),
PTA_B(indices, index_t, 1, 32),
num_batches,
batch_size,
num_output_lengths -
MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
grid_size > 1
? block_flags.data_ptr<int>()
: nullptr,
grid_size > 1
? block_sums.data_ptr<offset_t>()
: nullptr);
});
});
});
#endif
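The ROCm branch resolves the runtime entries_per_thread value to a compile-time template argument by passing std::integral_constant tags into the dispatch_cumsum lambda. A stripped-down sketch of that pattern, with hypothetical names, is shown here.

#include <cstdio>
#include <type_traits>

// Hypothetical consumer that needs its unroll factor as a template parameter.
template <int ENTRIES_PER_THREAD>
void run_with_unroll(int n) {
  std::printf("n=%d, ENTRIES_PER_THREAD=%d\n", n, ENTRIES_PER_THREAD);
}

void dispatch_by_entries(int entries_per_thread, int n) {
  // The tag carries the value in its type, so decltype(tag)::value is constexpr
  // inside the generic lambda, just as decltype(vec_tag)::value is above.
  auto dispatch = [&](auto tag) {
    constexpr int kEntries = decltype(tag)::value;
    run_with_unroll<kEntries>(n);
  };
  switch (entries_per_thread) {
    case 4: dispatch(std::integral_constant<int, 4>{}); break;
    case 2: dispatch(std::integral_constant<int, 2>{}); break;
    default: dispatch(std::integral_constant<int, 1>{}); break;
  }
}

This keeps a single launch site while still letting the compiler specialize and unroll the per-thread loops for each supported factor.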

const int64_t num_outputs = (selected_lengths_sum.has_value())
    ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
    : output_offsets[output_offsets.numel() - 1].item<int64_t>();
Tensor output = at::empty({num_outputs}, values.options());
Tensor output_weights;
if (weights.has_value()) {