[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel #2630
Conversation
Signed-off-by: Oleg Goncharov <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

This PR extends the grouped MXFP8 quantization kernel to support pre-swizzled scaling factors by adding a with_gemm_swizzled_scales flag on the output tensor.

Implementation Details

The kernel uses the gemm_swizzled_scale_idx helper to map each scaling factor to the offset expected by GEMM. The implementation maintains feature parity with the base kernel, supporting activations (GeLU, SiLU, ReLU), activation derivatives, and dbias computation.

Confidence Score: 4/5
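The per-block MXFP8 scaling referenced here (one shared E8M0 power-of-two scale per 32 elements, with elements quantized into E4M3 range) can be sketched as a toy Python model. This is an illustrative reference only, assuming the standard OCP MX recipe; the real kernel works on tiles via TMA and hardware conversion instructions, and its exact rounding may differ.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value

def quantize_mxfp8_block(block):
    """Toy model: quantize one 32-element block with a shared E8M0 scale.

    Returns the power-of-two scale and the scaled elements (a real kernel
    would additionally round each scaled element to E4M3 here).
    """
    assert len(block) == 32
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 1.0, [0.0] * 32
    # E8M0 scale: a power of two chosen so amax / scale fits in E4M3 range.
    exp = math.floor(math.log2(amax)) - math.floor(math.log2(FP8_E4M3_MAX))
    scale = 2.0 ** exp
    q = [x / scale for x in block]
    return scale, q

scale, q = quantize_mxfp8_block([float(i - 16) for i in range(32)])
assert all(abs(v) <= FP8_E4M3_MAX for v in q)
```

The scale is always an exact power of two, so storing it as an 8-bit E8M0 exponent loses no information.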
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant API as nvte_group_quantize
    participant Dispatch as group_quantize_fwd_helper
    participant Kernel as group_quantize_mxfp8_kernel
    participant Swizzle as gemm_swizzled_scale_idx
    User->>API: Call nvte_group_quantize(input, output, stream)
    API->>Dispatch: group_quantize_fwd_helper<IS_ACT, OP>()
    Dispatch->>Dispatch: Check scaling_mode (MXFP8_1D_SCALING)
    Dispatch->>Kernel: mxfp8::group_quantize(input, output, ...)
    Kernel->>Kernel: Read with_gemm_swizzled_scales from output->with_gemm_swizzled_scales
    Kernel->>Kernel: Instantiate kernel with WITH_GEMM_SWIZZLED_SCALES template parameter
    alt Multiple tensors (not single tensor)
        Kernel->>Kernel: Launch update_tma_descriptors kernel
        Kernel->>Kernel: Update tensor map descriptors per tensor
    end
    Kernel->>Kernel: Launch group_quantize_mxfp8_kernel<<<grid, block>>>
    loop For each tile in tensor
        Kernel->>Kernel: Load data via TMA
        Kernel->>Kernel: Compute activations (if IS_ACT or IS_DACT)
        alt Colwise Scaling
            Kernel->>Kernel: Compute column-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor
            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(x, y, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end
            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end
        alt Rowwise Scaling
            Kernel->>Kernel: Compute row-wise amax
            Kernel->>Kernel: Convert to E8M0 scaling factor
            alt WITH_GEMM_SWIZZLED_SCALES
                Kernel->>Swizzle: gemm_swizzled_scale_idx(y, x, num_tiles)
                Swizzle-->>Kernel: Return swizzled index
            else No swizzling
                Kernel->>Kernel: Use compact index (y * stride + x)
            end
            Kernel->>Kernel: Store scale at computed index
            Kernel->>Kernel: Apply scale and quantize to MXFP8
        end
        Kernel->>Kernel: Store quantized data via TMA
    end
    alt IS_DBIAS
        Kernel->>Kernel: Reduce dbias along columns
    end
    Kernel-->>User: Return quantized output with swizzled scales
```
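The gemm_swizzled_scale_idx helper itself is not shown here; as a rough Python model, assuming the 128x4 tiled scale-factor layout commonly used by Blackwell block-scaled GEMMs (the kernel's actual formula may be written differently but must be an equivalent bijection):

```python
def swizzled_scale_idx(row, col, rows, cols):
    """Map a (row, col) scale-factor coordinate to its swizzled linear offset.

    Hypothetical reference model of a 128x4 tiled layout: the scale matrix
    is split into 128-row x 4-column blocks of 512 scales each, and entries
    inside a block are interleaved in 32-row groups.
    """
    assert rows % 128 == 0 and cols % 4 == 0
    blocks_per_row = cols // 4
    block = (row // 128) * blocks_per_row + (col // 4)
    r, c = row % 128, col % 4
    inner = (r % 32) * 16 + (r // 32) * 4 + c
    return block * 512 + inner

# The mapping must be a bijection: every compact (row, col) position lands
# on a unique offset in a buffer of the same size.
rows, cols = 256, 8
offsets = {swizzled_scale_idx(r, c, rows, cols)
           for r in range(rows) for c in range(cols)}
assert offsets == set(range(rows * cols))
```

Because the mapping is a pure index permutation, fusing it into the store path changes only the address each scale is written to, not the quantization math.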
4 files reviewed, no comments
Description
This PR fuses pre-swizzling into the grouped MXFP8 quantization kernel so that scaling factors are stored directly in the layout expected by GEMM. It builds on PR #2586 ([Common] MXFP8 kernel for grouped tensors) and can be merged after that PR lands.
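What the fusion buys can be illustrated with a toy host-side model (hypothetical helper and layout; the real work happens inside the CUDA kernel): without fusion, quantization writes scales at compact offsets and a second pass permutes them into the GEMM layout; with fusion, each scale is written at its swizzled offset as it is produced, saving a kernel launch and a round trip through memory.

```python
def swizzled_idx(row, col, cols):
    # Toy model of a 128x4 tiled swizzle (assumed layout for illustration).
    block = (row // 128) * (cols // 4) + (col // 4)
    r, c = row % 128, col % 4
    return block * 512 + (r % 32) * 16 + (r // 32) * 4 + c

rows, cols = 128, 4
scales = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Two-pass: quantization writes compact scales, then a separate swizzle pass.
compact = [scales[r][c] for r in range(rows) for c in range(cols)]
two_pass = [0] * (rows * cols)
for r in range(rows):
    for c in range(cols):
        two_pass[swizzled_idx(r, c, cols)] = compact[r * cols + c]

# Fused: the quantization kernel stores each scale at its swizzled offset.
fused = [0] * (rows * cols)
for r in range(rows):
    for c in range(cols):
        fused[swizzled_idx(r, c, cols)] = scales[r][c]

assert fused == two_pass  # identical final layout, one fewer pass
```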
Type of change

Changes

GroupedTensor

Checklist: