24 | 24 | from megatron.core.datasets.gpt_dataset import GPTDatasetConfig as MCoreGPTDatasetConfig |
25 | 25 | from megatron.core.distributed import DistributedDataParallelConfig as MCoreDistributedDataParallelConfig |
26 | 26 | from megatron.core.optimizer import OptimizerConfig as MCoreOptimizerConfig |
| 27 | +from megatron.core.optimizer import ( |
| 28 | + ParamGroupOverride, |
| 29 | + ParamKey, |
| 30 | +) |
27 | 31 | from megatron.core.transformer.enums import AttnBackend |
| 32 | +from megatron.core.transformer.module import MegatronModule |
28 | 33 |
29 | 34 | from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs |
30 | 35 | from megatron.bridge.models import GPTModelProvider, T5ModelProvider |
@@ -258,6 +263,81 @@ class DatasetBuildContext: |
258 | 263 | tokenizer: Optional[MegatronTokenizer] = None |
259 | 264 |
260 | 265 |
| 266 | +@dataclass(frozen=True) |
| 267 | +class OptimizerConfigOverrideProviderContext: |
| 268 | + """Context for providing config overrides.""" |
| 269 | + |
| 270 | + scheduler_config: "SchedulerConfig" |
| 271 | + optimizer_config: OptimizerConfig |
| 272 | + model: Union[MegatronModule, list[MegatronModule]] |
| 273 | + |
| 274 | + |
| 275 | +@dataclass |
| 276 | +class OptimizerConfigOverrideProvider: |
| 277 | + """Base class for providing optimizer config overrides (also serves as the default provider).""" |
| 278 | + |
| 279 | + def build_config_overrides( |
| 280 | + self, context: OptimizerConfigOverrideProviderContext |
| 281 | + ) -> dict[ParamKey, ParamGroupOverride] | None: |
| 282 | + """Build config overrides for weight decay based on scheduler configuration. |
| 283 | +
| 284 | + This function creates parameter-specific overrides for weight decay behavior. |
| 285 | + By default, weight decay is skipped for bias parameters and 1D parameters. |
| 286 | + For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm. |
| 287 | +
| 288 | + Args: |
| 289 | + context: OptimizerConfigOverrideProviderContext which packages the scheduler |
| 290 | + configuration, optimizer configuration, and model. |
| 291 | +
| 292 | + Returns: |
| 293 | + Dictionary of ParamKey to ParamGroupOverride for the optimizer |
| 294 | + """ |
| 295 | + model = context.model |
| 296 | + scheduler_config = context.scheduler_config |
| 297 | + optimizer_config = context.optimizer_config |
| 298 | + |
| 299 | + config_overrides: dict[ParamKey, ParamGroupOverride] = {} |
| 300 | + |
| 301 | + # Collect param names that should skip weight decay |
| 302 | + # NOTE: this can be simplified once https://github.com/NVIDIA/Megatron-LM/pull/2753 |
| 303 | + # is merged into dev. Then we can re-use megatron's apply_wd_to_qk_layernorm option |
| 304 | + # and call megatron.core.optimizer.get_standard_config_overrides(optimizer_config) |
| 305 | + # directly for standard settings, replacing the custom logic below for qwen3-next. |
| 306 | + no_wd_names: list[str] = [] |
| 307 | + is_qwen3_next = scheduler_config.no_weight_decay_cond_type == "qwen3_next" |
| 308 | + |
| 309 | + model_list = model if isinstance(model, list) else [model] |
| 310 | + for model_chunk in model_list: |
| 311 | + for name, param in model_chunk.named_parameters(): |
| 312 | + # Skip weight decay for bias parameters |
| 313 | + if name.endswith(".bias"): |
| 314 | + no_wd_names.append(name) |
| 315 | + continue |
| 316 | + |
| 317 | + # Skip weight decay for 1D parameters |
| 318 | + if len(param.shape) == 1: |
| 319 | + if is_qwen3_next: |
| 320 | + # Qwen3-Next: apply weight decay to qk layernorm (don't add to skip list) |
| 321 | + if "q_layernorm" in name or "k_layernorm" in name: |
| 322 | + continue |
| 323 | + no_wd_names.append(name) |
| 324 | + |
| 325 | + # Create a single ParamKey with all names that should skip weight decay |
| 326 | + if no_wd_names: |
| 327 | + no_wd_key = ParamKey(name=tuple(no_wd_names)) |
| 328 | + config_overrides[no_wd_key] = ParamGroupOverride(wd_mult=0.0) |
| 329 | + |
| 330 | + # Now handle decoupled LR: |
| 331 | + if optimizer_config.decoupled_lr is not None: |
| 332 | + decoupled_lr_config: ParamGroupOverride = {"max_lr": optimizer_config.decoupled_lr} |
| 333 | + decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter") |
| 334 | + if optimizer_config.decoupled_min_lr is not None: |
| 335 | + decoupled_lr_config["min_lr"] = optimizer_config.decoupled_min_lr |
| 336 | + config_overrides[decoupled_param_key] = decoupled_lr_config |
| 337 | + |
| 338 | + return config_overrides if config_overrides else None |
| 339 | + |
| 340 | + |
261 | 341 | @dataclass |
262 | 342 | class DatasetProvider(DataloaderConfig, ABC): |
263 | 343 | """Abstract base class for custom dataset configurations. |
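
Note (illustrative sketch, not part of this diff): with the default provider, `decoupled_lr` unset, and `no_weight_decay_cond_type` not set to `"qwen3_next"`, the returned mapping collapses to a single entry keyed by every parameter name that should skip weight decay. Using two hypothetical parameter names:

from megatron.core.optimizer import ParamGroupOverride, ParamKey

# Hypothetical names: a bias and a 1D layernorm weight both land in the
# no-weight-decay group, so the provider returns one ParamKey covering both.
expected_overrides = {
    ParamKey(
        name=(
            "decoder.layers.0.mlp.linear_fc1.bias",
            "decoder.layers.0.self_attention.q_layernorm.weight",
        )
    ): ParamGroupOverride(wd_mult=0.0),
}
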
@@ -1163,6 +1243,9 @@ class ConfigContainer(Container): |
1163 | 1243 | train: TrainingConfig |
1164 | 1244 | model: GPTModelProvider | T5ModelProvider | MambaModelProvider |
1165 | 1245 | optimizer: OptimizerConfig |
| 1246 | + optimizer_config_override_provider: OptimizerConfigOverrideProvider = field( |
| 1247 | + default_factory=OptimizerConfigOverrideProvider |
| 1248 | + ) |
1166 | 1249 | ddp: DistributedDataParallelConfig = field(default_factory=DistributedDataParallelConfig) |
1167 | 1250 | scheduler: SchedulerConfig |
1168 | 1251 | dataset: GPTDatasetConfig | FinetuningDatasetConfig | DatasetProvider |
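
Usage sketch (illustrative; the import path below is assumed, and `RouterNoWeightDecayProvider` is a hypothetical subclass, not part of this diff): the new `optimizer_config_override_provider` field lets a recipe extend the default weight-decay handling by overriding `build_config_overrides`:

from dataclasses import dataclass

from megatron.core.optimizer import ParamGroupOverride, ParamKey

# Assumed import path for the classes introduced in this diff.
from megatron.bridge.training.config import (
    OptimizerConfigOverrideProvider,
    OptimizerConfigOverrideProviderContext,
)


@dataclass
class RouterNoWeightDecayProvider(OptimizerConfigOverrideProvider):
    """Illustrative subclass: additionally skip weight decay for '*router*' params."""

    def build_config_overrides(
        self, context: OptimizerConfigOverrideProviderContext
    ) -> dict[ParamKey, ParamGroupOverride] | None:
        # Start from the default overrides (bias/1D params, decoupled LR handling).
        overrides = super().build_config_overrides(context) or {}

        model_list = context.model if isinstance(context.model, list) else [context.model]
        router_names = tuple(
            name
            for chunk in model_list
            for name, _ in chunk.named_parameters()
            if "router" in name
        )
        if router_names:
            overrides[ParamKey(name=router_names)] = ParamGroupOverride(wd_mult=0.0)

        return overrides or None

The instance would then be assigned to the new field, e.g. `cfg.optimizer_config_override_provider = RouterNoWeightDecayProvider()`, alongside the existing optimizer and scheduler settings.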