[TorchOnnxToTorch] Add pad+scatter KV cache and mask to GQA #4471

Draft
IanWood1 wants to merge 2 commits into llvm:main from IanWood1:gqa-pt2-kv-cache-mask

Conversation

@IanWood1
Contributor

Replace cat-based KV cache construction with pad+scatter.

  • Pad past K/V with zeros to total_sequence_length
  • Compute scatter indices from seqlens_k (past positions)
  • Use torch.aten.scatter.src to place current K/V at correct positions
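
For reference, a minimal PyTorch sketch of the pad+scatter semantics described above. The lowering itself emits the equivalent torch-dialect ops from C++; the helper name here is made up, and I'm assuming seqlens_k holds the per-batch count of valid past positions and has already been cast to int64 (see "Also" below):

```python
import torch

def update_kv_cache(past_kv, cur_kv, seqlens_k, total_seq_len):
    # past_kv:   [batch, kv_heads, past_buf_len, head_size] pre-allocated buffer
    # cur_kv:    [batch, kv_heads, cur_seq_len, head_size] current K or V
    # seqlens_k: [batch] int64, assumed to be the valid past length per sequence
    b, h, cur_len, d = cur_kv.shape
    # Pad the past buffer with zeros along the sequence dim to total_sequence_length.
    pad_amount = total_seq_len - past_kv.shape[2]
    padded = torch.nn.functional.pad(past_kv, (0, 0, 0, pad_amount))
    # Scatter index: the i-th current token lands at position past_len + i,
    # broadcast to cur_kv's shape as scatter.src requires.
    pos = seqlens_k.view(b, 1, 1, 1) + torch.arange(
        cur_len, device=cur_kv.device).view(1, 1, cur_len, 1)
    index = pos.expand(b, h, cur_len, d)
    # Place the current K/V at the correct offsets inside the shared buffer.
    return padded.scatter(2, index, cur_kv)
```

The same update applies to both K and V; dim=2 is the sequence axis of the [batch, kv_heads, seq, head_size] layout.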

Add causal attention mask generation:

  • Construct a 3D boolean mask: k_positions <= past_len + q_positions
  • Convert to float mask with 0.0 (attend) and -inf (masked)
  • Reshape to 4D [batch, 1, seq, kv_seq] for SDPA broadcasting
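
Likewise, a sketch of the mask construction under the same assumptions; the result is an additive float mask suitable for scaled dot-product attention:

```python
import torch

def causal_mask(seqlens_k, cur_seq_len, total_seq_len, dtype=torch.float32):
    # Query i may attend to key j iff j <= past_len + i.
    q_pos = torch.arange(cur_seq_len).view(1, cur_seq_len, 1)      # [1, S, 1]
    k_pos = torch.arange(total_seq_len).view(1, 1, total_seq_len)  # [1, 1, T]
    past = seqlens_k.view(-1, 1, 1)                                # [B, 1, 1]
    allowed = k_pos <= past + q_pos                                # bool [B, S, T]
    # 0.0 where attention is allowed, -inf where masked.
    mask = torch.where(allowed,
                       torch.zeros((), dtype=dtype),
                       torch.full((), float("-inf"), dtype=dtype))
    return mask.unsqueeze(1)                                       # [B, 1, S, T]
```

The unsqueezed head dim of size 1 is what lets the mask broadcast against SDPA's [batch, heads, seq, kv_seq] attention scores.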

Also:

  • Add seqlens_k int32-to-int64 dtype cast that preserves tensor shape
  • Add tests for KV cache, causal mask, prefill, position IDs, f16
    dtype, variable seqlens_k, and scatter.src dim=2

Depends on #4470

Refactor the GroupQueryAttention ONNX-to-Torch lowering:

- Fix Q/K/V reshape to use reshape+transpose instead of direct reshape.
  The original code reshaped [batch, seq, hidden] directly to
  [batch, heads, seq, head_size], which incorrectly reinterprets the
  data layout. The correct approach reshapes to [batch, seq, heads,
  head_size] and then transposes dims 1 and 2 (see the first sketch
  after this list).

- Support dynamic batch and sequence dimensions by using aten.size.int
  instead of hardcoded constants.

- Simplify rotary embedding position ID calculation. Remove the complex
  isSubsequentPrompt conditional logic and replace it with a direct
  arange + past_seqlens offset computation (see the second sketch after
  this list).

- Fix key rotary embedding to use kInput.getType() (kv_num_heads)
  instead of query.getType() (num_heads) for GQA configurations.

- Fix seqlens_k test values (0 for empty past, not 1).

- Add output transpose+reshape to correctly convert attention output
  from [batch, heads, seq, head_size] back to [batch, seq, hidden].

- Minor fixes: reject smoothSoftmax > 0, suppress unused
  totalSequenceLength warning, use loc variable consistently.
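
Two small PyTorch sketches of items from the list above; they only illustrate the intended tensor semantics, not the actual C++ lowering code, and the shapes are made up for the example. First, the Q/K/V reshape fix and its inverse on the output path:

```python
import torch

batch, seq, heads, head_size = 2, 5, 4, 8
x = torch.randn(batch, seq, heads * head_size)        # [batch, seq, hidden]

# Incorrect: reinterprets the flat layout, mixing tokens across heads.
wrong = x.reshape(batch, heads, seq, head_size)

# Correct: split hidden into [heads, head_size] per token, then swap the
# sequence and head dims to get [batch, heads, seq, head_size].
q = x.reshape(batch, seq, heads, head_size).transpose(1, 2)

# Output path: undo the transpose and merge heads back into hidden.
attn_out = torch.randn(batch, heads, seq, head_size)
out = attn_out.transpose(1, 2).reshape(batch, seq, heads * head_size)
```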
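
Second, the simplified rotary position ID computation, again assuming seqlens_k carries the per-batch valid past length:

```python
import torch

def rotary_position_ids(seqlens_k, cur_seq_len):
    # [batch, cur_seq_len]: each query token's absolute position is the
    # number of valid past tokens plus its offset within the current chunk.
    return seqlens_k.view(-1, 1) + torch.arange(cur_seq_len).view(1, -1)
```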

Replace cat-based KV cache construction with pad+scatter for correct
shared-buffer semantics. The previous approach used torch.aten.cat to
concatenate past and current K/V tensors, but this doesn't support
the ONNX GQA spec's shared-buffer model where seqlens_k controls
valid positions within a pre-allocated buffer.

New approach:
- Pad past K/V with zeros to total_sequence_length
- Compute scatter indices from seqlens_k (past positions)
- Use torch.aten.scatter.src to place current K/V at correct positions

Add causal attention mask generation:
- Construct a 3D boolean mask: k_positions <= past_len + q_positions
- Convert to float mask with 0.0 (attend) and -inf (masked)
- Reshape to 4D [batch, 1, seq, kv_seq] for SDPA broadcasting

Also:
- Add seqlens_k int32-to-int64 dtype cast that preserves tensor shape
- Add tests for KV cache, causal mask, prefill, position IDs, f16
  dtype, variable seqlens_k, and scatter.src dim=2