FIX: TF32 warning (#43012) #43015
Conversation
ArthurZucker left a comment:

This does not sound bad, but at the same time it is very specific; we want users to know that full precision could be necessary.
I did some checking and the output of
What does this PR do?
This PR replaces the matrix multiplication operator (@) with broadcasting element-wise multiplication (*) in the RotaryEmbedding implementation for several major models (Llama, Mistral, Mixtral, Qwen2, Gemma, Gemma2).
When compiling a model with torch.compile in bfloat16, the RoPE frequency calculation (which is intentionally kept in float32 for precision) triggers a UserWarning about TensorFloat32 (TF32) whenever TF32 is not enabled.
Since the shapes involved in this specific operation, [batch, dim/2, 1] and [batch, 1, seq_len], reduce the matmul to an outer product, using @ is mathematically equivalent to * with broadcasting. However, * avoids the matrix-multiplication code path in the compiler, silencing the false-positive warning and potentially offering a minor performance gain by not dispatching a full GEMM call for a simple outer product.
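As a quick sanity check of the claimed equivalence, the sketch below uses NumPy (whose broadcasting and batched-matmul rules match PyTorch's for these shapes) with hypothetical small dimensions; the real RoPE code operates on `inv_freq` and `position_ids` tensors of the same layout:

```python
import numpy as np

# Hypothetical small sizes standing in for the real RoPE shapes.
batch, half_dim, seq_len = 2, 4, 6
inv_freq = np.random.rand(batch, half_dim, 1).astype(np.float32)      # [batch, dim/2, 1]
position_ids = np.random.rand(batch, 1, seq_len).astype(np.float32)   # [batch, 1, seq_len]

# Batched matmul contracts over the size-1 inner dimension -> outer product.
freqs_matmul = inv_freq @ position_ids        # shape [batch, dim/2, seq_len]

# Broadcasting elementwise multiply expands both size-1 axes -> same result,
# without ever entering the GEMM code path.
freqs_mul = inv_freq * position_ids           # shape [batch, dim/2, seq_len]

assert freqs_matmul.shape == (batch, half_dim, seq_len)
assert np.allclose(freqs_matmul, freqs_mul)
```

Because the contracted dimension has size 1, each output element is a single product, so no accumulation (and hence no TF32-sensitive reduction) is involved.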
Fixes #43012
Before submitting
Did you read the contributor guideline, Pull Request section?
Who can review?
Anyone in the community is free to review the PR once the tests have passed.