Skip to content

Commit 924ff91

Browse files
Merge branch 'main' into pr_persistent_grouped_mxfp8_kernel
2 parents 219e925 + d40b9de commit 924ff91

File tree

13 files changed

+703
-51
lines changed

13 files changed

+703
-51
lines changed

tests/cpp/operator/test_grouped_gemm.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ void run_grouped_gemm_case(const TestParams& params) {
123123

124124
for (size_t i = 0; i < num_gemms; ++i) {
125125
const auto [M, N, K] = shapes[i];
126-
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
127-
: std::vector<size_t>{K, M};
128-
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
129-
: std::vector<size_t>{N, K};
126+
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{N, K}
127+
: std::vector<size_t>{K, N};
128+
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, M}
129+
: std::vector<size_t>{M, K};
130130
switch (params.input_case) {
131131
case InputCase::kFP8Current: {
132132
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
440440
return heuristicResult.algo;
441441
}
442442

443+
// Device helper: compute the element offset for tensor `idx` given shape metadata.
444+
// Three cases:
445+
// 1. Explicit per-tensor offset array provided → use it directly.
446+
// 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products.
447+
// 3. Fully uniform shapes → idx * uniform_first * uniform_last.
448+
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
449+
size_t idx) {
450+
if (meta.offsets) {
451+
return meta.offsets[idx];
452+
} else if (meta.first_dims != nullptr || meta.last_dims != nullptr) {
453+
// offset[i] = sum_{j < i} (first_dims[j] * last_dims[j])
454+
int64_t cumsum = 0;
455+
for (size_t i = 0; i < idx; i++) {
456+
int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first;
457+
int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last;
458+
cumsum += f * l;
459+
}
460+
return cumsum;
461+
} else {
462+
return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
463+
}
464+
}
465+
443466
// Single kernel that sets up all GEMM parameters.
444467
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
445468
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
@@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel(
464487
int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
465488
int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;
466489

467-
// Compute offsets (from array or compute from uniform dims)
468-
int64_t a_offset =
469-
A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
470-
int64_t b_offset =
471-
B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
472-
int64_t c_offset =
473-
C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
474-
int64_t d_offset =
475-
D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
490+
// Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform)
491+
int64_t a_offset = compute_grouped_tensor_offset(A_meta, idx);
492+
int64_t b_offset = compute_grouped_tensor_offset(B_meta, idx);
493+
int64_t c_offset = compute_grouped_tensor_offset(C_meta, idx);
494+
int64_t d_offset = compute_grouped_tensor_offset(D_meta, idx);
476495

477496
// Compute data pointers
478497
A_ptrs[idx] = a_base + a_offset * a_elem_size;
@@ -487,9 +506,8 @@ __global__ void setup_grouped_gemm_kernel(
487506
a_cols[idx] = static_cast<int>(a_first);
488507
b_rows[idx] = static_cast<int>(b_last);
489508
b_cols[idx] = static_cast<int>(b_first);
490-
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
491-
d_rows[idx] = static_cast<int>(d_first);
492-
d_cols[idx] = static_cast<int>(d_last);
509+
d_rows[idx] = static_cast<int>(d_last);
510+
d_cols[idx] = static_cast<int>(d_first);
493511

494512
// Fill alpha/beta pointers (per-matrix)
495513
alpha_ptrs[idx] = alpha_ptr + idx;
@@ -535,6 +553,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
535553

536554
} // namespace
537555

556+
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
557+
NVTE_API_CALL(nvte_get_grouped_gemm_setup_workspace_size);
558+
return grouped_gemm_setup_workspace_size(num_tensors);
559+
}
560+
538561
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
539562
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
540563
const NVTETensor beta, NVTETensor workspace_setup,
@@ -642,4 +665,30 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
642665
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
643666
}
644667

668+
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
669+
NVTE_ERROR(
670+
"nvte_get_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time cuBLAS "
671+
"version is ",
672+
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
673+
return 0;
674+
}
675+
645676
#endif // CUBLAS_VERSION >= 130200
677+
678+
namespace {
679+
680+
__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) {
681+
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
682+
if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
683+
}
684+
685+
} // namespace
686+
687+
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
688+
NVTE_API_CALL(nvte_convert_int32_to_int64);
689+
if (n == 0) return;
690+
const int threads = 256;
691+
const int blocks = static_cast<int>((n + threads - 1) / threads);
692+
convert_int32_to_int64_kernel<<<blocks, threads, 0, stream>>>(src, dst, n);
693+
NVTE_CHECK_CUDA(cudaGetLastError());
694+
}

transformer_engine/common/include/transformer_engine/gemm.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,31 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
329329
* - Shape compatibility: if transa=false, transb=false:
330330
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
331331
*/
332+
/*! \brief Return the required size in bytes for the setup workspace of grouped GEMM.
333+
*
334+
* The setup workspace stores pointer arrays and per-matrix dimension arrays used
335+
* by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs)
336+
* in the group and is independent of matrix dimensions.
337+
*
338+
* Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm.
339+
*
340+
* \param[in] num_tensors Number of tensors (GEMMs) in the group.
341+
* \return Required size in bytes for workspace_setup.
342+
*/
343+
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors);
344+
345+
/*! \brief Convert a device array of int32 values to int64 values.
346+
*
347+
* Useful for preparing group_sizes for nvte_grouped_gemm when the caller
348+
* holds int32 sizes and needs int64 values on the device.
349+
*
350+
* \param[in] src Device pointer to source int32 array.
351+
* \param[out] dst Device pointer to destination int64 array.
352+
* \param[in] n Number of elements.
353+
* \param[in] stream CUDA stream.
354+
*/
355+
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream);
356+
332357
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
333358
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
334359
const NVTETensor beta, NVTETensor workspace_setup,

transformer_engine/common/triton/permutation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,10 @@ def _sort_chunks_by_map_kernel(
599599
input_ptr,
600600
row_id_map_ptr,
601601
probs_ptr,
602+
# Pre-allocated output buffer for JAX input_output_aliases.
603+
# Aliased to output_ptr in JAX so they point to the same memory.
604+
# In PyTorch, pass the same tensor as output_ptr.
605+
output_buf_ptr, # pylint: disable=unused-argument
602606
# strides
603607
stride_input_token,
604608
stride_input_hidden,

0 commit comments

Comments (0)