Skip to content

Conversation

@aryaman-gupta
Copy link
Contributor

This PR optimizes the performance on ROCm of the group_index_select_or_add_2d_kernel kernel on tables with small embedding dimensions (i.e., num_cols).

For tables with small embedding dimensions, the code is refactored to process multiple rows within the same warp. Two files are changed:

  1. fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp - The calculation of the warp_offsets is changed in the host-side code.
  2. fbgemm_gpu/src/sparse_ops/sparse_group_index.cu - The group_index_select_or_add_2d_kernel kernel is modified to process multiple rows within a warp for small embedding dimensions.

@meta-codesync
Copy link
Contributor

meta-codesync bot commented Dec 16, 2025

@q10 has imported this pull request. If you are a Meta employee, you can view this in D89316371.

return GROUP_INDEX_SELECT_COLS_PER_WARP;
}

int get_group_index_select_unroll_factor() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this constexpr?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not trivially, unfortunately. We would need to move the function definition to sparse_ops.h where it is declared, and move the GROUP_INDEX_SELECT_UNROLL_FACTOR variable there as well.

Happy to do this refactor if you would suggest.

if (num_cols < COLS_PER_WARP && num_cols >= UNROLL_FACTOR) {
// Need to ensure that [member_id] and [member_warp_id] are calculated correctly
// for the small embedding dimension path below
int rows_per_warp = COLS_PER_WARP / num_cols;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using const and auto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in bbdc17d

I left input, output, indices and idx as they were previously defined

q10 pushed a commit to q10/FBGEMM that referenced this pull request Jan 28, 2026
…rate codepath for small embedding dimensions (pytorch#5233)

Summary:
This PR optimizes the performance on ROCm of the `group_index_select_or_add_2d_kernel` kernel on tables with small embedding dimensions (i.e., `num_cols`).

For tables with small embedding dimensions, the code is refactored to process multiple rows within the same warp. Two files are changed:

1. `fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp` - The calculation of the `warp_offsets` is changed in the host-side code.
2. `fbgemm_gpu/src/sparse_ops/sparse_group_index.cu` - The `group_index_select_or_add_2d_kernel` kernel is modified to process multiple rows within a warp for small embedding dimensions.


Test Plan:
```
cd ~/fbsource/fbcode/ai_codesign/nonprod/bensonma415/scripts/D89316371

bash run_benchmark.sh amd/mi300 2>&1 | pastry
```

https://docs.google.com/document/d/12ywiAQhA3eZqcIwUyc8_CQinwZJy_WDD8B4_43G8Gr8/edit?tab=t.0#heading=h.5row1qfol66k

Reviewed By: echen4096

Differential Revision: D89316371

Pulled By: q10
cyyever pushed a commit to cyyever/FBGEMM that referenced this pull request Jan 29, 2026
…rate codepath for small embedding dimensions (pytorch#5233)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2340

This PR optimizes the performance on ROCm of the `group_index_select_or_add_2d_kernel` kernel on tables with small embedding dimensions (i.e., `num_cols`).

For tables with small embedding dimensions, the code is refactored to process multiple rows within the same warp. Two files are changed:

1. `fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp` - The calculation of the `warp_offsets` is changed in the host-side code.
2. `fbgemm_gpu/src/sparse_ops/sparse_group_index.cu` - The `group_index_select_or_add_2d_kernel` kernel is modified to process multiple rows within a warp for small embedding dimensions.

Pull Request resolved: pytorch#5233

Test Plan:
```
cd ~/fbsource/fbcode/ai_codesign/nonprod/bensonma415/scripts/D89316371

bash run_benchmark.sh amd/mi300 2>&1 | pastry
```

https://docs.google.com/document/d/12ywiAQhA3eZqcIwUyc8_CQinwZJy_WDD8B4_43G8Gr8/edit?tab=t.0#heading=h.5row1qfol66k

Reviewed By: echen4096

Differential Revision: D89316371

Pulled By: q10

fbshipit-source-id: 2742965773c92ff96419fa5978c93ca6d23dbed4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants