
[Common] Persistent Grouped NVFP4 quantization kernel#2743

Open
Oleg-Goncharov wants to merge 49 commits into NVIDIA:main from Oleg-Goncharov:pr_persistent_grouped_nvfp4_kernel

Conversation


Oleg-Goncharov (Collaborator) commented on Mar 6, 2026

Description

This PR adds a persistent grouped NVFP4 quantization + transpose kernel with static scheduling.
It is built on top of PR #2738, [Common] Persistent Grouped MXFP8 quantization kernel.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added persistent grouped kernel
  • Added test suite

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Oleg-Goncharov and others added 30 commits February 27, 2026 15:53
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
NVIDIA deleted a comment from greptile-apps bot Mar 9, 2026
Comment on lines +419 to +434
const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim,
const int64_t *const __restrict__ last_dims_ptr) {
  size_t cols_num = 0;
  switch (shape_rep) {
    case ShapeRepresentation::SAME_BOTH_DIMS:
    case ShapeRepresentation::VARYING_FIRST_DIM:
      cols_num = last_logical_dim;
      break;
    case ShapeRepresentation::VARYING_LAST_DIM:
    case ShapeRepresentation::VARYING_BOTH_DIMS:
      cols_num = static_cast<size_t>(last_dims_ptr[tensor_id]);
      if (cols_num % 128 != 0) {
        NVTE_DEVICE_ERROR(
            "For non-single tensors, the last dimension of each tensor in a group "
            "must be divisible by 128.");
      }

is_job_valid skips intra-tensor boundary check for non-SAME_BOTH_DIMS shapes

In the NVFP4 is_job_valid, once block_global_offset < tensor_end_offset is verified, the function returns true without checking whether the block's Y/X coordinates actually fall within [0, rows) and [0, cols):

const size_t tensor_end_offset = static_cast<size_t>(offsets_ptr[job.tensor_id + 1]);
if (job.block_global_offset >= tensor_end_offset) {
    return false;
}
return true;

The corresponding check in the MXFP8 version (group_quantize_mxfp8.cuh) also validates:

const size_t tensor_offset_from_start = job.block_global_offset - tensor_start_offset;
const size_t block_offset_Y_in_tensor = tensor_offset_from_start / job.cols;
const size_t block_offset_X_in_tensor = tensor_offset_from_start % job.cols;
if (block_offset_Y_in_tensor >= job.rows || block_offset_X_in_tensor >= job.cols) {
    return false;
}

For VARYING_LAST_DIM and VARYING_BOTH_DIMS shapes, omitting this check could allow stale or padding blocks (that are within tensor_end_offset but beyond the actual rows × cols footprint) to issue TMA loads from out-of-bounds addresses. Please consider adding the equivalent bounds check.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
std::vector<fp4e2m1> out_data_rowwise_h(total_elts / 2);
std::vector<fp4e2m1> out_data_colwise_h(total_elts / 2);
std::vector<fp8e4m3> out_scales_rowwise_h(rowwise_scales_num);
std::vector<fp8e4m3> out_scales_colwise_h(colwise_scales_num);

Wrong variable used in "more mismatches" condition

mismatch_messages is only appended while total_mismatches <= max_mismatches_to_print (3), so its size() can never exceed max_mismatches_to_print. The condition is therefore always false and the "... and X more mismatches" line is dead code — even inside the print_detailed_summary branch. The comparison should use total_mismatches:

Suggested change
if (total_mismatches > max_mismatches_to_print) {

Comment on lines +388 to +396
cudaMemcpy(last_dims_d, last_dims_h.data(), num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice);
cudaMemcpy(offsets_d, offsets_h.data(), (num_tensors + 1) * sizeof(int64_t), cudaMemcpyHostToDevice);

cudaMemset(out_data_rowwise_d, 0, out_data_size);
cudaMemset(out_data_colwise_d, 0, out_data_size);
cudaMemset(out_scales_rowwise_d, 0, rowwise_scales_size);
cudaMemset(out_scales_colwise_d, 0, colwise_scales_size);

NVTEShape logical_shape_ = nvte_make_shape(logical_shape.data(), logical_shape.size());

CUDA API return values are not checked

All cudaMalloc, cudaMemcpy, and cudaMemset calls in performTest silently ignore their return values. A failed allocation would leave the pointer uninitialized (or null) and the test would proceed, likely crashing or producing a spurious cudaGetLastError failure that obscures the real problem.

Consider wrapping the calls with a helper that asserts success, e.g.:

ASSERT_EQ(cudaMalloc((void**)&in_data_d, in_data_size), cudaSuccess);
ASSERT_EQ(cudaMemcpy(in_data_d, grouped_input.data(), in_data_size, cudaMemcpyHostToDevice), cudaSuccess);

This pattern applies to all CUDA API calls from the allocation block down through the cudaMemset calls (lines 388–410).

Comment on lines 163 to 164
* \param[in] stream CUDA stream used for the operation.
*/

Binary-incompatible API change without a breaking-change marker

nvte_group_quantize_dbias (and the five related nvte_group_quantize_dbias_d* functions) previously accepted NVTETensor dbias; this PR changes the parameter to NVTEGroupedTensor dbias. Any existing C/C++ caller that was compiled against the old header will silently pass the wrong type at runtime. Even though both types are opaque pointers at the ABI level, callers that stored the dbias as NVTETensor will need to migrate.

The PR description marks this as "New feature (non-breaking change)", but this signature change will break downstream callers (Python bindings, external C++ users) that previously compiled against NVTETensor dbias. It is worth auditing all internal call sites (Python pybind layer, etc.) and explicitly documenting the migration in the PR / changelog.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from 97ec071 to fef9220 on March 9, 2026 at 11:55
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from 2e289c9 to 9e37b4c on March 9, 2026 at 14:51
Comment on lines +104 to +115

const size_t rows = tensor_rows / chunk_dim_Y;
const size_t cols = last_logical_dim;

const size_t dbias_in_offset_Y =
(shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (static_cast<size_t>(offsets_ptr[tensor_id]) / cols / chunk_dim_Y);

const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id * nvec >= cols) {

group_reduce_dbias_kernel uses last_logical_dim as cols for all shape representations

cols is unconditionally set to last_logical_dim:

const size_t cols = last_logical_dim;

For VARYING_LAST_DIM and VARYING_BOTH_DIMS shapes, each tensor has a different last dimension. Using the scalar last_logical_dim for all tensors will produce incorrect partial-dbias strides and wrong output write offsets (tensor_id * cols assumes uniform column counts). The same issue affects the dbias_in_offset_Y calculation for those shape representations.

cast.h documents that "Grouped dbias is not yet supported for grouped tensors with a varying last dimension," but there is no runtime guard in grouped_reduce_dbias or this kernel to enforce that. If called with such shapes the function silently corrupts memory. Consider adding an explicit NVTE_CHECK(shape_rep != ShapeRepresentation::VARYING_LAST_DIM && shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS, ...) guard in grouped_reduce_dbias before the kernel launch.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from f6b5928 to 6a7409d on March 9, 2026 at 15:05
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from 50a4921 to 6c5cc7f on March 9, 2026 at 16:05
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from c83b558 to f5e2ba0 on March 9, 2026 at 16:06
Comment on lines 175 to +176
  void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                                NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                                NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);

The six nvte_group_quantize_dbias* functions (this one and nvte_group_quantize_dbias_dgelu, nvte_group_quantize_dbias_dsilu, nvte_group_quantize_dbias_drelu) now take NVTEGroupedTensor dbias instead of NVTETensor dbias. This is a signature change that may affect existing C/C++ callers compiled against the old header. While this appears intentional as part of the grouped tensor API consolidation, consider auditing Python bindings and any external C++ code to ensure compatibility, and explicitly document the migration path in the changelog.

Comment on lines +1044 to +1055

const bool use_single_work_grid = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);

const size_t first_logical_dim = input->logical_shape.data[0];
const size_t last_logical_dim = input->logical_shape.data[1];
const size_t elts_total = first_logical_dim * last_logical_dim;
const size_t num_tensors = input->num_tensors;

NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
"Number of tensors in a group is larger than the MAX number of supported "
"descriptors (64).");

For the SAME_BOTH_DIMS case, the kernel computes per-tensor row counts via integer division (first_logical_dim / num_tensors), which silently truncates if first_logical_dim is not exactly divisible by num_tensors. This causes incorrect base offsets and may skip or overwrite the last few rows. Add a host-side check:

if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
    NVTE_CHECK(first_logical_dim % num_tensors == 0,
               "For SAME_BOTH_DIMS, first_logical_dim (", first_logical_dim,
               ") must be divisible by num_tensors (", num_tensors, ").");
}

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov force-pushed the pr_persistent_grouped_nvfp4_kernel branch from 811a146 to eace4a6 on March 9, 2026 at 17:59