Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/megatron/bridge/training/utils/flop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,62 @@ def transformer_flops():
)
)

# Handle GDN (Gated DeltaNet) hybrid attention variant.
# When experimental_attention_variant is "gated_delta_net", a fraction of the
# layers use GDN (linear attention) instead of standard attention. Override
# self_attn_term with a weighted sum of GDN and standard-attention per-layer costs.
experimental_attention_variant = getattr(cfg.model, "experimental_attention_variant", None)
if experimental_attention_variant == "gated_delta_net":
    linear_attention_freq = cfg.model.linear_attention_freq
    if linear_attention_freq is None:
        raise ValueError(
            "linear_attention_freq must be set when experimental_attention_variant='gated_delta_net'"
        )
    if isinstance(linear_attention_freq, int):
        # Guard before the modulo below; freq <= 0 would otherwise raise a
        # confusing ZeroDivisionError (or produce a nonsense pattern).
        if linear_attention_freq <= 0:
            raise ValueError(
                f"linear_attention_freq must be a positive int, got {linear_attention_freq}"
            )
        # Every freq-th layer (1-indexed) uses standard attention (0); all
        # other layers use GDN (1).
        linear_attention_pattern = [
            0 if ((i + 1) % linear_attention_freq == 0) else 1 for i in range(num_layers)
        ]
    elif isinstance(linear_attention_freq, list):
        # BUG FIX: the pattern was never assigned in this branch, so the length
        # check below raised NameError instead of the intended ValueError, and a
        # user-supplied list mask was silently never applied. The list IS the
        # per-layer mask (1 = GDN layer, 0 = standard attention).
        linear_attention_pattern = linear_attention_freq
        if len(linear_attention_pattern) != num_layers:
            raise ValueError(
                f"Invalid length of linear_attention_pattern: {len(linear_attention_pattern)}, "
                f"expected {num_layers}, "
                f"current linear_attention_freq: {linear_attention_freq}"
            )
    else:
        raise TypeError(
            f"linear_attention_freq must be int or list, got {type(linear_attention_freq).__name__}"
        )

    num_gdn_layers = sum(linear_attention_pattern)
    num_standard_attn_layers = num_layers - num_gdn_layers

    # self_attn_term as computed above covers num_layers standard-attention
    # layers; recover the per-layer cost before re-weighting.
    standard_self_attn_per_layer = self_attn_term / num_layers if num_layers > 0 else 0

    qk_head_dim = cfg.model.linear_key_head_dim
    v_head_dim = cfg.model.linear_value_head_dim
    num_qk_heads = cfg.model.linear_num_key_heads
    num_v_heads = cfg.model.linear_num_value_heads
    conv_kernel_dim = cfg.model.linear_conv_kernel_dim

    qk_dim = qk_head_dim * num_qk_heads
    v_dim = v_head_dim * num_v_heads

    # 3 (fwd + bwd) * 2 (multiply-accumulate) * per-layer GDN parameter count.
    # NOTE(review): the inner sum presumably mirrors Megatron-LM's GDN FLOPs
    # formula (in/out projections, short conv, delta-rule terms) — confirm
    # term-by-term against the upstream reference implementation.
    gdn_self_attn_per_layer = (
        3
        * 2
        * (
            cfg.model.hidden_size * (2 * qk_dim + 2 * v_dim + 2 * num_v_heads)
            + conv_kernel_dim * (2 * qk_dim + v_dim)
            + num_v_heads * (v_head_dim**2) * 4
            + cfg.model.hidden_size * v_dim
        )
    )

    self_attn_term = (
        gdn_self_attn_per_layer * num_gdn_layers + standard_self_attn_per_layer * num_standard_attn_layers
    )

