[Pytorch] Add get_backward_dw_params api for TE module #2614
Wohox wants to merge 2 commits into NVIDIA:main
Conversation
Greptile Overview

Greptile Summary

This PR fixes a critical bug where gradient reduction hooks were not being triggered after CUDA graph replay in weight gradient computation, which would cause parameters to skip gradient reduction in Megatron-LM.

Key Changes:

Technical Context:

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant MegatronLM
participant GraphCallable
participant CUDAGraph
participant TEModule
participant Hooks
Note over MegatronLM,Hooks: CUDA Graph Replay Flow
MegatronLM->>GraphCallable: backward_dw()
GraphCallable->>CUDAGraph: replay wgrad graph
Note over CUDAGraph: CUDA operations execute<br/>(Python code skipped)
CUDAGraph-->>GraphCallable: replay complete
Note over GraphCallable,TEModule: New: Trigger hooks after replay
GraphCallable->>TEModule: check need_backward_dw()
TEModule-->>GraphCallable: returns True
GraphCallable->>TEModule: _trigger_wgrad_accumulation_and_reduce_hooks()
TEModule->>Hooks: execute registered hooks
Note over Hooks: Gradient reduction occurs
Hooks-->>TEModule: hooks complete
TEModule-->>GraphCallable: hooks triggered
GraphCallable-->>MegatronLM: backward_dw complete
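The flow above can be summarized with a minimal Python sketch. This is not the actual Transformer Engine implementation: the `GraphedWgradCallable` class, the `wgrad_graph` and `te_modules` attributes are hypothetical stand-ins, while `need_backward_dw()` and `_trigger_wgrad_accumulation_and_reduce_hooks()` follow the method names shown in the sequence diagram.

```python
# Sketch of the replay flow from the diagram above (hypothetical wrapper class).
class GraphedWgradCallable:
    def __init__(self, wgrad_graph, te_modules):
        self.wgrad_graph = wgrad_graph   # captured torch.cuda.CUDAGraph for wgrad
        self.te_modules = te_modules     # TE modules whose wgrad is captured

    def backward_dw(self):
        # Replaying the graph runs only the captured CUDA kernels; any Python-side
        # backward_post_hooks registered on the parameters are skipped.
        self.wgrad_graph.replay()

        # So the hooks (e.g. Megatron-LM's grad accumulation/reduction) have to be
        # fired explicitly after the replay.
        for module in self.te_modules:
            if module.need_backward_dw():
                module._trigger_wgrad_accumulation_and_reduce_hooks()
```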
    Get the parameters for the backward weight gradient computation.
    """
    params = []
    params.append(noop_cat(self._get_weight_tensors()))
logic: In backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that should not have them when fuse_wgrad_accumulation=True.
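A minimal sketch of the guard the comment is describing; the method, attribute, and helper names are taken from the diff above, and this is a suggestion rather than the actual implementation.

```python
# Sketch: only expose weight parameters for hook management when wgrad
# accumulation is not fused, mirroring the condition used in backward_dw().
def get_backward_dw_params(self):
    """Get the parameters for the backward weight gradient computation."""
    params = []
    if not self.fuse_wgrad_accumulation:
        params.append(noop_cat(self._get_weight_tensors()))
    return params
```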
Commit content reverted.
/te-ci pytorch L1
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Force-pushed from 0906e63 to 5dfe8c1
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
/te-ci pytorch L1
Description
This PR adds get_backward_dw_params for TE modules, which helps manage the hooks of parameters.

For Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph has been executed. Currently, the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip gradient reduction. A hypothetical caller-side sketch is shown below.

Type of change
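The following sketch illustrates how a framework such as Megatron-LM might consume get_backward_dw_params after replaying the wgrad CUDA graph. The `run_wgrad_step` function and the `graphed_callable`, `te_module`, and `reduce_grad` arguments are illustrative stand-ins, not real Megatron-LM APIs.

```python
# Hypothetical caller-side usage of get_backward_dw_params after graph replay.
def run_wgrad_step(graphed_callable, te_module, reduce_grad):
    # Replay the captured wgrad graph; backward_post_hooks are not fired here.
    graphed_callable.backward_dw()

    # Recover the parameters involved in the wgrad computation and apply the
    # gradient reduction that the skipped hooks would normally have performed.
    for param in te_module.get_backward_dw_params():
        if getattr(param, "grad", None) is not None:
            reduce_grad(param)
```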
Changes
Please list the changes introduced in this PR:
Checklist: