
mask_cache in kv-cache path seems to force attn_mask, preventing flash attention #2220

@sea-with-sakura

Description


Hi, thank you for this great implementation. I ran into a behavior that confuses me and would appreciate clarification:

1. In `model.py`, `GPT.set_kv_cache()` always does:

```python
if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
    # passing `attn_mask` to SDPA disables the flash implementation...
    self.mask_cache = build_mask_cache(max_seq_length, device)
```
2. In `GPT.forward()`, if `input_pos is not None`:

```python
if self.mask_cache is None:
    raise TypeError("You need to call `gpt.set_kv_cache()`")
mask = batched_index_select(self.mask_cache, 2, input_pos)
```
3. In `CausalSelfAttention.scaled_dot_product_attention()`:

```python
y = F.scaled_dot_product_attention(
    q, k, v, attn_mask=mask, ..., is_causal=mask is None
)
```

So it looks like the kv-cache path always passes a non-null `attn_mask`, which forces `is_causal=False` in PyTorch SDPA and appears to rule out the flash-attention fast path.
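To make the observation concrete, here is a minimal, self-contained sketch. The `build_mask_cache` below is a simplified stand-in I wrote for illustration, not litgpt's actual helper. It shows that during prefill the explicit-mask path computes the same result as `is_causal=True`, even though passing `attn_mask` to SDPA disqualifies the flash backend:

```python
import torch
import torch.nn.functional as F

def build_mask_cache(max_seq_length: int) -> torch.Tensor:
    # Simplified stand-in (assumption, not litgpt's actual helper):
    # lower-triangular boolean mask of shape (1, 1, T, T); True = may attend.
    ones = torch.ones(max_seq_length, max_seq_length, dtype=torch.bool)
    return torch.tril(ones).unsqueeze(0).unsqueeze(0)

torch.manual_seed(0)
B, H, T, D = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

mask_cache = build_mask_cache(T)
input_pos = torch.arange(T)                    # prefill: all positions at once
mask = mask_cache.index_select(2, input_pos)   # analogous to batched_index_select

# With an explicit attn_mask, SDPA runs with is_causal=False and cannot
# take the flash fast path -- but during prefill the math is identical
# to is_causal=True (up to kernel-level numerics):
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
y_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_mask, y_causal, atol=1e-5))
```

So the mask path is not wrong numerically; the question is purely about which SDPA backend gets dispatched.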

  • Is this intentional design?
  • Is there a recommended usage mode where the kv-cache and the flash-attention fast path are both active?
  • Or is this a misunderstanding on my side?

Thanks for any clarification.

Metadata

Assignees

No one assigned

Labels

question (Further information is requested)
