Body:
Hi, thank you for this great implementation. I ran into a behavior I'm confused about and would appreciate clarification:
- In `model.py`, `GPT.set_kv_cache()` always does:

  ```python
  if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
      # passing `attn_mask` to SDPA disables the flash implementation...
      self.mask_cache = build_mask_cache(max_seq_length, device)
  ```

- In `GPT.forward()`, if `input_pos is not None`:

  ```python
  if self.mask_cache is None:
      raise TypeError("You need to call `gpt.set_kv_cache()`")
  mask = batched_index_select(self.mask_cache, 2, input_pos)
  ```

- In `CausalSelfAttention.scaled_dot_product_attention()`:

  ```python
  y = F.scaled_dot_product_attention(
      q, k, v, attn_mask=mask, ..., is_causal=mask is None
  )
  ```
So it looks like the kv-cache path always passes a non-null `attn_mask`, which forces `is_causal=False` in PyTorch SDPA and appears to rule out the flash attention kernel.
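To make the equivalence concrete, here is a minimal standalone sketch (not the repo's code) comparing SDPA with an explicit lower-triangular boolean mask, like the one `build_mask_cache()` would produce, against `is_causal=True`. The outputs match numerically, but only the `is_causal=True` call is eligible for the flash kernel in PyTorch's dispatcher:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 1, 2, 8, 16  # batch, heads, sequence length, head dim
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)

# Explicit causal mask: True where attention is allowed.
mask = torch.ones(T, T, dtype=torch.bool).tril()

# Passing attn_mask (even a purely causal one) disqualifies the flash path.
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# is_causal=True expresses the same masking and keeps flash eligible.
y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(y_mask, y_causal, atol=1e-6))  # True
```

This illustrates why the distinction matters only for which kernel runs, not for correctness: during kv-cache decoding the indexed mask rows are no longer simply "causal" from SDPA's point of view, which may be why an explicit mask is used.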
- Is this an intentional design choice?
- Is there a recommended usage mode where the kv-cache and the flash-attention fast path are both active?
- Or is this a misunderstanding on my side?
Thanks for any clarification.