guard fuser grad checks on non-leaf nodes#2919

Open
CarlosGomes98 wants to merge 8 commits into NVIDIA:main from CarlosGomes98:cgomes/fuser_non_leaf_nodes

Conversation

Contributor

@CarlosGomes98 CarlosGomes98 commented Apr 23, 2026

Description

Pass an explicit flag that controls whether the fuser forward pass sets requires_grad on its outputs. This is required so that in no_grad mode we don't try to mutate this information on non-leaf nodes, which torch does not allow.
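For reference, a minimal PyTorch snippet (plain tensors, not TE code) showing the restriction this PR works around:

import torch

x = torch.ones(2, requires_grad=True)
y = x * 2                    # y is non-leaf: it has a grad_fn
assert not y.is_leaf
try:
    y.requires_grad_(False)  # torch forbids changing requires_grad on non-leaf tensors
except RuntimeError:
    pass                     # "you can only change requires_grad flags of leaf variables..."
y.detach().requires_grad_(False)  # leaf tensors can be mutated freely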

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a set_output_requires_grad flag to _OperationFuserAutogradFunction.forward and guard both requires_grad_() calls behind it
  • Pass is_grad_enabled as this flag from the call site and add the corresponding None to the backward return tuple

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

@timmoon10 timmoon10 left a comment


I'm not sure if this is really addressing the root cause of the issue. Two problems:

  • We aren't actually protecting against setting requires_grad on non-leaf nodes. We're just skipping the requires_grad logic when torch.is_grad_enabled() == False.
  • Do we even want to skip setting requires_grad on non-leaf nodes? The backward expects grads from each of the outputs, so we need requires_grad for autograd to do the right thing.

I think the right solution is smarter logic when setting requires_grad_. Maybe something like:

x_requires_grad = fuser.first_op_requiring_backward < fuser._num_basic_ops
if x_requires_grad != x.requires_grad:
    x = x.detach()
    if x_requires_grad:
        x.requires_grad_()

# Or maybe only detach if x is a non-leaf node?
# Need to check if the CPU overhead of checking
# is worth saving the CPU overhead of detaching.

...

return x

Another approach would be changing our ops to always return leaf nodes. For example, the forward pass of MakeExtraOutput, which returns its input directly as the extra output, would be changed to:

out = input_.detach()
return out, [(out,)] 
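As a standalone sanity check of why the detach-based version is safe (toy tensors, not the real op):

import torch

input_ = torch.ones(2, requires_grad=True) * 3   # stand-in for a non-leaf extra output
out = input_.detach()                            # leaf tensor sharing storage with input_
assert out.is_leaf and out.data_ptr() == input_.data_ptr()
out.requires_grad_(True)                         # safe: requires_grad can be toggled on a leaf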

Comment thread: transformer_engine/pytorch/ops/fuser.py (outdated)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
    for y in ys:
        y.requires_grad_(idx >= fuser.first_op_requiring_backward)
if func_ctx is not None:
Collaborator


This logic is not intuitive. func_ctx is None when torch.is_grad_enabled() == False:

if is_grad_enabled:
    forward_func = _OperationFuserAutogradFunction.apply
    args = []
else:
    forward_func = _OperationFuserAutogradFunction.forward
    args = [None]

It would be better to pass in is_grad_enabled as an arg so that we can be explicit and not rely on secret contracts.
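A minimal sketch of that with a toy op (the set_output_requires_grad name follows this PR; the real fuser signature has more arguments):

import torch

class FuserFn(torch.autograd.Function):
    @staticmethod
    def forward(func_ctx, x, set_output_requires_grad):
        out = x + 1                   # grad is disabled inside forward, so out is a leaf
        if set_output_requires_grad:
            out.requires_grad_(True)  # safe to mutate: out is a leaf here
        return out

    @staticmethod
    def backward(func_ctx, grad_out):
        return grad_out, None         # no gradient for the flag argument

x = torch.ones(2, requires_grad=True) * 2  # non-leaf input

y = FuserFn.apply(x, True)                 # grad enabled: autograd wraps the output
with torch.no_grad():
    z = FuserFn.forward(None, x, False)    # grad disabled: skip the requires_grad mutation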

Contributor Author


As I understand it, the real issue is that when the forward_func is .apply, we are free to set requires_grad_ on returned tensors. But when it is .forward, we cannot mutate this state on non-leaf tensors.

When torch.is_grad_enabled() is false, we bypass .apply and call .forward directly with no func_ctx. In that path no _OperationFuserAutogradFunction node is registered, so no fuser backward will run. So I think this claim:

The backward expects grads from each of the outputs, so we need requires_grad for autograd to do the right thing

is not true because we cannot run backward() through it.

I think it makes sense to pass this as an explicit argument, as you say, instead of relying on func_ctx being None. But I think the current logic is correct.
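A tiny standalone example of that point (plain tensors, not fuser code): under no_grad nothing is recorded, so there is no graph for a fuser backward to participate in:

import torch

x = torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
assert y.grad_fn is None and not y.requires_grad  # the op was not recorded, so backward() cannot reach it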

Collaborator


I see, the apply/forward distinction makes sense to me. We should include a comment making that clear.

CarlosGomes98 and others added 6 commits April 29, 2026 15:19
@CarlosGomes98 CarlosGomes98 force-pushed the cgomes/fuser_non_leaf_nodes branch from 55f1e55 to a67b89c on April 29, 2026 13:19
@CarlosGomes98 CarlosGomes98 marked this pull request as ready for review April 29, 2026 14:53
@CarlosGomes98 CarlosGomes98 requested a review from timmoon10 April 29, 2026 14:54
Contributor

greptile-apps Bot commented Apr 29, 2026

Greptile Summary

This bug fix guards the two requires_grad_() mutation calls inside _OperationFuserAutogradFunction.forward behind a new set_output_requires_grad boolean flag. The flag is passed as is_grad_enabled from OperationFuser.__call__. When PyTorch's grad mode is off (e.g. torch.no_grad(), torch.inference_mode()), the fuser calls forward directly without autograd wrapping, the inner ops may return non-leaf tensors, and mutating their requires_grad attribute is an error in that state. The change is minimal, correct, and backward-compatible.

Confidence Score: 5/5

Safe to merge — targeted, minimal bug fix with no regressions identified.

The change is a small, well-scoped fix that correctly guards the two requires_grad_() mutation sites behind a flag derived from is_grad_enabled. The backward return tuple is updated consistently, the comment explains the intent, and both the apply and direct-forward paths are handled correctly. No existing behavior changes when grad is enabled.

No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/ops/fuser.py
Overview: Adds a set_output_requires_grad flag to _OperationFuserAutogradFunction.forward and threads it through the call site, correctly guarding both requires_grad_() mutation sites; the backward return tuple is updated to include the corresponding None.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["OperationFuser.__call__()"] --> B{"is_grad_enabled?"}
    B -- "True" --> C["args = (..., set_output_requires_grad=True, ...)"] 
    B -- "False" --> D["args = (..., set_output_requires_grad=False, ...)"] 
    C --> E["_OperationFuserAutogradFunction.apply(*args)"]
    D --> F["_OperationFuserAutogradFunction.forward(None, *args)"]
    E --> G["forward() — tensors detached by autograd machinery\nset_output_requires_grad=True\nrequires_grad_() calls are SAFE"]
    F --> H["forward() — inner ops may produce non-leaf tensors\nset_output_requires_grad=False\nrequires_grad_() calls are SKIPPED"]
    G --> I["return output tensor(s)"]
    H --> I

Reviews (2). Last reviewed commit: "Revert cudnn-frontend submodule bump"

    input,
    self,
    basic_op_kwargs,
    is_grad_enabled,  # set_output_requires_grad
Contributor


P2 Implicit coupling between is_grad_enabled and set_output_requires_grad

is_grad_enabled is passed directly as set_output_requires_grad, hardcoding an equivalence between the two concepts. The guard in forward is really about "are the output tensors leaf nodes (safe to mutate requires_grad_) or not", which happens to correlate with grad being enabled today. If a future caller ever needs grad enabled but the outputs are already non-leaf (or vice versa), the coupling breaks silently. A dedicated flag computed from actual tensor leaf-ness, or at minimum a local variable with an explanatory name, would make the intent more resilient:

# Outputs produced by inner ops in the no-grad path may be non-leaf
# tensors; setting requires_grad_ on non-leaf tensors raises a RuntimeError.
set_output_requires_grad = is_grad_enabled

This is already implicit in the existing comment, so at a minimum a short inline note at the call site explaining why is_grad_enabled serves as the proxy would help future maintainers.

