[TorchOnnxToTorch] Add pad+scatter KV cache and mask to GQA #4471
Draft
Conversation
Refactor the GroupQueryAttention ONNX-to-Torch lowering:

- Fix the Q/K/V reshape to use reshape+transpose instead of a direct reshape. The original code reshaped [batch, seq, hidden] directly to [batch, heads, seq, head_size], which incorrectly reinterprets the data layout. The correct approach reshapes to [batch, seq, heads, head_size] and then transposes dims 1 and 2 (see the sketch below).
- Support dynamic batch and sequence dimensions by using aten.size.int instead of hardcoded constants.
- Simplify the rotary embedding position ID calculation: remove the complex isSubsequentPrompt conditional logic and replace it with a direct arange + past_seqlens offset computation.
- Fix the key rotary embedding to use kInput.getType() (kv_num_heads) instead of query.getType() (num_heads) for GQA configurations.
- Fix the seqlens_k test values (0 for an empty past, not 1).
- Add an output transpose+reshape to convert the attention output from [batch, heads, seq, head_size] back to [batch, seq, hidden].
- Minor fixes: reject smoothSoftmax > 0, suppress the unused totalSequenceLength warning, use the loc variable consistently.
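A minimal PyTorch sketch (illustrative only, not the C++ lowering itself) of why the intermediate [batch, seq, heads, head_size] shape is needed; the concrete sizes are made up for the demo:

```python
import torch

batch, seq, heads, head_size = 2, 3, 4, 8
hidden = heads * head_size
q = torch.randn(batch, seq, hidden)

# Incorrect: reinterprets the memory layout, mixing sequence and head data.
wrong = q.reshape(batch, heads, seq, head_size)

# Correct: split hidden into [heads, head_size] first, then move the heads
# dimension ahead of the sequence dimension.
right = q.reshape(batch, seq, heads, head_size).transpose(1, 2)

assert wrong.shape == right.shape == (batch, heads, seq, head_size)
assert not torch.equal(wrong, right)  # same shape, different element placement
```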
Replace cat-based KV cache construction with pad+scatter for correct shared-buffer semantics. The previous approach used torch.aten.cat to concatenate past and current K/V tensors, but that does not support the ONNX GQA spec's shared-buffer model, where seqlens_k controls the valid positions within a pre-allocated buffer. New approach (see the sketch after this list):

- Pad past K/V with zeros to total_sequence_length
- Compute scatter indices from seqlens_k (past positions)
- Use torch.aten.scatter.src to place current K/V at the correct positions

Add causal attention mask generation:

- Construct a 3D boolean mask: k_positions <= past_len + q_positions
- Convert to a float mask with 0.0 (attend) and -inf (masked)
- Reshape to 4D [batch, 1, seq, kv_seq] for SDPA broadcasting

Also:

- Add a seqlens_k int32-to-int64 dtype cast that preserves the tensor shape
- Add tests for the KV cache, causal mask, prefill, position IDs, f16 dtype, variable seqlens_k, and scatter.src with dim=2
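A hedged PyTorch sketch of the intended runtime semantics of the pad+scatter cache update and the causal mask. Variable names and sizes are illustrative, and seqlens_k is taken here as the per-batch count of valid past tokens, per the description above:

```python
import torch

batch, kv_heads, head_size = 2, 2, 8
past_len, cur_len = 5, 3
total_len = past_len + cur_len

past_k = torch.randn(batch, kv_heads, past_len, head_size)
cur_k = torch.randn(batch, kv_heads, cur_len, head_size)
# Per-batch count of valid past tokens (illustrative seqlens_k semantics).
seqlens_k = torch.tensor([5, 2], dtype=torch.int64)

# 1. Pad past K with zeros out to total_sequence_length (the shared buffer).
padded_k = torch.nn.functional.pad(past_k, (0, 0, 0, cur_len))  # [b, kvh, total, d]

# 2. Scatter current K right after each batch element's valid past positions.
positions = seqlens_k.view(batch, 1, 1, 1) + torch.arange(cur_len).view(1, 1, cur_len, 1)
index = positions.expand(batch, kv_heads, cur_len, head_size).contiguous()
new_k = padded_k.scatter(2, index, cur_k)

# 3. Causal mask: query position i may attend to key position j
#    iff j <= past_len_b + i for that batch element.
q_pos = torch.arange(cur_len).view(1, cur_len, 1)
k_pos = torch.arange(total_len).view(1, 1, total_len)
allowed = k_pos <= (seqlens_k.view(batch, 1, 1) + q_pos)          # bool, [b, q, kv]
mask = torch.where(allowed, 0.0, float("-inf"))                   # float, 0.0 / -inf
mask = mask.view(batch, 1, cur_len, total_len)                    # 4D for SDPA broadcast
```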
Replace cat-based KV cache construction with pad+scatter for correct shared-buffer semantics.

Add causal attention mask generation:

- boolean mask with k_positions <= past_len + q_positions, converted to a 0.0 / -inf float mask
- reshaped to [batch, 1, seq, kv_seq] for SDPA broadcasting (see the broadcasting sketch below)

Also:

- shape-preserving seqlens_k int32-to-int64 cast
- tests for the KV cache, causal mask, prefill, position IDs, f16 dtype, variable seqlens_k, and scatter.src with dim=2
Depends on #4470
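For reference, a small usage sketch (assumed shapes, not code from this PR) showing how a [batch, 1, seq, kv_seq] float mask broadcasts across the heads dimension inside torch.nn.functional.scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_size = 2, 4, 3, 8, 16
q = torch.randn(batch, heads, q_len, head_size)
k = torch.randn(batch, heads, kv_len, head_size)
v = torch.randn(batch, heads, kv_len, head_size)

# 0.0 where attention is allowed, -inf where it is masked out.
mask = torch.zeros(batch, 1, q_len, kv_len)
mask[..., q_len:] = float("-inf")  # e.g. mask out trailing (unused) key positions

# The [batch, 1, q_len, kv_len] mask broadcasts over the heads dimension.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
assert out.shape == (batch, heads, q_len, head_size)
```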