Add multi-precision training support to FSDP script#2662

Merged
vthumbe1503 merged 72 commits into NVIDIA:main from aagallo:extend-precision on Mar 5, 2026
Conversation

@aagallo aagallo commented Feb 9, 2026

Description

This PR adds comprehensive precision parameter support to the FSDP training script, enabling users to configure training with multiple precision formats (FP32, FP16, FP8, MXFP8, NVFP4) via a command-line argument. The implementation includes automatic configuration of appropriate dtypes and format-specific recipes for each precision type.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added precision() type validator function supporting fp32, fp16, fp8, mxfp8, and nvfp4 formats
  • Added --precision command-line argument to parse_fsdp_args() with default value "fp8"
  • Implemented match statement in train() function to configure precision-based training parameters
  • Configured format-specific recipes for each precision type:
    • FP32/FP16: Uses standard PyTorch dtypes with FP8 disabled
    • FP8: Uses DelayedScaling recipe with HYBRID format
    • MXFP8: Uses MXFP8BlockScaling recipe with E4M3 format
    • NVFP4: Uses NVFP4BlockScaling recipe with bfloat16 dtype
  • Set appropriate no_fp8 flags based on precision selection
  • Updated layer_kwargs["params_dtype"] to use precision-determined dtype
  • Imported required recipe classes: MXFP8BlockScaling and NVFP4BlockScaling

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Please reach out to Santosh Bhavani (sbhavani@nvidia.com) for additional context on the work

aagallo and others added 3 commits February 9, 2026 16:28
Enable configurable precision training with support for FP32, FP16, FP8,
MXFP8, and NVFP4 formats. Added precision argument parser and match statement
to configure appropriate dtype and recipe based on selected precision.

- Add precision() type validator function
- Implement precision-based configuration in train()
- Support FP32, FP16, FP8, MXFP8, and NVFP4 formats
- Configure format-specific recipes (DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling)
- Set appropriate no_fp8 flags based on precision selection

Signed-off-by: aagallo <aagallo@amazon.com>

greptile-apps bot commented Feb 9, 2026

Greptile Summary

This PR adds comprehensive multi-precision training support to the FSDP example script, allowing users to select between fp32, fp16, fp8, mxfp8, and nvfp4 presets via a new --precision flag. It introduces get_precision_preset() to encapsulate dtype/recipe/no_fp8 configuration per format, a StoreExplicitAction to track whether --dtype was explicitly set on the CLI, and proper amax_reduction_group routing (DelayedScaling needs the all_gpus group, while MXFP8BlockScaling/NVFP4BlockScaling use local block scaling and do not).

The implementation correctly handles all five precision presets, properly routes amax reduction groups, avoids logging crashes by computing amax_group before the fallback recipe substitution, and preserves backward-compatible default behavior. Two minor style issues remain:

  • Line 300 contains a misleading dead initialization of preset_dtype that is always overwritten before use, with a comment suggesting false defensive programming
  • The dtype-incompatibility warning (line 323) only checks for torch.float16 but should also warn about torch.float32, which is equally suboptimal for mxfp8/nvfp4 formats

Confidence Score: 4/5

  • PR is safe to merge; all critical logic issues have been resolved, leaving only minor style concerns that do not affect runtime behavior.
  • The implementation correctly addresses all previously identified critical and logic bugs, properly handles all five precision presets, and avoids the logging and recipe selection issues that were flagged in earlier reviews. The two remaining findings are style-level (dead code initialization and incomplete warning logic) that improve code maintainability and completeness but do not affect correctness or functionality.
  • examples/pytorch/fsdp/fsdp.py - two minor style improvements suggested

Comments Outside Diff (2)

  1. examples/pytorch/fsdp/fsdp.py, line 300-304 (link)

    Misleading dead initialization of preset_dtype

    preset_dtype is initialized to opts.dtype on line 300 as a "sensible fallback," but this initial value is always overwritten before it is ever read:

    • When opts.precision is not None (line 303), preset_dtype is immediately reassigned by get_precision_preset() on line 304.
    • The variable is used only once on line 305 to set dtype, then never referenced again.

The # sensible fallback comment creates a false impression of defensive programming. Since preset_dtype is guaranteed to be set by get_precision_preset() before use (or never used at all in the else branch), the initialization and type annotation can be removed.

  2. examples/pytorch/fsdp/fsdp.py, line 323-327 (link)

    Incomplete dtype compatibility warning

    The incompatibility warning only checks for torch.float16, but using torch.float32 as the parameter dtype with --precision mxfp8 or --precision nvfp4 is equally unusual — these block-scaling formats are designed for bfloat16 accumulation, so float32 parameters largely defeats the performance intent without a warning.

    Consider broadening the condition to also cover torch.float32.

Last reviewed commit: f21d0ac

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 3 comments

Comment on lines +243 to +246
    case "fp16":
        dtype = torch.bfloat16
        no_fp8 = True
    case "fp8":

Incorrect fp16 dtype
In the case "fp16" branch, the code sets dtype = torch.bfloat16. That contradicts the meaning of fp16 and also diverges from the existing --dtype parsing which supports torch.float16. If a user runs with --precision fp16 expecting fp16 parameters/inputs, they’ll silently get bf16 instead.

Comment on lines +206 to 212
    parser.add_argument(
        "--precision",
        type=precision,
        default="fp8",
        help="Precision to apply to model training (FP32, FP16, FP8, MXFP8, NVFP4)",
    )
    return parser.parse_args()

Conflicting CLI flags
--precision and --dtype/--no-fp8 now overlap: train() overrides dtype and no_fp8 based on --precision, but --dtype/--no-fp8 are still accepted and used as defaults. As written, --precision fp8 will force no_fp8=False even if the user explicitly passed --no-fp8, and --precision fp32 will ignore an explicit --dtype fp16/bf16. This makes the CLI behavior surprising and hard to reason about; either make --precision derive defaults only when the user didn’t specify --dtype/--no-fp8, or document/enforce precedence (e.g., error on incompatible combinations).
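One way to implement the "derive defaults only when the user didn't specify" policy the comment proposes can be sketched as a pure function. Dtype names are strings standing in for torch dtypes, and the helper name is hypothetical:

```python
# Preset table: precision name -> (default params dtype, default no_fp8 flag)
PRESETS = {
    "fp32": ("float32", True),
    "fp16": ("float16", True),
    "fp8": ("bfloat16", False),
    "mxfp8": ("bfloat16", False),
    "nvfp4": ("bfloat16", False),
}

def resolve_precision(precision: str, cli_dtype: str, cli_no_fp8: bool,
                      dtype_explicit: bool, no_fp8_explicit: bool):
    """Explicit --dtype / --no-fp8 win; otherwise the preset supplies defaults."""
    preset_dtype, preset_no_fp8 = PRESETS[precision]
    dtype = cli_dtype if dtype_explicit else preset_dtype
    no_fp8 = cli_no_fp8 if no_fp8_explicit else preset_no_fp8
    return dtype, no_fp8
```

With this shape, `--precision fp8` plus an explicit `--no-fp8` keeps FP8 disabled instead of silently forcing it back on.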


greptile-apps bot commented Feb 9, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP mixed_precision mismatch
layer_kwargs["params_dtype"] and the input tensor dtype are overridden by --precision, but FSDP is still configured with mixed_precision=MixedPrecision(param_dtype=opts.dtype, ...), and opts.dtype no longer matches the model param dtype when --precision is used. This will cause inconsistent param casting/communication behavior under FSDP, e.g. with --precision fp32 (params are fp32 but FSDP thinks they’re bf16) and --precision fp16 (currently sets dtype=torch.bfloat16). FSDP param_dtype should be driven by the same dtype selected in the precision switch, or the precision switch should not override param dtype when FSDP mixed precision is enabled.

Correct FP16 precision to use torch.float16 instead of torch.bfloat16,
and add precedence logic where --dtype and --no-fp8 flags override
--precision when explicitly set, with warnings issued for conflicts.

- Fix case fp16 to use torch.float16 instead of torch.bfloat16
- Add flag precedence detection by comparing against default values
- Implement warning messages when --dtype or --no-fp8 override --precision
- Update argument parser help text to document precedence behavior
- Ensure --dtype and --no-fp8 take precedence over --precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@ptrendx ptrendx requested a review from vthumbe1503 February 10, 2026 20:09
aagallo and others added 5 commits February 10, 2026 15:31
Add informative log messages and enhanced help text to clarify precision
configuration behavior and flag precedence for better user transparency.

- Add log message showing which precision preset is being used
- Add warning logs when --dtype or --no-fp8 override --precision
- Add final training configuration log (dtype, FP8 status, recipe)
- Enhance argument parser help text with precedence examples
- Add inline code comments explaining precedence logic

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Add recipe initialization for fp32 and fp16 precision cases to prevent
undefined variable errors, even though recipe is not used when no_fp8
is set to True.

- Add DelayedScaling recipe setup for fp32 case with no_fp8=True
- Add DelayedScaling recipe setup for fp16 case with no_fp8=True
- Add inline comments explaining recipe is set up but not used by autocast
- Ensure recipe variable is defined in all precision branches for consistency

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
Update flag precedence detection to use sys.argv for checking if --dtype
was explicitly set, ensuring dtype always overrides precision regardless
of whether it matches the default value.

- Add sys import for command-line argument detection
- Change dtype_explicitly_set check to use '--dtype' in sys.argv
- Change no_fp8_explicitly_set check to use '--no-fp8' in sys.argv
- Ensure --dtype bf16 correctly overrides --precision even when matching default
- Maintain warning messages when explicit flags override precision presets

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 3 comments

Comment on lines +296 to +302
    case _:
        dtype = torch.float16
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
        no_fp8 = opts.no_fp8

recipe variable not initialized in default case. If precision validator allows an unexpected value, this will cause UnboundLocalError at line 403 when passed to te.autocast().

Suggested change (replace no_fp8 = opts.no_fp8 with an explicit no_fp8 = False):

    case _:
        dtype = torch.float16
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
        no_fp8 = False

Comment on lines +242 to +243
dtype_explicitly_set = "--dtype" in sys.argv
no_fp8_explicitly_set = "--no-fp8" in sys.argv

Parsing sys.argv directly is fragile - will break if args are passed via config file, environment variables, or if arg uses = syntax (--dtype=fp16). Use parser.parse_known_args() or track which args were explicitly set via custom action class.

Example with custom action:

    class StoreExplicitAction(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, values)
            setattr(namespace, f'{self.dest}_explicitly_set', True)
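A custom action along these lines can be exercised directly; unlike scanning sys.argv, it also detects the --dtype=bf16 '=' syntax. A minimal runnable sketch (no type conversion wired in, string values only):

```python
import argparse

class StoreExplicitAction(argparse.Action):
    """Store the value and record that the option appeared on the command line."""
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)
        setattr(namespace, f"{self.dest}_explicitly_set", True)

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", action=StoreExplicitAction, default="bf16")

explicit = parser.parse_args(["--dtype=fp16"])  # '=' syntax is still detected
implicit = parser.parse_args([])                # default used, action never fires
```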


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
FSDP MixedPrecision uses opts.dtype instead of local dtype variable. When --precision is used without explicit --dtype, FSDP will use the default bfloat16 instead of the precision-determined dtype (e.g., float16 for fp8/mxfp8 or float32 for fp32).

        mixed_precision=MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=torch.float32,
        ),

aagallo and others added 3 commits February 10, 2026 16:15
Replace fragile sys.argv parsing with robust custom argparse action class
to track explicitly set arguments, and fix default precision case to
explicitly set no_fp8 to False for consistent FP8-enabled behavior.

- Add StoreExplicitAction custom action class for tracking explicit arguments
- Update --dtype argument to use StoreExplicitAction
- Replace sys.argv check with getattr for dtype_explicitly_set attribute
- Remove sys import from train() function
- Fix default case to set no_fp8 = False instead of opts.no_fp8
- Ensure recipe variable is properly initialized in all code paths
- Support all argument passing methods including config files and = syntax

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 5 comments

    def train(opts):
        # Check which flags were explicitly set
        dtype_explicitly_set = getattr(opts, "dtype_explicitly_set", False)
        no_fp8_explicitly_set = opts.no_fp8 != False

Detection of explicit --no-fp8 is incorrect - opts.no_fp8 != False is True when the flag is absent (defaults to False). User running with defaults will trigger else branch instead of precision preset logic.

Suggested change:

    no_fp8_explicitly_set = opts.no_fp8 != False    # current (incorrect)
    no_fp8_explicitly_set = "--no-fp8" in sys.argv  # suggested

@@ -173,7 +193,11 @@ def parse_fsdp_args():
"--no-fp8",

Missing action=StoreExplicitAction to track explicit --no-fp8 usage - required for precedence logic at line 250

Suggested change:

    "--no-fp8",
    action=StoreExplicitAction,

Comment on lines +332 to +334
    dist_print(f"Training configuration: dtype={dtype}, FP8={'disabled' if no_fp8 else 'enabled'}")
    if not no_fp8:
        dist_print(f"Using FP8 recipe: {type(recipe).__name__}")

Logging FP8 recipe when no_fp8=True and recipe=None (line 322) will cause AttributeError on type(recipe).__name__

    parser.add_argument(
        "--precision",
        type=precision,
        default="fp8",

Default changed from implicit FP8 disabled (via --dtype bf16 + --no-fp8 False) to explicit --precision fp8 enabled - existing scripts relying on default behavior will now use FP8 unexpectedly


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
param_dtype set to opts.dtype instead of local dtype variable - ignores precision preset

            param_dtype=dtype,

aagallo and others added 3 commits February 10, 2026 16:47
Remove params_dtype initialization from get_layer_args() and update FSDP
MixedPrecision to use computed dtype variable instead of raw opts.dtype,
ensuring precision presets are properly applied throughout the model.

- Remove params_dtype from get_layer_args() layer_kwargs initialization
- Update FSDP MixedPrecision param_dtype to use computed dtype variable
- Ensure precision preset logic is respected in both layer initialization and FSDP
- Maintain backward compatibility with original FP8-enabled default behavior

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment


greptile-apps bot commented Feb 10, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py
Broken --dtype parsing

--dtype uses a custom action=StoreExplicitAction, but StoreExplicitAction.__call__() just assigns values directly and never applies the type=torch_dtype conversion. With a custom Action, argparse won’t run the type converter for you, so passing --dtype bf16 will make opts.dtype a raw string (e.g., 'bf16') and later calls like torch.rand(..., dtype=dtype) / MixedPrecision(param_dtype=dtype) will fail.

One fix is to perform conversion inside the action:

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, torch_dtype(values))
        setattr(namespace, f"{self.dest}_explicitly_set", True)

(or alternatively, implement a custom action that subclasses _StoreAction and preserves normal type handling).

aagallo and others added 3 commits February 10, 2026 17:05
Add type converter application in StoreExplicitAction custom action to
ensure --dtype values are properly converted from strings to torch dtype
objects, preventing runtime errors in torch operations.

- Store type converter in StoreExplicitAction.__init__
- Apply type conversion in __call__ before setting attribute value
- Add error handling for invalid type conversions
- Ensure opts.dtype contains torch dtype object, not raw string
- Fix runtime errors in torch.rand() and MixedPrecision() calls

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 2 comments

Comment on lines +340 to +346
    if not no_fp8:
        precision_format = Format.HYBRID
        recipe = DelayedScaling(
            fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
        )
    else:
        recipe = None

CRITICAL BUG: Incorrect recipe selection when --dtype or --no-fp8 is explicitly set

When users explicitly set --dtype but also specify --precision mxfp8 or --precision nvfp4, the code ignores the precision preset and falls back to DelayedScaling recipe instead of the expected MXFP8BlockScaling or NVFP4BlockScaling recipe.

Reproduction:

    # User expects MXFP8BlockScaling recipe
    ./fsdp.py --precision mxfp8 --dtype bf16
    # But gets DelayedScaling recipe instead!

Root cause:
Lines 340-344 unconditionally use DelayedScaling when not no_fp8 is true, completely ignoring the opts.precision value.

Expected behavior:
The code should respect the precision preset even when dtype is explicitly set. The logic should check opts.precision and select the appropriate recipe (MXFP8BlockScaling for "mxfp8", NVFP4BlockScaling for "nvfp4", etc.).

Suggested fix:

    # Set up recipe based on precision preset
    if not no_fp8:
        match opts.precision:
            case "mxfp8":
                recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
            case "nvfp4":
                recipe = NVFP4BlockScaling()
            case _:
                precision_format = Format.HYBRID
                recipe = DelayedScaling(
                    fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
                )
    else:
        recipe = None

This is a merge-blocking bug as it breaks the advertised functionality of the --precision flag.

Comment on lines +334 to +346
    else:
        # dtype and/or no_fp8 were explicitly set - they take precedence
        dtype = opts.dtype
        no_fp8 = opts.no_fp8

        # Set up default recipe for FP8 cases
        if not no_fp8:
            precision_format = Format.HYBRID
            recipe = DelayedScaling(
                fp8_format=precision_format, amax_history_len=32, amax_compute_algo="max"
            )
        else:
            recipe = None

Precision preset ignored

In the else: branch (when --dtype and/or --no-fp8 were explicitly set), the code unconditionally falls back to DelayedScaling whenever FP8 is enabled:

if not no_fp8:
    recipe = DelayedScaling(...)

This ignores opts.precision entirely, so --precision mxfp8 or --precision nvfp4 will silently use DelayedScaling if the user also sets --dtype/--no-fp8 (e.g. --precision mxfp8 --dtype bf16). That breaks the advertised presets and applies the wrong quantization recipe.

Fix: in the explicit-flags path, either (a) still select recipe based on opts.precision when FP8 is enabled, or (b) explicitly error/warn and force opts.precision back to fp8 if you’re going to always use DelayedScaling.


greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 439
New DelayedScaling() allocated every iteration

recipe or DelayedScaling() is evaluated on each training iteration. When no_fp8=True, recipe is None, so a fresh DelayedScaling() object is constructed on every call even though enabled=False makes it unused. The original code created the recipe object once outside the loop; the current pattern regresses that behaviour.

Move the fallback outside the loop or replace the or with an upfront assignment so each iteration reuses the same (no-op) object:

        with te.autocast(
            enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus
        ):

And ensure recipe is always non-None before the loop (e.g. recipe = recipe or DelayedScaling() after the configuration block).

aagallo and others added 2 commits March 4, 2026 15:49
…tion

Move 'recipe or DelayedScaling()' fallback to a one-time assignment
before the training loop to avoid allocating a new DelayedScaling()
object on every iteration when FP8 is disabled.

- Add 'if recipe is None: recipe = DelayedScaling()' after the
  configuration block and before the training loop so the fallback
  object is created once and reused across all iterations
- Restore clean 'recipe=recipe' in te.autocast call, matching the
  original code pattern
- Add comment explaining why recipe is always set to a concrete
  object even when FP8 is disabled

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 441
Recipe fallback initialization belongs outside the training loop

The if recipe is None: recipe = DelayedScaling() guard is placed inside the for i in range(opts.num_iters): loop, but it only has any effect on the very first iteration — after that recipe is no longer None. Keeping it here has two downsides:

  1. The is None check runs on every iteration, even though it only matters once.
  2. It mutates a variable that was established before the loop, making the control-flow confusing for future readers.

Additionally, note that when this fallback fires (i.e. recipe was None), it creates a plain DelayedScaling() with default constructor arguments rather than the custom parameters (fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max") used throughout the rest of the file. While this is harmless at runtime because te.autocast(enabled=False, ...) ignores the recipe, it is inconsistent and could cause silent behavioral differences if no_fp8 logic ever changes.

Move the guard to just after the if opts.precision is not None / else block, before the loop begins:

    # Ensure recipe is always a concrete object before passing to te.autocast.
    # When FP8 is disabled, te.autocast ignores the recipe, but some TE versions
    # perform attribute access on it regardless of the enabled flag.
    if recipe is None:
        recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")

    for i in range(opts.num_iters):
        x = torch.rand(...)
        with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=all_gpus):
            ...

aagallo and others added 2 commits March 4, 2026 15:55
…arameters

Move 'if recipe is None: recipe = DelayedScaling(...)' guard to
just before the training loop instead of inside it to avoid
redundant is-None checks on every iteration and variable mutation
inside the loop.

- Use consistent DelayedScaling parameters (fp8_format=Format.HYBRID,
  amax_history_len=32, amax_compute_algo='max') matching the rest of
  the file, rather than plain DelayedScaling() with default args
- Guard runs once before the loop; recipe is stable for all iterations
- Restores clean 'recipe=recipe' in te.autocast call with no inline
  fallback expression

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 444
amax_reduction_group incompatible with block-scaling recipes

amax_reduction_group is an FP8-specific parameter used by DelayedScaling to aggregate per-tensor amax statistics across FSDP ranks. Both MXFP8BlockScaling and NVFP4BlockScaling use block-level scaling that does not require (and may not accept) this parameter.

Passing amax_reduction_group=all_gpus unconditionally when using --precision mxfp8 or --precision nvfp4 will likely trigger a TypeError or silently create unnecessary inter-rank communication, depending on which TE version is running.

The amax_reduction_group should only be passed when the active recipe uses per-tensor amax (i.e., DelayedScaling):

    # autocast needs to be given the FSDP process group for amax reductions
    amax_group = all_gpus if isinstance(recipe, DelayedScaling) else None
    with te.autocast(enabled=not no_fp8, recipe=recipe, amax_reduction_group=amax_group):
        y = te_model(x)
        loss = y.sum()

examples/pytorch/fsdp/fsdp.py, line 258
Overly verbose help strings

The help text for --no-fp8, --dtype, and --precision embeds structured prose with PRECEDENCE:, BEHAVIOR:, RATIONALE:, EXAMPLES: sections. This content will be displayed verbatim by argparse --help and will be very hard to read in a terminal (no line-break control, wrapping at 80 columns, etc.).

Concise, one-sentence help strings for the CLI and a README or docstring for the extended documentation would be significantly more user-friendly. For example:

    parser.add_argument(
        "--no-fp8",
        action="store_true",
        default=False,
        help=(
            "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4."
            " Default: False."
        ),
    )

The detailed interaction rules can live in a module-level docstring or in the script's README.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

aagallo and others added 2 commits March 4, 2026 16:06
…lp text

Pass amax_reduction_group conditionally based on recipe type and
replace verbose multi-paragraph help strings with concise one-liners.

- Compute amax_group = all_gpus if isinstance(recipe, DelayedScaling)
  else None and pass amax_group to te.autocast; amax_reduction_group
  is a DelayedScaling-specific parameter for per-tensor amax
  aggregation and is not accepted by MXFP8BlockScaling or
  NVFP4BlockScaling which use block-level scaling
- Replace multi-paragraph help strings for --no-fp8, --dtype, and
  --precision (with PRECEDENCE/BEHAVIOR/RATIONALE/EXAMPLES sections)
  with concise one-liner synopses suitable for terminal --help output
- Move detailed precedence rules to module-level docstring

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 409
amax_group is incorrectly set to all_gpus when FP8 is disabled.

When no_fp8=True, recipe starts as None (from fp32/fp16 presets or --no-fp8 flag), gets replaced with DelayedScaling() on line 408, and then isinstance(recipe, DelayedScaling) on line 409 evaluates to True — causing amax_group to be set to all_gpus even though FP8 is disabled.

This results in te.autocast(enabled=False, ..., amax_reduction_group=all_gpus) at line 422, which could trigger unnecessary distributed all-reduce operations if TE processes the group regardless of the enabled flag.

The guard condition should also check not no_fp8:

    if recipe is None:
        recipe = DelayedScaling()
    amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

aagallo and others added 2 commits March 4, 2026 16:18
… comms

Add 'not no_fp8' condition to amax_group assignment to prevent
amax_reduction_group=all_gpus being passed to te.autocast when
FP8 is disabled.

- When no_fp8=True and recipe was None, the DelayedScaling fallback
  causes isinstance(recipe, DelayedScaling) to return True, which
  incorrectly set amax_group=all_gpus even though enabled=False
- Add 'not no_fp8' guard so amax_group is only set to all_gpus when
  FP8 is active AND the recipe is DelayedScaling (per-tensor amax);
  all other cases (FP8 disabled, block-scaling recipes) use None

Signed-off-by: Andrea Gallo <aagallo@amazon.com>

greptile-apps bot commented Mar 4, 2026

Additional Comments (2)

examples/pytorch/fsdp/fsdp.py, line 287
--no-fp8 silently ignored with fp32/fp16 precision presets

The validation guard on line 281 only rejects --no-fp8 when the precision preset is FP8-based (fp8, mxfp8, nvfp4). When --precision fp32 or --precision fp16 is combined with --no-fp8, no warning or error is raised and opts.no_fp8 is never actually read anywhere in the if opts.precision is not None: branch — no_fp8 comes entirely from the preset (which already returns True). A user who explicitly passes --no-fp8 --precision fp32 receives no feedback that the flag is redundant/ignored, which can be confusing.

Consider emitting a warning for this case:

    if opts.precision in ["fp32", "fp16"] and opts.no_fp8:
        dist_print(
            f"Warning: --no-fp8 is redundant when using --precision {opts.precision} "
            "(FP8 is already disabled by this preset). The flag will be ignored."
        )

examples/pytorch/fsdp/fsdp.py, line 410
amax_reduction_group implicitly assumed None for MXFP8/NVFP4 — add a comment

The expression silently passes amax_reduction_group=None to te.autocast for MXFP8BlockScaling and NVFP4BlockScaling. While block-scaling recipes typically compute scales locally (no global amax reduction), this is a non-obvious correctness assumption. If a future TE version or variant of these recipes does require a process group, the call will silently fall back to local-only scaling without any diagnostic.

Please add an inline comment explaining why None is intentional, e.g.:

# MXFP8 and NVFP4 use local block scaling — no distributed amax reduction group needed.
# amax_reduction_group is only required for DelayedScaling (global AMAX allreduce).
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

…p=None

Emit a warning when --no-fp8 is combined with a non-FP8 precision
preset and add an inline comment explaining why amax_reduction_group
is None for block-scaling recipes.

- Add warning when opts.precision in ['fp32', 'fp16'] and opts.no_fp8
  is set; FP8 is already disabled by these presets so the flag is
  redundant and silently ignored without this feedback
- Add inline comment on amax_group assignment explaining that
  MXFP8BlockScaling and NVFP4BlockScaling use local block scaling
  and do not require a distributed amax reduction group, and that
  None is also correct when FP8 is disabled to avoid unnecessary
  distributed communication

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 332
preset_dtype / preset_recipe potentially unbound from a static-analysis perspective

preset_dtype and preset_recipe (used at lines 319 and 328) are only assigned inside the if opts.precision is not None: block (lines 300–303). The outer guard at line 317 includes the same opts.precision is not None condition, so they are always logically initialized when accessed — but Python's static analyzers (mypy, pylint, pyflakes) will still raise a "possibly undefined" / "unbound" warning because the else branch (line 304) never assigns these names.

This guarantee can silently break if the condition is ever refactored without updating both blocks, and it will trigger CI lint failures in projects that treat unbound-variable warnings as errors.

Suggested fix — initialize the trio before the if-else so the later access is always safe:

preset_dtype: torch.dtype = opts.dtype   # sensible fallback
preset_recipe = None

if opts.precision is not None:
    preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
    ...
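
As a standalone illustration of the pre-initialization pattern (using hypothetical stand-in names, not the real fsdp.py helpers), binding the name before the branch keeps it defined on every path:

```python
# Sketch of the possibly-unbound pattern and its fix, with stand-in names.
def resolve_dtype(precision, default_dtype):
    # Fix: bind the name before the branch so every path defines it.
    preset_dtype = default_dtype  # sensible fallback, as suggested above
    if precision is not None:
        # Stand-in for get_precision_preset(precision)
        preset_dtype = "bfloat16" if precision in ("fp8", "mxfp8", "nvfp4") else precision
    # Without the fallback above, static analyzers flag this read as
    # "possibly undefined" because the else path assigns nothing.
    return preset_dtype

print(resolve_dtype("nvfp4", "float32"))  # bfloat16
print(resolve_dtype(None, "float32"))     # float32
```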

aagallo and others added 2 commits March 4, 2026 16:36
Initialize preset_dtype and preset_recipe with fallback values before
the 'if opts.precision is not None' block to prevent static analyzer
warnings about potentially unbound variables.

- Assign preset_dtype = opts.dtype and preset_recipe = None as
  sensible fallbacks before the if-else block; these are overwritten
  by get_precision_preset() when opts.precision is not None and are
  never accessed in the else branch
- Satisfies mypy, pylint, and pyflakes 'possibly undefined' / 'unbound'
  warnings that would otherwise trigger CI lint failures in projects
  treating unbound-variable warnings as errors
- No behavioral change; the if-else logic is unchanged

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 416
amax_group check races with recipe mutation

After the recipe = DelayedScaling() fallback on line 416, recipe is now a DelayedScaling instance even when no_fp8=True. The amax_group computation on line 420 then evaluates isinstance(recipe, DelayedScaling) as True, relying solely on not no_fp8 to produce None. This is technically correct today, but the invariant is fragile: if the guard condition is ever reordered or the fallback recipe type changes, the FSDP process group could silently be passed to a recipe that doesn't expect it.

A safer pattern is to compute amax_group before the fallback assignment, or guard it only on no_fp8:

# Compute amax_group before the defensive recipe fallback
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

# Ensure recipe is always a concrete object before passing to te.autocast.
if recipe is None:
    recipe = DelayedScaling()
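
A self-contained sketch of the ordering dependency (using a stub class, not the real Transformer Engine API) shows why the group must be computed from the user-selected recipe before the fallback substitutes a concrete object:

```python
# Stub for transformer_engine's DelayedScaling recipe class.
class DelayedScaling:
    pass

def pick_amax_group(recipe, no_fp8, all_gpus="WORLD"):
    # Compute the group from the user-selected recipe first; isinstance on
    # None is safely False, so non-FP8 paths yield None here.
    amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None
    # Only then apply the defensive fallback so te.autocast always receives
    # a concrete recipe object, without affecting group selection.
    if recipe is None:
        recipe = DelayedScaling()
    return amax_group, recipe

# FP8 disabled, no recipe: group stays None even though the fallback
# substitutes a DelayedScaling instance afterwards.
group, recipe = pick_amax_group(None, no_fp8=True)
print(group)  # None

# FP8 active with DelayedScaling: the process group is selected.
group, _ = pick_amax_group(DelayedScaling(), no_fp8=False)
print(group)  # WORLD
```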

aagallo and others added 2 commits March 4, 2026 16:55
…ce race

Move amax_group computation before the 'if recipe is None' fallback
assignment so isinstance(recipe, DelayedScaling) reflects the actual
user-selected recipe rather than the defensive fallback object.

- When recipe is None (non-FP8 presets or --no-fp8), isinstance
  correctly returns False and amax_group is set to None before the
  fallback substitutes a DelayedScaling instance
- Prevents the fragile ordering dependency where not no_fp8 was the
  sole guard against passing all_gpus to a recipe that doesn't need it
- Add inline comment explaining why amax_group must be computed before
  the recipe fallback to preserve the invariant

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

examples/pytorch/fsdp/fsdp.py, line 338
No warning when overriding dtype to float16 with MXFP8/NVFP4 presets

When a user explicitly passes --dtype fp16 alongside --precision mxfp8 or --precision nvfp4, the code silently overrides the preset's bfloat16 dtype with float16. Both MXFP8 and NVFP4 are designed to accumulate into bfloat16; pairing them with float16 parameters may not be supported by TE (or may produce incorrect/undefined results at runtime). The PR description even states "NVFP4: Uses NVFP4BlockScaling recipe with bfloat16 dtype", which implies bfloat16 is required for that preset.

Consider adding an explicit warning (or an error) in the override block when the overridden dtype is incompatible with the selected precision recipe:

if dtype_explicitly_set and opts.precision is not None:
    new_dtype = opts.dtype
    if new_dtype != preset_dtype:
        # Warn if overriding to float16 for FP8-family presets that require bfloat16
        if opts.precision in ["fp8", "mxfp8", "nvfp4"] and new_dtype == torch.float16:
            dist_print(
                f"Warning: --dtype float16 may be incompatible with --precision "
                f"{opts.precision}, which expects bfloat16 accumulation."
            )
        dtype = new_dtype
        ...

aagallo and others added 3 commits March 4, 2026 17:24
… presets

Add explicit warning when --dtype float16 is combined with
--precision fp8, mxfp8, or nvfp4, which expect bfloat16 accumulation.

- Emit compatibility warning before applying the dtype override when
  opts.precision is in ['fp8', 'mxfp8', 'nvfp4'] and new_dtype is
  torch.float16; these presets are designed for bfloat16 accumulation
  and pairing with float16 may produce incorrect or undefined results
- Warning is emitted in addition to the existing dtype override warning
  so users see both the compatibility concern and the override confirmation
- Override is still applied (not blocked) to preserve user control;
  users who know their TE version supports float16 accumulation can
  proceed with awareness of the risk

Signed-off-by: Andrea Gallo <aagallo@amazon.com>
@vthumbe1503
Collaborator

@aagallo changes look good to me. Are there any other changes you need to make before merging the PR?

@greptile-apps
Contributor

greptile-apps bot commented Mar 4, 2026

Additional Comments (3)

examples/pytorch/fsdp/fsdp.py, line 204
--no-fp8 not tracked with StoreExplicitAction

--no-fp8 still uses action="store_true" (no explicit-set tracking), so opts.no_fp8 = False is ambiguous: it could mean the user never passed the flag, or simply left it at the default. More importantly, when --precision fp32 or --precision fp16 is combined with the default --no-fp8=False, the code at line 288 will never print the redundancy warning (correct), but if a user somehow produces no_fp8=False via environment trickery, the code silently ignores it.

More critically: --no-fp8 is the only flag in the combined --precision/--dtype/--no-fp8 trio that cannot distinguish "user explicitly passed it" from "left at default". For completeness and defensive programming, consider using StoreExplicitAction (or BooleanOptionalAction) on --no-fp8 the same way --dtype does, so the precedence logic is symmetrical.

    parser.add_argument(
        "--no-fp8",
        action=StoreExplicitAction,
        nargs=0,
        default=False,
        help=(
            "Disable te.autocast() FP8 context. Incompatible with --precision fp8/mxfp8/nvfp4."
            " Default: False."
        ),
    )

(Alternatively, keep store_true but document that --no-fp8 is only meaningful when --precision is absent or set to fp32/fp16.)
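
For reference, a minimal sketch of what a StoreExplicitAction-style argparse action might look like — this is a hypothetical illustration, not the actual helper defined in fsdp.py:

```python
import argparse

# Hypothetical sketch: behaves like store_true but also records whether
# the user passed the flag explicitly, via a companion "<dest>_explicit" attr.
class StoreExplicitAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        # Force nargs=0 so the flag takes no value, like store_true.
        super().__init__(option_strings, dest, nargs=0, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, True)
        setattr(namespace, self.dest + "_explicit", True)

parser = argparse.ArgumentParser()
parser.add_argument("--no-fp8", action=StoreExplicitAction, default=False)

opts = parser.parse_args([])
print(opts.no_fp8, getattr(opts, "no_fp8_explicit", False))  # False False

opts = parser.parse_args(["--no-fp8"])
print(opts.no_fp8, opts.no_fp8_explicit)  # True True
```

With this in place, the precedence logic can distinguish "left at default" from "explicitly disabled" the same way the --dtype handling does.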


examples/pytorch/fsdp/fsdp.py, line 431
amax_group computed before the None-recipe fallback, but comment mixes concerns

The comment at line 421–422 says "Compute amax_group BEFORE the recipe fallback so isinstance() reflects the actual recipe". This is correct and intentional, but there is a subtle contract violation in the surrounding code:

When no_fp8=True (FP8 disabled), recipe is None (set in the else branch at line 314). Line 423 evaluates isinstance(recipe, DelayedScaling) on None, which Python handles safely (False), so amax_group=None as expected. Then lines 428–431 replace None with a concrete DelayedScaling object "for TE versions that access recipe attributes even when disabled". This defensive instantiation is fine, but it means the recipe object passed to te.autocast when no_fp8=True is a different object than what amax_group was computed from. While this does not cause a bug today (because amax_group=None in that path), a future reader might be confused by the two-phase setup. Consider unifying by making recipe non-None earlier (in the preset/else branches) so the fallback here is unnecessary, or add a one-line comment clarifying the invariant:

# amax_group is already fixed; the fallback below only affects te.autocast's recipe
# attribute accesses, not the group selection.
amax_group = all_gpus if (not no_fp8 and isinstance(recipe, DelayedScaling)) else None

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


examples/pytorch/fsdp/fsdp.py, line 305
Dead initialisation of preset_dtype and preset_recipe

Lines 300–301 initialise preset_dtype = opts.dtype and preset_recipe = None as "sensible fallbacks". However:

  • preset_dtype is only read at line 322 (if new_dtype != preset_dtype), which is inside the if dtype_explicitly_set and opts.precision is not None block — a block that requires opts.precision is not None. Since that condition is true, line 304 has already overwritten preset_dtype with the real preset value before line 322 is reached. The initialisation on line 300 is therefore never the value actually used in the comparison.
  • preset_recipe is overwritten on line 304 immediately if opts.precision is not None, and never read in the else branch.

These initialisations add noise and could mislead a reader into thinking opts.dtype is being used as a fallback. Consider removing them and relying solely on the assignment in the if branch:

# preset_dtype / preset_recipe declared in the if-block below; no fallback needed.
if opts.precision is not None:
    preset_dtype, preset_no_fp8, preset_recipe = get_precision_preset(opts.precision)
    dtype, no_fp8, recipe = preset_dtype, preset_no_fp8, preset_recipe
    ...

@aagallo
Contributor Author

aagallo commented Mar 4, 2026

@vthumbe1503 I am mostly battling with the feedback from greptile. If you are good, I am good as well to merge! Thanks for the support

@vthumbe1503 vthumbe1503 merged commit 145e88c into NVIDIA:main Mar 5, 2026
9 of 12 checks passed