[perf] feat: add GDN (Gated DeltaNet) FLOPs calculator#2925
Conversation
Port the GDN FLOPs formula from Megatron-LM (PR #1989) into Bridge's flop_utils.py so that Qwen3.5 VL and Qwen3-Next models report accurate throughput numbers instead of treating all layers as standard attention. When experimental_attention_variant="gated_delta_net" is set on the model config, transformer_flops() now: - Parses linear_attention_freq to split layers into GDN vs standard attention - Computes per-layer GDN cost (in_proj + conv1d + gated delta rule + out_proj) - Produces a weighted self_attn_term over both layer types Signed-off-by: Chen Cui <chcui@nvidia.com>
📝 Walkthrough

This change adds support for computing floating-point operations for Gated DeltaNet (GDN) attention layers in mixed-attention configurations. The implementation conditionally recalculates self-attention FLOP costs based on frequency patterns specified in the configuration, allowing different layers to use either GDN or standard attention mechanisms.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/unit_tests/training/utils/test_flop_utils.py (1)
489-500: Make the explicit-list test verify the actual mask behavior.
`assert flops > 0` still passes if the list mask is ignored or its semantics flip. Please assert the expected 6-GDN/2-standard split, or compare against a hand-computed total, so the list branch is actually protected.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/utils/test_flop_utils.py` around lines 489 - 500, The test must verify the list mask is actually applied: after building model_cfg in test_gdn_layer_freq_list, assert that model_cfg.linear_attention_freq (or the config field that stores the per-layer mask) equals the freq_list ([1,1,0,1,1,0,1,1]) and additionally assert sum(model_cfg.linear_attention_freq) == 6 to ensure a 6-GDN/2-standard split before calling num_floating_point_operations; keep the existing flops check if desired.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/training/utils/flop_utils.py`:
- Around lines 389-411: When experimental_attention_variant == "gated_delta_net",
make the GDN config strict: if cfg.model.linear_attention_freq is None raise a
ValueError; if it's not an int or list raise a TypeError (do not silently fall
back to [1]*num_layers); when it's a list assert len(linear_attention_pattern)
== num_layers; and stop using default fallbacks for GDN dims—replace
getattr(cfg.model, "linear_key_head_dim", 128), "linear_value_head_dim",
"linear_num_key_heads", "linear_num_value_heads", and "linear_conv_kernel_dim"
with direct attribute access (e.g., cfg.model.linear_key_head_dim) so missing
fields raise immediately.
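A minimal sketch of the strict parsing this comment asks for. The helper name `resolve_linear_attention_pattern` and the exact config shape are assumptions, not the PR's actual code, and it assumes Megatron-LM's convention that with an int freq, every freq-th layer stays standard attention while the rest are GDN:

```python
def resolve_linear_attention_pattern(model_cfg, num_layers):
    """Return a per-layer 1/0 mask (1 = GDN, 0 = standard attention)."""
    freq = model_cfg.linear_attention_freq
    if freq is None:
        raise ValueError("linear_attention_freq must be set for gated_delta_net")
    if isinstance(freq, int):
        # Every freq-th layer remains standard attention; the rest are GDN.
        return [0 if (i + 1) % freq == 0 else 1 for i in range(num_layers)]
    if isinstance(freq, list):
        if len(freq) != num_layers:
            raise ValueError(
                f"linear_attention_freq list length {len(freq)} != num_layers {num_layers}"
            )
        return list(freq)
    raise TypeError(f"linear_attention_freq must be int or list, got {type(freq)}")
```

With freq=3 on 8 layers this yields a 6-GDN/2-standard split, matching the pattern discussed in the review.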
In `@tests/unit_tests/training/utils/test_flop_utils.py`:
- Around line 391-392: The test suite class TestGDNLayerFlops is not marked as a
unit test; add a pytest marker by importing pytest (if not already present) and
placing `@pytest.mark.unit` directly above the TestGDNLayerFlops class definition
so the whole class is selected by pytest's unit marker filtering.
📒 Files selected for processing (2)
src/megatron/bridge/training/utils/flop_utils.py
tests/unit_tests/training/utils/test_flop_utils.py
- Make GDN config strict: raise ValueError when `linear_attention_freq` is None, raise TypeError on invalid types, assert list length matches num_layers (mirrors the existing MoE validation pattern)
- Replace getattr fallbacks for GDN dims with direct attribute access so missing fields fail explicitly
- Add `@pytest.mark.unit` to the TestGDNLayerFlops class
- Strengthen test_gdn_layer_freq_list to verify the 6/2 split by comparing against the equivalent int freq=3 and against a pure-standard baseline

Signed-off-by: Chen Cui <chcui@nvidia.com>
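The equivalence the strengthened test checks can be sketched as follows. The helper `weighted_attn_flops` and the stand-in per-layer costs are hypothetical; the real test calls `num_floating_point_operations` on full model configs:

```python
def weighted_attn_flops(pattern, gdn_cost, std_cost):
    # Placeholder for the weighted self_attn_term: sum per-layer costs
    # according to the 1/0 GDN mask.
    return sum(gdn_cost if is_gdn else std_cost for is_gdn in pattern)

list_pattern = [1, 1, 0, 1, 1, 0, 1, 1]
# int form freq=3 on 8 layers: every 3rd layer stays standard attention
int_pattern = [0 if (i + 1) % 3 == 0 else 1 for i in range(8)]

gdn_cost, std_cost = 7.0, 11.0  # arbitrary, distinct stand-in costs

assert list_pattern == int_pattern
assert weighted_attn_flops(list_pattern, gdn_cost, std_cost) == \
       weighted_attn_flops(int_pattern, gdn_cost, std_cost)
# A pure-standard baseline must produce a different total, proving the
# mask is actually applied rather than silently ignored.
assert weighted_attn_flops(list_pattern, gdn_cost, std_cost) != \
       weighted_attn_flops([0] * 8, gdn_cost, std_cost)
```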
/claude review
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>
…to adding-model-support skill

- Step 4 (Discovery): Check for quantized weights (FP8/FP4) that silently break models without dequantization. Documents standalone script and in-bridge hook approaches.
- Phase 2: Update the FLOPs calculator when new architectural blocks (GDN, MTP, Mamba) differ from standard attention/MLP. References PR #2925 as an example.

Signed-off-by: Chen Cui <chcui@nvidia.com>
/ok to test a0c7541
What does this PR do?
Port the GDN (Gated DeltaNet) FLOPs formula from Megatron-LM into Bridge's `flop_utils.py` so that Qwen3.5 VL and Qwen3-Next models report accurate throughput numbers.

Changelog

- Add a GDN branch to `transformer_flops()` when `experimental_attention_variant="gated_delta_net"` is set
- Parse `linear_attention_freq` (int or list) to split layers into GDN vs standard attention, matching Megatron-LM's convention
- Produce a weighted `self_attn_term` combining GDN and standard-attention per-layer costs

Reference
Ported from Megatron-LM PR #1989 — "feat(moe): Support gated delta net for Qwen3-Next"
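As a rough structural illustration of the per-layer cost decomposition named in the changelog (in_proj + conv1d + gated delta rule + out_proj): this is a hedged sketch with illustrative dimension names and coefficients, not the formula ported from Megatron-LM.

```python
def gdn_layer_flops_sketch(seq_len, hidden, n_k_heads, k_dim, n_v_heads, v_dim, conv_k):
    """Illustrative decomposition only; the real formula lives in flop_utils.py."""
    qk_width = 2 * n_k_heads * k_dim   # q and k projection widths combined
    v_width = 2 * n_v_heads * v_dim    # v and gate projection widths combined
    # GEMM cost is 2*m*n*k; each term below is linear in seq_len.
    in_proj = 2 * seq_len * hidden * (qk_width + v_width)
    conv1d = 2 * seq_len * (qk_width + n_v_heads * v_dim) * conv_k   # depthwise conv
    delta_rule = 4 * seq_len * n_v_heads * k_dim * v_dim             # recurrent state update
    out_proj = 2 * seq_len * n_v_heads * v_dim * hidden
    return in_proj + conv1d + delta_rule + out_proj
```

Unlike quadratic attention, every term here scales linearly with sequence length, which is why splitting GDN from standard-attention layers changes the reported throughput.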
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI.
A Nvidia developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Previously, Bridge's `flop_utils.py` had no GDN support: Qwen3.5 VL and Qwen3-Next training reported FLOPs as if all layers were standard attention. This PR fixes that by implementing the same formula used in Megatron-LM's `training.py` (lines 488-514).

Summary by CodeRabbit

New Features

- Added FLOPs accounting for the `gated_delta_net` attention variant with configurable layer mixing.
- Added `linear_attention_freq` configuration to designate which layers use the new attention type.

Tests