Skip to content

Commit 7536553

Browse files
committed
resolve comments
Signed-off-by: Robin Zhang <robinz@nvidia.com>
1 parent 7f095c5 commit 7536553

File tree

1 file changed

+72
-71
lines changed

1 file changed

+72
-71
lines changed

transformer_engine/pytorch/graph.py

Lines changed: 72 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ def _make_graphed_callables(
360360
+ "for each callable must contain only Tensors. Other types are not allowed."
361361
)
362362

363+
if capture_time_hooks is not None and len(capture_time_hooks) != len(callables):
364+
raise ValueError(
365+
f"capture_time_hooks has {len(capture_time_hooks)} entries but there are "
366+
f"{len(callables)} callables"
367+
)
368+
363369
# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
364370
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
365371
# Note: These per_callable_* variables are not actually
@@ -449,7 +455,7 @@ def _make_graphed_callables(
449455
visited_te_modules = {}
450456
need_bwd_dw_graph = {}
451457

452-
def _run_warmup_forward(func_idx, func):
458+
def _run_warmup_forward(func_idx, func, callable_idx):
453459
"""Run forward for one callable during warmup; returns flattened outputs."""
454460
args = sample_args[func_idx]
455461
kwargs = sample_kwargs[func_idx]
@@ -479,42 +485,44 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus
479485

480486
if (
481487
capture_time_hooks is not None
482-
and func_idx < len(capture_time_hooks)
483-
and capture_time_hooks[func_idx] is not None
484-
and "forward_pre" in capture_time_hooks[func_idx]
488+
and capture_time_hooks[callable_idx] is not None
489+
and "pre_forward" in capture_time_hooks[callable_idx]
485490
):
486-
for hook in capture_time_hooks[func_idx]["forward_pre"].values():
487-
hook(func, args, kwargs)
491+
for hook in capture_time_hooks[callable_idx]["pre_forward"].values():
492+
result = hook(func, args, kwargs)
493+
if result is not None:
494+
args, kwargs = result
488495

489496
hooks = []
490497
for module in func.modules():
491498
hooks.append(module.register_forward_hook(hook_fn))
492-
outputs, _ = _tree_flatten(func(*args, **kwargs))
499+
outputs = func(*args, **kwargs)
493500
for hook in hooks:
494501
hook.remove()
495502

496503
if (
497504
capture_time_hooks is not None
498-
and func_idx < len(capture_time_hooks)
499-
and capture_time_hooks[func_idx] is not None
500-
and "forward" in capture_time_hooks[func_idx]
505+
and capture_time_hooks[callable_idx] is not None
506+
and "forward" in capture_time_hooks[callable_idx]
501507
):
502-
for hook in capture_time_hooks[func_idx]["forward"].values():
503-
hook(func, args, outputs)
508+
for hook in capture_time_hooks[callable_idx]["forward"].values():
509+
result = hook(func, args, outputs)
510+
if result is not None:
511+
outputs = result
504512

513+
outputs, _ = _tree_flatten(outputs)
505514
return outputs
506515

507-
def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
516+
def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx):
508517
"""Run dgrad backward for one callable during warmup."""
509518
static_input_surface = per_callable_static_input_surfaces[func_idx]
510519

511520
if (
512521
capture_time_hooks is not None
513-
and func_idx < len(capture_time_hooks)
514-
and capture_time_hooks[func_idx] is not None
515-
and "pre_backward" in capture_time_hooks[func_idx]
522+
and capture_time_hooks[callable_idx] is not None
523+
and "pre_backward" in capture_time_hooks[callable_idx]
516524
):
517-
for hook in capture_time_hooks[func_idx]["pre_backward"].values():
525+
for hook in capture_time_hooks[callable_idx]["pre_backward"].values():
518526
hook(func)
519527

520528
inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -528,11 +536,10 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
528536

