Commit a2c11f8

Refactor the optimizer override function so that users can swap in their own (#2010)
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent: a0bffa9 · commit: a2c11f8

File tree: 3 files changed, +98 −52 lines

  src/megatron/bridge/training/config.py  (+83 −0)
  src/megatron/bridge/training/optim.py   (+14 −52)
  src/megatron/bridge/training/setup.py   (+1 −0)
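
The core of the change: ConfigContainer gains an optimizer_config_override_provider field, and the former module-level _build_config_overrides helper in optim.py becomes the default OptimizerConfigOverrideProvider.build_config_overrides method, so users can subclass it and swap in their own override logic. A minimal sketch of such a subclass (the name BiasOnlyNoWeightDecayProvider and its bias-only rule are illustrative, not part of this commit):

    from megatron.core.optimizer import ParamGroupOverride, ParamKey

    from megatron.bridge.training.config import (
        OptimizerConfigOverrideProvider,
        OptimizerConfigOverrideProviderContext,
    )


    class BiasOnlyNoWeightDecayProvider(OptimizerConfigOverrideProvider):
        """Illustrative custom provider: skip weight decay for bias parameters only."""

        def build_config_overrides(
            self, context: OptimizerConfigOverrideProviderContext
        ) -> dict[ParamKey, ParamGroupOverride] | None:
            model_list = context.model if isinstance(context.model, list) else [context.model]
            bias_names = tuple(
                name
                for chunk in model_list
                for name, _ in chunk.named_parameters()
                if name.endswith(".bias")
            )
            if not bias_names:
                return None
            # One ParamKey covering every bias parameter, with weight decay zeroed out.
            return {ParamKey(name=bias_names): ParamGroupOverride(wd_mult=0.0)}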

src/megatron/bridge/training/config.py

Lines changed: 83 additions & 0 deletions
@@ -24,7 +24,12 @@
 from megatron.core.datasets.gpt_dataset import GPTDatasetConfig as MCoreGPTDatasetConfig
 from megatron.core.distributed import DistributedDataParallelConfig as MCoreDistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig
+from megatron.core.optimizer import (
+    ParamGroupOverride,
+    ParamKey,
+)
 from megatron.core.transformer.enums import AttnBackend
+from megatron.core.transformer.module import MegatronModule
 
 from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs
 from megatron.bridge.models import GPTModelProvider, T5ModelProvider
@@ -258,6 +263,81 @@ class DatasetBuildContext:
     tokenizer: Optional[MegatronTokenizer] = None
 
 
+@dataclass(frozen=True)
+class OptimizerConfigOverrideProviderContext:
+    """Context for providing config overrides."""
+
+    scheduler_config: "SchedulerConfig"
+    optimizer_config: OptimizerConfig
+    model: Union[MegatronModule, list[MegatronModule]]
+
+
+@dataclass
+class OptimizerConfigOverrideProvider:
+    """Default provider of config overrides; subclass and override build_config_overrides to customize."""
+
+    def build_config_overrides(
+        self, context: OptimizerConfigOverrideProviderContext
+    ) -> dict[ParamKey, ParamGroupOverride] | None:
+        """Build config overrides for weight decay based on scheduler configuration.
+
+        This method creates parameter-specific overrides for weight decay behavior.
+        By default, weight decay is skipped for bias parameters and 1D parameters.
+        For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.
+
+        Args:
+            context: OptimizerConfigOverrideProviderContext which packages the scheduler
+                configuration, optimizer configuration, and model.
+
+        Returns:
+            Dictionary of ParamKey to ParamGroupOverride for the optimizer.
+        """
+        model = context.model
+        scheduler_config = context.scheduler_config
+        optimizer_config = context.optimizer_config
+
+        config_overrides: dict[ParamKey, ParamGroupOverride] = {}
+
+        # Collect param names that should skip weight decay.
+        # NOTE: this can be simplified once https://github.com/NVIDIA/Megatron-LM/pull/2753
+        # is merged into dev. Then we can re-use megatron's apply_wd_to_qk_layernorm option
+        # and call megatron.core.optimizer.get_standard_config_overrides(optimizer_config)
+        # directly for standard settings, replacing the custom logic below for qwen3-next.
+        no_wd_names: list[str] = []
+        is_qwen3_next = scheduler_config.no_weight_decay_cond_type == "qwen3_next"
+
+        model_list = model if isinstance(model, list) else [model]
+        for model_chunk in model_list:
+            for name, param in model_chunk.named_parameters():
+                # Skip weight decay for bias parameters
+                if name.endswith(".bias"):
+                    no_wd_names.append(name)
+                    continue
+
+                # Skip weight decay for 1D parameters
+                if len(param.shape) == 1:
+                    if is_qwen3_next:
+                        # Qwen3-Next: apply weight decay to qk layernorm (don't add to skip list)
+                        if "q_layernorm" in name or "k_layernorm" in name:
+                            continue
+                    no_wd_names.append(name)
+
+        # Create a single ParamKey with all names that should skip weight decay
+        if no_wd_names:
+            no_wd_key = ParamKey(name=tuple(no_wd_names))
+            config_overrides[no_wd_key] = ParamGroupOverride(wd_mult=0.0)
+
+        # Now handle decoupled LR:
+        if optimizer_config.decoupled_lr is not None:
+            decoupled_lr_config: ParamGroupOverride = {"max_lr": optimizer_config.decoupled_lr}
+            decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
+            if optimizer_config.decoupled_min_lr is not None:
+                decoupled_lr_config["min_lr"] = optimizer_config.decoupled_min_lr
+            config_overrides[decoupled_param_key] = decoupled_lr_config
+
+        return config_overrides if config_overrides else None
+
+
 @dataclass
 class DatasetProvider(DataloaderConfig, ABC):
     """Abstract base class for custom dataset configurations.
@@ -1163,6 +1243,9 @@ class ConfigContainer(Container):
     train: TrainingConfig
     model: GPTModelProvider | T5ModelProvider | MambaModelProvider
     optimizer: OptimizerConfig
+    optimizer_config_override_provider: OptimizerConfigOverrideProvider = field(
+        default_factory=OptimizerConfigOverrideProvider
+    )
     ddp: DistributedDataParallelConfig = field(default_factory=DistributedDataParallelConfig)
     scheduler: SchedulerConfig
     dataset: GPTDatasetConfig | FinetuningDatasetConfig | DatasetProvider
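
For intuition about what the stock provider produces: it groups every no-weight-decay parameter name under one ParamKey with wd_mult=0.0 and, when decoupled_lr is set, adds a second entry keyed on ParamKey(attr="is_embedding_or_output_parameter"). A hedged sketch of calling it directly, assuming model, scheduler_config, and optimizer_config are already constructed:

    from megatron.bridge.training.config import (
        OptimizerConfigOverrideProvider,
        OptimizerConfigOverrideProviderContext,
    )

    # model, scheduler_config, and optimizer_config are assumed to exist in scope.
    provider = OptimizerConfigOverrideProvider()
    overrides = provider.build_config_overrides(
        OptimizerConfigOverrideProviderContext(
            scheduler_config=scheduler_config,
            optimizer_config=optimizer_config,
            model=model,
        )
    )
    # Expect one wd_mult=0.0 entry for bias/1D params, plus a {"max_lr": ...} entry
    # for embedding/output params when optimizer_config.decoupled_lr is set.
    if overrides is not None:
        for key, override in overrides.items():
            print(key, override)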

src/megatron/bridge/training/optim.py

Lines changed: 14 additions & 52 deletions
@@ -12,73 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Union
+from typing import Optional, Union
 
 from megatron.core.optimizer import (
     MegatronOptimizer,
     OptimizerConfig,
-    ParamKey,
     get_megatron_optimizer,
 )
 from megatron.core.optimizer.muon import get_megatron_muon_optimizer
-from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler, ParamGroupOverride
+from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.core.transformer.module import MegatronModule
 
-from megatron.bridge.training.config import SchedulerConfig
-
-
-def _build_config_overrides(
-    scheduler_config: SchedulerConfig,
-    model: Union[MegatronModule, list[MegatronModule]],
-) -> Optional[Dict[ParamKey, ParamGroupOverride]]:
-    """Build config overrides for weight decay based on scheduler configuration.
-
-    This function creates parameter-specific overrides for weight decay behavior.
-    By default, weight decay is skipped for bias parameters and 1D parameters.
-    For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.
-
-    Args:
-        scheduler_config: Scheduler configuration containing weight decay settings
-        model: The model or list of model chunks to collect parameter names from
-
-    Returns:
-        Dictionary of ParamKey to ParamGroupOverride for the optimizer
-    """
-    config_overrides: Dict[ParamKey, ParamGroupOverride] = {}
-
-    # Collect param names that should skip weight decay
-    no_wd_names: list[str] = []
-    is_qwen3_next = scheduler_config.no_weight_decay_cond_type == "qwen3_next"
-
-    model_list = model if isinstance(model, list) else [model]
-    for model_chunk in model_list:
-        for name, param in model_chunk.named_parameters():
-            # Skip weight decay for bias parameters
-            if name.endswith(".bias"):
-                no_wd_names.append(name)
-                continue
-
-            # Skip weight decay for 1D parameters
-            if len(param.shape) == 1:
-                if is_qwen3_next:
-                    # Qwen3-Next: apply weight decay to qk layernorm (don't add to skip list)
-                    if "q_layernorm" in name or "k_layernorm" in name:
-                        continue
-                no_wd_names.append(name)
-
-    # Create a single ParamKey with all names that should skip weight decay
-    if no_wd_names:
-        no_wd_key = ParamKey(name=tuple(no_wd_names))
-        config_overrides[no_wd_key] = ParamGroupOverride(wd_mult=0.0)
-
-    return config_overrides if config_overrides else None
+from megatron.bridge.training.config import (
+    OptimizerConfigOverrideProvider,
+    OptimizerConfigOverrideProviderContext,
+    SchedulerConfig,
+)
 
 
 def setup_optimizer(
     optimizer_config: OptimizerConfig,
     scheduler_config: SchedulerConfig,
     model: Union[MegatronModule, list[MegatronModule]],
     use_gloo_process_groups: bool = False,
+    optimizer_config_override_provider: Optional[OptimizerConfigOverrideProvider] = None,
 ) -> tuple[MegatronOptimizer, OptimizerParamScheduler]:
     """Set up the optimizer and scheduler.
@@ -91,8 +48,13 @@ def setup_optimizer(
     Returns:
         tuple containing the optimizer and scheduler
     """
+    if optimizer_config_override_provider is None:
+        optimizer_config_override_provider = OptimizerConfigOverrideProvider()
+
     # Build config overrides for weight decay based on scheduler config and model params
-    config_overrides = _build_config_overrides(scheduler_config, model)
+    config_overrides = optimizer_config_override_provider.build_config_overrides(
+        OptimizerConfigOverrideProviderContext(scheduler_config, optimizer_config, model)
+    )
 
     if "muon" not in optimizer_config.optimizer and "soap" not in optimizer_config.optimizer:
         optimizer = get_megatron_optimizer(
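
The new keyword argument also lets callers bypass the config container and hand a provider straight to setup_optimizer; passing None (the default) falls back to the stock provider. A sketch reusing the illustrative subclass from above:

    from megatron.bridge.training.optim import setup_optimizer

    # optimizer_config, scheduler_config, and model are assumed to be built already.
    optimizer, scheduler = setup_optimizer(
        optimizer_config=optimizer_config,
        scheduler_config=scheduler_config,
        model=model,
        optimizer_config_override_provider=BiasOnlyNoWeightDecayProvider(),
    )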

src/megatron/bridge/training/setup.py

Lines changed: 1 addition & 0 deletions
@@ -226,6 +226,7 @@ def modelopt_pre_wrap_hook(model):
         scheduler_config=cfg.scheduler,
         model=model,
         use_gloo_process_groups=cfg.dist.use_gloo_process_groups,
+        optimizer_config_override_provider=cfg.optimizer_config_override_provider,
     )
     timers("model-and-optimizer-setup").stop()
     barrier_and_log("after model, optimizer, and learning rate scheduler are built")
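
In the config-driven path, setup.py simply forwards whatever provider the ConfigContainer carries, so swapping in a custom one is a single assignment (a sketch, assuming cfg is an existing ConfigContainer and the illustrative subclass from earlier):

    # The field defaults to the stock OptimizerConfigOverrideProvider.
    cfg.optimizer_config_override_provider = BiasOnlyNoWeightDecayProvider()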