padded_vocab_size = calculate_padded_vocab_size(
cfg.model.vocab_size,
cfg.model.make_vocab_size_divisible_by,
Expand Down
221 changes: 221 additions & 0 deletions tests/unit_tests/training/utils/test_flop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ class MockModelConfig:
group_query_attention: bool = True
gated_linear_unit: bool = True
activation_func: object = field(default=None)
# GDN (Gated DeltaNet) settings
experimental_attention_variant: str | None = None
linear_attention_freq: int | list | None = None
linear_conv_kernel_dim: int = 4
linear_key_head_dim: int = 128
linear_value_head_dim: int = 128
linear_num_key_heads: int = 16
linear_num_value_heads: int = 48

def __post_init__(self):
import torch.nn.functional as F
Expand Down Expand Up @@ -380,6 +388,219 @@ def test_swiglu_scaling_factor(self):
assert flops_swiglu == expected_swiglu, f"SwiGLU: expected {expected_swiglu:.2e} but got {flops_swiglu:.2e}"


@pytest.mark.unit
class TestGDNLayerFlops:
    """Tests for Gated DeltaNet (GDN) FLOPs calculation in transformer_flops path.

    Pattern semantics under test: with an int ``linear_attention_freq``, every
    freq-th layer (1-indexed) is standard attention and the remaining layers are
    GDN; with a list, the list is the per-layer mask (1 = GDN, 0 = standard).
    """

    def _qwen35_27b_config(self, **overrides):
        """Return a MockModelConfig resembling Qwen3.5-27B (dense, 64 layers, freq=4)."""
        defaults = dict(
            num_layers=64,
            hidden_size=5120,
            seq_length=4096,
            ffn_hidden_size=17408,
            num_attention_heads=24,
            num_query_groups=4,
            kv_channels=256,
            vocab_size=248320,
            make_vocab_size_divisible_by=128,
            tensor_model_parallel_size=1,
            gated_linear_unit=True,
            experimental_attention_variant="gated_delta_net",
            linear_attention_freq=4,
            linear_conv_kernel_dim=4,
            linear_key_head_dim=128,
            linear_value_head_dim=128,
            linear_num_key_heads=16,
            linear_num_value_heads=48,
        )
        # Callers override any subset of the defaults (e.g. num_layers, freq).
        defaults.update(overrides)
        return MockModelConfig(**defaults)

    def test_gdn_flops_differ_from_pure_attention(self):
        """GDN-enabled config should produce different FLOPs than pure-attention baseline."""
        batch_size = 1
        gdn_cfg = MockConfigContainer(model=self._qwen35_27b_config())
        # Same geometry, but experimental_attention_variant=None disables the GDN path.
        baseline_cfg = MockConfigContainer(model=self._qwen35_27b_config(experimental_attention_variant=None))
        gdn_flops = num_floating_point_operations(gdn_cfg, batch_size=batch_size)
        baseline_flops = num_floating_point_operations(baseline_cfg, batch_size=batch_size)
        assert gdn_flops != baseline_flops, "GDN FLOPs should differ from pure-attention FLOPs"
        assert gdn_flops > 0

    def test_gdn_only_layers(self):
        """With linear_attention_freq=1 (no standard attn), self_attn_term should be pure GDN."""
        batch_size = 1
        num_layers = 4
        hidden_size = 1024
        seq_length = 512
        vocab_size = 32000
        qk_head_dim = 64
        v_head_dim = 64
        num_qk_heads = 8
        num_v_heads = 16
        conv_kernel_dim = 4

        model_cfg = MockModelConfig(
            num_layers=num_layers,
            hidden_size=hidden_size,
            seq_length=seq_length,
            ffn_hidden_size=4096,
            num_attention_heads=8,
            num_query_groups=8,
            kv_channels=128,
            vocab_size=vocab_size,
            make_vocab_size_divisible_by=128,
            tensor_model_parallel_size=1,
            gated_linear_unit=False,
            experimental_attention_variant="gated_delta_net",
            linear_attention_freq=1,
            linear_conv_kernel_dim=conv_kernel_dim,
            linear_key_head_dim=qk_head_dim,
            linear_value_head_dim=v_head_dim,
            linear_num_key_heads=num_qk_heads,
            linear_num_value_heads=num_v_heads,
        )
        cfg = MockConfigContainer(model=model_cfg)
        actual_flops = num_floating_point_operations(cfg, batch_size=batch_size)

        # freq=1: pattern = [0 if (i+1)%1==0 else 1 for i in range(4)] = [0,0,0,0]
        # All layers are standard attention, 0 GDN layers.
        # This is because freq=1 means every layer is standard attention.
        # So actual_flops should equal baseline (no GDN).
        # NOTE(review): the docstring says "no standard attn" but the 1-indexed
        # modulo semantics make freq=1 all-standard; the assertion below encodes
        # the actual (all-standard) behavior.
        baseline_cfg = MockConfigContainer(
            model=MockModelConfig(
                num_layers=num_layers,
                hidden_size=hidden_size,
                seq_length=seq_length,
                ffn_hidden_size=4096,
                num_attention_heads=8,
                num_query_groups=8,
                kv_channels=128,
                vocab_size=vocab_size,
                make_vocab_size_divisible_by=128,
                tensor_model_parallel_size=1,
                gated_linear_unit=False,
            )
        )
        baseline_flops = num_floating_point_operations(baseline_cfg, batch_size=batch_size)
        assert actual_flops == baseline_flops, (
            "freq=1 means every layer is standard attention, so FLOPs should match baseline"
        )

    def test_gdn_layer_freq_list(self):
        """Test GDN with linear_attention_freq as a list pattern (6 GDN, 2 standard)."""
        batch_size = 1
        freq_list = [1, 1, 0, 1, 1, 0, 1, 1]  # 6 GDN, 2 standard
        assert sum(freq_list) == 6
        model_cfg = self._qwen35_27b_config(
            num_layers=8,
            linear_attention_freq=freq_list,
        )
        cfg = MockConfigContainer(model=model_cfg)
        flops = num_floating_point_operations(cfg, batch_size=batch_size)
        assert flops > 0

        # Verify the mask is actually applied: must differ from pure standard attention.
        baseline_cfg = MockConfigContainer(
            model=self._qwen35_27b_config(num_layers=8, experimental_attention_variant=None)
        )
        baseline_flops = num_floating_point_operations(baseline_cfg, batch_size=batch_size)
        assert flops != baseline_flops, "List-based GDN mask should differ from pure standard attention"

        # freq_list [1,1,0,1,1,0,1,1] is identical to the pattern generated by int freq=3
        # (layers 3 and 6, 1-indexed, are standard attention).
        int_freq_cfg = MockConfigContainer(model=self._qwen35_27b_config(num_layers=8, linear_attention_freq=3))
        int_freq_flops = num_floating_point_operations(int_freq_cfg, batch_size=batch_size)
        assert flops == int_freq_flops, (
            "List [1,1,0,1,1,0,1,1] should produce the same FLOPs as int freq=3 (equivalent 6/2 split)"
        )

    def test_gdn_exact_self_attn_term(self):
        """Verify the GDN self_attn_term matches the expected formula from Megatron-LM."""
        batch_size = 1
        num_layers = 4
        hidden_size = 1024
        seq_length = 256
        vocab_size = 32000
        qk_head_dim = 64
        v_head_dim = 64
        num_qk_heads = 8
        num_v_heads = 16
        conv_kernel_dim = 4
        ffn_hidden_size = 4096

        qk_dim = qk_head_dim * num_qk_heads
        v_dim = v_head_dim * num_v_heads

        # freq=2: layers 0,2 are GDN (pattern[i]=1), layers 1,3 are standard (pattern[i]=0)
        model_cfg = MockModelConfig(
            num_layers=num_layers,
            hidden_size=hidden_size,
            seq_length=seq_length,
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=8,
            num_query_groups=8,
            kv_channels=128,
            vocab_size=vocab_size,
            make_vocab_size_divisible_by=128,
            tensor_model_parallel_size=1,
            gated_linear_unit=False,
            experimental_attention_variant="gated_delta_net",
            linear_attention_freq=2,
            linear_conv_kernel_dim=conv_kernel_dim,
            linear_key_head_dim=qk_head_dim,
            linear_value_head_dim=v_head_dim,
            linear_num_key_heads=num_qk_heads,
            linear_num_value_heads=num_v_heads,
        )
        cfg = MockConfigContainer(model=model_cfg)
        gdn_flops = num_floating_point_operations(cfg, batch_size=batch_size)

        # Compute expected manually, mirroring the production formula term by term.
        expansion_factor = 3 * 2 * 2  # 12: fwd+bwd (3) * MAC (2) * GEMM pair (2)
        # Standard attention per-layer (MHA, num_query_groups==num_attention_heads so ratio=1)
        kv_channels = 128
        query_projection_size = kv_channels * 8
        query_projection_to_hidden_size_ratio = query_projection_size / hidden_size
        standard_attn_per_layer = (
            expansion_factor
            * hidden_size
            * hidden_size
            * ((1 + 8 / 8 + seq_length / hidden_size / 2) * query_projection_to_hidden_size_ratio)
        )
        # GDN per-layer (must stay in sync with gdn_self_attn_per_layer in flop_utils)
        gdn_per_layer = (
            3
            * 2
            * (
                hidden_size * (2 * qk_dim + 2 * v_dim + 2 * num_v_heads)
                + conv_kernel_dim * (2 * qk_dim + v_dim)
                + num_v_heads * (v_head_dim**2) * 4
                + hidden_size * v_dim
            )
        )
        # freq=2: pattern = [1, 0, 1, 0] -> 2 GDN, 2 standard
        expected_self_attn = gdn_per_layer * 2 + standard_attn_per_layer * 2
        # MLP: gated_linear_unit=False -> gated_linear_multiplier=1
        expected_mlp = expansion_factor * num_layers * hidden_size * ffn_hidden_size * 1
        # Logit
        padded_vocab = vocab_size  # 32000 is already divisible by 128
        expected_logit = 3 * 2 * hidden_size * padded_vocab * 1
        expected_total = batch_size * seq_length * (expected_mlp + expected_self_attn + expected_logit)

        assert gdn_flops == expected_total, f"Expected {expected_total:.6e} but got {gdn_flops:.6e}"

    def test_gdn_more_gdn_layers_changes_flops(self):
        """Increasing GDN layer ratio (higher freq) should change FLOPs."""
        batch_size = 1
        # freq=4: 3/4 GDN, 1/4 standard
        cfg_freq4 = MockConfigContainer(model=self._qwen35_27b_config(num_layers=8, linear_attention_freq=4))
        # freq=8: 7/8 GDN, 1/8 standard
        cfg_freq8 = MockConfigContainer(model=self._qwen35_27b_config(num_layers=8, linear_attention_freq=8))
        flops_freq4 = num_floating_point_operations(cfg_freq4, batch_size=batch_size)
        flops_freq8 = num_floating_point_operations(cfg_freq8, batch_size=batch_size)
        assert flops_freq4 != flops_freq8, "Different GDN ratios should produce different FLOPs"


class TestHybridMtpPatternParsing:
"""Tests for hybrid/MTP pattern parsing in FLOPs accounting."""

Expand Down
Loading