529537
if (
530538
capture_time_hooks is not None
531-
and func_idx < len(capture_time_hooks)
532-
and capture_time_hooks[func_idx] is not None
533-
and "backward" in capture_time_hooks[func_idx]
539+
and capture_time_hooks[callable_idx] is not None
540+
and "backward" in capture_time_hooks[callable_idx]
534541
):
535-
for hook in capture_time_hooks[func_idx]["backward"].values():
542+
for hook in capture_time_hooks[callable_idx]["backward"].values():
536543
hook(func)
537544

538545
# Filter module params that get None grad from grad_inputs and remove them
@@ -589,11 +596,11 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
589596
# All forwards in order, then all backwards in reverse order.
590597
warmup_outputs = []
591598
for func_idx, func in zip(warmup_func_idx, warmup_func):
592-
outputs = _run_warmup_forward(func_idx, func)
599+
outputs = _run_warmup_forward(func_idx, func, func_idx)
593600
warmup_outputs.append((func_idx, func, outputs))
594601
if is_training:
595602
for func_idx, func, outputs in reversed(warmup_outputs):
596-
_run_warmup_backward(func_idx, func, outputs, warmup_iter)
603+
_run_warmup_backward(func_idx, func, outputs, warmup_iter, func_idx)
597604
else:
598605
# Follow _order exactly, mirroring the capture phase.
599606
per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs
@@ -604,25 +611,27 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
604611
# Forward pass for chunk c_id.
605612
m_chunk = c_id - 1
606613
for l_no in range(_num_layers_per_chunk[m_chunk]):
614+
callable_idx = _prefix_num_layers[m_chunk] + l_no
607615
per_callable_fwd_idx = (
608616
_prefix_num_layers[m_chunk] * num_microbatches
609617
) + (fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
610-
func = callables[_prefix_num_layers[m_chunk] + l_no]
611-
outputs = _run_warmup_forward(per_callable_fwd_idx, func)
618+
func = callables[callable_idx]
619+
outputs = _run_warmup_forward(per_callable_fwd_idx, func, callable_idx)
612620
per_fwd_outputs[per_callable_fwd_idx] = outputs
613621
fwd_idx[m_chunk] += 1
614622
elif ceil(c_id) == c_id:
615623
# Backward pass for chunk -c_id.
616624
if is_training:
617625
m_chunk = -c_id - 1
618626
for l_no in reversed(range(_num_layers_per_chunk[m_chunk])):
627+
callable_idx = _prefix_num_layers[m_chunk] + l_no
619628
per_callable_bwd_idx = (
620629
_prefix_num_layers[m_chunk] * num_microbatches
621630
) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
622-
func = callables[_prefix_num_layers[m_chunk] + l_no]
631+
func = callables[callable_idx]
623632
outputs = per_fwd_outputs[per_callable_bwd_idx]
624633
_run_warmup_backward(
625-
per_callable_bwd_idx, func, outputs, warmup_iter
634+
per_callable_bwd_idx, func, outputs, warmup_iter, callable_idx
626635
)
627636
bwd_idx[m_chunk] += 1
628637

@@ -653,42 +662,39 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
653662
# Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
654663
m_chunk = c_id - 1
655664
for l_no in range(_num_layers_per_chunk[m_chunk]):
656-
func = callables[_prefix_num_layers[m_chunk] + l_no]
665+
callable_idx = _prefix_num_layers[m_chunk] + l_no
666+
func = callables[callable_idx]
657667
per_callable_fwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
658668
fwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
659669
)
660670
args = sample_args[per_callable_fwd_idx]
661671
kwargs = sample_kwargs[per_callable_fwd_idx]
662672
fwd_graph = fwd_graphs[per_callable_fwd_idx]
663673

664-
# Call forward_pre hooks before forward graph capture (outside capture context)
674+
# Call pre_forward hooks before forward graph capture (outside capture context)
665675
if (
666676
capture_time_hooks is not None
667-
and per_callable_fwd_idx < len(capture_time_hooks)
668-
and capture_time_hooks[per_callable_fwd_idx] is not None
669-
and "forward_pre" in capture_time_hooks[per_callable_fwd_idx]
677+
and capture_time_hooks[callable_idx] is not None
678+
and "pre_forward" in capture_time_hooks[callable_idx]
670679
):
671-
for hook in capture_time_hooks[per_callable_fwd_idx][
672-
"forward_pre"
673-
].values():
674-
hook(
675-
func, args, kwargs
676-
) # forward_pre hook signature: (module, args, kwargs)
680+
for hook in capture_time_hooks[callable_idx]["pre_forward"].values():
681+
result = hook(func, args, kwargs)
682+
if result is not None:
683+
args, kwargs = result
677684

678685
with _graph_context_wrapper(fwd_graph, pool=mempool):
679686
outputs = func(*args, **kwargs)
680687

681688
# Call forward hooks after forward graph capture (outside capture context)
682689
if (
683690
capture_time_hooks is not None
684-
and per_callable_fwd_idx < len(capture_time_hooks)
685-
and capture_time_hooks[per_callable_fwd_idx] is not None
686-
and "forward" in capture_time_hooks[per_callable_fwd_idx]
691+
and capture_time_hooks[callable_idx] is not None
692+
and "forward" in capture_time_hooks[callable_idx]
687693
):
688-
for hook in capture_time_hooks[per_callable_fwd_idx]["forward"].values():
689-
hook(
690-
func, args, outputs
691-
) # forward hook signature: (module, inputs, output)
694+
for hook in capture_time_hooks[callable_idx]["forward"].values():
695+
result = hook(func, args, outputs)
696+
if result is not None:
697+
outputs = result
692698

