add torch.compile to blockwise quantized kernel unit tests#4187
iamzainhuda wants to merge 1 commit into main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4187
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit 25c1971 with merge base 96a9cdf.
BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@iamzainhuda I think there was a miscommunication. We don't want to directly wrap an individual Triton custom op in torch.compile and test it; we want to compile a full blockwise linear layer (this test), ensure there are no graph breaks (fullgraph=True), and check that the numerics of outputs/grads pass the same threshold testing as the eager mode tests.
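A minimal sketch of the kind of test being asked for. It uses a plain `nn.Linear` as a stand-in for the blockwise-quantized linear layer (the real test would construct the prototype layer instead); the helper name and tolerances here are illustrative, not from the PR. The full module is compiled with `fullgraph=True`, so any graph break raises, and outputs and input gradients are compared against eager with `torch.testing.assert_close`:

```python
import torch
import torch.nn as nn

def check_compiled_matches_eager(make_layer, x, atol=1e-2, rtol=1e-2):
    # Build two identical copies of the layer: one run eagerly, one compiled.
    torch.manual_seed(0)
    eager = make_layer()
    to_compile = make_layer()
    to_compile.load_state_dict(eager.state_dict())

    # fullgraph=True makes torch.compile raise on any graph break instead
    # of silently falling back to eager for part of the graph.
    compiled = torch.compile(to_compile, fullgraph=True)

    x_eager = x.clone().requires_grad_(True)
    x_comp = x.clone().requires_grad_(True)

    # Forward numerics must agree within the same thresholds as eager tests.
    out_eager = eager(x_eager)
    out_comp = compiled(x_comp)
    torch.testing.assert_close(out_comp, out_eager, atol=atol, rtol=rtol)

    # Backward numerics: compare input gradients after identical losses.
    out_eager.sum().backward()
    out_comp.sum().backward()
    torch.testing.assert_close(x_comp.grad, x_eager.grad, atol=atol, rtol=rtol)

# Plain linear layer standing in for the blockwise-quantized one.
check_compiled_matches_eager(lambda: nn.Linear(16, 8), torch.randn(4, 16))
```

In the actual test, `make_layer` would build the blockwise FP8 linear for each parametrized `block_size`, and the tolerances would match the existing eager-mode assertions.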
JiwaniZakir
left a comment
The _maybe_compile helper in test_blockwise_kernels.py is clean and avoids repetition, but wrapping a freshly-created lambda on each call means torch.compile receives a new function object on every test invocation, with no opportunity for cache reuse across parametrized runs. This is fine for correctness but adds compile overhead.

More importantly, there are no torch._dynamo.reset() calls between tests: compiled artifacts from one parametrized combination (e.g., block_size=128) can leak into the next (block_size=256), which can mask failures or produce misleading error messages in CI. It would be worth adding a torch._dynamo.reset() at the start of each test body when use_compile=True, or using a pytest fixture/autouse teardown for that.

Also, the use_compile parametrize decorator is placed closer to the function than block_size in all cases, which means use_compile varies as the inner loop in the test matrix. This is a minor point; keeping decorator ordering consistent across all five tests is good hygiene, and it is consistent here.
Summary
The quantization tests in `test/prototype/blockwise_fp8_training/test_blockwise_kernels.py` are now parameterized with `use_compile`, and a `_maybe_compile(..., fullgraph=True)` helper is used to compile the Triton quantization entry points when requested.
Testing
pytest test/prototype/blockwise_fp8_training/test_blockwise_kernels.py
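The `_maybe_compile` helper itself is not shown in this conversation; a plausible reconstruction (the exact signature in the PR may differ) is a conditional wrapper that only invokes `torch.compile` when the parametrized flag asks for it:

```python
import torch

def _maybe_compile(fn, use_compile: bool):
    """Hypothetical sketch of the helper described in the summary:
    wrap a kernel entry point with torch.compile(fullgraph=True) only
    when the parametrized use_compile flag is set; otherwise return
    the function unchanged so the eager path is exercised."""
    if use_compile:
        return torch.compile(fn, fullgraph=True)
    return fn

# Usage, with a trivial stand-in for a quantization entry point.
# The eager path (use_compile=False) runs the function as-is.
quantize = _maybe_compile(lambda t: t * 0.5, use_compile=False)
out = quantize(torch.ones(4))
```

With `use_compile=True`, the same call sites go through the compiled path, which is how each parametrized test covers both eager and compiled numerics without duplicating test bodies.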