30 changes: 20 additions & 10 deletions transformer_engine/pytorch/ops/fuser.py
@@ -65,6 +65,7 @@ def forward(
input_: torch.Tensor,
fuser: OperationFuser,
basic_op_kwargs: list[dict[str, Any]],
+ set_output_requires_grad: bool,
*params_and_extra_inputs: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""Forward pass
@@ -79,6 +80,8 @@
Container for the pipeline of operations to run
basic_op_kwargs: list of dict
Keyword arguments to BasicOperation
+ set_output_requires_grad: bool
+     Whether to set ``requires_grad`` flags on returned tensors
*params_and_extra_inputs: torch.Tensor
Other tensor inputs to include in autograd graph. Consists
of parameter tensors, followed by extra operation inputs.
@@ -138,7 +141,8 @@
)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
- y.requires_grad_(idx >= fuser.first_op_requiring_backward)
+ if set_output_requires_grad:
+     y.requires_grad_(idx >= fuser.first_op_requiring_backward)
extra_outputs[idx] = ys

# Flatten list of extra outputs
@@ -190,7 +194,8 @@
for tensor in [x] + extra_outputs_flat:
tensor._do_not_clear = True

- x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)
+ if set_output_requires_grad:
+     x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops)

if extra_outputs_flat:
return x, *extra_outputs_flat
@@ -293,6 +298,7 @@ def backward(
dx, # input_
None, # fuser
None, # basic_op_kwargs
+ None,  # set_output_requires_grad
*grad_params_flat,
*grad_extra_inputs_flat,
)
@@ -501,20 +507,24 @@ def __call__(
op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward)

# Fuser forward pass
- if is_grad_enabled:
-     forward_func = _OperationFuserAutogradFunction.apply
-     args = []
- else:
-     forward_func = _OperationFuserAutogradFunction.forward
-     args = [None]
- args += (
+ # When is_grad_enabled is False, we call forward directly. This does
+ # not register a PyTorch autograd node, so no fuser backward will run.
+ # We pass set_output_requires_grad=False in this path to avoid setting
+ # requires_grad on outputs, since they may be non-leaf tensors from
+ # the inner ops.
+ args = (
input,
self,
basic_op_kwargs,
+ is_grad_enabled,  # set_output_requires_grad
Contributor

P2 Implicit coupling between is_grad_enabled and set_output_requires_grad

is_grad_enabled is passed directly as set_output_requires_grad, hardcoding an equivalence between the two concepts. The guard in forward is really about "are the output tensors leaf nodes (safe to mutate requires_grad_) or not", which happens to correlate with grad being enabled today. If a future caller ever needs grad enabled but the outputs are already non-leaf (or vice versa), the coupling breaks silently. A dedicated flag computed from actual tensor leaf-ness, or at minimum a local variable with an explanatory name, would make the intent more resilient:

# Outputs produced by inner ops in the no-grad path may be non-leaf
# tensors; setting requires_grad_ on non-leaf tensors raises a RuntimeError.
set_output_requires_grad = is_grad_enabled

This is already implicit in the existing comment, so at a minimum a short inline note at the call site explaining why is_grad_enabled serves as the proxy would help future maintainers.
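
For reference, a minimal standalone repro of the failure mode the comment describes (plain PyTorch behavior, independent of this PR):

import torch

x = torch.ones(2, requires_grad=True)
y = x * 2  # non-leaf: produced by an op while grad mode is on

# Flipping the flag on a non-leaf tensor raises
# "RuntimeError: you can only change requires_grad flags of leaf variables."
try:
    y.requires_grad_(False)
except RuntimeError as err:
    print(err)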

*self._flat_basic_op_params,
*extra_inputs,
)
- return forward_func(*args)
+
+ if not is_grad_enabled:
+     return _OperationFuserAutogradFunction.forward(None, *args)
+
+ return _OperationFuserAutogradFunction.apply(*args)


def register_forward_fusion(
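
As background for the dispatch change above: calling a torch.autograd.Function's forward as a plain staticmethod (passing None for ctx) runs the computation without recording an autograd node, while apply records one. A minimal sketch of the pattern with a hypothetical Double op (not part of this PR):

import torch

class Double(torch.autograd.Function):
    # Toy op: doubles its input; the gradient of 2*x is 2*grad_out.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.ones(3, requires_grad=True)

y1 = Double.apply(x)  # autograd node recorded; backward will run
print(y1.grad_fn)     # <torch.autograd.function.DoubleBackward object ...>

with torch.no_grad():
    y2 = Double.forward(None, x)  # plain call: no node, ctx unused here
print(y2.grad_fn, y2.is_leaf)     # None True -> leaf, flags safe to mutate

# With grad enabled, Double.forward(None, x) would instead return a
# non-leaf tensor (grad_fn=MulBackward0) built by the inner op -- the
# case the review comment warns about.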