Gate FA2 #370

Merged
bputzeys merged 2 commits into release from main
Apr 21, 2026

@bputzeys
Collaborator

No description provided.

dmiv-helical and others added 2 commits April 20, 2026 22:16
Rationale
---------
select_attn_backend previously returned "flash_attention_2" whenever
flash_attn was installed and the device was CUDA, without checking
whether the target model class actually declares FA2 support via HF's
dispatcher. For BertForMaskedLM (Geneformer) that silently routed the
model down a code path that transformers cannot dispatch, so the
"Loading ... in bfloat16 for flash_attention_2 compatibility" warning
was not cosmetic noise: it flagged a branch that could not work.
The helical integration-tests job doesn't install flash_attn, so this
gap was invisible in CI.
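
The missing check described above can be sketched roughly as follows. This is an illustration only, not the code from this PR: the helper name `declares_fa2_support` and the two stand-in classes are hypothetical, and the only thing taken from the commit message is the pair of class attributes `_supports_flash_attn` / `_supports_flash_attn_2`.

```python
def declares_fa2_support(model_cls: type) -> bool:
    """Return True if the class opts in to FA2 through either class-level flag.

    Hypothetical helper: the flag names come from the commit message; the
    function itself is an assumption made for illustration.
    """
    return bool(
        getattr(model_cls, "_supports_flash_attn_2", False)
        or getattr(model_cls, "_supports_flash_attn", False)
    )


# Stand-in classes for illustration; real model classes would come from transformers.
class FakeModelWithFA2:
    _supports_flash_attn_2 = True


class FakeModelWithNewFlag:
    _supports_flash_attn = True


class FakeModelWithoutFA2:
    pass
```

A class that sets neither flag, like `FakeModelWithoutFA2` here, is the case the old selection logic never looked at.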

Plan
----
* Add a supports_fa2 parameter to select_attn_backend. Only models
  whose class declares _supports_flash_attn / _supports_flash_attn_2
  can take the FA2 branch; others (Geneformer) fall back to sdpa.
* Pass supports_fa2=True from HelixmRNA. Leave Geneformer on the
  default (False) and annotate the call site so callers who want FA2
  for BertForMaskedLM know they have to wire flash_attn directly.
* Drop the now-unreachable bfloat16-for-FA2 warnings from Geneformer;
  the sdpa fallback path never triggers them.
* Add a flash-attn-integration CI job that installs flash_attn and
  smoke-tests both paths: Geneformer (regression guard — must still
  load on sdpa even with flash_attn present) and HelixmRNA (must
  actually run on the FA2 branch).
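
A minimal sketch of the gating idea from the plan, under stated assumptions. Only the function name `select_attn_backend`, the `supports_fa2` parameter with its default of `False`, and the `"flash_attention_2"` / sdpa outcomes come from the commit message. The rest of the signature (`device_type` as a string, the `fa2_available` override), the helper `flash_attn_installed`, and the choice to fall back to `"sdpa"` on non-CUDA devices are assumptions for illustration and may differ from the merged code.

```python
import importlib.util
from typing import Optional


def flash_attn_installed() -> bool:
    """Check whether the flash_attn package is importable (hypothetical helper)."""
    return importlib.util.find_spec("flash_attn") is not None


def select_attn_backend(
    device_type: str,
    supports_fa2: bool = False,
    fa2_available: Optional[bool] = None,
) -> str:
    """Pick an attention backend name.

    FA2 is chosen only when all three conditions hold: the caller states the
    model class supports it, flash_attn is available, and the device is CUDA.
    `fa2_available` is an assumed override so the logic can be exercised
    without installing flash_attn.
    """
    if fa2_available is None:
        fa2_available = flash_attn_installed()
    if supports_fa2 and fa2_available and device_type == "cuda":
        return "flash_attention_2"
    # Assumed fallback for every other combination.
    return "sdpa"
```

With this shape, a Geneformer-style caller that leaves `supports_fa2` at its default gets `"sdpa"` even when flash_attn is installed, while a HelixmRNA-style caller passing `supports_fa2=True` can reach the FA2 branch:

```python
select_attn_backend("cuda", fa2_available=True)                     # "sdpa"
select_attn_backend("cuda", supports_fa2=True, fa2_available=True)  # "flash_attention_2"
```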
@bputzeys merged commit 47bc7b8 into release on Apr 21, 2026
11 checks passed
2 participants