Add SYCL grouped_mm kernel using sycl-tla v0.7 #3122
xuhancn wants to merge 7 commits into intel:main
Conversation
Add XPU implementation of _grouped_mm operator using sycl-tla (CUTLASS for Intel GPUs) as the grouped GEMM backend. Supports BF16 inputs with FP32 accumulation and BF16 output across all 4 input modes:

- 3D x 3D (batched)
- 2D x 3D (ragged A / MoE pattern)
- 3D x 2D (ragged B)
- 2D x 2D (ragged K)

Changes:

- Upgrade sycl-tla from v0.6 to v0.7
- Add sycltla GroupedMM kernel in src/ATen/native/xpu/sycltla/
- Update CMakeLists.txt to include new sycltla sources
- Add unit tests for all 4 input modes
Pull request overview
Adds an XPU backend for torch.nn.functional.grouped_mm using sycl-tla (CUTLASS-on-Intel-GPU) grouped GEMM, along with build wiring and unit tests, and bumps the sycl-tla dependency to v0.7.
Changes:
- Upgrade sycl-tla FetchContent tag from v0.6 to v0.7.
- Add a new SYCL-TLA-based BF16 grouped GEMM implementation under src/ATen/native/xpu/sycltla/.
- Add XPU unit tests intended to cover all 4 grouped_mm input modes.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| test/xpu/test_grouped_mm_xpu.py | New XPU unit tests for grouped_mm modes and a larger-size accuracy check. |
| src/ATen/native/xpu/sycltla/GroupedMM.h | Declares the new BF16 grouped_mm entrypoint. |
| src/ATen/native/xpu/sycltla/GroupedMM.cpp | Implements grouped GEMM dispatch/config and handles the 4 input-mode layouts. |
| src/ATen/CMakeLists.txt | Adds the new sycltla sources/headers to the build/install. |
| cmake/SYCLTLA.cmake | Updates sycl-tla version pin to v0.7. |
```cpp
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] =
      reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
```
ptr_c_host is built by reinterpret_cast-ing ElementOutput* (bf16) pointers to ElementAccumulator* (float). Even with beta=0, this is unsafe: the kernel may still issue loads from C (e.g., for predication), and this pointer is mis-typed/misaligned for float. Prefer passing a real C buffer of the correct type, or configure the epilogue/kernel so that C is either unused (and can be nullptr if supported) or has an element type matching the underlying storage type.
Suggested change:
```diff
-// C matrix (not used, pass same as D with beta=0)
-cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
-ptr_C_device.reset(group_count);
-std::vector<const ElementAccumulator*> ptr_c_host(group_count);
-for (int i = 0; i < group_count; ++i) {
-  ptr_c_host[i] =
-      reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
-}
+// C matrix (not used when beta=0). Pass nullptr for all groups to avoid
+// mis-typed/misaligned pointers while keeping the expected element type.
+cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
+ptr_C_device.reset(group_count);
+std::vector<const ElementAccumulator*> ptr_c_host(group_count, nullptr);
```
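The hazard behind this comment can be reproduced on the host without any GPU: the following sketch (pure Python, hypothetical values) packs two adjacent bf16 outputs the way they would sit in the D buffer and then reads the same four bytes as a single float32 — the result is unrelated to either stored value.

```python
import struct

def bf16_bits(x: float) -> int:
    """Truncate a float32 to bfloat16 by keeping the top 16 bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

# Two adjacent bf16 outputs, as they would sit in the D buffer.
d0, d1 = bf16_bits(1.0), bf16_bits(2.0)

# Reading 4 bytes at the same address as a float32 fuses both values
# into one unrelated float -- the effect of the reinterpret_cast.
fused = struct.unpack("<f", struct.pack("<HH", d0, d1))[0]
assert fused not in (1.0, 2.0)
```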
```cpp
stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M_g, K, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
```
The packed stride for B appears to use {N, K, 1}, but throughout this function B is treated as a contiguous matrix with logical shape K×N (e.g., pointer arithmetic uses g * K * N and the 2D×2D path offsets by k_start * N, implying rows are K and columns are N). Using {N, K, 1} will cause the kernel to interpret the memory layout incorrectly when K != N. Update the stride construction to match the actual logical shape of B (e.g., {K, N, 1} / {K_g, N, 1} / {K, N_g, 1} as appropriate).
Suggested change:
```diff
-    cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
+    cutlass::make_cute_packed_stride(StrideB{}, {K, N, 1}));
```
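Why the swapped extents misread memory can be checked with plain stride arithmetic. This sketch (hypothetical sizes, stdlib-only) computes packed row-major strides for the actual K×N buffer versus strides derived from the swapped {N, K} extents, and shows they address different elements whenever K != N:

```python
def packed_row_major_strides(rows, cols):
    # Packed (fully contiguous) row-major strides: step `cols` per row, 1 per column.
    return (cols, 1)

def element_offset(i, j, strides):
    return i * strides[0] + j * strides[1]

K, N = 4, 3  # hypothetical per-group extents with K != N

correct = packed_row_major_strides(K, N)  # B stored as K x N -> (3, 1)
swapped = packed_row_major_strides(N, K)  # derived from {N, K} -> (4, 1)

# Element (i=1, j=0) of a contiguous K x N buffer lives at offset N,
# but the swapped strides address offset K -- a different element.
assert element_offset(1, 0, correct) == N
assert element_offset(1, 0, swapped) == K
```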
```cpp
// Read offs tensor to host
std::vector<int32_t> offs_host;
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}

if (!a_is_2d && !b_is_2d) {
  // 3D x 3D: regular batched MM
  group_count = mat_a.size(0);
  int M = mat_a.size(1);
  int N = mat_b.size(2);
  int K = mat_a.size(2);

  for (int g = 0; g < group_count; ++g) {
    problem_sizes.push_back({M, N, K});
    ptr_a_vec.push_back(base_a + g * M * K);
    ptr_b_vec.push_back(base_b + g * K * N);
    ptr_d_vec.push_back(base_d + g * M * N);
    stride_a_vec.push_back(
        cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
    stride_b_vec.push_back(
        cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
    stride_d_vec.push_back(
        cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
  }
} else if (a_is_2d && !b_is_2d) {
  // 2D x 3D: ragged A (MoE pattern)
  group_count = mat_b.size(0);
  int K = mat_a.size(1);
  int N = mat_b.size(2);
  int64_t out_stride_row = out.size(1);

  int32_t row_start = 0;
  for (int g = 0; g < group_count; ++g) {
    int32_t row_end = offs_host[g];
    int M_g = row_end - row_start;

    problem_sizes.push_back({M_g, N, K});
    ptr_a_vec.push_back(base_a + row_start * K);
```
Several code paths index offs_host[g] (e.g., 2D×3D, 3D×2D, 2D×2D) without validating that offs is present and has the expected length/type. If offs is omitted or has fewer than group_count entries, this will read out of bounds and can crash. Add explicit TORCH_CHECK(offs.has_value(), ...), validate offs dtype is int32, and check offs_host.size() matches the required group count before using it.
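The requested validation can be sketched host-side. This is a pure-Python stand-in (the real code would use `TORCH_CHECK`); the function name and exact invariants are illustrative, not the PR's API:

```python
def validate_offs(offs, group_count):
    """Reject missing, short, or non-monotonic offsets before indexing.

    Mirrors the TORCH_CHECKs the review asks for: presence, length,
    and monotonicity, so offs_host[g] can never read out of bounds
    or produce a negative group size.
    """
    if offs is None:
        raise ValueError("offs is required for ragged grouped_mm modes")
    if len(offs) != group_count:
        raise ValueError(f"offs must have {group_count} entries, got {len(offs)}")
    prev = 0
    for i, o in enumerate(offs):
        if o < prev:
            raise ValueError(f"offs must be non-decreasing (offs[{i}]={o} < {prev})")
        prev = o

validate_offs([2, 4, 6], group_count=3)   # ok
try:
    validate_offs([2, 4], group_count=3)  # wrong length -> rejected
except ValueError:
    pass
```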
```cpp
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
    at::Tensor& out) {
  // Ensure contiguous
  mat_a = mat_a.contiguous();
  mat_b = mat_b.contiguous();
```
The bias argument is currently ignored. If callers can pass a bias tensor to this kernel, the result will be silently incorrect. Either implement bias support in the epilogue (preferred if the API expects it) or explicitly reject it with a TORCH_CHECK(!bias.has_value(), ...) to avoid silent misbehavior.
```cpp
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
  return status;
}

compat::wait();
return cutlass::Status::kSuccess;
```
compat::wait() forces a device-wide synchronization after every grouped GEMM launch, which can severely limit throughput (especially if grouped_mm is called repeatedly in a model). If the wait is only there to ensure temporary allocations (workspace / device arrays) stay alive until completion, consider integrating with PyTorch’s XPU stream/queue semantics (e.g., enqueue on the current queue and tie allocation lifetimes to an event) so the call can remain asynchronous.
```python
a = torch.randn(m * n_groups, k, device=device, dtype=dtype)
b = torch.randn(n_groups, k, n, device=device, dtype=dtype)

a.requires_grad_(True)
b.requires_grad_(True)

offs = torch.arange(
    m, n_groups * m + 1, m, device=device, dtype=torch.int32
)

f = F.grouped_mm
out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
gO = torch.rand_like(out)
```
In this test, b is created with shape (n_groups, k, n) and then transposed before calling F.grouped_mm, but grouped_mm_helper() assumes b iterates into per-group matrices shaped (n, k) (so that b.t() becomes (k, n)). As written, iterating b yields (k, n) and b.t() becomes (n, k), which makes torch.mm(a, b.t()) dimensionally inconsistent and should fail. Adjust the test to either (a) construct b as (n_groups, n, k) and keep the transpose in the grouped_mm call, or (b) keep b as (n_groups, k, n) and remove the transpose + update the helper reference accordingly.
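The dimensional inconsistency described above can be verified with plain shape arithmetic, no tensors needed. A sketch with hypothetical sizes (m, k, n chosen so k != n):

```python
def mm_shape(a_shape, b_shape):
    """Shape of a matrix product a @ b; raises on an inner-dimension mismatch."""
    (m, k1), (k2, n) = a_shape, b_shape
    if k1 != k2:
        raise ValueError(f"inner dims differ: {k1} vs {k2}")
    return (m, n)

m, k, n = 16, 8, 4  # hypothetical sizes with k != n

# If b is built as (n_groups, n, k), each group slice is (n, k) and its
# transpose (k, n) lines up with a (m, k) operand.
assert mm_shape((m, k), (k, n)) == (m, n)

# With b built as (n_groups, k, n), the helper's b.t() is (n, k) instead,
# and the reference torch.mm(a, b.t()) cannot line up.
try:
    mm_shape((m, k), (n, k))
except ValueError:
    pass  # dimensionally inconsistent, as the comment predicts
```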
```python
a = torch.randn(n_groups, m, k, device=device, dtype=dtype)
b = torch.randn(n_groups, k, n, device=device, dtype=dtype)

a.requires_grad_(True)
b.requires_grad_(True)

f = F.grouped_mm
out = f(a, b.transpose(-2, -1), out_dtype=dtype)
gO = torch.rand_like(out)
```
Same issue as the 2D×3D test: b is created as (n_groups, k, n) but passed to grouped_mm_helper() which expects per-group b slices shaped (n, k) (so it can use b.t() as (k, n)). With the current shapes/transposes, the reference torch.mm(a, b.t()) will have incompatible dimensions. Align b’s construction and the transpose passed into F.grouped_mm with the helper’s expected layout.
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 14 comments.
```cpp
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_d_vec.push_back(
```
In the 2D×3D path, the packed stride for B is built with {N, K, 1} even though the input B (after transpose(-2, -1) in the reference tests) is K×N per group. This will misinterpret B’s memory layout. Please build the B stride with the correct (K, N) extents/order for LayoutB.
```cpp
stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N_g, K, 1}));
```
In the 3D×2D (ragged B) path, b_slice has shape K×N_g (sliced from a K×(sum N_g) tensor), but the packed stride for B is constructed with {N_g, K, 1}. This swapped stride will interpret the contiguous K×N_g buffer incorrectly. Please build the stride with (K, N_g) extents/order to match the actual layout.
Suggested change:
```diff
-    cutlass::make_cute_packed_stride(StrideB{}, {N_g, K, 1}));
+    cutlass::make_cute_packed_stride(StrideB{}, {K, N_g, 1}));
```
```cpp
stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M, K_g, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K_g, 1}));
```
In the 2D×2D (ragged K) path, B points into a contiguous K_total×N buffer (base_b + k_start * N), but the packed stride for B is built with {N, K_g, 1}. This swaps dimensions and will read B incorrectly. Please construct the B stride with (K_g, N) extents/order for the chosen layout.
Suggested change:
```diff
-    cutlass::make_cute_packed_stride(StrideB{}, {N, K_g, 1}));
+    cutlass::make_cute_packed_stride(StrideB{}, {K_g, N, 1}));
```
```cpp
int32_t col_start = 0;
for (int g = 0; g < group_count; ++g) {
  int32_t col_end = offs_host[g];
  int N_g = col_end - col_start;
```
The 3D×2D path indexes offs_host[g] but does not check that offs is present and sized for group_count. Please add a TORCH_CHECK to validate offs (presence and length) before this loop to avoid out-of-bounds reads.
```cpp
int32_t k_start = 0;
for (int g = 0; g < group_count; ++g) {
  int32_t k_end = offs_host[g];
  int K_g = k_end - k_start;
```
The 2D×2D (ragged K) path uses offs_host[g] but doesn’t validate that offs was provided and has the expected number of elements. Please add a TORCH_CHECK for offs.has_value() and offs_host.size() == group_count (or appropriate invariant) before looping.
```python
TEST_XPU = torch.xpu.is_available()
```
TEST_XPU = torch.xpu.is_available() is defined but not used anywhere in this test module. Please remove it (or use it for conditional skips) to avoid dead code.
Suggested change:
```diff
-TEST_XPU = torch.xpu.is_available()
```
```cpp
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
    at::Tensor& out) {
  // Ensure contiguous
```
This new entrypoint (at::xpu::detail::bf16bf16_grouped_mm) does not appear to be referenced anywhere else in the repository (a repo-wide search only finds its declaration/definition). Unless it’s wired up via additional changes not shown here, the new kernel won’t be invoked by torch.nn.functional.grouped_mm and the added tests may end up exercising a fallback path instead. Please add the XPU dispatch/registration hook that calls into this function (or otherwise ensure the operator routes here) so the implementation is actually used.
```cpp
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
```
ptr_C_device is populated by reinterpret-casting the bf16 output pointers (ptr_d_host) to const ElementAccumulator* (float). Even with beta=0, CUTLASS epilogues may still read C, and the type/align mismatch can lead to incorrect reads or misaligned access. Please pass a correctly-typed C buffer (or use an epilogue/kernel configuration that does not require C, or set ElementC to match D when C is unused).
Suggested change:
```diff
-// C matrix (not used, pass same as D with beta=0)
-cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
-ptr_C_device.reset(group_count);
-std::vector<const ElementAccumulator*> ptr_c_host(group_count);
-for (int i = 0; i < group_count; ++i) {
-  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
+// C matrix (not used, pass same as D with beta=0). Use ElementOutput to match D's element type.
+cutlass::DeviceAllocation<const ElementOutput*> ptr_C_device;
+ptr_C_device.reset(group_count);
+std::vector<const ElementOutput*> ptr_c_host(group_count);
+for (int i = 0; i < group_count; ++i) {
+  ptr_c_host[i] = ptr_d_host[i];
```
```cpp
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include <iostream>
```
#include <iostream> is not used in this file; please remove it to keep dependencies minimal.
Suggested change:
```diff
-#include <iostream>
```
```python
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)
```
These XPU grouped_mm tests don’t cover the zero-size group cases (repeated offsets) that the CUDA grouped_gemm tests exercise (e.g., setting offs[0] = offs[1]). Since ragged inputs are a primary use case, please add coverage for zero-sized groups (and ensure backward is handled/skipped accordingly) for the offs-based modes.
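The zero-size-group construction the CUDA suite uses can be sketched with plain offset arithmetic (sizes here are hypothetical): duplicating one boundary in `offs` yields an empty group without changing the total row count.

```python
# Offsets for 4 groups of nominal size m; duplicating the first boundary
# makes group 1 span rows [m, m), i.e., a zero-size group.
m, n_groups = 16, 4
offs = list(range(m, n_groups * m + 1, m))  # [16, 32, 48, 64]
offs[1] = offs[0]                           # [16, 16, 48, 64]

# Per-group row counts implied by consecutive offsets.
sizes = [end - start for start, end in zip([0] + offs[:-1], offs)]
assert sizes == [16, 0, 32, 16]
assert sum(sizes) == n_groups * m  # total rows are unchanged
```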
- Add allow_xpu=True to instantiate_device_type_tests so tests are discovered
- Remove incorrect b.transpose() in 2d_3d and 3d_3d tests since grouped_mm expects B in (K,N) form, not transposed
- Add transpose_b parameter to helper to handle both B layouts
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
```cpp
std::vector<int32_t> offs_host;
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
```
offs_host is populated only when offs.has_value(), but later code assumes it contains at least group_count elements (e.g., offs_host[g] in ragged branches). Add explicit validation (e.g., require offs for ragged modes and check offs_host.size()/dtype) before any indexing to avoid out-of-bounds reads.
```cpp
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
```
This header uses std::optional in the public declaration but doesn’t include <optional>. Relying on transitive includes from ATen/ATen.h can be brittle; include <optional> explicitly here to ensure the header is self-contained.
The sycltla kernel is built as a separate shared library only when USE_SYCLTLA is ON. Add a wrapper in regular native sources that guards the sycltla call, matching the flash_attention pattern. This prevents undefined symbol errors when sycltla is not available.
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
```cpp
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}
```
offs is copied to CPU into offs_host, but the code later indexes offs_host[g] in the ragged modes without validating that offs was provided and has the expected length. Please add TORCH_CHECK(offs.has_value(), ...), TORCH_CHECK(offs_host.size() == group_count, ...), and validate monotonicity / bounds to prevent out-of-bounds reads and negative group sizes.
```cpp
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
```
ptr_c_host[i] is built by reinterpret_cast-ing a BF16 D pointer to const float* for C. This is type-unsafe and can lead to misaligned/invalid reads if the epilogue ever reads C (even when beta==0 at runtime). Prefer passing a correctly typed C buffer (float), or use an epilogue configuration that does not require C when beta=0 / supports C=nullptr.
Suggested change:
```diff
-// C matrix (not used, pass same as D with beta=0)
-cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
-ptr_C_device.reset(group_count);
-std::vector<const ElementAccumulator*> ptr_c_host(group_count);
-for (int i = 0; i < group_count; ++i) {
-  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
-}
+// C matrix (not used when beta = 0). Pass nullptrs to avoid type-unsafe casts.
+cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
+ptr_C_device.reset(group_count);
+std::vector<const ElementAccumulator*> ptr_c_host(group_count, nullptr);
```
```cpp
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape>
    problem_sizes_device;
problem_sizes_device.reset(group_count);
problem_sizes_device.copy_from_host(problem_sizes_host.data());
```
run_grouped_gemm allocates and uploads multiple device buffers on every call (problem sizes / pointer arrays / strides). For workloads that call grouped_mm frequently, this per-call allocation/copy overhead can dominate. Consider caching/reusing these allocations (or using a workspace allocator) where possible.
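The caching idea can be sketched in a few lines. This is a hypothetical stand-in, not the PR's code: `DeviceBuffer` plays the role of `cutlass::DeviceAllocation`, and the cache is keyed by `group_count` so repeated calls with the same group structure reuse one allocation.

```python
class DeviceBuffer:
    """Stand-in for a device-side allocation such as cutlass::DeviceAllocation."""
    def __init__(self, count):
        self.count = count
        self.data = [None] * count

    def copy_from_host(self, host):
        # Upload host metadata into the (reused) device buffer.
        self.data[: len(host)] = host

_buffer_cache = {}

def get_metadata_buffer(group_count):
    # Allocate once per distinct group_count; later calls hit the cache
    # instead of paying per-call allocation/copy overhead.
    buf = _buffer_cache.get(group_count)
    if buf is None:
        buf = DeviceBuffer(group_count)
        _buffer_cache[group_count] = buf
    return buf

b1 = get_metadata_buffer(8)
b1.copy_from_host(list(range(8)))
b2 = get_metadata_buffer(8)
assert b1 is b2  # second call reuses the allocation
```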
```cpp
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include <iostream>
```
#include <iostream> appears unused in this translation unit; consider removing it to reduce compile time and avoid unnecessary dependencies.
Suggested change:
```diff
-#include <iostream>
```
```python
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

f = F.grouped_mm
out = f(a, b, offs=offs, out_dtype=dtype)
gO = torch.rand_like(out)
```
The CUDA grouped_mm tests cover the edge case where offs creates zero-size groups (e.g., offs[0] = offs[1]) and verify behavior/grad handling. This new XPU test suite doesn’t exercise that scenario for the offs-based tests, so regressions around empty groups could slip through. Please add at least one zero-size group case for the ragged tests (2D×3D / 3D×2D, and optionally 2D×2D).
```cpp
auto* base_a = reinterpret_cast<const ElementA*>(mat_a.data_ptr());
auto* base_b = reinterpret_cast<const ElementB*>(mat_b.data_ptr());
auto* base_d = reinterpret_cast<ElementOutput*>(out.data_ptr());
```
This implementation uses out.data_ptr() and manual pointer arithmetic in several branches (e.g., 3D×3D, 2D×3D, 2D×2D), which assumes out is contiguous with the expected layout. Add an explicit TORCH_CHECK(out.is_contiguous(), ...) (or write into a contiguous temp and copy_ back) so non-contiguous out tensors don’t produce incorrect writes.
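The contiguity property the check would enforce is easy to state in stride terms. A stdlib-only sketch mirroring what `Tensor.is_contiguous()` verifies for a packed row-major layout:

```python
def is_contiguous(shape, strides):
    """True if (shape, strides) describes a packed row-major tensor.

    Walk dimensions from innermost to outermost: each stride must equal
    the product of all inner extents (size-1 dims are exempt).
    """
    expected = 1
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True

assert is_contiguous((4, 3), (3, 1))      # packed row-major
assert not is_contiguous((4, 3), (1, 4))  # transposed / column-major view
```

A non-contiguous `out` (e.g., a transposed view) would pass manual `base_d + offset` arithmetic silently yet land writes at the wrong addresses, which is exactly why the comment asks for an explicit check.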
```python
TEST_XPU = torch.xpu.is_available()
```
TEST_XPU is defined but never used. Please remove it to avoid dead code in the test module.
Suggested change:
```diff
-TEST_XPU = torch.xpu.is_available()
```
Authored with Claude.