@@ -360,6 +360,12 @@ def _make_graphed_callables(
360360 + "for each callable must contain only Tensors. Other types are not allowed."
361361 )
362362
363+ if capture_time_hooks is not None and len (capture_time_hooks ) != len (callables ):
364+ raise ValueError (
365+ f"capture_time_hooks has { len (capture_time_hooks )} entries but there are "
366+ f"{ len (callables )} callables"
367+ )
368+
363369 # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
364370 # passes to forward (ie, its sample_args) AND the module's parameter attributes.
365371 # Note: These per_callable_* variables are not actually
@@ -449,7 +455,7 @@ def _make_graphed_callables(
449455 visited_te_modules = {}
450456 need_bwd_dw_graph = {}
451457
452- def _run_warmup_forward (func_idx , func ):
458+ def _run_warmup_forward (func_idx , func , callable_idx ):
453459 """Run forward for one callable during warmup; returns flattened outputs."""
454460 args = sample_args [func_idx ]
455461 kwargs = sample_kwargs [func_idx ]
@@ -479,42 +485,44 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus
479485
480486 if (
481487 capture_time_hooks is not None
482- and func_idx < len (capture_time_hooks )
483- and capture_time_hooks [func_idx ] is not None
484- and "forward_pre" in capture_time_hooks [func_idx ]
488+ and capture_time_hooks [callable_idx ] is not None
489+ and "pre_forward" in capture_time_hooks [callable_idx ]
485490 ):
486- for hook in capture_time_hooks [func_idx ]["forward_pre" ].values ():
487- hook (func , args , kwargs )
491+ for hook in capture_time_hooks [callable_idx ]["pre_forward" ].values ():
492+ result = hook (func , args , kwargs )
493+ if result is not None :
494+ args , kwargs = result
488495
489496 hooks = []
490497 for module in func .modules ():
491498 hooks .append (module .register_forward_hook (hook_fn ))
492- outputs , _ = _tree_flatten ( func (* args , ** kwargs ) )
499+ outputs = func (* args , ** kwargs )
493500 for hook in hooks :
494501 hook .remove ()
495502
496503 if (
497504 capture_time_hooks is not None
498- and func_idx < len (capture_time_hooks )
499- and capture_time_hooks [func_idx ] is not None
500- and "forward" in capture_time_hooks [func_idx ]
505+ and capture_time_hooks [callable_idx ] is not None
506+ and "forward" in capture_time_hooks [callable_idx ]
501507 ):
502- for hook in capture_time_hooks [func_idx ]["forward" ].values ():
503- hook (func , args , outputs )
508+ for hook in capture_time_hooks [callable_idx ]["forward" ].values ():
509+ result = hook (func , args , outputs )
510+ if result is not None :
511+ outputs = result
504512
513+ outputs , _ = _tree_flatten (outputs )
505514 return outputs
506515
507- def _run_warmup_backward (func_idx , func , outputs , warmup_iter ):
516+ def _run_warmup_backward (func_idx , func , outputs , warmup_iter , callable_idx ):
508517 """Run dgrad backward for one callable during warmup."""
509518 static_input_surface = per_callable_static_input_surfaces [func_idx ]
510519
511520 if (
512521 capture_time_hooks is not None
513- and func_idx < len (capture_time_hooks )
514- and capture_time_hooks [func_idx ] is not None
515- and "pre_backward" in capture_time_hooks [func_idx ]
522+ and capture_time_hooks [callable_idx ] is not None
523+ and "pre_backward" in capture_time_hooks [callable_idx ]
516524 ):
517- for hook in capture_time_hooks [func_idx ]["pre_backward" ].values ():
525+ for hook in capture_time_hooks [callable_idx ]["pre_backward" ].values ():
518526 hook (func )
519527
520528 inputs = tuple (i for i in static_input_surface if i .requires_grad )
@@ -528,11 +536,10 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
528536
529537 if (
530538 capture_time_hooks is not None
531- and func_idx < len (capture_time_hooks )
532- and capture_time_hooks [func_idx ] is not None
533- and "backward" in capture_time_hooks [func_idx ]
539+ and capture_time_hooks [callable_idx ] is not None
540+ and "backward" in capture_time_hooks [callable_idx ]
534541 ):
535- for hook in capture_time_hooks [func_idx ]["backward" ].values ():
542+ for hook in capture_time_hooks [callable_idx ]["backward" ].values ():
536543 hook (func )
537544
538545 # Filter module params that get None grad from grad_inputs and remove them
@@ -589,11 +596,11 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
589596 # All forwards in order, then all backwards in reverse order.
590597 warmup_outputs = []
591598 for func_idx , func in zip (warmup_func_idx , warmup_func ):
592- outputs = _run_warmup_forward (func_idx , func )
599+ outputs = _run_warmup_forward (func_idx , func , func_idx )
593600 warmup_outputs .append ((func_idx , func , outputs ))
594601 if is_training :
595602 for func_idx , func , outputs in reversed (warmup_outputs ):
596- _run_warmup_backward (func_idx , func , outputs , warmup_iter )
603+ _run_warmup_backward (func_idx , func , outputs , warmup_iter , func_idx )
597604 else :
598605 # Follow _order exactly, mirroring the capture phase.
599606 per_fwd_outputs = {} # per_callable_fwd_idx -> flattened outputs
@@ -604,25 +611,27 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
604611 # Forward pass for chunk c_id.
605612 m_chunk = c_id - 1
606613 for l_no in range (_num_layers_per_chunk [m_chunk ]):
614+ callable_idx = _prefix_num_layers [m_chunk ] + l_no
607615 per_callable_fwd_idx = (
608616 _prefix_num_layers [m_chunk ] * num_microbatches
609617 ) + (fwd_idx [m_chunk ] * _num_layers_per_chunk [m_chunk ] + l_no )
610- func = callables [_prefix_num_layers [ m_chunk ] + l_no ]
611- outputs = _run_warmup_forward (per_callable_fwd_idx , func )
618+ func = callables [callable_idx ]
619+ outputs = _run_warmup_forward (per_callable_fwd_idx , func , callable_idx )
612620 per_fwd_outputs [per_callable_fwd_idx ] = outputs
613621 fwd_idx [m_chunk ] += 1
614622 elif ceil (c_id ) == c_id :
615623 # Backward pass for chunk -c_id.
616624 if is_training :
617625 m_chunk = - c_id - 1
618626 for l_no in reversed (range (_num_layers_per_chunk [m_chunk ])):
627+ callable_idx = _prefix_num_layers [m_chunk ] + l_no
619628 per_callable_bwd_idx = (
620629 _prefix_num_layers [m_chunk ] * num_microbatches
621630 ) + (bwd_idx [m_chunk ] * _num_layers_per_chunk [m_chunk ] + l_no )
622- func = callables [_prefix_num_layers [ m_chunk ] + l_no ]
631+ func = callables [callable_idx ]
623632 outputs = per_fwd_outputs [per_callable_bwd_idx ]
624633 _run_warmup_backward (
625- per_callable_bwd_idx , func , outputs , warmup_iter
634+ per_callable_bwd_idx , func , outputs , warmup_iter , callable_idx
626635 )
627636 bwd_idx [m_chunk ] += 1
628637
@@ -653,42 +662,39 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
653662 # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
654663 m_chunk = c_id - 1
655664 for l_no in range (_num_layers_per_chunk [m_chunk ]):
656- func = callables [_prefix_num_layers [m_chunk ] + l_no ]
665+ callable_idx = _prefix_num_layers [m_chunk ] + l_no
666+ func = callables [callable_idx ]
657667 per_callable_fwd_idx = (_prefix_num_layers [m_chunk ] * num_microbatches ) + (
658668 fwd_idx [m_chunk ] * _num_layers_per_chunk [m_chunk ] + l_no
659669 )
660670 args = sample_args [per_callable_fwd_idx ]
661671 kwargs = sample_kwargs [per_callable_fwd_idx ]
662672 fwd_graph = fwd_graphs [per_callable_fwd_idx ]
663673
664- # Call forward_pre hooks before forward graph capture (outside capture context)
674+ # Call pre_forward hooks before forward graph capture (outside capture context)
665675 if (
666676 capture_time_hooks is not None
667- and per_callable_fwd_idx < len (capture_time_hooks )
668- and capture_time_hooks [per_callable_fwd_idx ] is not None
669- and "forward_pre" in capture_time_hooks [per_callable_fwd_idx ]
677+ and capture_time_hooks [callable_idx ] is not None
678+ and "pre_forward" in capture_time_hooks [callable_idx ]
670679 ):
671- for hook in capture_time_hooks [per_callable_fwd_idx ][
672- "forward_pre"
673- ].values ():
674- hook (
675- func , args , kwargs
676- ) # forward_pre hook signature: (module, args, kwargs)
680+ for hook in capture_time_hooks [callable_idx ]["pre_forward" ].values ():
681+ result = hook (func , args , kwargs )
682+ if result is not None :
683+ args , kwargs = result
677684
678685 with _graph_context_wrapper (fwd_graph , pool = mempool ):
679686 outputs = func (* args , ** kwargs )
680687
681688 # Call forward hooks after forward graph capture (outside capture context)
682689 if (
683690 capture_time_hooks is not None
684- and per_callable_fwd_idx < len (capture_time_hooks )
685- and capture_time_hooks [per_callable_fwd_idx ] is not None
686- and "forward" in capture_time_hooks [per_callable_fwd_idx ]
691+ and capture_time_hooks [callable_idx ] is not None
692+ and "forward" in capture_time_hooks [callable_idx ]
687693 ):
688- for hook in capture_time_hooks [per_callable_fwd_idx ]["forward" ].values ():
689- hook (
690- func , args , outputs
691- ) # forward hook signature: (module, inputs, output)
694+ for hook in capture_time_hooks [callable_idx ]["forward" ].values ():
695+ result = hook (func , args , outputs )
696+ if result is not None :
697+ outputs = result
692698
693699 flatten_outputs , spec = _tree_flatten (outputs )
694700 per_callable_static_outputs [per_callable_fwd_idx ] = tuple (flatten_outputs )
@@ -700,6 +706,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
700706 m_chunk = - ceil (c_id ) - 1
701707 previous_per_callable_bwd_idx = None
702708 for l_no in list (reversed (range (_num_layers_per_chunk [m_chunk ]))):
709+ callable_idx = _prefix_num_layers [m_chunk ] + l_no
703710 per_callable_bwd_idx = (_prefix_num_layers [m_chunk ] * num_microbatches ) + (
704711 bwd_idx [m_chunk ] * _num_layers_per_chunk [m_chunk ] + l_no
705712 )
@@ -789,17 +796,14 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
789796 # Call pre_backward hooks before backward graph capture (outside capture context)
790797 if (
791798 capture_time_hooks is not None
792- and per_callable_bwd_idx < len (capture_time_hooks )
793- and capture_time_hooks [per_callable_bwd_idx ] is not None
794- and "pre_backward" in capture_time_hooks [per_callable_bwd_idx ]
799+ and capture_time_hooks [callable_idx ] is not None
800+ and "pre_backward" in capture_time_hooks [callable_idx ]
795801 ):
796802 # Get the callable module for this backward index
797803 callable_module = graph_callables [per_callable_bwd_idx ]
798- for hook in capture_time_hooks [per_callable_bwd_idx ][
804+ for hook in capture_time_hooks [callable_idx ][
799805 "pre_backward"
800806 ].values ():
801- # During capture, call with the actual module (not None)
802- # FSDP hooks need to access module attributes
803807 hook (callable_module )
804808
805809 inputs = tuple (i for i in static_input_surface if i .requires_grad )
@@ -818,17 +822,12 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
818822 # Call backward hooks after backward graph capture (outside capture context)
819823 if (
820824 capture_time_hooks is not None
821- and per_callable_bwd_idx < len (capture_time_hooks )
822- and capture_time_hooks [per_callable_bwd_idx ] is not None
823- and "backward" in capture_time_hooks [per_callable_bwd_idx ]
825+ and capture_time_hooks [callable_idx ] is not None
826+ and "backward" in capture_time_hooks [callable_idx ]
824827 ):
825828 # Get the callable module for this backward index
826829 callable_module = graph_callables [per_callable_bwd_idx ]
827- for hook in capture_time_hooks [per_callable_bwd_idx ][
828- "backward"
829- ].values ():
830- # During capture, call with the actual module (not None)
831- # FSDP hooks need to access module attributes
830+ for hook in capture_time_hooks [callable_idx ]["backward" ].values ():
832831 hook (callable_module )
833832
834833 # Constructs a tuple suitable for returning from Graphed.backward:
@@ -888,15 +887,16 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
888887 for func_idx , (func , args , kwargs , fwd_graph ) in enumerate (
889888 zip (callables , sample_args , sample_kwargs , fwd_graphs )
890889 ):
891- # Call forward_pre hooks before forward graph capture (outside capture context)
890+ # Call pre_forward hooks before forward graph capture (outside capture context)
892891 if (
893892 capture_time_hooks is not None
894- and func_idx < len (capture_time_hooks )
895893 and capture_time_hooks [func_idx ] is not None
896- and "forward_pre " in capture_time_hooks [func_idx ]
894+ and "pre_forward " in capture_time_hooks [func_idx ]
897895 ):
898- for hook in capture_time_hooks [func_idx ]["forward_pre" ].values ():
899- hook (func , args , kwargs )
896+ for hook in capture_time_hooks [func_idx ]["pre_forward" ].values ():
897+ result = hook (func , args , kwargs )
898+ if result is not None :
899+ args , kwargs = result
900900
901901 with _graph_context_wrapper (fwd_graph , pool = mempool ):
902902 outputs = func (* args , ** kwargs )
@@ -905,12 +905,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
905905 # Call forward hooks after forward graph capture (outside capture context)
906906 if (
907907 capture_time_hooks is not None
908- and func_idx < len (capture_time_hooks )
909908 and capture_time_hooks [func_idx ] is not None
910909 and "forward" in capture_time_hooks [func_idx ]
911910 ):
912911 for hook in capture_time_hooks [func_idx ]["forward" ].values ():
913- hook (func , args , outputs )
912+ result = hook (func , args , outputs )
913+ if result is not None :
914+ outputs = result
914915
915916 flatten_outputs , spec = _tree_flatten (outputs )
916917 per_callable_static_outputs .append (tuple (flatten_outputs ))
@@ -935,7 +936,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
935936 # Call pre_backward hooks before backward graph capture (outside capture context)
936937 if (
937938 capture_time_hooks is not None
938- and bwd_idx < len (capture_time_hooks )
939939 and capture_time_hooks [bwd_idx ] is not None
940940 and "pre_backward" in capture_time_hooks [bwd_idx ]
941941 ):
@@ -957,7 +957,6 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter):
957957 # Call backward hooks after backward graph capture (outside capture context)
958958 if (
959959 capture_time_hooks is not None
960- and bwd_idx < len (capture_time_hooks )
961960 and capture_time_hooks [bwd_idx ] is not None
962961 and "backward" in capture_time_hooks [bwd_idx ]
963962 ):
@@ -1362,9 +1361,11 @@ def make_graphed_callables(
13621361 when `_order` is provided. All callables in `modules` are assumed to have
13631362 inputs and outputs with the same dtype and shape.
13641363 pre_warmup_hook: callable, default = None
1365- A hook function that will be called before the warmup iterations.
1364+ A hook function that will be called once before all warmup iterations
1365+ (not once per callable).
13661366 post_warmup_hook: callable, default = None
1367- A hook function that will be called after the warmup iterations.
1367+ A hook function that will be called once after all warmup iterations
1368+ (not once per callable).
13681369 capture_time_hooks: list of dict, optional
13691370 Per-callable hooks invoked at capture time (during warmup iterations and
13701371 graph capture), but intentionally executed **outside** the CUDA graph
@@ -1374,7 +1375,7 @@ def make_graphed_callables(
13741375 CPU-side state updates.
13751376 Each element corresponds to one callable and is a dict with any subset
13761377 of the following keys:
1377- - ``"forward_pre "``: dict of hooks called *before* the forward pass.
1378+ - ``"pre_forward "``: dict of hooks called *before* the forward pass.
13781379 Hook signature: ``hook(module, args, kwargs)``.
13791380 - ``"forward"``: dict of hooks called *after* the forward pass.
13801381 Hook signature: ``hook(module, args, output)``.
0 commit comments