guard fuser grad checks on non-leaf nodes #2919
CarlosGomes98 wants to merge 8 commits into NVIDIA:main
Conversation
I'm not sure if this is really addressing the root cause of the issue. Two problems:
- We aren't actually protecting against setting `requires_grad` on non-leaf nodes. We're just skipping the `requires_grad` logic when `torch.is_grad_enabled() == True`.
- Do we even want to skip setting `requires_grad` on non-leaf nodes? The backward expects grads from each of the outputs, so we need `requires_grad` for autograd to do the right thing.
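For context, the constraint under discussion can be reproduced in a few lines (a minimal standalone sketch, not TransformerEngine code): PyTorch rejects clearing `requires_grad` on non-leaf tensors.

```python
import torch

# Minimal repro of the underlying constraint: requires_grad may be
# mutated freely on leaf tensors, but clearing it on a non-leaf
# tensor raises a RuntimeError.
leaf = torch.ones(2, requires_grad=True)
non_leaf = leaf * 2  # has a grad_fn, so it is not a leaf
assert leaf.is_leaf and not non_leaf.is_leaf

leaf.requires_grad_(False)  # fine: leaf tensors may be mutated
leaf.requires_grad_(True)
try:
    non_leaf.requires_grad_(False)
    raised = False
except RuntimeError:
    # "you can only change requires_grad flags of leaf variables ..."
    raised = True
assert raised
```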
I think the right solution is smarter logic when setting `requires_grad_`. Maybe something like:

```python
x_requires_grad = fuser.first_op_requiring_backward < fuser._num_basic_ops
if x_requires_grad != x.requires_grad:
    x = x.detach()
    if x_requires_grad:
        x.requires_grad_()
    # Or maybe only detach if x is a non-leaf node?
    # Need to check if the CPU overhead of checking
    # is worth saving the CPU overhead of detaching.
...
return x
```

Another approach would be changing our ops to always return leaf nodes. For example, here is the forward pass of `MakeExtraOutput`:
This would be changed to:
```python
out = input_.detach()
return out, [(out,)]
```

Referenced diff:

```python
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
    for y in ys:
        y.requires_grad_(idx >= fuser.first_op_requiring_backward)
if func_ctx is not None:
```
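The detach-based suggestions above rest on the fact that `detach()` returns a leaf alias of a tensor; a standalone sketch (assumed example, not the actual fuser or `MakeExtraOutput` code):

```python
import torch

x = torch.ones(3, requires_grad=True) * 2  # non-leaf output of an op
out = x.detach()                           # leaf alias sharing storage
assert out.is_leaf and not out.requires_grad

out.requires_grad_(True)                   # legal now: out is a leaf
assert out.requires_grad
assert out.data_ptr() == x.data_ptr()      # detach made no copy
```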
This logic is not intuitive. `func_ctx` is `None` when `torch.is_grad_enabled() == False`:

TransformerEngine/transformer_engine/pytorch/ops/fuser.py, lines 504 to 509 in 0c2e7b0

It would be better to pass in `is_grad_enabled` as an arg so that we can be explicit and not rely on secret contracts.
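The explicit-argument idea can be sketched as follows (hypothetical function name and signature; the actual fuser API differs):

```python
import torch

def fused_forward(x, set_output_requires_grad):
    # The caller states intent explicitly instead of the callee
    # inferring it from func_ctx being None (sketch, not fuser code).
    y = x * 2
    if set_output_requires_grad:
        y.requires_grad_(True)
    return y

# No-grad path: the flag is computed once at the call site.
with torch.no_grad():
    out = fused_forward(torch.ones(2), torch.is_grad_enabled())
assert not out.requires_grad
```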
As I understand it, the real issue is that when the forward_func is `.apply`, we are free to set `requires_grad_` on returned tensors. But when it is `.forward`, we cannot mutate this state on non-leaf tensors.

When `torch.is_grad_enabled()` is false, we bypass `.apply` and call `.forward` directly with no `func_ctx`. In that path there is no `OperationFuserAutogradFunction` node registered, so no fuser backward will run. So I think this:

> The backward expects grads from each of the outputs, so we need requires_grad for autograd to do the right thing

is not true, because we cannot run `backward()` through it.

I think it makes sense to pass this as an explicit argument, as you say, instead of relying on `func_ctx` being `None`. But I think the current logic is correct.
I see, the apply/forward distinction makes sense to me. We should include a comment making that clear.
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Greptile Summary

This bug fix guards the two `requires_grad_` calls so they only run when grad is enabled.

Confidence Score: 5/5
Safe to merge — targeted, minimal bug fix with no regressions identified. The change is a small, well-scoped fix that correctly guards the two `requires_grad_` calls. No files require special attention.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["OperationFuser.__call__()"] --> B{"is_grad_enabled?"}
    B -- "True" --> C["args = (..., set_output_requires_grad=True, ...)"]
    B -- "False" --> D["args = (..., set_output_requires_grad=False, ...)"]
    C --> E["_OperationFuserAutogradFunction.apply(*args)"]
    D --> F["_OperationFuserAutogradFunction.forward(None, *args)"]
    E --> G["forward() — tensors detached by autograd machinery\nset_output_requires_grad=True\nrequires_grad_() calls are SAFE"]
    F --> H["forward() — inner ops may produce non-leaf tensors\nset_output_requires_grad=False\nrequires_grad_() calls are SKIPPED"]
    G --> I["return output tensor(s)"]
    H --> I
```
Reviews (2): Last reviewed commit: "Revert cudnn-frontend submodule bump"
```python
input,
self,
basic_op_kwargs,
is_grad_enabled,  # set_output_requires_grad
```
Implicit coupling between `is_grad_enabled` and `set_output_requires_grad`

`is_grad_enabled` is passed directly as `set_output_requires_grad`, hardcoding an equivalence between the two concepts. The guard in `forward` is really about "are the output tensors leaf nodes (safe to mutate with `requires_grad_`) or not", which happens to correlate with grad being enabled today. If a future caller ever needs grad enabled but the outputs are already non-leaf (or vice versa), the coupling breaks silently. A dedicated flag computed from actual tensor leaf-ness, or at minimum a local variable with an explanatory name, would make the intent more resilient:

```python
# Outputs produced by inner ops in the no-grad path may be non-leaf
# tensors; setting requires_grad_ on non-leaf tensors raises a RuntimeError.
set_output_requires_grad = is_grad_enabled
```

This is already implicit in the existing comment, so at a minimum a short inline note at the call site explaining why `is_grad_enabled` serves as the proxy would help future maintainers.
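A sketch of the "flag computed from actual tensor leaf-ness" alternative mentioned above (hypothetical helper, not the PR's implementation):

```python
import torch

def set_requires_grad_if_leaf(tensors, value):
    # Guard on the real precondition (leaf-ness) rather than on grad
    # mode; non-leaf tensors are left untouched instead of raising.
    for t in tensors:
        if t.is_leaf:
            t.requires_grad_(value)

base = torch.ones(2, requires_grad=True)
tensors = [torch.zeros(2), base * 2]  # one leaf, one non-leaf
set_requires_grad_if_leaf(tensors, True)
assert tensors[0].requires_grad       # leaf was mutated
assert tensors[1].requires_grad       # non-leaf already tracks grad
```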
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt> Made-with: Cursor
Description
Pass an explicit flag that controls whether the fuser forward pass sets `requires_grad` on outputs. This is required so that in `no_grad` mode we don't try to mutate this state on non-leaf nodes (which is not allowed by torch).
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: