[NVFP4][Dense/MoE] Integrate Cutlass NVFP4 Row-Cast-Col-RHT-Transpose-Cast Fusion Kernel #2555
Conversation
Greptile Overview
Greptile Summary
Integrates a Cutlass-based CUDA fusion kernel for NVFP4 quantization that combines rowwise casting and columnwise RHT-transpose-cast operations into a single kernel for dense linear layers and shared experts. The optimization reduces memory bandwidth by reading high-precision input once instead of twice, with eligibility detection based on BF16 input and 64×128 dimension alignment. The NVTE_USE_FAST_MATH environment variable controls fast-math optimizations for RHT kernels.
Confidence Score: 2/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant Input as BF16 Input Tensor
participant Quantizer as NVFP4Quantizer
participant Fusion as Fused Kernel
participant Unfused as Unfused Path
participant Output as NVFP4 Output
Input->>Quantizer: "quantize_impl()"
Quantizer->>Quantizer: "Check: BF16 & rows%64==0 & cols%128==0"
alt Fusion Eligible
Quantizer->>Fusion: "nvte_hadamard_transform_cast_fusion()"
Fusion->>Fusion: "Single kernel: Rowwise + RHT + Colwise"
Fusion->>Output: "Both quantizations complete"
else Unfused Path
alt Rowwise enabled
Quantizer->>Unfused: "nvte_quantize_v2(rowwise)"
Unfused->>Output: "Rowwise quantization"
end
alt Columnwise enabled
Quantizer->>Unfused: "nvte_hadamard_transform()"
Unfused->>Unfused: "Apply RHT"
Unfused->>Unfused: "nvte_quantize_v2(columnwise)"
Unfused->>Output: "Columnwise quantization"
end
end
Force-pushed from c80932f to fc42825
/te-ci arm L1
Additional Comments (2)
8 files reviewed, 2 comments
Signed-off-by: Zhongbo Zhu <[email protected]>
Force-pushed from 2bc695e to 6ea9dab
Greptile Overview
Greptile Summary
This PR integrates a Cutlass-based fusion kernel that combines row-wise quantization and column-wise RHT (Random Hadamard Transform) + quantization + transpose operations for NVFP4 dense linear layers and shared experts. The key optimization reduces memory bandwidth by reading high-precision input data once instead of twice.
Key Changes
New Fusion Kernel (row_cast_col_hadamard_transform_cast_fusion.cu):
- Implements the nvte_hadamard_transform_cast_fusion API, which performs both rowwise and columnwise quantization in a single pass
- Uses MMA hardware for efficient Hadamard transform computation
- Eligible when input is BF16 with dimensions divisible by 64×128
- Reads pre-computed amax values to calculate FP8 scaling factors
- Supports stochastic rounding and fast-math optimization flags
Refactored Quantizer Logic (quantizer.cpp):
- Moved the unfused RHT path into a quantize_with_rht_unfused_helper method for cleaner code organization
- Improved RNG state handling: a single RNG state when fusion is used, separate states for rowwise/columnwise when unfused
- Added NVTE_USE_FAST_MATH environment variable support for accelerating high-precision math operations
- Moved the eligibility check before RNG state generation to avoid unnecessary work
Extended Test Coverage (test_nvfp4_rht_quantize_exact.py):
- Added "columnwise-only" quantization mode testing alongside existing "quantize" and "quantize_transpose" modes
- Tests now validate rowwise/columnwise results conditionally based on the quantization mode
Grouped Quantization Support (cast.cpp):
- Split-quantize path now uses fused kernel when all tensors have 128-aligned dimensions
- Bulk RNG state generation for grouped kernels (single state shared across splits)
- Fast math flag propagation to all quantization configs
Architecture Notes
The fusion provides optimal performance when:
- Input dtype is BF16
- Rows are divisible by 64 (MMA tile requirement)
- Columns are divisible by 128 (MMA tile requirement)
When these conditions aren't met, the code gracefully falls back to the unfused path with separate kernel launches for rowwise and columnwise quantization.
Confidence Score: 4/5
- This PR is safe to merge with minimal risk after addressing documentation and TODO items mentioned in the PR description
- Score of 4 reflects a well-engineered feature with thorough implementation. The code demonstrates good software practices: clean refactoring with extracted helper methods, proper error handling, graceful fallback paths, and comprehensive test coverage including the new columnwise-only mode. The fusion kernel follows established patterns from the grouped quantization PR #2411. Deducted 1 point due to: (1) PR author notes cutlass deprecation warnings need addressing, (2) TODOs remain about potentially defaulting fast math on, and (3) the ~1400 line CUDA kernel file has limited inline documentation for complex template logic
- The main CUDA kernel file (row_cast_col_hadamard_transform_cast_fusion.cu) would benefit from additional inline comments explaining the template parameter switches and MMA computation flow, but no files have critical issues requiring immediate attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/csrc/quantizer.cpp | 4/5 | Refactored NVFP4 quantize_impl to use new fused RHT cast kernel, extracted unfused helper, improved RNG state handling for fused vs unfused paths |
| transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu | 4/5 | New CUDA kernel implementing fused row-cast and column-RHT-transpose-cast using Cutlass MMA hardware for BF16 inputs with 64x128 alignment |
| transformer_engine/common/include/transformer_engine/hadamard_transform.h | 5/5 | Added new API function nvte_hadamard_transform_cast_fusion for dense layer quantization, marked old columnwise function for future deprecation |
| transformer_engine/pytorch/csrc/extensions/cast.cpp | 4/5 | Added NVTE_USE_FAST_MATH env var support in split_quantize for grouped NVFP4 kernels, improved RNG state setup with bulk generation flag |
| tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py | 5/5 | Extended test coverage to support columnwise-only quantization mode, added return_identity parameter to test all three modes |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Quantizer as NVFP4Quantizer
participant API as nvte_hadamard_transform_cast_fusion
participant Kernel as row_col_rht_gemm_ntt_w_sfc
participant AmaxKernel as nvte_hadamard_transform_amax
User->>Quantizer: quantize(input, output)
Quantizer->>Quantizer: Check eligibility (BF16, rows%64==0, cols%128==0)
alt With RHT and eligible for fusion
Quantizer->>AmaxKernel: Compute rowwise & columnwise amax
AmaxKernel-->>Quantizer: amax values populated
alt Stochastic rounding enabled
Quantizer->>Quantizer: Generate RNG state
end
alt Fast math enabled (NVTE_USE_FAST_MATH)
Quantizer->>Quantizer: Set use_fast_math flag
end
Quantizer->>API: Call with input, output, hadamard_matrix, quant_config
API->>Kernel: Launch fused kernel
Kernel->>Kernel: Read amax values
Kernel->>Kernel: Perform rowwise quantization to FP4
Kernel->>Kernel: Compute RHT using MMA hardware
Kernel->>Kernel: Transpose and quantize to FP4
Kernel->>Kernel: Write FP8 scales
Kernel-->>API: Complete
API-->>Quantizer: Return
else Not eligible for fusion
Quantizer->>AmaxKernel: Compute amax
AmaxKernel-->>Quantizer: amax values
alt Rowwise usage
Quantizer->>Quantizer: Call nvte_quantize_v2 for rowwise
end
alt Columnwise usage
Quantizer->>Quantizer: Call nvte_hadamard_transform for RHT
Quantizer->>Quantizer: Call nvte_quantize_v2 for columnwise
end
end
Quantizer-->>User: Quantized output
1 file reviewed, 1 comment
7 files reviewed, 7 comments
```cpp
bool eligible_for_rht_cast_fusion =
    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
```
Fusion eligibility check uses rows % 64 == 0 && cols % 128 == 0 but this doesn't match the alignment check in cast.cpp line 815 which uses split_section % 128 == 0. The mismatched alignment requirements could cause the fusion kernel to be called with incompatible dimensions. Should this be rows % 128 == 0 && cols % 128 == 0 to match the grouped kernel requirements?
```cpp
const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht &&
                                          this->columnwise_usage &&
                                          (!eligible_for_rht_cast_fusion);
```
The logic for need_separate_columnwise_rng checks !eligible_for_rht_cast_fusion, but the actual kernel path selection happens later at line 1673. If the fusion path is taken, both rowwise and columnwise quantization use the same RNG state (from quant_config), but if unfused, they use separate states. This could cause non-deterministic behavior when conditions change slightly (e.g., input shape changes by a few elements). Consider generating separate RNG states unconditionally when both rowwise and columnwise are enabled, then have the fusion kernel use both appropriately.
```cpp
if (eligible_for_rht_cast_fusion) {
  // fusion kernel requires passing in RHT matrix directly for maximum performance
  auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix);
  // Fusion kernel that does the following:
  // 1. Rowwise quantization
  // 2. RHT followed by columnwise quantization & transpose
  NVTE_SCOPED_GIL_RELEASE({
    nvte_hadamard_transform_cast_fusion(input.data(), out.data(), rht_matrix_nvte.data(),
                                        quant_config, stream);
  });
```
Missing validation that input dimensions meet fusion kernel requirements before calling nvte_hadamard_transform_cast_fusion. The kernel expects dimensions aligned to specific tile sizes (128×128 based on the kernel code), but there's no check that rows and cols satisfy these constraints. If dimensions are misaligned, the kernel could produce incorrect results or crash.
```cpp
// Assert checks problem size should be multiple of 64
assert(M % 64 == 0);
assert(N % 64 == 0);
```
The assertions assert(M % 64 == 0) and assert(N % 64 == 0) in the kernel launch function will cause program termination in release builds where asserts are disabled. For production code handling user inputs, these should be replaced with proper error checking using NVTE_CHECK that works in all build configurations.
```cpp
// Allocate SMEMork
```
Typo in comment: "SMEMork" should be "SMEM work" or "SMEM fork"
Description
Note: #2558 reported a bug in #2411. The fix is in #2564; make sure you cherry-pick that one too until it lands in main.
Previously, a similar optimization was applied to the MoE grouped quantize with RHT in #2411. This PR targets dense linear layers & shared experts when quantizing to NVFP4. With this fusion, the high-precision input only needs to be read once; without it, the input must be read twice.
Similarly, we have the env var NVTE_USE_FAST_MATH to control the numerical behavior of the RHT quant fusion kernel and accelerate it further. Fast math is applied only to the high-precision math, so it has minimal impact on training convergence.
What fast-math toggle controls:
Therefore, I DO recommend turning it on, since it significantly improves RHT kernel performance.
The only reason it is not enabled by default is that we want ZERO TOLERANCE tests between our CUDA quantize kernels and our PyTorch-based emulated quantize references. With the fast-math toggle turned on, it is hard to pass the tests with zero tolerance without further investigation into how to relax the test conditions while still providing high confidence in the test case.
TODO items: