
Conversation

@Oleg-Goncharov
Collaborator

Description

This PR fuses pre-swizzling into the grouped MXFP8 quantization kernel so that scaling factors are written directly in the layout expected by GEMM, removing the need for a separate swizzle pass. It builds on PR #2586 ([Common] MXFP8 kernel for grouped tensors) and can be merged once that PR lands.

Type of change

  • Documentation change (a fix or new content affecting only the documentation)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a template parameter (`WITH_GEMM_SWIZZLED_SCALES`) to the kernel to control the scaling-factor format.
  • Added a `with_gemm_swizzled_scales` member to `GroupedTensor`.
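The two changes above work together: the host side reads a runtime flag from the grouped tensor and turns it into a compile-time kernel parameter. A minimal sketch of that pattern is below; the struct and kernel bodies are simplified stand-ins, not the real Transformer Engine code.

```cpp
#include <cassert>

// Hypothetical, simplified stand-in for the PR's GroupedTensor
// (the real one lives in transformer_engine/common/common.h).
struct GroupedTensor {
  bool with_gemm_swizzled_scales = false;  // new member added by this PR
};

// The template parameter bakes the scale-factor layout choice into the
// kernel at compile time, so the inner loop pays no per-element branch.
template <bool WITH_GEMM_SWIZZLED_SCALES>
int quantize_kernel_stub() {
  // Stand-in body: report which variant was instantiated.
  return WITH_GEMM_SWIZZLED_SCALES ? 1 : 0;
}

// Host-side launcher: turn the runtime flag into a compile-time parameter.
int launch(const GroupedTensor &out) {
  return out.with_gemm_swizzled_scales ? quantize_kernel_stub<true>()
                                       : quantize_kernel_stub<false>();
}
```

The same runtime-flag-to-template dispatch is how CUDA kernels typically avoid per-element branching on a configuration bit.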

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Oleg-Goncharov and others added 20 commits January 21, 2026 16:50
Signed-off-by: Oleg Goncharov <[email protected]>
@greptile-apps

greptile-apps bot commented Jan 28, 2026

Greptile Overview

Greptile Summary

This PR extends the grouped MXFP8 quantization kernel to support pre-swizzled scaling factors by adding a `WITH_GEMM_SWIZZLED_SCALES` template parameter. Swizzling stores the scaling factors in the layout expected by GEMM operations, eliminating the need for a separate post-processing pass.

Key Changes

  • New kernel implementation (group_quantize_mxfp8.cuh): 983-line CUDA kernel supporting grouped tensor quantization with optional scale swizzling via template parameter
  • API extension (common.h): Added with_gemm_swizzled_scales boolean member to GroupedTensor struct
  • Public API additions: New C functions for grouped quantization (nvte_group_quantize, nvte_group_quantize_dbias) and grouped activation functions (GeLU, SiLU, ReLU variants)
  • Dispatch layer: Extended quantize.cuh with group_quantize_fwd_helper and group_quantize_bwd_helper template functions
  • Comprehensive testing: 777-line test file covering various shape representations and activation functions
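The dispatch-layer bullet above can be sketched host-side. Everything in this snippet is an illustrative stand-in: the real `group_quantize_fwd_helper` lives in `common/cast/dispatch/quantize.cuh` and launches CUDA kernels, and the enum values here are simplified.

```cpp
#include <stdexcept>

// Illustrative scaling-mode enum; the real values come from the TE headers.
enum class ScalingMode { MXFP8_1D_SCALING, DELAYED_TENSOR_SCALING };

// Stand-in for the mxfp8::group_quantize kernel launch.
int mxfp8_group_quantize_stub() { return 0; }

// Sketch of group_quantize_fwd_helper: route a grouped-tensor quantization
// request by scaling mode, as the summary above describes.
template <bool IS_ACT>
int group_quantize_fwd_helper(ScalingMode mode) {
  if (mode == ScalingMode::MXFP8_1D_SCALING) {
    return mxfp8_group_quantize_stub();  // grouped MXFP8 path from this PR
  }
  // Other scaling modes are not handled by the grouped path.
  throw std::runtime_error("group quantize: unsupported scaling mode");
}
```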

Implementation Details

The kernel uses the `gemm_swizzled_scale_idx` function to compute swizzled indices when `WITH_GEMM_SWIZZLED_SCALES=true`. The swizzling is applied consistently for both rowwise and columnwise scaling:

  • Colwise: `gemm_swizzled_scale_idx(X, Y, rows/128)`
  • Rowwise: `gemm_swizzled_scale_idx(Y, X, cols/128)`

The implementation maintains feature parity with the base kernel, supporting activations (GeLU, SiLU, ReLU), their derivatives, and dbias computation.
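To make "pre-swizzling" concrete: the PR's `gemm_swizzled_scale_idx` is not reproduced here, so the sketch below implements one common choice, the 128x4-tile scale-factor layout used by Blackwell block-scaled GEMMs, purely as an illustration. Both the formula and the function signature are assumptions; the real function in `group_quantize_mxfp8.cuh` may differ.

```cpp
#include <cstddef>
#include <set>

// Assumed layout: scales are viewed as a grid of 128x4 tiles. Within a tile,
// the scale for (row m, col k) lands at (m % 32) * 16 + k * 4 + m / 32, so
// 512 consecutive bytes hold one tile in tensor-core consumption order.
size_t swizzled_scale_idx(size_t row, size_t col, size_t num_col_tiles) {
  const size_t tile = (row / 128) * num_col_tiles + (col / 4);  // tile-major
  const size_t m = row % 128;
  const size_t k = col % 4;
  return tile * 512 + (m % 32) * 16 + k * 4 + m / 32;
}

// The unswizzled ("compact") layout is plain row-major, matching the
// y * stride + x indexing in the sequence diagram below.
size_t compact_scale_idx(size_t row, size_t col, size_t stride) {
  return row * stride + col;
}
```

Within any single 128x4 tile the swizzle is a bijection onto offsets 0..511, which is what lets the kernel write scales in-place in GEMM order instead of swizzling afterward.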

Confidence Score: 4/5

  • This PR is safe to merge after verifying test coverage for the swizzled-scale path.
  • The implementation follows established patterns from existing kernels (`quantize_mxfp8.cuh`) and applies the swizzling logic correctly. Note one inconsistency: the checklist item for tests is unchecked, yet test files are present in the changeset.
  • Check that `tests/cpp/operator/test_cast_mxfp8_grouped.cu` includes test cases specifically for `with_gemm_swizzled_scales=true`.

Important Files Changed

| Filename | Overview |
|----------|----------|
| `transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh` | New file: CUDA kernel for grouped MXFP8 quantization with fused pre-swizzling controlled by the `WITH_GEMM_SWIZZLED_SCALES` template parameter |
| `transformer_engine/common/common.h` | Added `with_gemm_swizzled_scales` boolean member to `GroupedTensor` to control the scaling-factor format |
| `transformer_engine/common/cast/dispatch/quantize.cuh` | Added dispatch helpers `group_quantize_fwd_helper` and `group_quantize_bwd_helper` for grouped tensors, routing to `mxfp8::group_quantize` |
| `tests/cpp/operator/test_cast_mxfp8_grouped.cu` | New test file covering grouped MXFP8 quantization across shape representations and activation functions |

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as nvte_group_quantize
    participant Dispatch as group_quantize_fwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant Swizzle as gemm_swizzled_scale_idx

    User->>API: Call nvte_group_quantize(input, output, stream)
    API->>Dispatch: group_quantize_fwd_helper<IS_ACT, OP>()
    Dispatch->>Dispatch: Check scaling_mode (MXFP8_1D_SCALING)
    Dispatch->>Kernel: mxfp8::group_quantize(input, output, ...)

    Kernel->>Kernel: Read with_gemm_swizzled_scales from output->with_gemm_swizzled_scales
    Kernel->>Kernel: Instantiate kernel with WITH_GEMM_SWIZZLED_SCALES template parameter

    alt Multiple tensors (not single tensor)
        Kernel->>Kernel: Launch update_tma_descriptors kernel
        Kernel->>Kernel: Update tensor map descriptors per tensor
    end

    Kernel->>Kernel: Launch group_quantize_mxfp8_kernel<<<grid, block>>>

    loop For each tile in tensor
        Kernel->>Kernel: Load data via TMA
        Kernel->>Kernel: Compute activations (if IS_ACT or IS_DACT)

        alt Colwise Scaling
            Kernel->>Kernel: Compute column-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor

            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(x, y, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end

            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end

        alt Rowwise Scaling
            Kernel->>Kernel: Compute row-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor

            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(y, x, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end

            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end

        Kernel->>Kernel: Store quantized data via TMA
    end

    alt IS_DBIAS
        Kernel->>Kernel: Reduce dbias along columns
    end

    Kernel-->>User: Return quantized output with swizzled scales
```


@greptile-apps greptile-apps bot left a comment


4 files reviewed, no comments

