Add scaled_grouped_mm kernel for XPU (FP8 x FP8 -> BF16) #3172

Open

xuhancn wants to merge 10 commits into intel:main from xuhancn:xpu-scaled-grouped-mm

Conversation

Contributor

@xuhancn xuhancn commented Mar 24, 2026

Add _scaled_grouped_mm support for Intel XPU using sycl-tla. The kernel
dequantizes FP8 inputs to BF16 with rowwise float32 scale application,
then dispatches to the existing BF16 grouped GEMM sycl-tla kernel.

Supports all 4 input modes: 3D×3D (batched), 2D×3D (ragged A / MoE),
3D×2D (ragged B), and 2D×2D (ragged K).
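A minimal host-side sketch of that dequantize-then-dispatch strategy, using numpy as a stand-in (FP8 storage is simulated here with float32 arrays; the real kernel works on FP8 tensors and BF16 outputs):

```python
import numpy as np

def dequantize_rowwise(fp8_vals, scale):
    # One scale entry per row, broadcast across the K dimension,
    # mirroring the "rowwise float32 scale application" described above.
    return fp8_vals * scale[..., None]

a_fp8 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
scale_a = np.array([0.5, 0.25], dtype=np.float32)
a_deq = dequantize_rowwise(a_fp8, scale_a)
# a_deq -> [[0.5, 1.0], [0.75, 1.0]]
```

After both operands are dequantized this way, the result can be fed to any plain BF16 grouped GEMM, which is exactly the dispatch the PR describes.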

xuhancn added 8 commits March 24, 2026 10:16
Add XPU implementation of _grouped_mm operator using sycl-tla (CUTLASS
for Intel GPUs) as the grouped GEMM backend. Supports BF16 inputs with
FP32 accumulation and BF16 output across all 4 input modes:
- 3D x 3D (batched)
- 2D x 3D (ragged A / MoE pattern)
- 3D x 2D (ragged B)
- 2D x 2D (ragged K)

Changes:
- Upgrade sycl-tla from v0.6 to v0.7
- Add sycltla GroupedMM kernel in src/ATen/native/xpu/sycltla/
- Update CMakeLists.txt to include new sycltla sources
- Add unit tests for all 4 input modes
- Add allow_xpu=True to instantiate_device_type_tests so tests are
  discovered
- Remove incorrect b.transpose() in 2d_3d and 3d_3d tests since
  grouped_mm expects B in (K,N) form, not transposed
- Add transpose_b parameter to helper to handle both B layouts

The sycltla kernel is built as a separate shared library only when
USE_SYCLTLA is ON. Add a wrapper in regular native sources that guards
the sycltla call, matching the flash_attention pattern. This prevents
undefined symbol errors when sycltla is not available.

Add _scaled_grouped_mm support for Intel XPU using sycl-tla. The kernel
dequantizes FP8 inputs to BF16 with rowwise float32 scale application,
then dispatches to the existing BF16 grouped GEMM sycl-tla kernel.

Supports all 4 input modes: 3D×3D (batched), 2D×3D (ragged A / MoE),
3D×2D (ragged B), and 2D×2D (ragged K).

Files:
- ScaledGroupedMM.{h,cpp}: USE_SYCLTLA-guarded wrapper
- sycltla/ScaledGroupedMM.{h,cpp}: sycl-tla kernel implementation
- test_scaled_grouped_mm_xpu.py: accuracy tests for all 4 modes
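As a rough illustration of the four input modes listed in the commit messages above, here is a hypothetical numpy reference. The exact ragged-dimension conventions (which operand is split by `offs`, and the output layout) are assumptions for illustration only; the PR's kernel and tests define the authoritative layout:

```python
import numpy as np

def grouped_mm_ref(a, b, offs=None):
    """Sketch of the four grouped-GEMM modes under assumed conventions."""
    if a.ndim == 3 and b.ndim == 3:          # 3D x 3D: plain batched matmul
        return a @ b
    starts = np.concatenate(([0], offs[:-1]))
    segs = list(zip(starts, offs))
    if a.ndim == 2 and b.ndim == 3:          # 2D x 3D: A ragged along M (MoE)
        return np.concatenate([a[s:e] @ b[g] for g, (s, e) in enumerate(segs)])
    if a.ndim == 3 and b.ndim == 2:          # 3D x 2D: B ragged along N
        return np.concatenate(
            [a[g] @ b[:, s:e] for g, (s, e) in enumerate(segs)], axis=1)
    # 2D x 2D: ragged along K; one output matrix per group
    return np.stack([a[:, s:e] @ b[s:e] for s, e in segs])

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 4)).astype(np.float32)     # (total_M, K)
b = rng.standard_normal((2, 4, 3)).astype(np.float32)  # (G, K, N)
out = grouped_mm_ref(a, b, offs=np.array([2, 5]))      # (total_M, N)
```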
Copilot AI review requested due to automatic review settings March 24, 2026 23:39

Copilot AI left a comment

Pull request overview

Adds SYCLTLA-backed XPU implementations and tests for grouped GEMM and scaled grouped GEMM (FP8 inputs with rowwise scaling), and wires the new SYCLTLA sources into the build.

Changes:

  • Introduce SYCLTLA kernels for grouped_mm (BF16) and scaled_grouped_mm (FP8→BF16 with scales).
  • Add XPU unit tests covering the 4 supported input modes for both ops, plus a larger “stress” case.
  • Update CMake to build/install new native/xpu/sycltla sources and bump sycl-tla to v0.7.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 12 comments.

Summary per file:

  • test/xpu/test_scaled_grouped_mm_xpu.py: Adds correctness tests for XPU torch._scaled_grouped_mm across 2D/3D input modes.
  • test/xpu/test_grouped_mm_xpu.py: Adds grouped GEMM tests for XPU using F.grouped_mm, including autograd checks.
  • src/ATen/native/xpu/sycltla/ScaledGroupedMM.h: Declares SYCLTLA scaled grouped MM entrypoint.
  • src/ATen/native/xpu/sycltla/ScaledGroupedMM.cpp: Implements FP8 scaled grouped GEMM via FP8→BF16 dequant + SYCLTLA grouped GEMM.
  • src/ATen/native/xpu/sycltla/GroupedMM.h: Declares SYCLTLA grouped MM entrypoint.
  • src/ATen/native/xpu/sycltla/GroupedMM.cpp: Implements BF16 grouped GEMM via SYCLTLA grouped GEMM.
  • src/ATen/native/xpu/ScaledGroupedMM.h: Exposes native XPU scaled grouped MM API and availability query.
  • src/ATen/native/xpu/ScaledGroupedMM.cpp: Dispatches to SYCLTLA implementation (or errors if not enabled).
  • src/ATen/native/xpu/GroupedMM.h: Exposes native XPU grouped MM API and availability query.
  • src/ATen/native/xpu/GroupedMM.cpp: Dispatches to SYCLTLA implementation (or errors if not enabled).
  • src/ATen/CMakeLists.txt: Adds native/xpu/sycltla/*.cpp to build and installs its headers.
  • cmake/SYCLTLA.cmake: Bumps sycl-tla dependency from v0.6 to v0.7.

Comment on lines +167 to +169

```cpp
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
```

Copilot AI commented Mar 24, 2026:

This passes the D buffer (BF16) as the C pointer by reinterpreting it as `const ElementAccumulator*` (float). If the epilogue/kernel ever dereferences C (even incidentally), this is type/size-mismatch UB (misaligned reads, out-of-bounds). Prefer one of: (1) change the epilogue/C element type to match ElementOutput so C can alias D safely, (2) allocate a real float C buffer (or a single dummy float buffer) and point all C pointers at valid float memory, or (3) if CUTLASS supports it for this epilogue, pass a null C pointer explicitly when beta == 0 and ensure the epilogue does not read it.

Suggested change:

```cpp
// Allocate a small valid ElementAccumulator buffer for C to avoid
// reinterpreting the BF16 D buffer as float (undefined behavior).
cutlass::DeviceAllocation<ElementAccumulator> dummy_C_device;
dummy_C_device.reset(1);
{
  ElementAccumulator zero = ElementAccumulator(0);
  dummy_C_device.copy_from_host(&zero);
}
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = dummy_C_device.get();
```
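The hazard the review describes can be illustrated by a loose analogy in numpy (not the CUTLASS code itself): reinterpreting a buffer of one element type as another makes both the element count and the decoded values wrong.

```python
import numpy as np

# Two float16 values stand in for the BF16 D buffer.
d = np.array([1.0, 2.0], dtype=np.float16)

# Pretend the same bytes are float32, as the reinterpret_cast does:
# 2 x 2-byte elements become 1 x 4-byte element with an unrelated value.
c_view = d.view(np.float32)
print(c_view.size, float(c_view[0]))  # element count and value both wrong
```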
Comment on lines +156 to +162

```cpp
// C matrix (not used, pass same as D with beta=0)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count);
for (int i = 0; i < group_count; ++i) {
  ptr_c_host[i] = reinterpret_cast<const ElementAccumulator*>(ptr_d_host[i]);
}
```

Copilot AI commented Mar 24, 2026:

Same issue as in ScaledGroupedMM.cpp: `ptr_C_device` is typed as `const float*` (`ElementAccumulator*`) but is populated with BF16 output pointers via `reinterpret_cast`. This is undefined behavior if C is accessed. Fix by providing valid float storage for C (or adjusting the epilogue/C type so C can legally alias D, or passing null C if supported when beta=0).

Suggested change:

```cpp
// C matrix (not used, pass null pointers with beta=0 to avoid undefined behavior)
cutlass::DeviceAllocation<const ElementAccumulator*> ptr_C_device;
ptr_C_device.reset(group_count);
std::vector<const ElementAccumulator*> ptr_c_host(group_count, nullptr);
```
```cpp
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}
```

Copilot AI commented Mar 24, 2026:

Several branches index `offs_host[g]` unconditionally, but `offs` is an `std::optional` and `offs_host` remains empty when `offs` is not provided. This will crash (out-of-bounds) for the ragged modes (2D×3D, 3D×2D, 2D×2D). Add explicit validation in the ragged branches, e.g. `TORCH_CHECK(offs.has_value(), ...)` and `TORCH_CHECK(offs_host.size() == group_count, ...)` (or the expected size), before indexing.

Suggested change:

```cpp
// For ragged modes (2D×3D, 3D×2D, 2D×2D), an offs tensor is required.
if (a_is_2d || b_is_2d) {
  TORCH_CHECK(
      offs.has_value(),
      "ScaledGroupedMM ragged modes (2Dx3D, 3Dx2D, 2Dx2D) require a valid 'offs' tensor.");
}
```
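The validation the review asks for can be sketched in Python. `validate_offs` below is a hypothetical stand-in for the proposed `TORCH_CHECK` calls: it refuses to proceed when `offs` is absent or too short to cover every group, which is exactly the out-of-bounds case described above.

```python
import numpy as np

def validate_offs(offs, group_count):
    # Hypothetical stand-in for TORCH_CHECK: ragged modes must not index
    # offs unless it exists and has an entry for every group.
    if offs is None:
        raise ValueError("ragged grouped_mm modes require an 'offs' tensor")
    offs = np.asarray(offs)
    if offs.size < group_count:
        raise ValueError(
            f"expected offs to cover {group_count} groups, got {offs.size} entries")
    return offs

offs = validate_offs([2, 5, 9], group_count=3)  # passes
```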
```cpp
int K = a_bf16.size(1);
int N = b_bf16.size(2);
int64_t out_stride_row = out.size(1);
```

Copilot AI commented Mar 24, 2026:

Several branches index `offs_host[g]` unconditionally, but `offs` is an `std::optional` and `offs_host` remains empty when `offs` is not provided. This will crash (out-of-bounds) for the ragged modes (2D×3D, 3D×2D, 2D×2D). Add explicit validation in the ragged branches before indexing.

Suggested change:

```cpp
TORCH_CHECK(
    offs.has_value(),
    "ScaledGroupedMM: ragged 2D x 3D mode requires `offs` to be provided.");
TORCH_CHECK(
    static_cast<int64_t>(offs_host.size()) >= group_count,
    "ScaledGroupedMM: ragged 2D x 3D mode expects offs_host.size() >= group_count, but got offs_host.size() = ",
    static_cast<int64_t>(offs_host.size()),
    " and group_count = ",
    static_cast<int64_t>(group_count),
    ".");
```
```cpp
  const int32_t* p = offs_cpu.data_ptr<int32_t>();
  offs_host.assign(p, p + offs_cpu.numel());
}
```

Copilot AI commented Mar 24, 2026:

Same issue as in ScaledGroupedMM.cpp: ragged paths index into `offs_host` without asserting that `offs` was provided and has the expected length. Please add `TORCH_CHECK(offs.has_value(), ...)` and size validation prior to indexing in ragged branches.

Suggested change:

```cpp
// For ragged paths (mixed 2D/3D inputs), an offs tensor describing
// segment boundaries is required and must have the expected length.
if (a_is_2d != b_is_2d) {
  TORCH_CHECK(
      offs.has_value(),
      "GroupedMM: 'offs' tensor must be provided for ragged (mixed 2D/3D) "
      "batched matrix multiplication.");
  // offs is typically a 1D tensor of length batch_size + 1, where
  // batch_size is taken from the 3D operand's leading dimension.
  int64_t batch_size = (!a_is_2d) ? mat_a.size(0) : mat_b.size(0);
  TORCH_CHECK(
      static_cast<int64_t>(offs_host.size()) == batch_size + 1,
      "GroupedMM: 'offs' tensor must have length batch_size + 1; got ",
      offs_host.size(),
      " elements for batch_size=",
      batch_size,
      ".");
}
```
Comment on lines +40 to +44

```python
a_f32 = a_fp8.float() * scale_a.unsqueeze(-1)
b_phys = b_fp8_t.t().contiguous()  # (N, K) row-major
b_f32 = b_phys.float() * scale_b.unsqueeze(-1)  # (N, K) * (N, 1)
out = a_f32 @ b_f32.t()
return out.to(torch.bfloat16)
```

Copilot AI commented Mar 24, 2026:

The implementation in ScaledGroupedMM.cpp dequantizes FP8 to BF16 (and casts back to BF16 after applying the float scale) before the GEMM, while this reference computes the GEMM in float32 and only casts the output to BF16. This mismatch can hide real kernel errors behind wide tolerances. Update the reference to mirror the kernel behavior (apply scale + cast to BF16 prior to matmul), or adjust the kernel comment/tests if the intended reference is a float32 GEMM.
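The mismatch can be made concrete without torch. The `to_bf16` helper below is my own numpy sketch that rounds float32 values to bfloat16 precision (round-to-nearest-even; NaN/inf corner cases ignored). Rounding the scaled inputs to BF16 before the matmul, as the kernel does, generally yields a different BF16 result than rounding only the output, as the test reference does:

```python
import numpy as np

def to_bf16(x):
    # Round float32 values to bfloat16 precision (round-to-nearest-even),
    # keeping float32 storage. NaN/inf corner cases are ignored in this sketch.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)).astype(np.float32)
b = rng.standard_normal((8, 4)).astype(np.float32)
scale_a = rng.uniform(0.5, 2.0, (4, 1)).astype(np.float32)

# Kernel-style reference: scale, round inputs to BF16, matmul, round output.
kernel_ref = to_bf16(to_bf16(a * scale_a) @ b)
# Test-style reference: full float32 matmul, round only the output.
fp32_ref = to_bf16((a * scale_a) @ b)
print(np.abs(kernel_ref - fp32_ref).max())  # typically nonzero
```

This is why the review suggests making the Python reference mirror the kernel: with matching rounding points, the test tolerance can be tightened.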
Comment on lines +26 to +28

```python
TEST_XPU = torch.xpu.is_available()
```

Copilot AI commented Mar 24, 2026:

`TEST_XPU` is defined but unused in this file, and the `dtype` test parameter is also unused across these tests. Since this is a standalone new file, consider removing `TEST_XPU` and either removing the unused `dtype` parameter/decorator (if supported by the test harness), or using `dtype` in assertions to avoid dead parameters (e.g., assert it is `torch.bfloat16` and/or pass it where applicable).

Suggested change: delete the `TEST_XPU = torch.xpu.is_available()` line.

```python
@onlyXPU
@dtypes(torch.bfloat16)
def test_scaled_grouped_gemm_3d_3d(self, device, dtype):
```

Copilot AI commented Mar 24, 2026:

Same remark as above: `TEST_XPU` is unused in this file and the `dtype` parameter is unused in these tests; remove the dead names or use `dtype` in the test bodies.
Comment on lines +257 to +271

```cpp
// Dequantize a 2D FP8 tensor with a 1D rowwise scale.
// Result is BF16 (bf16 * float32 promotes to float32, so cast back).
at::Tensor dequantize_rowwise_2d(
    const at::Tensor& fp8_tensor,
    const at::Tensor& scale) {
  auto bf16 = fp8_tensor.to(at::kBFloat16);
  return (bf16 * scale.unsqueeze(-1)).to(at::kBFloat16);
}

// Dequantize a 3D FP8 tensor with a 2D rowwise scale.
at::Tensor dequantize_rowwise_3d(
    const at::Tensor& fp8_tensor,
    const at::Tensor& scale) {
  auto bf16 = fp8_tensor.to(at::kBFloat16);
  return (bf16 * scale.unsqueeze(-1)).to(at::kBFloat16);
```

Copilot AI commented Mar 24, 2026:

`dequantize_rowwise_2d` and `dequantize_rowwise_3d` currently have identical bodies (the broadcasting works the same way for both). Consider consolidating into a single helper (or renaming to reflect the shared behavior) to reduce duplication and the chance of future divergence.

Suggested change:

```cpp
// Shared implementation for rowwise dequantization.
// Supports both 2D (tensor, 1D scale) and 3D (tensor, 2D scale) cases
// via broadcasting on the last dimension.
at::Tensor dequantize_rowwise_impl(
    const at::Tensor& fp8_tensor,
    const at::Tensor& scale) {
  auto bf16 = fp8_tensor.to(at::kBFloat16);
  return (bf16 * scale.unsqueeze(-1)).to(at::kBFloat16);
}

// Dequantize a 2D FP8 tensor with a 1D rowwise scale.
// Result is BF16 (bf16 * float32 promotes to float32, so cast back).
at::Tensor dequantize_rowwise_2d(
    const at::Tensor& fp8_tensor,
    const at::Tensor& scale) {
  return dequantize_rowwise_impl(fp8_tensor, scale);
}

// Dequantize a 3D FP8 tensor with a 2D rowwise scale.
at::Tensor dequantize_rowwise_3d(
    const at::Tensor& fp8_tensor,
    const at::Tensor& scale) {
  return dequantize_rowwise_impl(fp8_tensor, scale);
```
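The review's claim that the two bodies behave identically rests on broadcasting: unsqueezing the scale on the last axis works the same for a 2D tensor with a 1D scale as for a 3D tensor with a 2D scale. A quick numpy check of that rule (using `np.expand_dims` as the analogue of `unsqueeze(-1)`):

```python
import numpy as np

def dequant(vals, scale):
    # Same expression for both ranks; broadcasting fills in the rest.
    return vals * np.expand_dims(scale, -1)

x2 = np.ones((2, 3), dtype=np.float32)                      # (M, K)
s1 = np.array([2.0, 3.0], dtype=np.float32)                 # (M,)
x3 = np.ones((2, 2, 3), dtype=np.float32)                   # (G, M, K)
s2 = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=np.float32)   # (G, M)

y2 = dequant(x2, s1)   # (2, 3): each row scaled by its own factor
y3 = dequant(x3, s2)   # (2, 2, 3): each (g, m) row scaled by its own factor
```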
```cpp
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include <iostream>
```

Copilot AI commented Mar 24, 2026:

`<iostream>` is included but not used in this translation unit. Please remove it to reduce compile overhead and keep includes minimal.

Suggested change: delete the `#include <iostream>` line.
xuhancn changed the title from "Xpu scaled grouped mm" to "Add scaled_grouped_mm kernel for XPU (FP8 x FP8 -> BF16)" on Mar 25, 2026