fa4 sdpa #1947

Open

Jackmin801 wants to merge 1 commit into main from feat-fa4

Conversation

@Jackmin801 (Member) commented Mar 4, 2026

Note

Cursor Bugbot is generating a summary for commit cae5038.

@cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix prepared fixes for both issues found in the latest run.

  • ✅ Fixed: Missing CHANGELOG entry for new config option
    • Added CHANGELOG entry documenting sdpa_fa4 as a new attention implementation option.
  • ✅ Fixed: Unregistered attention implementation passed to HuggingFace transformers
    • Modified get_model() to map sdpa_fa4 to sdpa when passing to AutoConfig.from_pretrained() since FA4 backend activation is handled separately.

Or push these changes by commenting:

@cursor push e2f556bda3
Preview (e2f556bda3)
diff --git a/CHANGELOG.md b/CHANGELOG.md
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -50,6 +50,7 @@
 - **`orchestrator.tasks_per_minute`**: Added optional rate limiting for sandbox tasks per environment worker. Uses token bucket algorithm. If None (default), no rate limiting is applied (2026-02-02)
 - **`model.cp`**: When `cp>1` with `attn="flash_attention_3"`, require `model.impl="custom"` (FA3 ring-attention kernel only in custom path) (2026-02-06)
 - **`model.attn`**: Added `fa4` as an attention implementation option. Flash attention 4 is only supported with the custom implementation (#1726, 2026-02-06)
+- **`model.attn`**: Added `sdpa_fa4` as an attention implementation option. Uses PyTorch SDPA with FA4 backend (2026-03-04)
 - **`inference.model.enable_prefix_caching`**: Added flag to enable prefix caching in vLLM. Passed to vLLM as `--enable-prefix-caching` (default: None) (2026-02-08)
 - **`orchestrator.env.address`**: Added address field on `EnvConfig`. If set, connect to an environment server at this address; if None, spawn a server in a subprocess (2026-02-06)
 - **`orchestrator.env.extra_env_kwargs`**: Added on `EnvConfig`. Extra kwargs passed to the env (e.g. seq_len, interleaved_rollouts, score_rollouts). Auto-populated by the orchestrator for training envs; generally not recommended for user override. Main use case is to match these kwargs when running an env in an isolated environment server (default: {}) (2026-02-06)

diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py
--- a/src/prime_rl/trainer/model.py
+++ b/src/prime_rl/trainer/model.py
@@ -165,10 +165,11 @@
     if is_vlm:
         logger.info(f"Detected vision-language model: {config.name}")
 
+    attn_for_hf = "sdpa" if config.attn == "sdpa_fa4" else config.attn
     model_config = cast(
         PretrainedConfig,
         AutoConfig.from_pretrained(
-            config.name, attn_implementation=config.attn, trust_remote_code=config.trust_remote_code
+            config.name, attn_implementation=attn_for_hf, trust_remote_code=config.trust_remote_code
         ),
     )
     model_config.use_cache = False

# -- Shared trainer configs (used by both SFT and RL trainers) --

-AttnImplementation: TypeAlias = Literal["sdpa", "flash_attention_2", "flash_attention_3", "fa4"]
+AttnImplementation: TypeAlias = Literal["sdpa", "sdpa_fa4", "flash_attention_2", "flash_attention_3", "fa4"]
Missing CHANGELOG entry for new config option

Low Severity

Adding "sdpa_fa4" to the AttnImplementation type alias introduces a new valid value for the model.attn config field, but CHANGELOG.md is not updated. A precedent exists at line 52 of the changelog for the earlier fa4 addition. The project rule requires a changelog entry when configuration structures or usage patterns are modified.

Triggered by project rule: BugBot Instructions

_register_fa4_attention_interface()

if config.attn == "sdpa_fa4":
    _activate_sdpa_fa4_backend()

Unregistered attention implementation passed to HuggingFace transformers

High Severity

When config.attn is "sdpa_fa4", the string "sdpa_fa4" is passed as attn_implementation to AutoConfig.from_pretrained, but HuggingFace transformers does not recognize this value. Unlike "fa4", which has _register_fa4_attention_interface() to register a dummy implementation and flash_attention_4_only_with_custom_impl to restrict it to the custom impl, "sdpa_fa4" has neither a registration nor an impl restriction. This will cause model loading to fail.
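The autofix's remapping can be sketched in isolation. The function name below is hypothetical; only the "sdpa_fa4" → "sdpa" mapping itself comes from the model.py diff above:

```python
def resolve_hf_attn_implementation(attn: str) -> str:
    """Hypothetical helper: map the trainer's attn setting to a value
    HuggingFace transformers recognizes. "sdpa_fa4" is the trainer's own
    alias, so AutoConfig receives plain "sdpa"; the FA4 backend is
    activated separately after model construction."""
    return "sdpa" if attn == "sdpa_fa4" else attn
```

All other values pass through unchanged, so the existing "fa4" path (which relies on its own dummy registration) is unaffected by this mapping.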

