@@ -440,6 +440,29 @@ inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
440440 return heuristicResult.algo ;
441441}
442442
443+ // Device helper: compute the element offset for tensor `idx` given shape metadata.
444+ // Three cases:
445+ // 1. Explicit per-tensor offset array provided → use it directly.
446+ // 2. Per-tensor first/last dims provided but no offsets → cumulative sum of (first*last) products.
447+ // 3. Fully uniform shapes → idx * uniform_first * uniform_last.
// Device helper: element offset of tensor `idx` inside the grouped buffer,
// derived from the shape metadata. Resolution order:
//   1. Explicit per-tensor offset array -> use it directly.
//   2. Per-tensor first/last dims (no offsets) -> running sum of (first * last).
//   3. Fully uniform shapes -> idx * uniform_first * uniform_last.
__forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorShapeInfo &meta,
                                                                 size_t idx) {
  // Case 1: an explicit offset table wins over everything else.
  if (meta.offsets != nullptr) {
    return meta.offsets[idx];
  }
  // Case 3: no per-tensor dims at all -> every tensor has the same extent.
  if (meta.first_dims == nullptr && meta.last_dims == nullptr) {
    return static_cast<int64_t>(idx) * meta.uniform_first * meta.uniform_last;
  }
  // Case 2: offset[i] = sum_{j < i} (first_dims[j] * last_dims[j]); whichever
  // dims array is absent falls back to the corresponding uniform dimension.
  int64_t running = 0;
  for (size_t j = 0; j < idx; ++j) {
    const int64_t first = (meta.first_dims != nullptr) ? meta.first_dims[j] : meta.uniform_first;
    const int64_t last = (meta.last_dims != nullptr) ? meta.last_dims[j] : meta.uniform_last;
    running += first * last;
  }
  return running;
}
465+
443466// Single kernel that sets up all GEMM parameters.
444467// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
445468// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
@@ -464,15 +487,11 @@ __global__ void setup_grouped_gemm_kernel(
464487 int64_t d_first = D_meta.first_dims ? D_meta.first_dims [idx] : D_meta.uniform_first ;
465488 int64_t d_last = D_meta.last_dims ? D_meta.last_dims [idx] : D_meta.uniform_last ;
466489
467- // Compute offsets (from array or compute from uniform dims)
468- int64_t a_offset =
469- A_meta.offsets ? A_meta.offsets [idx] : (idx * A_meta.uniform_first * A_meta.uniform_last );
470- int64_t b_offset =
471- B_meta.offsets ? B_meta.offsets [idx] : (idx * B_meta.uniform_first * B_meta.uniform_last );
472- int64_t c_offset =
473- C_meta.offsets ? C_meta.offsets [idx] : (idx * C_meta.uniform_first * C_meta.uniform_last );
474- int64_t d_offset =
475- D_meta.offsets ? D_meta.offsets [idx] : (idx * D_meta.uniform_first * D_meta.uniform_last );
490+ // Compute offsets (from explicit array, cumulative from per-tensor dims, or uniform)
491+ int64_t a_offset = compute_grouped_tensor_offset (A_meta, idx);
492+ int64_t b_offset = compute_grouped_tensor_offset (B_meta, idx);
493+ int64_t c_offset = compute_grouped_tensor_offset (C_meta, idx);
494+ int64_t d_offset = compute_grouped_tensor_offset (D_meta, idx);
476495
477496 // Compute data pointers
478497 A_ptrs[idx] = a_base + a_offset * a_elem_size;
@@ -487,9 +506,8 @@ __global__ void setup_grouped_gemm_kernel(
487506 a_cols[idx] = static_cast <int >(a_first);
488507 b_rows[idx] = static_cast <int >(b_last);
489508 b_cols[idx] = static_cast <int >(b_first);
490- // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
491- d_rows[idx] = static_cast <int >(d_first);
492- d_cols[idx] = static_cast <int >(d_last);
509+ d_rows[idx] = static_cast <int >(d_last);
510+ d_cols[idx] = static_cast <int >(d_first);
493511
494512 // Fill alpha/beta pointers (per-matrix)
495513 alpha_ptrs[idx] = alpha_ptr + idx;
@@ -535,6 +553,11 @@ inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
535553
536554} // namespace
537555
// Public API: workspace size needed by the grouped-GEMM setup pass for
// `num_tensors` matrices. Thin wrapper that delegates to the internal helper.
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  NVTE_API_CALL(nvte_get_grouped_gemm_setup_workspace_size);
  const size_t workspace_bytes = grouped_gemm_setup_workspace_size(num_tensors);
  return workspace_bytes;
}
560+
538561void nvte_grouped_gemm (const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
539562 const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
540563 const NVTETensor beta, NVTETensor workspace_setup,
@@ -642,4 +665,30 @@ void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedT
642665 CUBLAS_VERSION, " . Please upgrade to CUDA 13.1 or newer." );
643666}
644667
// Fallback stub compiled when cuBLAS is older than 13.2: the grouped-GEMM path
// does not exist, so any call fails loudly with an upgrade hint.
size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors) {
  NVTE_ERROR(
      "nvte_get_grouped_gemm_setup_workspace_size requires cuBLAS 13.2+, but compile-time cuBLAS "
      "version is ",
      CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
  // Presumably unreachable — NVTE_ERROR appears to throw/abort (TODO confirm);
  // the return value only satisfies the non-void signature.
  return 0;
}
675+
645676#endif // CUBLAS_VERSION >= 130200
677+
678+ namespace {
679+
// Widening copy: dst[i] = (int64_t)src[i] for i in [0, n).
// Grid/block layout: 1-D launch, any configuration — the grid-stride loop
// covers all n elements regardless of grid size. No shared memory.
// Fix vs. previous version: `blockIdx.x * blockDim.x` was evaluated in 32-bit
// unsigned arithmetic before widening, so the index wrapped for n > 2^32;
// all index math is now done in size_t.
__global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, size_t n) {
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < n;
       idx += stride) {
    dst[idx] = static_cast<int64_t>(src[idx]);
  }
}
684+
685+ } // namespace
686+
// Asynchronously widens `n` int32 values from `src` into int64 `dst` on
// `stream`. Pointers are presumably device-accessible (TODO confirm against
// callers); n == 0 is a no-op. Launch-configuration errors are surfaced via
// cudaGetLastError immediately after the launch.
void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
  NVTE_API_CALL(nvte_convert_int32_to_int64);
  if (n == 0) return;  // nothing to copy; skip the zero-work launch
  constexpr int kThreadsPerBlock = 256;
  // Ceil-divide so the final partial block is still covered.
  const int num_blocks = static_cast<int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  convert_int32_to_int64_kernel<<<num_blocks, kThreadsPerBlock, 0, stream>>>(src, dst, n);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
0 commit comments