
[Pytorch] Add get_backward_dw_params api for TE module #2614

Open
Wohox wants to merge 2 commits into NVIDIA:main from Wohox:pingtian/add_linear_wgrad_compute_param_api

Conversation

@Wohox (Contributor) commented Jan 22, 2026

Description

This PR adds a get_backward_dw_params API to TE modules, which exposes the parameters involved in the weight-gradient (wgrad) computation so that callers can manage the gradient hooks on those parameters.

In Megatron-LM, get_backward_dw_params is called after the wgrad CUDA graph is replayed. Currently, the backward post-hook of the wgrad computation is discarded during graph replay, which causes the affected parameters to skip gradient reduction.
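
For illustration, here is a minimal caller-side sketch of how a framework such as Megatron-LM might consume the new API after replaying a wgrad CUDA graph. The return value of get_backward_dw_params() and the grad_reducer callback are assumptions for this example, not the exact interface defined in this PR.

    # Hypothetical caller-side sketch, not the Megatron-LM implementation.
    # Assumes get_backward_dw_params() returns the parameters that receive
    # weight gradients during the module's backward_dw() computation.
    def replay_wgrad_graph_and_reduce(module, wgrad_graph, grad_reducer):
        # Replay the captured wgrad CUDA graph; only recorded CUDA ops run,
        # so Python-side backward post-hooks are skipped.
        wgrad_graph.replay()

        # Re-apply the bookkeeping the skipped hooks would normally perform,
        # e.g. marking each parameter ready for gradient reduction.
        for param in module.get_backward_dw_params():
            if param.requires_grad:
                grad_reducer(param)  # hypothetical callback that enqueues the grad reduction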

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Changes introduced in this PR:

  • Add a get_backward_dw_params API to TE modules for retrieving the parameters involved in the wgrad computation
  • Trigger the wgrad accumulation and reduce hooks after the wgrad CUDA graph is replayed, so CUDA-graphed modules do not skip gradient reduction

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile Overview

Greptile Summary

This PR fixes a critical bug where gradient reduction hooks were not being triggered after CUDA graph replay in weight gradient computation, which would cause parameters to skip gradient reduction in Megatron-LM.

Key Changes:

  • Refactored hook triggering logic into a reusable _trigger_wgrad_accumulation_and_reduce_hooks() method in TransformerEngineBaseModule
  • Updated all subclasses (GroupedLinear, LayerNormMLP) to use the new method for consistency
  • Added hook triggering in graph.py after wgrad CUDA graph replay to ensure hooks execute even though Python code is skipped during graph replay

Technical Context:
When CUDA graphs are replayed, only the captured CUDA operations are executed; Python code, including hook triggering, does not run. This PR addresses that by manually triggering the wgrad accumulation and reduce hooks after the graph replay completes, ensuring gradient reduction occurs as expected for CUDA-graphed modules.
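
As a rough sketch of this post-replay pattern, the snippet below uses the need_backward_dw() and _trigger_wgrad_accumulation_and_reduce_hooks() names from the summary above; the surrounding wrapper function and its arguments are assumptions, not the literal graph.py diff.

    # Sketch of the post-replay hook trigger; only the two method names
    # come from this PR, the rest is illustrative.
    def _graphed_backward_dw(wgrad_graph, te_modules):
        wgrad_graph.replay()  # CUDA ops run; Python hooks do not fire

        for module in te_modules:
            # Only modules that actually produced weight gradients need their
            # accumulation/reduce hooks fired manually after replay.
            if module.need_backward_dw():
                module._trigger_wgrad_accumulation_and_reduce_hooks()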

Confidence Score: 5/5

  • This PR is safe to merge with no issues found
  • The implementation correctly addresses the CUDA graph hook triggering issue with proper type checks and method extraction. All hook triggering locations have been consistently refactored, and the new code follows established patterns in the codebase.
  • No files require special attention

Important Files Changed

Filename | Overview
transformer_engine/pytorch/module/base.py | Extracted hook triggering logic into a reusable _trigger_wgrad_accumulation_and_reduce_hooks() method
transformer_engine/pytorch/graph.py | Added hook triggering after wgrad graph replay so gradient reduction occurs for CUDA-graphed modules

Sequence Diagram

sequenceDiagram
    participant MegatronLM
    participant GraphCallable
    participant CUDAGraph
    participant TEModule
    participant Hooks

    Note over MegatronLM,Hooks: CUDA Graph Replay Flow

    MegatronLM->>GraphCallable: backward_dw()
    GraphCallable->>CUDAGraph: replay wgrad graph
    Note over CUDAGraph: CUDA operations execute<br/>(Python code skipped)
    CUDAGraph-->>GraphCallable: replay complete
    
    Note over GraphCallable,TEModule: New: Trigger hooks after replay
    
    GraphCallable->>TEModule: check need_backward_dw()
    TEModule-->>GraphCallable: returns True
    GraphCallable->>TEModule: _trigger_wgrad_accumulation_and_reduce_hooks()
    TEModule->>Hooks: execute registered hooks
    Note over Hooks: Gradient reduction occurs
    Hooks-->>TEModule: hooks complete
    TEModule-->>GraphCallable: hooks triggered
    GraphCallable-->>MegatronLM: backward_dw complete

@greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@Wohox (Contributor, Author) commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

@greptile-apps bot left a comment

1 file reviewed, 1 comment

Reviewed excerpt from transformer_engine/pytorch/module/base.py:

    def get_backward_dw_params(self):
        """Get the parameters for the backward weight gradient computation."""
        params = []
        params.append(noop_cat(self._get_weight_tensors()))
greptile-apps bot commented:

logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns the weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that should not have them when fuse_wgrad_accumulation=True.
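
To illustrate the reviewer's point, here is a hypothetical gated variant of the excerpt above that mirrors the condition used in backward_dw(); whether this gating is desirable depends on how Megatron-LM registers its hooks, and the change under review was later reverted.

    # Illustrative only: gate the returned parameters the same way
    # backward_dw() gates its weight-tensor access (per the comment above).
    def get_backward_dw_params(self):
        params = []
        if not self.fuse_wgrad_accumulation:
            # Only expose weight tensors when wgrad accumulation is not fused.
            params.append(noop_cat(self._get_weight_tensors()))
        return params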

@Wohox (Author) replied:

Commit content reverted.

@greptile-apps bot left a comment

No files reviewed, no comments

@Wohox (Contributor, Author) commented Jan 30, 2026

@ksivaman Can you help review this PR? It is a bug fix for #2376.

@greptile-apps bot left a comment

2 files reviewed, no comments

@lhb8125 (Contributor) commented Feb 5, 2026

/te-ci pytorch L1

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
@Wohox force-pushed the pingtian/add_linear_wgrad_compute_param_api branch from 0906e63 to 5dfe8c1 on February 5, 2026 03:30
@greptile-apps bot left a comment

2 files reviewed, no comments

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
@Wohox (Contributor, Author) commented Feb 5, 2026

/te-ci pytorch L1

@greptile-apps bot left a comment

2 files reviewed, no comments

@lhb8125 (Contributor) left a comment

LGTM.
