
[PyTorch] Introduce semantic quantizer roles#2620

Open
negvet wants to merge 15 commits intoNVIDIA:mainfrom
negvet:semantic_quantizer_roles

Conversation

@negvet
Collaborator

@negvet negvet commented Jan 23, 2026

Description

Introduces semantic quantizer roles, e.g. linear:input and layernorm_linear:grad_output.
Roles are emitted by each module/op and consumed through RecipeState.create(..., roles=...), so that the right quantizers can be constructed without relying on their index in a list.

Currently used only by CustomRecipe, but this can be extended to all recipes.
It is also extendable to arbitrary operations, e.g. dpa:qkv and dpa:s (scores) for attention.
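The flow described above can be sketched roughly as follows. The dataclass fields follow the PR description, but make_quantizers and the factory signature here are simplified assumptions, not the actual Transformer Engine API:

```python
from dataclasses import dataclass

# Hypothetical sketch of the role-based API described above; field names
# follow the PR description, everything else is a simplified assumption.
@dataclass(frozen=True)
class QuantizerRole:
    module_type: str   # e.g. "linear", "layernorm_linear"
    tensor_type: str   # e.g. "input", "weight", "grad_output"
    name: str = ""     # module instance name, e.g. "qkv"

def make_quantizers(roles, qfactory):
    # Each role is mapped to a quantizer by the user-supplied factory,
    # so ordering within the list no longer carries meaning.
    return [qfactory(role) for role in roles]

roles = [
    QuantizerRole("linear", "input"),
    QuantizerRole("layernorm_linear", "grad_output"),
]
quantizers = make_quantizers(
    roles, qfactory=lambda r: f"quantizer<{r.module_type}:{r.tensor_type}>"
)
```

The factory receives the full role object, so it can dispatch on any combination of module type, tensor type, and instance name.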

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

negvet and others added 4 commits January 23, 2026 15:14
…ipe state

Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested review from cyanguwa and timmoon10 January 23, 2026 15:32
@greptile-apps
Contributor

greptile-apps bot commented Jan 23, 2026

Greptile Summary

This PR introduces semantic quantizer roles via the QuantizerRole dataclass, replacing string-based role identification with structured objects containing module_type, tensor_type, and name fields.

Key changes:

  • Adds QuantizerRole frozen dataclass in quantization.py with helper method is_gemm()
  • Modules/operations implement get_quantizer_roles() to provide semantic role information
  • RecipeState.create() accepts optional roles parameter, passed to CustomRecipeState
  • CustomRecipeState.make_quantizers() uses QuantizerRole objects instead of hardcoded string patterns
  • Factory functions now receive QuantizerRole objects and inspect fields like role.tensor_type
  • All tests and example factories updated to work with the new API
  • Documentation updated to describe the new structure

Benefits:

  • More flexible and extensible than string-based roles
  • Allows factories to inspect multiple dimensions (module type, tensor type, instance name)
  • Paves the way for extending to non-GEMM operations like attention (DPA)
  • Maintains backward compatibility for non-CustomRecipe recipes

Confidence Score: 5/5

  • This PR is safe to merge with no critical issues found.
  • The changes are well-architected, maintain backward compatibility, include comprehensive test coverage, and follow good software engineering practices. The frozen dataclass approach ensures immutability, all existing tests pass, and the migration path is clear.
  • No files require special attention

Important Files Changed

  • transformer_engine/pytorch/quantization.py: Introduces the QuantizerRole frozen dataclass with module_type, tensor_type, and name fields. Updates RecipeState.create() to accept a roles parameter and CustomRecipeState.make_quantizers() to use roles instead of hardcoded string patterns.
  • transformer_engine/pytorch/module/base.py: Adds a get_quantizer_roles() method to the TransformerEngineBaseModule base class that returns None by default. Updates recipe state creation to pass roles from modules.
  • transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py: Updates current_scaling_ref_quantizer_factory() to receive a QuantizerRole object instead of a string. Uses role.tensor_type to determine the dtype (E5M2 for grad tensors, E4M3 otherwise).
  • transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py: Updates nvfp4_ref_rht_2d_quantizer_factory() to receive a QuantizerRole object. Uses role.is_gemm() and role.tensor_type to select quantization parameters. Simplifies the logic by using a default return for most cases.
  • transformer_engine/common/recipe/init.py: Updates the CustomRecipe documentation to describe the new QuantizerRole parameter instead of string roles, with detailed field descriptions.
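The E5M2/E4M3 dispatch described for current_scaling_ref_quantizer_factory can be sketched as follows. This is a simplified stand-in, not the actual factory; the real one returns Quantizer instances rather than dtype strings:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizerRole:  # simplified stand-in for the PR's dataclass
    module_type: str
    tensor_type: str
    name: str = ""

def ref_quantizer_dtype(role: QuantizerRole) -> str:
    # Mirrors the dispatch described above: E5M2 for gradient tensors
    # (wider dynamic range), E4M3 for everything else.
    if role.tensor_type.startswith("grad"):
        return "E5M2"
    return "E4M3"
```

Matching on role.tensor_type replaces the old approach of pattern-matching on a role string or an index position.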

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TB
    subgraph "Module/Operation Layer"
        Module["Module (Linear, GroupedLinear, etc.)"]
        Module -->|implements| GetRoles["get_quantizer_roles()"]
        GetRoles -->|returns| RoleList["List[QuantizerRole]"]
    end
    
    subgraph "QuantizerRole Dataclass"
        RoleList --> Role["QuantizerRole(frozen=True)"]
        Role --> ModType["module_type: str<br/>(e.g., 'linear', 'grouped_linear')"]
        Role --> TensType["tensor_type: str<br/>(e.g., 'input', 'weight', 'grad_output')"]
        Role --> Name["name: str<br/>(e.g., 'qkv', 'fc1')"]
        Role --> Helper["is_gemm(): bool"]
    end
    
    subgraph "Recipe State Creation"
        RoleList --> RecipeState["RecipeState.create(recipe, roles=...)"]
        RecipeState --> CustomState["CustomRecipeState"]
        CustomState --> MakeQ["make_quantizers()"]
    end
    
    subgraph "Quantizer Factory"
        MakeQ --> Factory["qfactory(role: QuantizerRole)"]
        Factory --> InspectRole["Inspect role.tensor_type,<br/>role.module_type, role.name"]
        InspectRole --> ReturnQ["Return appropriate<br/>Quantizer instance"]
    end
    
    Factory -.->|example| CurrentScaling["current_scaling_ref_quantizer_factory:<br/>E5M2 for grad_*, E4M3 otherwise"]
    Factory -.->|example| NVFP4["nvfp4_ref_rht_2d_quantizer_factory:<br/>16x16 tiles for GEMM weights,<br/>1x16 with RHT otherwise"]

Last reviewed commit: a86fdad


Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>

Signed-off-by: Evgeny <etsykunov@nvidia.com>

Signed-off-by: Evgeny <etsykunov@nvidia.com>


@timmoon10 timmoon10 left a comment


Overall this design is quite clean and generalizable.

position : str
    Module-internal sub-slot. For modules that fuse multiple sequential operations,
    e.g. `LayerNormMLP` has `"fc1"` and `"fc2"` sub-slots.
    Empty string for simple modules.

I feel name and position are redundant. I see how position is basically just there to accommodate LayerNormMLP, but I'm uneasy about designing just for that (especially since it's not used publicly in Megatron-LM or Megatron-Bridge).

Instead of contorting QuantizerRole to work with LayerNormMLP, how about we contort LayerNormMLP? Instead of the module having a single name, it could have fc1_name and fc2_name.
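The alternative design mentioned above could look roughly like this. All identifiers here (layernorm_mlp_roles, the fc1_name/fc2_name parameters) are hypothetical, illustrating the suggestion rather than actual PR code:

```python
from dataclasses import dataclass

# Hypothetical sketch of the reviewer's suggestion: instead of a separate
# `position` field on QuantizerRole, give LayerNormMLP two configurable
# names, one per fused sub-layer.
@dataclass(frozen=True)
class QuantizerRole:
    module_type: str
    tensor_type: str
    name: str = ""

def layernorm_mlp_roles(fc1_name="fc1", fc2_name="fc2"):
    # Emit one role per (sub-layer, tensor) pair; the sub-slot is encoded
    # in the existing `name` field, so no extra field is needed.
    return [
        QuantizerRole("layernorm_mlp", tensor, name)
        for name in (fc1_name, fc2_name)
        for tensor in ("input", "weight")
    ]
```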

Comment on lines +1320 to +1329
    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
        QuantizerRole(module_type="linear", tensor_type="output", name=name),
    ]
else:
    base = [
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
    ]

@timmoon10 timmoon10 Feb 20, 2026


"output" and "grad_input" roles don't make sense. In reality, we are implicitly assuming that the tensor will be consumed by another linear-like layer.

Suggested change

Before:

    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
        QuantizerRole(module_type="linear", tensor_type="output", name=name),
    ]
else:
    base = [
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
        QuantizerRole(module_type="linear", tensor_type="grad_input", name=name),
    ]

After:

    base = [
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
        QuantizerRole(module_type="linear", tensor_type="weight", name=name),
        QuantizerRole(module_type="linear", tensor_type="input", name=name),
    ]
else:
    base = [
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
        QuantizerRole(module_type="linear", tensor_type="grad_output", name=name),
    ]

Alternatively, if we want to use the output in FP8 DPA, the right role would be module_type="dpa" and tensor_type="input". We should probably make this configurable. I kind of like that this design is exposing the hidden assumptions we've been making.
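The configurability suggested above might be sketched like this. Everything here (forward_roles, the output_consumer parameter) is hypothetical illustration, not PR code:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizerRole:
    module_type: str
    tensor_type: str
    name: str = ""

def forward_roles(name="", output_consumer=("linear", "input")):
    # By default, the output is assumed to feed another linear layer's
    # input; a caller could instead target FP8 DPA via ("dpa", "input").
    consumer_module, consumer_tensor = output_consumer
    return [
        QuantizerRole("linear", "input", name),
        QuantizerRole("linear", "weight", name),
        # The output quantizer takes on the role of its consumer's input,
        # making the formerly hidden assumption explicit.
        QuantizerRole(consumer_module, consumer_tensor, name),
    ]
```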

Comment on lines +310 to +314
assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1

Suggested change

Before:

assert counts["input"] == 1
assert counts["weight"] == 1
assert counts["output"] == 1
assert counts["grad_output"] == 1
assert counts["grad_input"] == 1

After:

assert counts["input"] == 2
assert counts["weight"] == 1
assert counts["output"] == 0
assert counts["grad_output"] == 2
assert counts["grad_input"] == 0

negvet and others added 2 commits February 20, 2026 14:31
Signed-off-by: Evgeny Tsykunov <etsykunov@nvidia.com>

negvet and others added 5 commits February 20, 2026 15:05
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>

@greptile-apps greptile-apps bot left a comment


15 files reviewed, no comments


Comment on lines +85 to +88
def is_gemm(self) -> bool:
    """Whether this role belongs to a GEMM-based module."""
    return self.module_type in self.GEMM_MODULE_TYPES


I think this is baking in assumptions about which formats are similar (our recent experiences with grouped tensors make me wonder whether the requirements for "linear" and "grouped_linear" will diverge in the future), and it's also not giving us that much convenience.

Suggested change (remove the helper):

def is_gemm(self) -> bool:
    """Whether this role belongs to a GEMM-based module."""
    return self.module_type in self.GEMM_MODULE_TYPES
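Without the is_gemm() helper, factories would match on module_type explicitly, which keeps "linear" and "grouped_linear" free to diverge later. A hypothetical sketch (tile shapes follow the NVFP4 example in the summary; the function name is an assumption):

```python
from collections import namedtuple

# Minimal stand-in for QuantizerRole, enough to show the dispatch.
Role = namedtuple("Role", ["module_type", "tensor_type"])

def select_tile_shape(role: Role) -> tuple:
    # Explicit per-module-type branches instead of a shared is_gemm()
    # check, so each case can evolve independently.
    if role.module_type == "linear" and role.tensor_type == "weight":
        return (16, 16)   # 2D tiles for plain GEMM weights
    if role.module_type == "grouped_linear" and role.tensor_type == "weight":
        return (16, 16)   # same today, but free to diverge
    return (1, 16)        # 1x16 with RHT otherwise
```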

@negvet negvet mentioned this pull request Feb 23, 2026
13 tasks