Skip to content

Commit 145e88c

Browse files
aagalloaagallopre-commit-ci[bot]vthumbe1503
authored
Add multi-precision training support to FSDP script (#2662)
* Add precision parameter support for multiple training formats Enable configurable precision training with support for FP32, FP16, FP8, MXFP8, and NVFP4 formats. Added precision argument parser and match statement to configure appropriate dtype and recipe based on selected precision. - Add precision() type validator function - Implement precision-based configuration in train() - Support FP32, FP16, FP8, MXFP8, and NVFP4 formats - Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling) - Set appropriate no_fp8 flags based on precision selection Signed-off-by: aagallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix FP16 dtype mapping and implement CLI flag precedence Correct FP16 precision to use torch.float16 instead of torch.bfloat16, and add precedence logic where --dtype and --no-fp8 flags override --precision when explicitly set, with warnings issued for conflicts. - Fix case fp16 to use torch.float16 instead of torch.bfloat16 - Add flag precedence detection by comparing against default values - Implement warning messages when --dtype or --no-fp8 override --precision - Update argument parser help text to document precedence behavior - Ensure --dtype and --no-fp8 take precedence over --precision presets Signed-off-by: Andrea Gallo <[email protected]> * Add logging and documentation for precision configuration Add informative log messages and enhanced help text to clarify precision configuration behavior and flag precedence for better user transparency. - Add log message showing which precision preset is being used - Add warning logs when --dtype or --no-fp8 override --precision - Add final training configuration log (dtype, FP8 status, recipe) - Enhance argument parser help text with precedence examples - Add inline code comments explaining precedence logic Signed-off-by: Andrea Gallo <[email protected]> * Initialize recipe variable in all precision cases Add recipe initialization for fp32 and fp16 precision cases to prevent undefined variable errors, even though recipe is not used when no_fp8 is set to True. - Add DelayedScaling recipe setup for fp32 case with no_fp8=True - Add DelayedScaling recipe setup for fp16 case with no_fp8=True - Add inline comments explaining recipe is set up but not used by autocast - Ensure recipe variable is defined in all precision branches for consistency Signed-off-by: Andrea Gallo <[email protected]> * Fix dtype flag detection to support explicit override behavior Update flag precedence detection to use sys.argv for checking if --dtype was explicitly set, ensuring dtype always overrides precision regardless of whether it matches the default value. - Add sys import for command-line argument detection - Change dtype_explicitly_set check to use '--dtype' in sys.argv - Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv - Ensure --dtype bf16 correctly overrides --precision even when matching default - Maintain warning messages when explicit flags override precision presets Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace sys.argv parsing with custom action and fix default case Replace fragile sys.argv parsing with robust custom argparse action class to track explicitly set arguments, and fix default precision case to explicitly set no_fp8 to False for consistent FP8-enabled behavior. - Add StoreExplicitAction custom action class for tracking explicit arguments - Update --dtype argument to use StoreExplicitAction - Replace sys.argv check with getattr for dtype_explicitly_set attribute - Remove sys import from train() function - Fix default case to set no_fp8 = False instead of opts.no_fp8 - Ensure recipe variable is properly initialized in all code paths - Support all argument passing methods including config files and = syntax Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix params_dtype to use computed dtype from precision logic Remove params_dtype initialization from get_layer_args() and update FSDP MixedPrecision to use computed dtype variable instead of raw opts.dtype, ensuring precision presets are properly applied throughout the model. - Remove params_dtype from get_layer_args() layer_kwargs initialization - Update FSDP MixedPrecision param_dtype to use computed dtype variable - Ensure precision preset logic is respected in both layer initialization and FSDP - Maintain backward compatibility with original FP8-enabled default behavior Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix type conversion in StoreExplicitAction for --dtype argument Add type converter application in StoreExplicitAction custom action to ensure --dtype values are properly converted from strings to torch dtype objects, preventing runtime errors in torch operations. - Store type converter in StoreExplicitAction.__init__ - Apply type conversion in __call__ before setting attribute value - Add error handling for invalid type conversions - Ensure opts.dtype contains torch dtype object, not raw string - Fix runtime errors in torch.rand() and MixedPrecision() calls Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix precision preset recipe selection and add incompatibility validation Address critical bugs where FP8 recipes were incorrectly selected when explicit flags were set, and add validation to prevent incompatible flag combinations that would silently disable FP8 training. - Remove default value from --precision parameter (set to None for backward compatibility) - Add get_precision_preset() and get_recipe_for_precision() helper functions - Implement two-path configuration logic: backward compatibility mode vs. precision preset mode - Add incompatibility validation: raise ValueError when --no-fp8 used with fp8/mxfp8/nvfp4 presets - Preserve FP8 recipe selection when --dtype explicitly overrides precision preset dtype - Fix fp16 case to correctly map to torch.float16 instead of torch.bfloat16 - Update parameter help text with precedence rules and usage examples - Ensure backward compatibility: scripts without --precision work identically to original version Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix unreachable default case and redundant recipe recreation Remove dead code in get_precision_preset() default case and eliminate redundant recipe recreation when dtype is explicitly overridden, ensuring cleaner logic flow and preventing duplicate recipe instantiation. - Remove unreachable case _: branch from get_precision_preset() function - Delete redundant recipe recreation when dtype_explicitly_set is true - Preserve existing recipe from preset when dtype override occurs - Ensure dtype override only affects parameter storage, not FP8 recipe selection Signed-off-by: Andrea Gallo <[email protected]> * Add explicit error handling for invalid precision presets Prevent silent failures when precision validation is bypassed or new presets are added without updating get_precision_preset() function by adding explicit ValueError for unhandled cases. - Add case _: branch to get_precision_preset() that raises ValueError - Ensure invalid precision values fail loudly with clear error message - Prevent TypeError on tuple unpacking if function returns None - Improve maintainability when adding new precision presets Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: address argparse robustness and cleanup issues in fsdp.py Resolve three code review issues in examples/pytorch/fsdp/fsdp.py: dead commented-out code, unhelpful TypeError in precision(), and rigid __init__ signature in StoreTrueExplicitAction. - Remove commented-out layer_kwargs["params_dtype"] = dtype at line 106; dead code after params_dtype was moved to train() - Replace bare raise TypeError in precision() with argparse.ArgumentTypeError and explicit list of supported values (fp32, fp16, fp8, mxfp8, nvfp4) for a meaningful error message - Add **kwargs to StoreTrueExplicitAction.__init__ and forward to super().__init__(); aligns with StoreExplicitAction for robustness Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: suppress spurious dtype override warning when value matches preset Guard the --dtype override warning and dtype reassignment behind an actual value change check to avoid a false positive when the user explicitly passes --dtype with the same value the precision preset would have selected. - Add new_dtype != preset_dtype guard inside the dtype_explicitly_set branch so warning and dtype reassignment only trigger on a real override - Suppress redundant recipe re-creation when dtype matches preset default; recipe is already correctly set from preset_recipe above - No behavioral change when --dtype differs from preset default Signed-off-by: Andrea Gallo <[email protected]> * fix: add type conversion with error handling in StoreExplicitAction.__call__ Apply type_converter inside __call__ with proper exception handling to ensure --dtype values are converted and validated at parse time rather than silently passing raw strings through. - Wrap type_converter call in try/except catching ValueError, TypeError, and argparse.ArgumentTypeError to surface conversion failures via parser.error() with a descriptive message - Guard conversion behind if self.type_converter is not None check for cases where no converter is registered - Ensures --dtype argument is correctly converted and validated consistently with standard argparse type= behavior Signed-off-by: Andrea Gallo <[email protected]> * fix: remove redundant condition, deduplicate recipe logic, guard re-instantiation Address three code review issues in examples/pytorch/fsdp/fsdp.py: redundant opts.no_fp8 check, duplicated recipe construction, and unnecessary recipe re-instantiation when dtype matches preset. - Remove redundant 'and opts.no_fp8' from no_fp8_explicitly_set guard at line 373; StoreTrueExplicitAction always sets opts.no_fp8 to True when it fires, making the extra check always True - Refactor get_recipe_for_precision() to delegate to get_precision_preset() and extract the recipe, eliminating duplicated recipe construction logic and silent drift hazard when recipe parameters are tuned in one place but not the other - Guard recipe re-creation inside new_dtype != preset_dtype branch to avoid unnecessary re-instantiation when dtype_explicitly_set but the value matches the preset default Signed-off-by: Andrea Gallo <[email protected]> * fix: validate flags before dist.init_process_group and remove redundant arg Move incompatible-flags check before dist.init_process_group() to avoid leaving the NCCL process group partially initialized, and remove redundant fp8_format=Format.E4M3 from MXFP8BlockScaling(). - Move no_fp8_explicitly_set + precision conflict check to the top of train() before dist.init_process_group() to prevent deadlocks or 'Address already in use' errors on other ranks that are still waiting inside init_process_group when rank 0 raises ValueError - Remove explicit fp8_format=Format.E4M3 from MXFP8BlockScaling() call; Format.E4M3 is already the dataclass default and passing it explicitly adds noise without adding clarity Signed-off-by: Andrea Gallo <[email protected]> * fix: simplify StoreExplicitAction and improve training config log Delegate type conversion to argparse in StoreExplicitAction and include active FP8 recipe type in the training configuration log. - Remove self.type_converter field and manual try/except block from StoreExplicitAction.__call__; forward type= kwarg to super().__init__() so argparse handles conversion natively before __call__ is invoked, restoring standard error messages and %(type)s help interpolation - Simplify StoreExplicitAction.__init__ to use **kwargs passthrough, removing the now-unnecessary type= interception logic - Include active recipe type in training configuration log output using type(recipe).__name__ so log emits messages like 'FP8=enabled (MXFP8BlockScaling)' or 'FP8=disabled', making it easier to verify the intended quantization scheme is in use Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: forward kwargs in StoreTrueExplicitAction, improve dtype log, document rank assumption Address three code review issues in examples/pytorch/fsdp/fsdp.py: silent kwargs drop in StoreTrueExplicitAction, missing confirmation log when dtype matches preset, and undocumented torchrun assumption. - Forward **kwargs to super().__init__() in StoreTrueExplicitAction to prevent silent discard of unexpected keyword arguments (e.g. metavar, choices) if argument registration is ever extended - Add info log when dtype_explicitly_set but new_dtype == preset_dtype so user receives confirmation their --dtype flag was acknowledged even when it matches the preset default and no override is needed - Add comment above no_fp8_explicitly_set validation documenting that raising ValueError before dist.init_process_group is safe because torchrun guarantees all ranks receive identical CLI arguments Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing typo in code documentation Signed-off-by: Andrea Gallo <[email protected]> * fix: format dtype in log messages and document recipe=None intent Strip 'torch.' prefix from dtype in user-facing log messages and add a comment documenting the intentional recipe=None behavior when FP8 is disabled. - Replace raw dtype formatting with str(dtype).replace('torch.', '') in both Warning and Info log messages so users see 'float32' or 'bfloat16' instead of 'torch.float32' or 'torch.bfloat16' - Add inline comment on recipe=None assignment explaining that te.autocast safely substitutes get_default_fp8_recipe() internally when recipe is None, and skips check_recipe_support when enabled=False, so the assignment is intentional and safe despite populating global FP8 state with a default recipe Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: remove redundant __init__ override in StoreExplicitAction Remove the __init__ override from StoreExplicitAction since it only calls super().__init__() with the same arguments, which Python does automatically. The class now consists solely of __call__, eliminating dead code without any behavioral change. Signed-off-by: Andrea Gallo <[email protected]> * fix: use 'quantization' label in log and remove redundant recipe re-instantiation Replace misleading 'FP8' label with 'quantization' in training configuration log and remove redundant recipe re-instantiation in the dtype_explicitly_set path. - Replace 'FP8=enabled/disabled' with 'quantization=enabled/disabled' in dist_print configuration log to accurately cover all TE precision modes including NVFP4 which is 4-bit, not FP8 - Remove get_recipe_for_precision() call inside dtype_explicitly_set block; recipe is already correctly assigned from preset_recipe above and re-instantiating it is wasteful and creates a second object discarding the first - Add inline comment clarifying that recipe requires no update in the dtype_explicitly_set path since it is determined by opts.precision, not dtype Signed-off-by: Andrea Gallo <[email protected]> * fix: remove redundant recipe re-instantiation in equal-dtype path Remove unnecessary get_recipe_for_precision() call in the else branch of the dtype_explicitly_set block where new_dtype == preset_dtype. - recipe is already correctly assigned from preset_recipe before the dtype_explicitly_set block; no re-instantiation is needed in either branch since recipe is determined by opts.precision, not dtype - Previous else branch was re-creating the recipe (wasteful) while the if branch was not, inverting the logic implied by the comment - Replace with a comment clarifying that recipe requires no update in the dtype_explicitly_set path Signed-off-by: Andrea Gallo <[email protected]> * fix: define dtype_name unconditionally and guard dtype override warning Fix two bugs in train() precision configuration block. - Define dtype_name unconditionally before the dtype_explicitly_set block to prevent NameError in the config log when dtype_explicitly_set is False (the common case when --dtype is not explicitly passed) - Guard dtype override warning behind 'dtype_explicitly_set and opts.precision is not None' to prevent spurious warning when user passes --dtype without --precision (original behavior path) Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: recompute dtype_name after override and restore DelayedScaling default Fix two bugs in train() precision configuration block: stale dtype_name in log messages after dtype override, and behavioral regression where recipe=None was passed to te.autocast when FP8 was enabled in backward-compatible mode. - Recompute dtype_name immediately after dtype = new_dtype in the dtype override branch so warning and config log reflect the effective dtype rather than the stale preset dtype - Restore original default behavior in opts.precision is None path: when no_fp8 is False (FP8 enabled), supply DelayedScaling recipe to preserve the original te.autocast behavior instead of passing recipe=None which changed the implicit fallback behavior Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: simplify opts.dtype None check and remove redundant no_fp8 assignment Remove two redundant lines in train() precision configuration block. - Remove 'if opts.dtype is not None' guard in opts.precision is None branch; --dtype has default=torch.bfloat16 so opts.dtype is never None and the condition is always True - Remove redundant 'no_fp8 = preset_no_fp8' assignment in the else branch; no_fp8 is already assigned from preset_no_fp8 at the tuple unpack above and reassigning it in the else branch adds noise without changing behavior Signed-off-by: Andrea Gallo <[email protected]> * fix: guard recipe=None, remove dead case None, shorten help text Address three code review issues in examples/pytorch/fsdp/fsdp.py: recipe=None passed to te.autocast, dead case None in get_precision_preset, and excessively verbose help strings for --no-fp8 and --dtype. - Use 'recipe or DelayedScaling()' fallback at te.autocast call site to preserve original defensive pattern of always passing a concrete recipe instance, even when enabled=False - Remove case None from get_precision_preset() and guard call site in train() with 'if opts.precision is not None' to eliminate dead-code path whose return values were immediately overridden by the caller - Replace multi-paragraph help strings for --no-fp8 and --dtype with concise one-liner synopses; move detailed precedence rules to module-level docstring or README Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: resolve recipe=None fallback before training loop, not per-iteration Move 'recipe or DelayedScaling()' fallback to a one-time assignment before the training loop to avoid allocating a new DelayedScaling() object on every iteration when FP8 is disabled. - Add 'if recipe is None: recipe = DelayedScaling()' after the configuration block and before the training loop so the fallback object is created once and reused across all iterations - Restore clean 'recipe=recipe' in te.autocast call, matching the original code pattern - Add comment explaining why recipe is always set to a concrete object even when FP8 is disabled Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: move recipe=None fallback before training loop with consistent parameters Move 'if recipe is None: recipe = DelayedScaling(...)' guard to just before the training loop instead of inside it to avoid redundant is-None checks on every iteration and variable mutation inside the loop. - Use consistent DelayedScaling parameters (fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo='max') matching the rest of the file, rather than plain DelayedScaling() with default args - Guard runs once before the loop; recipe is stable for all iterations - Restores clean 'recipe=recipe' in te.autocast call with no inline fallback expression Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: pass amax_reduction_group only for DelayedScaling and shorten help text Pass amax_reduction_group conditionally based on recipe type and replace verbose multi-paragraph help strings with concise one-liners. - Compute amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None and pass amax_group to te.autocast; amax_reduction_group is a DelayedScaling-specific parameter for per-tensor amax aggregation and is not accepted by MXFP8BlockScaling or NVFP4BlockScaling which use block-level scaling - Replace multi-paragraph help strings for --no-fp8, --dtype, and --precision (with PRECEDENCE/BEHAVIOR/RATIONALE/EXAMPLES sections) with concise one-liner synopses suitable for terminal --help output - Move detailed precedence rules to module-level docstring Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: guard amax_group with not no_fp8 to prevent spurious distributed comms Add 'not no_fp8' condition to amax_group assignment to prevent amax_reduction_group=all_gpus being passed to te.autocast when FP8 is disabled. - When no_fp8=True and recipe was None, the DelayedScaling fallback causes isinstance(recipe, DelayedScaling) to return True, which incorrectly set amax_group=all_gpus even though enabled=False - Add 'not no_fp8' guard so amax_group is only set to all_gpus when FP8 is active AND the recipe is DelayedScaling (per-tensor amax); all other cases (FP8 disabled, block-scaling recipes) use None Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: warn on redundant --no-fp8 with fp32/fp16 and document amax_group=None Emit a warning when --no-fp8 is combined with a non-FP8 precision preset and add an inline comment explaining why amax_reduction_group is None for block-scaling recipes. - Add warning when opts.precision in ['fp32', 'fp16'] and opts.no_fp8 is set; FP8 is already disabled by these presets so the flag is redundant and silently ignored without this feedback - Add inline comment on amax_group assignment explaining that MXFP8BlockScaling and NVFP4BlockScaling use local block scaling and do not require a distributed amax reduction group, and that None is also correct when FP8 is disabled to avoid unnecessary distributed communication Signed-off-by: Andrea Gallo <[email protected]> * fix: initialize preset_dtype and preset_recipe before conditional block Initialize preset_dtype and preset_recipe with fallback values before the 'if opts.precision is not None' block to prevent static analyzer warnings about potentially unbound variables. - Assign preset_dtype = opts.dtype and preset_recipe = None as sensible fallbacks before the if-else block; these are overwritten by get_precision_preset() when opts.precision is not None and are never accessed in the else branch - Satisfies mypy, pylint, and pyflakes 'possibly undefined' / 'unbound' warnings that would otherwise trigger CI lint failures in projects treating unbound-variable warnings as errors - No behavioral change; the if-else logic is unchanged Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: compute amax_group before recipe=None fallback to avoid isinstance race Move amax_group computation before the 'if recipe is None' fallback assignment so isinstance(recipe, DelayedScaling) reflects the actual user-selected recipe rather than the defensive fallback object. - When recipe is None (non-FP8 presets or --no-fp8), isinstance correctly returns False and amax_group is set to None before the fallback substitutes a DelayedScaling instance - Prevents the fragile ordering dependency where not no_fp8 was the sole guard against passing all_gpus to a recipe that doesn't need it - Add inline comment explaining why amax_group must be computed before the recipe fallback to preserve the invariant Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: warn on potentially incompatible --dtype float16 with FP8-family presets Add explicit warning when --dtype float16 is combined with --precision fp8, mxfp8, or nvfp4, which expect bfloat16 accumulation. - Emit compatibility warning before applying the dtype override when opts.precision is in ['fp8', 'mxfp8', 'nvfp4'] and new_dtype is torch.float16; these presets are designed for bfloat16 accumulation and pairing with float16 may produce incorrect or undefined results - Warning is emitted in addition to the existing dtype override warning so users see both the compatibility concern and the override confirmation - Override is still applied (not blocked) to preserve user control; users who know their TE version supports float16 accumulation can proceed with awareness of the risk Signed-off-by: Andrea Gallo <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: aagallo <[email protected]> Signed-off-by: Andrea Gallo <[email protected]> Co-authored-by: aagallo <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: vthumbe1503 <[email protected]>
1 parent d2e4755 commit 145e88c

