
Add SYCL grouped_mm kernel using sycl-tla v0.7#3122

Open
xuhancn wants to merge 7 commits into intel:main from xuhancn:xpu-grouped-mm

Conversation

@xuhancn (Contributor) commented Mar 24, 2026

Add XPU implementation of _grouped_mm operator using sycl-tla (CUTLASS for Intel GPUs) as the grouped GEMM backend. Supports BF16 inputs with FP32 accumulation and BF16 output across all 4 input modes:

  • 3D x 3D (batched)
  • 2D x 3D (ragged A / MoE pattern)
  • 3D x 2D (ragged B)
  • 2D x 2D (ragged K)

Changes:

  • Upgrade sycl-tla from v0.6 to v0.7
  • Add sycltla GroupedMM kernel in src/ATen/native/xpu/sycltla/
  • Update CMakeLists.txt to include new sycltla sources
  • Add unit tests for all 4 input modes

Authored with Claude.
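The ragged modes above key off an offs tensor of cumulative group boundaries. As a plain-Python illustration of the 2D x 3D ("ragged A" / MoE) semantics, with list-of-lists shapes standing in for tensors (this is a sketch of the operator's contract, not the kernel's API):

```python
# Illustrative reference for the 2D x 3D ("ragged A") grouped_mm mode:
# A is (sum_M, K) flattened across groups, B is (G, K, N), and offs[g]
# is the cumulative row boundary of group g within A. Pure stdlib sketch.
def grouped_mm_2d_3d(a, b, offs):
    out = []
    row_start = 0
    for g, row_end in enumerate(offs):
        K = len(b[g])
        N = len(b[g][0])
        for i in range(row_start, row_end):
            # output row i = a[i] (1 x K) @ b[g] (K x N)
            out.append([sum(a[i][k] * b[g][k][n] for k in range(K))
                        for n in range(N)])
        row_start = row_end
    return out

# Rows 0-1 route to expert 0; row 2 routes to expert 1 (the identity here):
out = grouped_mm_2d_3d(
    [[1, 0], [0, 1], [2, 3]],
    [[[1, 2], [3, 4]], [[1, 0], [0, 1]]],
    offs=[2, 3],
)
assert out == [[1, 2], [3, 4], [2, 3]]
```

Each group contributes M_g = offs[g] - offs[g-1] output rows, which is exactly the per-group problem size the kernel builds on the host.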

Copilot AI review requested due to automatic review settings March 24, 2026 02:41
Copilot AI left a comment

Pull request overview

Adds an XPU backend for torch.nn.functional.grouped_mm using sycl-tla (CUTLASS-on-Intel-GPU) grouped GEMM, along with build wiring and unit tests, and bumps the sycl-tla dependency to v0.7.

Changes:

  • Upgrade sycl-tla FetchContent tag from v0.6 to v0.7.
  • Add a new SYCL-TLA-based BF16 grouped GEMM implementation under src/ATen/native/xpu/sycltla/.
  • Add XPU unit tests intended to cover all 4 grouped_mm input modes.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 7 comments.

  • test/xpu/test_grouped_mm_xpu.py: New XPU unit tests for grouped_mm modes and a larger-size accuracy check.
  • src/ATen/native/xpu/sycltla/GroupedMM.h: Declares the new BF16 grouped_mm entrypoint.
  • src/ATen/native/xpu/sycltla/GroupedMM.cpp: Implements grouped GEMM dispatch/config and handles the 4 input-mode layouts.
  • src/ATen/CMakeLists.txt: Adds the new sycltla sources/headers to the build/install.
  • cmake/SYCLTLA.cmake: Updates the sycl-tla version pin to v0.7.


Comment on lines +146 to +153
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] =
      reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
Copilot AI Mar 24, 2026

ptr_c_host is built by reinterpret_cast-ing ElementOutput* (bf16) pointers to ElementAccumulator* (float). Even with beta=0, this is unsafe: the kernel may still issue loads from C (e.g., for predication), and this pointer is mis-typed/misaligned for float. Prefer passing a real C buffer of the correct type, or configure the epilogue/kernel so that C is either unused (and can be nullptr if supported) or has an element type matching the underlying storage type.

Suggested change
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] =
      reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
// C matrix (not used when beta=0). Pass nullptr for all groups to avoid
// mis-typed/misaligned pointers while keeping the expected element type.
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count, nullptr);

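The type-size hazard this comment describes can be seen without any GPU code: reinterpreting a buffer of 2-byte elements as 4-byte elements changes both the values decoded and the number of addressable elements. A stdlib sketch (the hex constants are arbitrary stand-ins for bf16 storage, not real kernel data):

```python
import struct

# Four 2-byte elements, as a bf16-like output buffer would store them.
two_byte_elems = [0x3F80, 0x4000, 0x4040, 0x4080]
raw = struct.pack("<4H", *two_byte_elems)  # 8 bytes of storage

# "Reinterpret-casting" the same storage to 4-byte floats: each float
# read consumes two of the original elements, so the extents no longer
# match and every decoded value comes from unrelated bit patterns.
as_float = struct.unpack("<2f", raw)
assert len(two_byte_elems) == 4 and len(as_float) == 2
```

This is why passing a correctly typed (or null) C pointer is safer than punning the bf16 D pointer to float, even when beta is zero at runtime.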
stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M_g, K, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
Copilot AI Mar 24, 2026

The packed stride for B appears to use {N, K, 1}, but throughout this function B is treated as a contiguous matrix with logical shape K×N (e.g., pointer arithmetic uses g * K * N and the 2D×2D path offsets by k_start * N, implying rows are K and columns are N). Using {N, K, 1} will cause the kernel to interpret the memory layout incorrectly when K != N. Update the stride construction to match the actual logical shape of B (e.g., {K, N, 1} / {K_g, N, 1} / {K, N_g, 1} as appropriate).

Suggested change
cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
cutlass::make_cute_packed_stride(StrideB{}, {K, N, 1}));

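The effect of the swapped stride can be reproduced in plain Python: for a contiguous row-major K x N matrix, element (k, n) lives at flat offset k * N + n, so using K as the row pitch (which is what the {N, K, 1} extents amount to here) reads the wrong element whenever K != N. A stdlib sketch, not CUTLASS code:

```python
K, N = 2, 3
# Flat row-major K x N buffer whose value at (k, n) equals its own offset.
buf = [k * N + n for k in range(K) for n in range(N)]

def read(buf, k, n, row_pitch):
    # Address element (k, n) assuming row_pitch elements per row.
    return buf[k * row_pitch + n]

assert read(buf, 1, 2, row_pitch=N) == 1 * N + 2  # correct pitch: N
assert read(buf, 1, 2, row_pitch=K) != 1 * N + 2  # swapped pitch misreads
```

The bug is silent for square per-group shapes (K == N), which is one reason a unit test with K != N is worth having.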
Comment on lines +260 to +300
// Read offs tensor to host
std::vector<int32_t> offs_host;
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}

if (!a_is_2d && !b_is_2d) {
  // 3D x 3D: regular batched MM
  group_count = mat_a.size(0);
  int M = mat_a.size(1);
  int N = mat_b.size(2);
  int K = mat_a.size(2);

  for (int g = 0; g < group_count; ++g) {
    problem_sizes.push_back({M, N, K});
    ptr_a_vec.push_back(base_a + g * M * K);
    ptr_b_vec.push_back(base_b + g * K * N);
    ptr_d_vec.push_back(base_d + g * M * N);
    stride_a_vec.push_back(
        cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
    stride_b_vec.push_back(
        cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
    stride_d_vec.push_back(
        cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}));
  }
} else if (a_is_2d && !b_is_2d) {
  // 2D x 3D: ragged A (MoE pattern)
  group_count = mat_b.size(0);
  int K = mat_a.size(1);
  int N = mat_b.size(2);
  int64_t out_stride_row = out.size(1);

  int32_t row_start = 0;
  for (int g = 0; g < group_count; ++g) {
    int32_t row_end = offs_host[g];
    int M_g = row_end - row_start;

    problem_sizes.push_back({M_g, N, K});
    ptr_a_vec.push_back(base_a + row_start * K);
Copilot AI Mar 24, 2026

Several code paths index offs_host[g] (e.g., 2D×3D, 3D×2D, 2D×2D) without validating that offs is present and has the expected length/type. If offs is omitted or has fewer than group_count entries, this will read out of bounds and can crash. Add explicit TORCH_CHECK(offs.has_value(), ...), validate offs dtype is int32, and check offs_host.size() matches the required group count before using it.

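The checks this comment asks for amount to validating that offs exists, has one entry per group, and is monotonically non-decreasing within bounds before any indexing. A hedged host-side sketch in plain Python (exceptions stand in for TORCH_CHECK; the function name is illustrative):

```python
def group_extents(offs_host, group_count, total):
    """Derive per-group extents from cumulative offsets, validating first."""
    if offs_host is None:
        raise ValueError("offs is required for ragged grouped_mm modes")
    if len(offs_host) != group_count:
        raise ValueError("offs must have exactly group_count entries")
    extents, start = [], 0
    for end in offs_host:
        if end < start or end > total:
            raise ValueError("offs must be non-decreasing and in range")
        extents.append(end - start)
        start = end
    return extents

# Valid offsets, including a zero-size middle group:
assert group_extents([2, 2, 5], 3, 5) == [2, 0, 3]
```

Running these checks once on the host copy costs nothing compared with the GEMM itself and turns an out-of-bounds read into a clear error message.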
Comment on lines +234 to +243
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
    at::Tensor& out) {
  // Ensure contiguous
  mat_a = mat_a.contiguous();
  mat_b = mat_b.contiguous();

Copilot AI Mar 24, 2026

The bias argument is currently ignored. If callers can pass a bias tensor to this kernel, the result will be silently incorrect. Either implement bias support in the epilogue (preferred if the API expects it) or explicitly reject it with a TORCH_CHECK(!bias.has_value(), ...) to avoid silent misbehavior.

Comment on lines +221 to +227
status = gemm_op.run();
if (status != cutlass::Status::kSuccess) {
  return status;
}

compat::wait();
return cutlass::Status::kSuccess;
Copilot AI Mar 24, 2026

compat::wait() forces a device-wide synchronization after every grouped GEMM launch, which can severely limit throughput (especially if grouped_mm is called repeatedly in a model). If the wait is only there to ensure temporary allocations (workspace / device arrays) stay alive until completion, consider integrating with PyTorch’s XPU stream/queue semantics (e.g., enqueue on the current queue and tie allocation lifetimes to an event) so the call can remain asynchronous.

Comment on lines +75 to +87
a = torch.randn(m * n_groups, k, device=device, dtype=dtype)
b = torch.randn(n_groups, k, n, device=device, dtype=dtype)

a.requires_grad_(True)
b.requires_grad_(True)

offs = torch.arange(
    m, n_groups * m + 1, m, device=device, dtype=torch.int32
)

f = F.grouped_mm
out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
gO = torch.rand_like(out)
Copilot AI Mar 24, 2026

In this test, b is created with shape (n_groups, k, n) and then transposed before calling F.grouped_mm, but grouped_mm_helper() assumes b iterates into per-group matrices shaped (n, k) (so that b.t() becomes (k, n)). As written, iterating b yields (k, n) and b.t() becomes (n, k), which makes torch.mm(a, b.t()) dimensionally inconsistent and should fail. Adjust the test to either (a) construct b as (n_groups, n, k) and keep the transpose in the grouped_mm call, or (b) keep b as (n_groups, k, n) and remove the transpose + update the helper reference accordingly.

Comment on lines +106 to +114
a = torch.randn(n_groups, m, k, device=device, dtype=dtype)
b = torch.randn(n_groups, k, n, device=device, dtype=dtype)

a.requires_grad_(True)
b.requires_grad_(True)

f = F.grouped_mm
out = f(a, b.transpose(-2, -1), out_dtype=dtype)
gO = torch.rand_like(out)
Copilot AI Mar 24, 2026

Same issue as the 2D×3D test: b is created as (n_groups, k, n) but passed to grouped_mm_helper() which expects per-group b slices shaped (n, k) (so it can use b.t() as (k, n)). With the current shapes/transposes, the reference torch.mm(a, b.t()) will have incompatible dimensions. Align b’s construction and the transpose passed into F.grouped_mm with the helper’s expected layout.

Copilot AI review requested due to automatic review settings March 24, 2026 03:26
Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 14 comments.



Comment on lines +320 to +322
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}));
stride_d_vec.push_back(
Copilot AI Mar 24, 2026

In the 2D×3D path, the packed stride for B is built with {N, K, 1} even though the input B (after transpose(-2, -1) in the reference tests) is K×N per group. This will misinterpret B’s memory layout. Please build the B stride with the correct (K, N) extents/order for LayoutB.

stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N_g, K, 1}));
Copilot AI Mar 24, 2026

In the 3D×2D (ragged B) path, b_slice has shape K×N_g (sliced from a K×(sum N_g) tensor), but the packed stride for B is constructed with {N_g, K, 1}. This swapped stride will interpret the contiguous K×N_g buffer incorrectly. Please build the stride with (K, N_g) extents/order to match the actual layout.

Suggested change
cutlass::make_cute_packed_stride(StrideB{}, {N_g, K, 1}));
cutlass::make_cute_packed_stride(StrideB{}, {K, N_g, 1}));

stride_a_vec.push_back(
    cutlass::make_cute_packed_stride(StrideA{}, {M, K_g, 1}));
stride_b_vec.push_back(
    cutlass::make_cute_packed_stride(StrideB{}, {N, K_g, 1}));
Copilot AI Mar 24, 2026

In the 2D×2D (ragged K) path, B points into a contiguous K_total×N buffer (base_b + k_start * N), but the packed stride for B is built with {N, K_g, 1}. This swaps dimensions and will read B incorrectly. Please construct the B stride with (K_g, N) extents/order for the chosen layout.

Suggested change
cutlass::make_cute_packed_stride(StrideB{}, {N, K_g, 1}));
cutlass::make_cute_packed_stride(StrideB{}, {K_g, N, 1}));

Comment on lines +336 to +340
int32_t col_start = 0;
for (int g = 0; g < group_count; ++g) {
  int32_t col_end = offs_host[g];
  int N_g = col_end - col_start;

Copilot AI Mar 24, 2026

The 3D×2D path indexes offs_host[g] but does not check that offs is present and sized for group_count. Please add a TORCH_CHECK to validate offs (presence and length) before this loop to avoid out-of-bounds reads.

Comment on lines +392 to +396
int32_t k_start = 0;
for (int g = 0; g < group_count; ++g) {
  int32_t k_end = offs_host[g];
  int K_g = k_end - k_start;

Copilot AI Mar 24, 2026

The 2D×2D (ragged K) path uses offs_host[g] but doesn’t validate that offs was provided and has the expected number of elements. Please add a TORCH_CHECK for offs.has_value() and offs_host.size() == group_count (or appropriate invariant) before looping.

Comment on lines +17 to +19
TEST_XPU = torch.xpu.is_available()


Copilot AI Mar 24, 2026

TEST_XPU = torch.xpu.is_available() is defined but not used anywhere in this test module. Please remove it (or use it for conditional skips) to avoid dead code.

Suggested change
TEST_XPU = torch.xpu.is_available()

Comment on lines +249 to +255
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
    at::Tensor& out) {
  // Ensure contiguous
Copilot AI Mar 24, 2026

This new entrypoint (at::xpu::detail::bf16bf16_grouped_mm) does not appear to be referenced anywhere else in the repository (a repo-wide search only finds its declaration/definition). Unless it’s wired up via additional changes not shown here, the new kernel won’t be invoked by torch.nn.functional.grouped_mm and the added tests may end up exercising a fallback path instead. Please add the XPU dispatch/registration hook that calls into this function (or otherwise ensure the operator routes here) so the implementation is actually used.

Comment on lines +156 to +161
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
Copilot AI Mar 24, 2026

ptr_C_device is populated by reinterpret-casting the bf16 output pointers (ptr_d_host) to const ElementAccumulator* (float). Even with beta=0, CUTLASS epilogues may still read C, and the type/align mismatch can lead to incorrect reads or misaligned access. Please pass a correctly-typed C buffer (or use an epilogue/kernel configuration that does not require C, or set ElementC to match D when C is unused).

Suggested change
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
// C matrix (not used, pass same as D with beta=0). Use ElementOutput to match D's element type.
cutlass::DeviceAllocation<const ElementOutput*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementOutput*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = ptr_d_host[i];

#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include <iostream>
Copilot AI Mar 24, 2026

#include <iostream> is not used in this file; please remove it to keep dependencies minimal.

Suggested change
#include <iostream>

Comment on lines +70 to +71
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

Copilot AI Mar 24, 2026

These XPU grouped_mm tests don’t cover the zero-size group cases (repeated offsets) that the CUDA grouped_gemm tests exercise (e.g., setting offs[0] = offs[1]). Since ragged inputs are a primary use case, please add coverage for zero-sized groups (and ensure backward is handled/skipped accordingly) for the offs-based modes.

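The zero-size-group edge case this comment refers to is just a repeated cumulative offset: group g is empty whenever offs[g] == offs[g-1]. A minimal, illustrative sketch of what such a test would assert about the derived group sizes (pure stdlib, not the actual test harness):

```python
# offs with a repeated boundary: group 1 covers rows [2, 2), i.e. none.
offs = [2, 2, 5]
sizes, start = [], 0
for end in offs:
    sizes.append(end - start)
    start = end
assert sizes == [2, 0, 3]  # middle group is zero-sized
```

A ragged-mode test built on offsets like these would then check that the empty group's output region is untouched and that backward handles the zero-row slice without error.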
xuhancn added 2 commits March 24, 2026 12:05
- Add allow_xpu=True to instantiate_device_type_tests so tests are
  discovered
- Remove incorrect b.transpose() in 2d_3d and 3d_3d tests since
  grouped_mm expects B in (K,N) form, not transposed
- Add transpose_b parameter to helper to handle both B layouts
Copilot AI review requested due to automatic review settings March 24, 2026 05:54
Copilot AI left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.



Comment on lines +276 to +280
std::vector<int32_t> offs_host;
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
Copilot AI Mar 24, 2026

offs_host is populated only when offs.has_value(), but later code assumes it contains at least group_count elements (e.g., offs_host[g] in ragged branches). Add explicit validation (e.g., require offs for ragged modes and check offs_host.size()/dtype) before any indexing to avoid out-of-bounds reads.

Comment on lines +17 to +21
void bf16bf16_grouped_mm(
    at::Tensor mat_a,
    at::Tensor mat_b,
    std::optional<at::Tensor> offs,
    std::optional<at::Tensor> bias,
Copilot AI Mar 24, 2026

This header uses std::optional in the public declaration but doesn’t include <optional>. Relying on transitive includes from ATen/ATen.h can be brittle; include <optional> explicitly here to ensure the header is self-contained.

xuhancn added 2 commits March 24, 2026 15:08
The sycltla kernel is built as a separate shared library only when
USE_SYCLTLA is ON. Add a wrapper in regular native sources that guards
the sycltla call, matching the flash_attention pattern. This prevents
undefined symbol errors when sycltla is not available.
Copilot AI review requested due to automatic review settings March 24, 2026 07:26
Copilot AI left a comment

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.



Comment on lines +277 to +281
if (offs.has_value()) {
  auto offs_cpu = offs->cpu().contiguous();
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}
Copilot AI Mar 24, 2026

offs is copied to CPU into offs_host, but the code later indexes offs_host[g] in the ragged modes without validating that offs was provided and has the expected length. Please add TORCH_CHECK(offs.has_value(), ...), TORCH_CHECK(offs_host.size() == group_count, ...), and validate monotonicity / bounds to prevent out-of-bounds reads and negative group sizes.

Comment on lines +156 to +162
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
Copilot AI Mar 24, 2026

ptr_c_host[i] is built by reinterpret_cast-ing a BF16 D pointer to const float* for C. This is type-unsafe and can lead to misaligned/invalid reads if the epilogue ever reads C (even when beta==0 at runtime). Prefer passing a correctly typed C buffer (float), or use an epilogue configuration that does not require C when beta=0 / supports C=nullptr.

Suggested change
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
// C matrix (not used when beta = 0). Pass nullptrs to avoid type-unsafe casts.
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count, nullptr);

Comment on lines +143 to +147
cutlass::DeviceAllocation<typename ProblemShape::UnderlyingProblemShape>
    problem_sizes_device;
problem_sizes_device.reset(group_count);
problem_sizes_device.copy_from_host(problem_sizes_host.data());

Copilot AI Mar 24, 2026

run_grouped_gemm allocates and uploads multiple device buffers on every call (problem sizes / pointer arrays / strides). For workloads that call grouped_mm frequently, this per-call allocation/copy overhead can dominate. Consider caching/reusing these allocations (or using a workspace allocator) where possible.

#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include <iostream>
Copilot AI Mar 24, 2026

#include <iostream> appears unused in this translation unit; consider removing it to reduce compile time and avoid unnecessary dependencies.

Suggested change
#include <iostream>

Comment on lines +90 to +94
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

f = F.grouped_mm
out = f(a, b, offs=offs, out_dtype=dtype)
gO = torch.rand_like(out)
Copilot AI Mar 24, 2026

The CUDA grouped_mm tests cover the edge case where offs creates zero-size groups (e.g., offs[0] = offs[1]) and verify behavior/grad handling. This new XPU test suite doesn’t exercise that scenario for the offs-based tests, so regressions around empty groups could slip through. Please add at least one zero-size group case for the ragged tests (2D×3D / 3D×2D, and optionally 2D×2D).

Comment on lines +271 to +274
auto* base_a = reinterpret_cast<const ElementA*>(mat_a.data_ptr());
auto* base_b = reinterpret_cast<const ElementB*>(mat_b.data_ptr());
auto* base_d = reinterpret_cast<ElementOutput*>(out.data_ptr());

Copilot AI Mar 24, 2026

This implementation uses out.data_ptr() and manual pointer arithmetic in several branches (e.g., 3D×3D, 2D×3D, 2D×2D), which assumes out is contiguous with the expected layout. Add an explicit TORCH_CHECK(out.is_contiguous(), ...) (or write into a contiguous temp and copy_ back) so non-contiguous out tensors don’t produce incorrect writes.

Comment on lines +27 to +29
TEST_XPU = torch.xpu.is_available()


Copilot AI Mar 24, 2026

TEST_XPU is defined but never used. Please remove it to avoid dead code in the test module.

Suggested change
TEST_XPU = torch.xpu.is_available()
