
fix undefined values for tail elements in act quant kernels #4186

Merged
iamzainhuda merged 1 commit into main from fix-tail-elements-act-quant on Mar 27, 2026

Conversation

@iamzainhuda (Contributor) commented Mar 26, 2026

Summary

Fixed a tail-handling bug in the blockwise FP8 activation quantization kernels that was corrupting reciprocal scale tensors and causing test_triton_fp8_gemm_1x128_128x128 to fail on ragged shapes.

The GEMM consumes a_s / b_s scale tensors produced by the quantization kernels, and those scale tensors could be corrupted at the tensor edges. In the LHS activation quant kernel, masked tail lanes were still storing scales into a compact column-major as_strided buffer. Because that buffer has no padding, logically invalid row writes could alias valid scale entries from the next column. The GEMM then used those corrupted scales and SQNR collapsed on small or ragged M.
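To make the aliasing concrete, here is a hypothetical illustration (the shapes, block counts, and variable names below are made up for the example; the real kernels build the scale buffer with as_strided in the same compact column-major way):

```python
import torch

M, num_k_blocks = 3, 2  # 3 valid rows, 2 column blocks of scales (hypothetical)

# compact column-major buffer: element (row, col) lives at flat offset col * M + row,
# with no padding between columns
storage = torch.zeros(M * num_k_blocks)
scales = torch.as_strided(storage, size=(M, num_k_blocks), stride=(1, M))

# an unmasked tail lane writing at the logically invalid (row=3, col=0) computes
# flat offset 0 * M + 3 == 3, which is exactly where (row=0, col=1) lives
storage[0 * M + 3] = 123.0
print(scales[0, 1])  # tensor(123.) -- a valid scale, corrupted by the tail write
```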

Updated the RHS and LHS blockwise quant kernels to use tl.load(..., other=0.0) for masked tail loads, and masked the reciprocal scale tl.store calls so invalid lanes don't write into the scale buffers (see the sketch below).
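A minimal sketch of the two changes, assuming a row-major input and a 2-D launch grid. The kernel name, signature, and tiling here are illustrative only, not the actual torchao kernels:

```python
import triton
import triton.language as tl


@triton.jit
def _quant_lhs_sketch(  # hypothetical name; grid = (cdiv(M, BLOCK_M), cdiv(K, BLOCK_K))
    x_ptr, y_ptr, s_ptr,           # input, fp8 output, reciprocal scales
    M, K,
    stride_s_m, stride_s_k,        # strides of the compact column-major scale buffer
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,         # quantization group size along K (e.g. 128)
    FP8_MAX: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    row_mask = offs_m < M
    col_mask = offs_k < K
    mask = row_mask[:, None] & col_mask[None, :]

    # fix 1: masked tail lanes load a defined 0.0 instead of whatever sits
    # past the tensor edge, so the amax reduction below is well defined
    x = tl.load(x_ptr + offs_m[:, None] * K + offs_k[None, :],
                mask=mask, other=0.0)

    amax = tl.max(tl.abs(x), axis=1)          # per-row amax over the 1xBLOCK_K group
    scale = FP8_MAX / tl.maximum(amax, 1e-12)
    y = tl.minimum(tl.maximum(x * scale[:, None], -FP8_MAX), FP8_MAX)
    tl.store(y_ptr + offs_m[:, None] * K + offs_k[None, :],
             y.to(tl.float8e4nv), mask=mask)

    # fix 2: the reciprocal-scale store is masked with row_mask, so tail rows
    # past M never write into the compact scale buffer and can no longer
    # alias the next column's valid entries
    tl.store(s_ptr + offs_m * stride_s_m + pid_k * stride_s_k,
             1.0 / scale, mask=row_mask)
```

The RHS kernel gets the same treatment, with the mask applied along its own tail dimension.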

Failures before fix:
(screenshot of failing test output omitted)

Testing

  • pytest -q test/prototype/blockwise_fp8_training/test_blockwise_kernels.py -k test_triton_fp8_gemm_1x128_128x128
  • pytest -q test/prototype/blockwise_fp8_training/test_blockwise_kernels.py
    • now passes entirely (42 passed)

pytorch-bot bot commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4186

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 1d57301 with merge base 96a9cdf:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label on Mar 26, 2026
iamzainhuda added the topic: bug fix and module: training labels on Mar 26, 2026
@danielvegamyhre (Contributor) left a comment:

Lgtm, thanks for fixing this!

@danielvegamyhre (Contributor) commented:

(This fix is indeed needed; just note that the Triton FP8 blockwise GEMMs were experimental, not performant, and not actually used. We use torch._scaled_mm, which dispatches to cuBLAS.)

@iamzainhuda iamzainhuda merged commit a6be48f into main Mar 27, 2026
23 of 26 checks passed