File tree

1 file changed

+168
-12
lines changed

1 file changed

+168
-12
lines changed

examples/pytorch/fsdp/fsdp.py

Lines changed: 168 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
)
1919

2020
import transformer_engine.pytorch as te
21-
from transformer_engine.common.recipe import Format, DelayedScaling
21+
from transformer_engine.common.recipe import (
22+
Format,
23+
DelayedScaling,
24+
MXFP8BlockScaling,
25+
NVFP4BlockScaling,
26+
)
2227
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
2328

2429
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
@@ -64,10 +69,21 @@ def torch_dtype(d):
6469
"bfloat16": torch.bfloat16,
6570
}
6671
if lowercase(d) not in typemap.keys():
67-
raise TypeError
72+
raise argparse.ArgumentTypeError(
73+
f"invalid dtype '{d}'. Supported values: fp32/float32, fp16/float16, bf16/bfloat16"
74+
)
6875
return typemap[lowercase(d)]
6976

7077

78+
def precision(d):
79+
typemap = ["fp32", "fp16", "fp8", "mxfp8", "nvfp4"]
80+
if lowercase(d) not in typemap:
81+
raise argparse.ArgumentTypeError(
82+
f"invalid precision '{d}'. Supported values: {', '.join(typemap)}"
83+
)
84+
return lowercase(d)
85+
86+
7187
te_layer_map = {
7288
"linear": te.Linear,
7389
"layernorm": te.LayerNorm,
@@ -91,7 +107,6 @@ def get_layer_args(opts):
91107
hidden_size = opts.num_heads * opts.head_dim
92108
layer_args = (hidden_size,)
93109
layer_kwargs = {
94-
"params_dtype": opts.dtype,
95110
"device": "cuda" if opts.no_defer_init else "meta",
96111
"get_rng_state_tracker": get_cuda_rng_tracker,
97112
}
@@ -112,6 +127,15 @@ def get_layer_args(opts):
112127
return layer_args, layer_kwargs
113128

114129

130+
class StoreExplicitAction(argparse.Action):
131+
"""Custom action that tracks whether an argument was explicitly set."""
132+
133+
def __call__(self, parser, namespace, values, option_string=None):
134+
# values already converted by argparse via action.type
135+
setattr(namespace, self.dest, values)
136+
setattr(namespace, f"{self.dest}_explicitly_set", True)
137+
138+
115139
def parse_fsdp_args():
116140
parser = argparse.ArgumentParser(
117141
description="Run Transformer Engine modules with the "
@@ -173,7 +197,10 @@ def parse_fsdp_args():
173197
"--no-fp8",
174198
action="store_true",
175199
default=False,
176-
help="Disables the te.autocast() context.",
200+
help=(
201+
"Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4."
202+
" Default: False."
203+
),
177204
)
178205
parser.add_argument(
179206
"--no-defer-init",
@@ -189,7 +216,21 @@ def parse_fsdp_args():
189216
"--dtype",
190217
type=torch_dtype,
191218
default=torch.bfloat16,
192-
help="Data type for input tensor and Transformer Engine module parameters.",
219+
action=StoreExplicitAction,
220+
help=(
221+
"Parameter dtype: fp32/float32, fp16/float16, bf16/bfloat16. Overrides --precision"
222+
" dtype when explicitly set. Default: bfloat16."
223+
),
224+
)
225+
parser.add_argument(
226+
"--precision",
227+
type=precision,
228+
default=None,
229+
help=(
230+
"Precision preset: fp32, fp16, fp8, mxfp8, nvfp4. Configures dtype and FP8 recipe"
231+
" automatically. Overridden by explicit --dtype. Default: None (use --dtype and"
232+
" --no-fp8 directly)."
233+
),
193234
)
194235
return parser.parse_args()
195236

@@ -200,15 +241,118 @@ def dist_print(text, all_ranks=False, no_new_line=False):
200241
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
201242

202243

244+
def get_precision_preset(precision_value):
245+
"""Get dtype, no_fp8, and recipe based on precision preset.
246+
247+
Returns:
248+
tuple: (dtype, no_fp8, recipe)
249+
"""
250+
match precision_value:
251+
case "fp32":
252+
return torch.float32, True, None
253+
case "fp16":
254+
return torch.float16, True, None
255+
case "fp8":
256+
recipe = DelayedScaling(
257+
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
258+
)
259+
return torch.bfloat16, False, recipe
260+
case "mxfp8":
261+
recipe = MXFP8BlockScaling()
262+
return torch.bfloat16, False, recipe
263+
case "nvfp4":
264+
recipe = NVFP4BlockScaling()
265+
return torch.bfloat16, False, recipe
266+
case _:
267+
raise ValueError(
268+
f"Invalid precision preset: {precision_value}. "
269+
"Supported values: fp32, fp16, fp8, mxfp8, nvfp4"
270+
)
271+
272+
203273
def train(opts):
274+
# Check which flags were explicitly set
275+
dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
276+
277+
# Validate flag combinations before touching distributed state.
278+
# Error if user requests FP8-based precision but also sets --no-fp8
279+
# Safe to raise here because torchrun guarantees all ranks receive
280+
# identical CLI arguments; all ranks will raise simultaneously.
281+
if opts.precision in ["fp8", "mxfp8", "nvfp4"] and opts.no_fp8:
282+
raise ValueError(
283+
f"Cannot use --no-fp8 with --precision {opts.precision}. "
284+
"These flags are incompatible. "
285+
f"Either remove --no-fp8 to use {opts.precision} training, "
286+
"or use --precision fp32/fp16 for non-FP8 training."
287+
)
288+
if opts.precision in ["fp32", "fp16"] and opts.no_fp8:
289+
dist_print(
290+
f"Warning: --no-fp8 is redundant when using --precision {opts.precision} "
291+
"(FP8 is already disabled by this preset). The flag will be ignored."
292+
)
293+
204294
# Initialize torch.distributed global process group
205295
dist.init_process_group(backend="nccl")
206296
torch.cuda.set_device(LOCAL_RANK)
207297
dist_print(f"WORLD_SIZE = {WORLD_SIZE}")
208298
torch.manual_seed(opts.seed)
209299

300+
preset_dtype: torch.dtype = opts.dtype # sensible fallback
301+
preset_recipe = None
302+
303+
if opts.precision is not None:
304+
preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
305+
dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe
306+
dist_print(f"Using precision preset: {opts.precision}")
307+
else:
308+
# Original behavior: --dtype and --no-fp8 control training directly
309+
dtype = opts.dtype
310+
no_fp8 = opts.no_fp8
311+
recipe = (
312+
DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")
313+
if not no_fp8
314+
else None
315+
)
316+
317+
dtype_name = str(dtype).replace("torch.", "")
318+
319+
# Apply explicit dtype override with warning
320+
if dtype_explicitly_set and opts.precision is not None:
321+
new_dtype = opts.dtype
322+
if new_dtype != preset_dtype:
323+
if opts.precision in ["fp8", "mxfp8", "nvfp4"] and new_dtype == torch.float16:
324+
dist_print(
325+
"Warning: --dtype float16 may be incompatible with --precision"
326+
f" {opts.precision}, which expects bfloat16 accumulation."
327+
)
328+
329+
dtype = new_dtype
330+
dtype_name = str(dtype).replace("torch.", "")
331+
332+
dist_print(
333+
f"Warning: --dtype {dtype_name} overrides --precision {opts.precision} dtype"
334+
" setting"
335+
)
336+
else:
337+
new_dtype_name = str(new_dtype).replace("torch.", "")
338+
dist_print(
339+
f"Info: --dtype {new_dtype_name} matches --precision {opts.precision} preset"
340+
" default, no override needed"
341+
)
342+
343+
# recipe is already set correctly from preset_recipe above;
344+
# dtype only affects parameter storage, not the quantization recipe
345+
346+
# Always log the final configuration being used
347+
dist_print(
348+
f"Training configuration: dtype={dtype_name}, "
349+
f"quantization={'disabled' if no_fp8 else f'enabled ({type(recipe).__name__})'}"
350+
)
351+
210352
# Construct a simple homogeneous model (only one layer type) with NO PARALLELISM
211353
layer_args, layer_kwargs = get_layer_args(opts)
354+
layer_kwargs["params_dtype"] = dtype
355+
212356
if opts.num_layers > 1:
213357
te_layer_list = []
214358
for i in range(opts.num_layers):
@@ -239,7 +383,7 @@ def train(opts):
239383
process_group=all_gpus,
240384
use_orig_params=True,
241385
mixed_precision=MixedPrecision(
242-
param_dtype=opts.dtype,
386+
param_dtype=dtype,
243387
reduce_dtype=torch.float32,
244388
),
245389
auto_wrap_policy=fsdp_wrap_policy,
@@ -258,10 +402,6 @@ def train(opts):
258402
dist_print(f"Post-FSDP memory use = {post_mem_use}MiB")
259403
dist_print(f"FSDP-Wrapped + Checkpointed TE Model:\n{te_model}")
260404

261-
# Fp8 setup for TE
262-
fp8_format = Format.HYBRID
263-
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
264-
265405
# Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded
266406
optim = torch.optim.Adam(te_model.parameters(), lr=0.0001)
267407

@@ -275,17 +415,33 @@ def train(opts):
275415
torch.cuda.synchronize()
276416
start.record()
277417

418+
# MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed.
419+
# amax_reduction_group is only required for DelayedScaling (global AMAX allreduce).
420+
# Also skip when FP8 is disabled to avoid unnecessary distributed communication.
421+
# Compute amax_group BEFORE the recipe fallback so isinstance() reflects the actual
422+
# recipe, not the defensive DelayedScaling() substituted for None.
423+
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None
424+
425+
# Ensure recipe is always a concrete object before passing to te.autocast.
426+
# When FP8 is disabled, te.autocast ignores the recipe, but some TE versions
427+
# perform attribute access on it regardless of the enabled flag.
428+
if recipe is None:
429+
recipe = DelayedScaling(
430+
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
431+
)
432+
278433
for i in range(opts.num_iters):
279434
# Generate a random input batch
280435
x = torch.rand(
281436
opts.seq_length,
282437
opts.batch_size,
283438
opts.num_heads * opts.head_dim,
284-
dtype=opts.dtype,
439+
dtype=dtype,
285440
device="cuda",
286441
)
442+
287443
# autocast needs to be given the FSDP process group for amax reductions
288-
with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus):
444+
with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group):
289445
y = te_model(x)
290446
loss = y.sum()
291447
# calculate gradient and take training step outside the autocast context

0 commit comments

Comments
 (0)