Skip to content

Commit 74faf7e

Browse files
[PyTorch Debug] NVFP4 debug stats support (#2296)
* init Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * year update in license Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * fix Signed-off-by: Pawel Gadzinski <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 29b84c1 commit 74faf7e

File tree

11 files changed

+636
-87
lines changed

11 files changed

+636
-87
lines changed

docs/debug/1_getting_started.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
1515
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
1616
- run selected GEMMs in higher precision,
1717
- run current scaling - with one scaling factor per tensor - for particular GEMMs,
18-
- test new precisions and integrate them with FP8 training,
18+
- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
1919
- ... and many more.
2020

2121
There are 4 things one needs to do to use Transformer Engine debug features:

docs/debug/3_api_features.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ Debug features
88

99
.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
1010
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
11-
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
12-
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
11+
.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
12+
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
13+
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
1314
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
1415
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
16+
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
17+
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer

tests/pytorch/debug/test_log.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_fp8_available,
1616
is_mxfp8_available,
1717
is_fp8_block_scaling_available,
18+
is_nvfp4_available,
1819
)
1920
from transformer_engine.pytorch.quantization import RecipeState
2021
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -29,6 +30,7 @@
2930
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
3031
return_reason=True
3132
)
33+
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
3234

3335
LOG_QUANTIZED_CONFIG_BASE = """
3436
log:
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
363365
TEDebugState._reset()
364366

365367

368+
# NVFP4 tests
369+
LOG_NVFP4_CONFIG_BASE = """
370+
log:
371+
layers:
372+
layer_name_regex_pattern: .*
373+
enabled:
374+
True
375+
transformer_engine:
376+
LogNvfp4TensorStats:
377+
enabled: True
378+
stats: [
379+
{stats}
380+
]
381+
tensors: [activation, gradient, weight]
382+
freq: 2
383+
start_step: 0
384+
end_step: 10
385+
"""
386+
387+
388+
def test_nvfp4_numeric(feature_dirs):
389+
"""Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
390+
if not nvfp4_available:
391+
pytest.skip(reason_for_no_nvfp4)
392+
393+
log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
394+
395+
with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
396+
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
397+
from transformer_engine.pytorch.quantization import RecipeState
398+
399+
recipe_state = RecipeState.create(
400+
recipe.NVFP4BlockScaling(),
401+
mode="forward",
402+
num_quantizers=3,
403+
)
404+
405+
# Create test tensor with known distribution
406+
torch.manual_seed(42)
407+
tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
408+
# Add some small values that should underflow to zero in FP4
409+
tensor[0, :16] = 0.0001
410+
411+
quantizer = recipe_state.make_quantizers()[0]
412+
quantized_tensor = quantizer(tensor)
413+
414+
debug_api.transformer_engine.inspect_tensor(
415+
layer_name="test_layer",
416+
tensor_name="activation",
417+
iteration=0,
418+
tp_group=None,
419+
tensor=tensor,
420+
quantizer=quantizer,
421+
rowwise_quantized_tensor=quantized_tensor,
422+
columnwise_quantized_tensor=quantized_tensor,
423+
)
424+
debug_api.step()
425+
426+
dequantized_tensor = quantized_tensor.dequantize()
427+
output = read_log(log_dir)
428+
429+
# Validate both stats are present
430+
assert "nvfp4_underflows%" in output, "underflows% stat missing"
431+
assert "nvfp4_mse" in output, "mse stat missing"
432+
433+
# Extract values and validate numerics
434+
underflows_value = None
435+
mse_value = None
436+
437+
for line in output.splitlines():
438+
if "nvfp4_underflows%" in line and "value=" in line:
439+
underflows_value = float(line.split("value=")[1].split()[0])
440+
if "nvfp4_mse" in line and "value=" in line:
441+
mse_value = float(line.split("value=")[1].split()[0])
442+
443+
# Compute expected underflows: non-zero elements that became zero after quantization
444+
orig_nonzero_mask = tensor != 0
445+
dequant_zero_mask = dequantized_tensor == 0
446+
expected_underflows = (
447+
(orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
448+
)
449+
450+
# Allow some tolerance
451+
assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
452+
453+
# Compute expected MSE
454+
expected_mse = torch.nn.functional.mse_loss(
455+
dequantized_tensor.float(), tensor.float(), reduction="mean"
456+
)
457+
458+
assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
459+
460+
461+
def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
462+
"""Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
463+
if not nvfp4_available:
464+
pytest.skip(reason_for_no_nvfp4)
465+
466+
# Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
467+
log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
468+
469+
with debug_session(log_fp8_config, feature_dirs) as log_dir:
470+
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
471+
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
472+
473+
# Should work - recipe-prefixed stats compute MXFP8 separately for comparison
474+
for _ in range(2):
475+
with te.autocast(recipe=recipe.NVFP4BlockScaling()):
476+
output = model(inp)
477+
loss = output.sum()
478+
loss.backward()
479+
debug_api.step()
480+
481+
output = read_log(log_dir)
482+
# Should have logged MXFP8 MSE stat (what-if scenario)
483+
assert "mxfp8_mse" in output
484+
485+
366486
def test_log_grouped_gemm(feature_dirs):
367487
if not fp8_available:
368488
pytest.skip(reason_for_no_fp8)

transformer_engine/debug/features/disable_fp8_gemm.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,28 @@
22
#
33
# See LICENSE for license information.
44

5-
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
5+
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
66
7-
from nvdlfw_inspect.registry import Registry, api_method
8-
from transformer_engine.debug.features.api import TEConfigAPIMapper
7+
DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
8+
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
9+
"""
10+
11+
import warnings
12+
13+
from nvdlfw_inspect.registry import Registry
14+
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM
915

1016

1117
@Registry.register_feature(namespace="transformer_engine")
12-
class DisableFP8GEMM(TEConfigAPIMapper):
18+
class DisableFP8GEMM(DisableQuantizationGEMM):
1319
"""
1420
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
1521
22+
.. deprecated::
23+
Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
24+
backward compatibility only. DisableQuantizationGEMM works with all quantization
25+
formats (FP8, NVFP4, etc.), not just FP8.
26+
1627
Parameters
1728
----------
1829
@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
3243
layers:
3344
layer_types: [fc1]
3445
transformer_engine:
35-
DisableFP8GEMM:
46+
DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM
3647
enabled: True
3748
gemms: [dgrad, wgrad]
3849
"""
3950

40-
@api_method
41-
def fp8_gemm_enabled(
42-
self, config, layer_name: str, gemm: str, iteration: int
43-
): # pylint: disable=unused-argument
44-
"""API call responsible for choice between high-precision and FP8 GEMM execution."""
45-
46-
for key in config:
47-
if key != "gemm":
48-
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
49-
50-
# If this feature is invoked, then FP8 GEMM is disabled.
51-
# If not, then default behaviour in TransformerEngineAPI
52-
# is that fp8_gemm() API call returns True.
53-
return False, iteration + 1
51+
def __init__(self, *args, **kwargs):
52+
warnings.warn(
53+
"DisableFP8GEMM is deprecated. "
54+
"Use DisableQuantizationGEMM instead, which works with all quantization "
55+
"formats (FP8, NVFP4, etc.).",
56+
DeprecationWarning,
57+
stacklevel=2,
58+
)
59+
super().__init__(*args, **kwargs)

transformer_engine/debug/features/disable_fp8_layer.py

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,47 @@
22
#
33
# See LICENSE for license information.
44

5-
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
5+
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
66
7-
import nvdlfw_inspect.api as debug_api
8-
from nvdlfw_inspect.registry import Registry, api_method
7+
DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
8+
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
9+
"""
10+
11+
import warnings
12+
13+
from nvdlfw_inspect.registry import Registry
14+
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer
915

1016

1117
@Registry.register_feature(namespace="transformer_engine")
12-
class DisableFP8Layer:
18+
class DisableFP8Layer(DisableQuantizationLayer):
1319
"""
1420
Disables all FP8 GEMMs in the layer.
1521
22+
.. deprecated::
23+
Use :class:`DisableQuantizationLayer` instead. This class is maintained for
24+
backward compatibility only. DisableQuantizationLayer works with all quantization
25+
formats (FP8, NVFP4, etc.), not just FP8.
1626
1727
Example
1828
-------
1929
.. code-block:: yaml
2030
2131
example_disable_fp8_layer:
2232
enabled: True
23-
layers:
24-
layer_types: [fc1]
25-
transformer_engine:
26-
DisableFP8Layer:
27-
enabled: True
33+
layers:
34+
layer_types: [fc1]
35+
transformer_engine:
36+
DisableFP8Layer: # Deprecated: use DisableQuantizationLayer
37+
enabled: True
2838
"""
2939

30-
@api_method
31-
def fp8_gemm_enabled(
32-
self, config, layer_name: str, gemm: str, iteration: int
33-
): # pylint: disable=unused-argument
34-
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
35-
for key in config:
36-
if key not in ["enabled", "gemm"]:
37-
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
38-
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
39-
debug_api.log_message("FP8 Disabled", layer_name)
40-
41-
# If this feature is invoked, then FP8 GEMM is disabled.
42-
# If not, then default behavior in TransformerEngineAPI
43-
# is that fp8_gemm() API call returns True.
44-
return False, iteration + 1
45-
46-
def parse_config_and_api(self, config, **_kwargs):
47-
"""Determines whether to run the API
48-
DisableFP8Layer is the only feature provided by the Transformer Engine
49-
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
50-
parsing gemms and tensors fields from the config, which are not needed for this feature.
51-
52-
Explanation of the parse_config_and_api can be found in the
53-
nvidia-dlframework-inspect documentation.
54-
"""
55-
return config["enabled"], None
40+
def __init__(self, *args, **kwargs):
41+
warnings.warn(
42+
"DisableFP8Layer is deprecated. "
43+
"Use DisableQuantizationLayer instead, which works with all quantization "
44+
"formats (FP8, NVFP4, etc.).",
45+
DeprecationWarning,
46+
stacklevel=2,
47+
)
48+
super().__init__(*args, **kwargs)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""
6+
7+
from nvdlfw_inspect.registry import Registry, api_method
8+
from transformer_engine.debug.features.api import TEConfigAPIMapper
9+
10+
11+
@Registry.register_feature(namespace="transformer_engine")
12+
class DisableQuantizationGEMM(TEConfigAPIMapper):
13+
"""
14+
Disables specific GEMM operations from using quantization, forcing high-precision execution.
15+
16+
Works with any quantization format (FP8, NVFP4, etc.).
17+
18+
Parameters
19+
----------
20+
21+
gemms: List[str]
22+
list of gemms to disable quantization for
23+
24+
- fprop
25+
- dgrad
26+
- wgrad
27+
28+
Example
29+
-------
30+
.. code-block:: yaml
31+
32+
example_disable_quantization_gemm:
33+
enabled: True
34+
layers:
35+
layer_types: [fc1]
36+
transformer_engine:
37+
DisableQuantizationGEMM:
38+
enabled: True
39+
gemms: [dgrad, wgrad]
40+
"""
41+
42+
@api_method
43+
def fp8_gemm_enabled(
44+
self, config, layer_name: str, gemm: str, iteration: int
45+
): # pylint: disable=unused-argument
46+
"""API call responsible for choice between high-precision and quantized GEMM execution.
47+
48+
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
49+
but it applies to all quantization formats (FP8, NVFP4, etc.).
50+
"""
51+
52+
for key in config:
53+
if key != "gemm":
54+
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
55+
56+
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
57+
# If not, then default behavior in TransformerEngineAPI
58+
# is that fp8_gemm() API call returns True.
59+
return False, iteration + 1

0 commit comments

Comments (0)