Commit b9f4013

pggPL, pre-commit-ci[bot], and ptrendx authored
[common] Add support for cuBLASLt GEMM for GroupedTensor (#2502)
* code drop
* Add FP8 scale support and fix alignment for grouped GEMM
  - Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
  - Fix random padding in tests to ensure 16-byte alignment for all dtypes
  - Reorder GroupedGemmSetupWorkspace members for natural alignment
  - Remove debug prints
* Grouped GEMM: code cleanup and NULL C support
  - Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
  - Simplify select_grouped_operand by removing dead code branches
  - Add GroupedOperandSelection.tensor field to avoid passing tensor separately
  - Extract set_fp8_scale_pointers and init_matrix_layouts helpers
  - Add safety check for FP8 on Hopper column-wise fallback
  - Support NULL C tensor when beta=0 (uses D as placeholder)
  - Remove unused get_scale_inv() from test
  - Add use_null_c test parameter and test case
  - Fix documentation: alpha/beta are single element tensors only
* Grouped GEMM: per-matrix alpha/beta support
  - Change alpha/beta from single values to per-matrix arrays
  - Validate alpha/beta have exactly num_tensors elements
  - Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
  - Move alpha/beta validation to validate_grouped_gemm_inputs
  - Update tests to use per-matrix alpha/beta arrays
  - Update documentation
* Fix alpha/beta numel - use SimpleTensor::numel()
* Refactor: move grouped GEMM to separate file and cleanup API
* Require Blackwell (SM100) and cuBLAS 13.1+ for grouped GEMM
* Update transformer_engine/common/gemm/config.h
* Refactored Hopper tensor selection
* Assorted fixes, review suggestions, and [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
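For orientation, the semantics this commit settles on (per-matrix alpha/beta, optional C operand when beta is zero) can be written as a small host-side reference. This is an illustrative sketch, not code from the PR; the function name and the flattened row-major layout are assumptions made here for clarity:

#include <array>
#include <cstddef>
#include <vector>

// Reference semantics of a grouped GEMM with per-matrix alpha/beta:
//   D[g] = alpha[g] * A[g] * B[g] + beta[g] * C[g]
// C may be absent (nullptr), which is valid only when beta[g] == 0 for every g.
void grouped_gemm_reference(const std::vector<std::vector<float>>& A,  // A[g] is M x K, row-major
                            const std::vector<std::vector<float>>& B,  // B[g] is K x N, row-major
                            const std::vector<std::vector<float>>* C,  // optional; C[g] is M x N
                            std::vector<std::vector<float>>& D,        // D[g] is M x N
                            const std::vector<float>& alpha,
                            const std::vector<float>& beta,
                            const std::vector<std::array<std::size_t, 3>>& mnk) {
  for (std::size_t g = 0; g < D.size(); ++g) {
    const auto [M, N, K] = mnk[g];
    for (std::size_t i = 0; i < M; ++i) {
      for (std::size_t j = 0; j < N; ++j) {
        float acc = 0.f;
        for (std::size_t k = 0; k < K; ++k) {
          acc += A[g][i * K + k] * B[g][k * N + j];
        }
        // With a NULL C, beta[g] must be 0, so the C term drops out.
        const float c_val = (C != nullptr) ? (*C)[g][i * N + j] : 0.f;
        D[g][i * N + j] = alpha[g] * acc + beta[g] * c_val;
      }
    }
  }
}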
1 parent f04b094 commit b9f4013

12 files changed

Lines changed: 1494 additions & 20 deletions

File tree

tests/cpp/operator/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ add_executable(test_operator
     test_causal_softmax.cu
     test_swizzle.cu
     test_swap_first_dims.cu
+    test_grouped_gemm.cu
     ../test_common.cu)

 # Find required packages
tests/cpp/operator/test_grouped_gemm.cu

Lines changed: 308 additions & 0 deletions

@@ -0,0 +1,308 @@
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <memory>
#include <numeric>
#include <optional>
#include <random>
#include <tuple>
#include <vector>

#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/transformer_engine.h>

#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

enum class InputCase {
  kFP8Current,
  kBF16,
};

enum class ShapeCase {
  kAllSame,
  kSameFirst,
  kSameLast,
  kAllDifferent,
};

size_t grouped_setup_workspace_size(const size_t num_tensors) {
  const size_t ptr_bytes = num_tensors * sizeof(void*);
  const size_t int_bytes = num_tensors * sizeof(int);
  // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) +
  // 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols)
  size_t size = 6 * ptr_bytes + 6 * int_bytes;
  const size_t alignment = 256;
  size = ((size + alignment - 1) / alignment) * alignment;
  return size;
}
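// Worked example (illustrative, assuming 64-bit pointers and 4-byte ints):
// num_tensors = 3 gives 6 * 3 * 8 + 6 * 3 * 4 = 144 + 72 = 216 bytes,
// which rounds up to the next 256-byte boundary, i.e. 256 bytes of setup workspace.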

Tensor make_fp8_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor input_fp32(name + "_fp32", shape, DType::kFloat32);
  fillUniform(&input_fp32);

  Tensor fp8(name, shape, TypeInfo<fp8e4m3>::dtype, true, true, NVTE_DELAYED_TENSOR_SCALING);

  nvte_compute_amax(input_fp32.data(), fp8.data(), 0);
  QuantizationConfigWrapper config;
  nvte_compute_scale_from_amax(fp8.data(), config, 0);
  nvte_quantize(input_fp32.data(), fp8.data(), 0);
  return fp8;
}
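// Note: the sequence above measures the amax of the current FP32 data, derives
// the FP8 scale from that amax, and then quantizes with it (current scaling,
// stored in the delayed-tensor-scaling format), matching the kFP8Current case.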

Tensor make_bf16_operand(const std::string& name, const std::vector<size_t>& shape) {
  Tensor t(name, shape, DType::kBFloat16);
  const size_t numel = shape[0] * shape[1];
  std::vector<__nv_bfloat16> ones(numel, __float2bfloat16(1.0f));
  NVTE_CHECK_CUDA(cudaMemcpy(t.rowwise_dptr(), ones.data(),
                             numel * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice));
  return t;
}
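// Filling the BF16 operands with ones makes every output element equal to the
// reduction length K, which is exactly representable in BF16 at these sizes,
// so the grouped-vs-multi comparison is presumably insensitive to rounding.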

struct TestParams {
  InputCase input_case;
  bool transa;
  bool transb;
  ShapeCase shape_case;
  bool use_null_c = false;  // When true, pass nullptr for C (valid when beta=0)
};

// Returns a vector of (M, N, K) tuples for each GEMM in the group.
// M - number of rows in output D
// N - number of columns in output D
// K - reduction dimension shared between A and B
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
  switch (scase) {
    case ShapeCase::kAllSame:
      return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
    case ShapeCase::kSameFirst:
      // Same M (first dim), varying N and K
      return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
    case ShapeCase::kSameLast:
      // Same N (last dim), varying M and K
      return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
    case ShapeCase::kAllDifferent:
    default:
      return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
  }
}
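// All dimensions above are multiples of 16, consistent with the 16-byte
// operand alignment the commit message mentions fixing in the tests.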

void run_grouped_gemm_case(const TestParams& params) {
#if CUBLAS_VERSION < 130100
  GTEST_SKIP() << "Grouped GEMM requires cuBLAS 13.1+, but compile-time cuBLAS version is "
               << CUBLAS_VERSION << ".";
#else
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP() << "Grouped GEMM requires Blackwell (SM100) or newer.";
  }

  const std::vector<std::tuple<size_t, size_t, size_t>> shapes = make_shapes(params.shape_case);

  const size_t num_gemms = shapes.size();
  std::vector<Tensor> A_tensors;
  std::vector<Tensor> B_tensors;
  std::vector<Tensor> D_multi;

  A_tensors.reserve(num_gemms);
  B_tensors.reserve(num_gemms);
  D_multi.reserve(num_gemms);

  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = shapes[i];
    const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
                                                      : std::vector<size_t>{K, M};
    const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
                                                      : std::vector<size_t>{N, K};
    switch (params.input_case) {
      case InputCase::kFP8Current: {
        A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
        B_tensors.emplace_back(make_fp8_operand("B" + std::to_string(i), b_shape));
        break;
      }
      case InputCase::kBF16: {
        A_tensors.emplace_back(make_bf16_operand("A" + std::to_string(i), a_shape));
        B_tensors.emplace_back(make_bf16_operand("B" + std::to_string(i), b_shape));
        break;
      }
    }
    D_multi.emplace_back(Tensor("D_multi" + std::to_string(i),
                                std::vector<size_t>{M, N},
                                DType::kBFloat16));
  }

  std::vector<NVTETensor> A_ptrs(num_gemms);
  std::vector<NVTETensor> B_ptrs(num_gemms);
  std::vector<NVTETensor> D_ptrs(num_gemms);
  std::vector<Tensor> workspaces(num_gemms);
  std::vector<NVTETensor> workspace_ptrs(num_gemms, nullptr);
  std::vector<Tensor*> A_views;
  std::vector<Tensor*> B_views;
  A_views.reserve(num_gemms);
  B_views.reserve(num_gemms);

  // Empty bias/gelu arrays for nvte_multi_tensor_gemm (no epilogues)
  std::vector<NVTETensor> bias_ptrs(num_gemms, nullptr);
  std::vector<NVTETensor> gelu_ptrs(num_gemms, nullptr);

  const size_t cublas_ws_bytes = 32ull * 1024 * 1024;

  for (size_t i = 0; i < num_gemms; ++i) {
    A_ptrs[i] = A_tensors[i].data();
    B_ptrs[i] = B_tensors[i].data();
    D_ptrs[i] = D_multi[i].data();
    workspaces[i] =
        Tensor("workspace" + std::to_string(i), std::vector<size_t>{cublas_ws_bytes}, DType::kByte);
    workspace_ptrs[i] = workspaces[i].data();
    A_views.push_back(&A_tensors[i]);
    B_views.push_back(&B_tensors[i]);
  }

  nvte_multi_tensor_gemm(A_ptrs.data(),
                         B_ptrs.data(),
                         D_ptrs.data(),
                         bias_ptrs.data(),
                         gelu_ptrs.data(),
                         static_cast<int>(num_gemms),
                         params.transa,
                         params.transb,
                         false,  // grad
                         workspace_ptrs.data(),
                         false,  // accumulate
                         false,  // use_split_accumulator
                         0,      // sm_count
                         0);
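  // The multi-tensor path above produces per-GEMM reference outputs (D_multi);
  // the grouped-tensor path below must reproduce them elementwise.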

  GroupedBuffers grouped_A = build_grouped_tensor(A_views, A_tensors[0].scaling_mode());
  GroupedBuffers grouped_B = build_grouped_tensor(B_views, B_tensors[0].scaling_mode());

  std::vector<Tensor> C_tensors;
  std::vector<Tensor> D_group_tensors;
  C_tensors.reserve(num_gemms);
  D_group_tensors.reserve(num_gemms);
  for (size_t i = 0; i < num_gemms; ++i) {
    const auto [M, N, K] = shapes[i];
    (void)K;
    if (!params.use_null_c) {
      C_tensors.emplace_back(Tensor("C" + std::to_string(i),
                                    std::vector<size_t>{M, N},
                                    DType::kBFloat16));
    }
    D_group_tensors.emplace_back(Tensor("D_group" + std::to_string(i),
                                        std::vector<size_t>{M, N},
                                        DType::kBFloat16));
    NVTE_CHECK_CUDA(cudaMemset(D_group_tensors.back().rowwise_dptr(), 0,
                               bytes(D_group_tensors.back().rowwise_shape(),
                                     D_group_tensors.back().dtype())));
  }

  std::vector<Tensor*> C_views, D_views;
  for (size_t i = 0; i < num_gemms; ++i) {
    if (!params.use_null_c) {
      C_views.push_back(&C_tensors[i]);
    }
    D_views.push_back(&D_group_tensors[i]);
  }

  std::optional<GroupedBuffers> grouped_C;
  if (!params.use_null_c) {
    grouped_C = build_grouped_tensor(C_views, NVTE_DELAYED_TENSOR_SCALING);
  }
  GroupedBuffers grouped_D = build_grouped_tensor(D_views, NVTE_DELAYED_TENSOR_SCALING);

  // Per-matrix alpha/beta (all 1.0 and 0.0 respectively)
  Tensor alpha_tensor("alpha", std::vector<size_t>{num_gemms}, DType::kFloat32);
  Tensor beta_tensor("beta", std::vector<size_t>{num_gemms}, DType::kFloat32);
  std::vector<float> alpha_vals(num_gemms, 1.f);
  std::vector<float> beta_vals(num_gemms, 0.f);
  NVTE_CHECK_CUDA(cudaMemcpy(alpha_tensor.rowwise_dptr(), alpha_vals.data(),
                             num_gemms * sizeof(float), cudaMemcpyHostToDevice));
  NVTE_CHECK_CUDA(cudaMemcpy(beta_tensor.rowwise_dptr(), beta_vals.data(),
                             num_gemms * sizeof(float), cudaMemcpyHostToDevice));
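  // beta == 0 for every GEMM here, which is exactly the condition under which
  // the NULL-C variant below is valid: C is never read, and D stands in as a
  // placeholder operand.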

  const size_t setup_ws_bytes = grouped_setup_workspace_size(num_gemms);
  Tensor setup_ws("setup_ws", std::vector<size_t>{setup_ws_bytes}, DType::kByte);
  Tensor cublas_ws("cublas_ws", std::vector<size_t>{cublas_ws_bytes}, DType::kByte);

  nvte_grouped_gemm(grouped_A.get_handle(),
                    params.transa,
                    grouped_B.get_handle(),
                    params.transb,
                    params.use_null_c ? nullptr : grouped_C->get_handle(),
                    grouped_D.get_handle(),
                    alpha_tensor.data(),
                    beta_tensor.data(),
                    setup_ws.data(),
                    cublas_ws.data(),
                    nullptr,  // config (use defaults)
                    0);
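  // Split the packed grouped output back into per-GEMM tensors and compare
  // each against the corresponding multi-tensor reference.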

  for (size_t i = 0; i < num_gemms; ++i) {
    Tensor grouped_split("grouped_D" + std::to_string(i),
                         std::vector<size_t>{std::get<0>(shapes[i]), std::get<1>(shapes[i])},
                         D_multi[i].dtype());
    const size_t offset_bytes =
        static_cast<size_t>(grouped_D.offsets_host[i]) * grouped_D.elem_size;
    NVTE_CHECK_CUDA(cudaMemcpy(grouped_split.rowwise_dptr(),
                               static_cast<char*>(grouped_D.get_data()) + offset_bytes,
                               grouped_D.tensor_bytes[i],
                               cudaMemcpyDeviceToDevice));
    grouped_split.to_cpu();
    D_multi[i].to_cpu();
    auto [atol, rtol] = getTolerances(D_multi[i].dtype());
    compareResults("grouped_vs_multi",
                   grouped_split,
                   D_multi[i].rowwise_cpu_dptr<bf16>(),
                   true,
                   atol,
                   rtol);
  }
#endif  // CUBLAS_VERSION >= 130100
}

class GroupedGemmTest : public ::testing::TestWithParam<TestParams> {};

TEST_P(GroupedGemmTest, CompareWithMultiTensorGemm) {
  run_grouped_gemm_case(GetParam());
}

std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest::ParamType>& info) {
  constexpr const char* kInputNames[] = {"FP8Current", "BF16"};
  constexpr const char* kShapeNames[] = {"AllSame", "SameM", "SameN", "AllDiff"};
  const std::string layout = std::string("ta") + (info.param.transa ? "T" : "N") +
                             "tb" + (info.param.transb ? "T" : "N");
  const std::string null_c = info.param.use_null_c ? "_NullC" : "";
  return std::string(kInputNames[static_cast<int>(info.param.input_case)]) + "_" +
         kShapeNames[static_cast<int>(info.param.shape_case)] + "_" + layout + null_c;
}
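// For example, the first entry in kTestParams below
// ({kFP8Current, true, false, kAllDifferent, false}) is named
// "FP8Current_AllDiff_taTtbN".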

// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
    // Basic tests
    {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
    {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
    {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
    {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
    {InputCase::kBF16, false, true, ShapeCase::kSameLast, false},
    {InputCase::kBF16, false, false, ShapeCase::kAllSame, false},
    {InputCase::kBF16, true, true, ShapeCase::kAllDifferent, false},
    // Test NULL C (valid when beta=0)
    {InputCase::kBF16, false, false, ShapeCase::kAllSame, true},
};

INSTANTIATE_TEST_SUITE_P(OperatorTest,
                         GroupedGemmTest,
                         ::testing::ValuesIn(kTestParams),
                         MakeGroupedGemmTestName);

}  // namespace
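
Since the test is registered in the OperatorTest suite via INSTANTIATE_TEST_SUITE_P, it can be selected with a standard GoogleTest filter once the test_operator binary from tests/cpp/operator/CMakeLists.txt is built (the invocation below is illustrative; the binary's location depends on your build setup):

  ./test_operator --gtest_filter='OperatorTest/GroupedGemmTest.*'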
