@@ -43,10 +43,11 @@ def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool:
4343 )
4444
4545 if dual_pipe_v and job_config .activation_checkpoint .mode != "none" :
46- raise NotImplementedError (
47- "Expert Parallel with DualPipeV and Activation Checkpointing "
48- "cannot be used together. Please disable one of them."
49- )
46+ pass
47+ # raise NotImplementedError(
48+ # "Expert Parallel with DualPipeV and Activation Checkpointing "
49+ # "cannot be used together. Please disable one of them."
50+ # )
5051
5152 return dual_pipe_v
5253
@@ -96,6 +97,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
9697 )
9798
9899
# Thread-local flag marking the current thread as the backward thread.
# It is set/cleared by _set_backward_thread_flag() around the backward
# thread's work and read by _is_in_backward_thread(); threads that never
# set it observe the default (False). Any SyncHook.forward call made from
# the backward thread is treated as activation-checkpoint recomputation,
# since the forward thread never sets this flag.
_backward_thread_flag = threading.local()
103+
104+
99105class HookCoordinator :
100106 def __init__ (self ):
101107 # Barrier for 2 threads (forward and backward) to synchronize
@@ -139,6 +145,16 @@ def is_coordination_enabled(self):
139145 return self ._coordination_enabled
140146
141147
def _is_in_backward_thread() -> bool:
    """Return whether the calling thread has been marked as the backward thread.

    The flag lives in thread-local storage, so a thread that never called
    _set_backward_thread_flag simply has no ``value`` attribute and is
    reported as not being the backward thread.
    """
    try:
        return _backward_thread_flag.value
    except AttributeError:
        # Thread never set the flag -> not the backward thread.
        return False
151+
152+
def _set_backward_thread_flag(value: bool):
    """Mark (``True``) or unmark (``False``) the calling thread as the backward thread.

    The write goes to thread-local storage, so other threads are unaffected.
    """
    setattr(_backward_thread_flag, "value", value)
156+
157+
142158# Global coordinator
143159_hook_coordinator = HookCoordinator ()
144160
@@ -147,6 +163,16 @@ class SyncHook(torch.autograd.Function):
147163 @staticmethod
148164 def forward (ctx , x , hook_name = "" ):
149165 ctx .hook_name = hook_name
166+
167+ # Skip barrier if we're in the backward thread - this means we're being called
168+ # during checkpoint recomputation (the forward thread never sets this flag)
169+ if _is_in_backward_thread ():
170+ print ("skipping backward barrier" , flush = True )
171+ ctx .skip_backward_barrier = True
172+ return x
173+
174+ ctx .skip_backward_barrier = False
175+
150176 # handle edge case for transformer level boundary
151177 if _hook_coordinator ._coordination_enabled and hook_name == "D" :
152178 _hook_coordinator ._cycle_count += 1
@@ -161,6 +187,13 @@ def forward(ctx, x, hook_name=""):
161187 def backward (ctx , grad_output ):
162188 hook_name = ctx .hook_name
163189
190+ # Skip barrier if this backward corresponds to a checkpoint recompute forward
191+ # These are "extra" backward nodes created by checkpoint that don't have
192+ # corresponding partners in the other thread
193+ if ctx .skip_backward_barrier :
194+ print ("skipping backward barrier" , flush = True )
195+ return grad_output , None
196+
164197 # Edge case, skip initial barrier, all subsequent backward hooks will acquire
165198 if hook_name == "D" and _hook_coordinator ._cycle_count == 0 :
166199 return grad_output , None
@@ -262,32 +295,44 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
262295 main_stream = torch .accelerator .current_stream (device_module )
263296
264297 # Shared container for exception from backward thread
265- def run_backward ():
266- schedule ._assert_unsharded (backward_stage )
267- # Set the backward thread to use the same stream as forward
268- device_module .set_stream (main_stream )
269- with record_function (
270- f"backward_stage_{ backward_stage_index } _mb_{ backward_mb_index } "
271- ):
272- loss = schedule ._maybe_get_loss (backward_stage , backward_mb_index )
273- schedule .backward_counter [backward_stage_index ] += 1
274- last_backward = (
275- schedule .backward_counter [backward_stage_index ]
276- == schedule ._n_microbatches
277- )
278- backward_stage .backward_one_chunk (
279- backward_mb_index ,
280- loss = loss ,
281- full_backward = True ,
282- last_backward = last_backward ,
283- )
298+ backward_exception : list [BaseException ] = []
284299
285- if backward_is_prev_stage_on_this_rank :
286- stage_index_to_stage [backward_stage_index - 1 ].set_local_bwd_input (
287- backward_stage .get_local_bwd_output (backward_mb_index ),
300+ def run_backward ():
301+ # Mark this thread as the backward thread so SyncHook.forward
302+ # can detect checkpoint recomputation (forward called from backward thread)
303+ _set_backward_thread_flag (True )
304+ try :
305+ schedule ._assert_unsharded (backward_stage )
306+ # Set the backward thread to use the same stream as forward
307+ device_module .set_stream (main_stream )
308+ with record_function (
309+ f"backward_stage_{ backward_stage_index } _mb_{ backward_mb_index } "
310+ ):
311+ loss = schedule ._maybe_get_loss (backward_stage , backward_mb_index )
312+ schedule .backward_counter [backward_stage_index ] += 1
313+ last_backward = (
314+ schedule .backward_counter [backward_stage_index ]
315+ == schedule ._n_microbatches
316+ )
317+ backward_stage .backward_one_chunk (
288318 backward_mb_index ,
319+ loss = loss ,
320+ full_backward = True ,
321+ last_backward = last_backward ,
289322 )
290323
324+ if backward_is_prev_stage_on_this_rank :
325+ stage_index_to_stage [backward_stage_index - 1 ].set_local_bwd_input (
326+ backward_stage .get_local_bwd_output (backward_mb_index ),
327+ backward_mb_index ,
328+ )
329+ except BaseException as e :
330+ backward_exception .append (e )
331+ # Abort barrier to unblock forward thread if it's waiting
332+ _hook_coordinator .disable_coordination ()
333+ finally :
334+ _set_backward_thread_flag (False )
335+
291336 def run_forward ():
292337 schedule ._assert_unsharded (forward_stage )
293338 output = forward_stage .forward_one_chunk (
@@ -315,3 +360,7 @@ def run_forward():
315360 thread .join ()
316361
317362 _hook_coordinator .disable_coordination ()
363+
364+ # Re-raise exception from backward thread with full traceback
365+ if backward_exception :
366+ raise backward_exception [0 ]
0 commit comments