693699
flatten_outputs, spec = _tree_flatten(outputs)
694700
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
@@ -700,6 +706,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
700706
m_chunk = -ceil(c_id) - 1
701707
previous_per_callable_bwd_idx = None
702708
for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))):
709+
callable_idx = _prefix_num_layers[m_chunk] + l_no
703710
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
704711
bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
705712
)
@@ -789,17 +796,14 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
789796
# Call pre_backward hooks before backward graph capture (outside capture context)
790797
if (
791798
capture_time_hooks is not None
792-
and per_callable_bwd_idx < len(capture_time_hooks)
793-
and capture_time_hooks[per_callable_bwd_idx] is not None
794-
and "pre_backward" in capture_time_hooks[per_callable_bwd_idx]
799+
and capture_time_hooks[callable_idx] is not None
800+
and "pre_backward" in capture_time_hooks[callable_idx]
795801
):
796802
# Get the callable module for this backward index
797803
callable_module = graph_callables[per_callable_bwd_idx]
798-
for hook in capture_time_hooks[per_callable_bwd_idx][
804+
for hook in capture_time_hooks[callable_idx][
799805
"pre_backward"
800806
].values():
801-
# During capture, call with the actual module (not None)
802-
# FSDP hooks need to access module attributes
803807
hook(callable_module)
804808

805809
inputs = tuple(i for i in static_input_surface if i.requires_grad)
@@ -818,17 +822,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
818822
# Call backward hooks after backward graph capture (outside capture context)
819823
if (
820824
capture_time_hooks is not None
821-
and per_callable_bwd_idx < len(capture_time_hooks)
822-
and capture_time_hooks[per_callable_bwd_idx] is not None
823-
and "backward" in capture_time_hooks[per_callable_bwd_idx]
825+
and capture_time_hooks[callable_idx] is not None
826+
and "backward" in capture_time_hooks[callable_idx]
824827
):
825828
# Get the callable module for this backward index
826829
callable_module = graph_callables[per_callable_bwd_idx]
827-
for hook in capture_time_hooks[per_callable_bwd_idx][
828-
"backward"
829-
].values():
830-
# During capture, call with the actual module (not None)
831-
# FSDP hooks need to access module attributes
830+
for hook in capture_time_hooks[callable_idx]["backward"].values():
832831
hook(callable_module)
833832

834833
# Constructs a tuple suitable for returning from Graphed.backward:
@@ -888,15 +887,16 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
888887
for func_idx, (func, args, kwargs, fwd_graph) in enumerate(
889888
zip(callables, sample_args, sample_kwargs, fwd_graphs)
890889
):
891-
# Call forward_pre hooks before forward graph capture (outside capture context)
890+
# Call pre_forward hooks before forward graph capture (outside capture context)
892891
if (
893892
capture_time_hooks is not None
894-
and func_idx < len(capture_time_hooks)
895893
and capture_time_hooks[func_idx] is not None
896-
and "forward_pre" in capture_time_hooks[func_idx]
894+
and "pre_forward" in capture_time_hooks[func_idx]
897895
):
898-
for hook in capture_time_hooks[func_idx]["forward_pre"].values():
899-
hook(func, args, kwargs)
896+
for hook in capture_time_hooks[func_idx]["pre_forward"].values():
897+
result = hook(func, args, kwargs)
898+
if result is not None:
899+
args, kwargs = result
900900

901901
with _graph_context_wrapper(fwd_graph, pool=mempool):
902902
outputs = func(*args, **kwargs)
@@ -905,12 +905,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
905905
# Call forward hooks after forward graph capture (outside capture context)
906906
if (
907907
capture_time_hooks is not None
908-
and func_idx < len(capture_time_hooks)
909908
and capture_time_hooks[func_idx] is not None
910909
and "forward" in capture_time_hooks[func_idx]
911910
):
912911
for hook in capture_time_hooks[func_idx]["forward"].values():
913-
hook(func, args, outputs)
912+
result = hook(func, args, outputs)
913+
if result is not None:
914+
outputs = result
914915

915916
flatten_outputs, spec = _tree_flatten(outputs)
916917
per_callable_static_outputs.append(tuple(flatten_outputs))
@@ -935,7 +936,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
935936
# Call pre_backward hooks before backward graph capture (outside capture context)
936937
if (
937938
capture_time_hooks is not None
938-
and bwd_idx < len(capture_time_hooks)
939939
and capture_time_hooks[bwd_idx] is not None
940940
and "pre_backward" in capture_time_hooks[bwd_idx]
941941
):
@@ -957,7 +957,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
957957
# Call backward hooks after backward graph capture (outside capture context)
958958
if (
959959
capture_time_hooks is not None
960-
and bwd_idx < len(capture_time_hooks)
961960
and capture_time_hooks[bwd_idx] is not None
962961
and "backward" in capture_time_hooks[bwd_idx]
963962
):
@@ -1362,9 +1361,11 @@ def make_graphed_callables(
13621361
when `_order` is provided. All callables in `modules` are assumed to have
13631362
inputs and outputs with the same dtype and shape.
13641363
pre_warmup_hook: callable, default = None
1365-
A hook function that will be called before the warmup iterations.
1364+
A hook function that will be called once before all warmup iterations
1365+
(not once per callable).
13661366
post_warmup_hook: callable, default = None
1367-
A hook function that will be called after the warmup iterations.
1367+
A hook function that will be called once after all warmup iterations
1368+
(not once per callable).
13681369
capture_time_hooks: list of dict, optional
13691370
Per-callable hooks invoked at capture time (during warmup iterations and
13701371
graph capture), but intentionally executed **outside** the CUDA graph
@@ -1374,7 +1375,7 @@ def make_graphed_callables(
13741375
CPU-side state updates.
13751376
Each element corresponds to one callable and is a dict with any subset
13761377
of the following keys:
1377-
- ``"forward_pre"``: dict of hooks called *before* the forward pass.
1378+
- ``"pre_forward"``: dict of hooks called *before* the forward pass.
13781379
Hook signature: ``hook(module, args, kwargs)``.
13791380
- ``"forward"``: dict of hooks called *after* the forward pass.
13801381
Hook signature: ``hook(module, args, output)``.

0 commit comments

Comments
 (0)