Skip to content

Commit d4a7e9b

Browse files
committed
rebase
1 parent 0e2b303 commit d4a7e9b

File tree

4 files changed

+78
-39
lines changed

4 files changed

+78
-39
lines changed

torchtitan/distributed/dual_pipe_v.py

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool:
4343
)
4444

4545
if dual_pipe_v and job_config.activation_checkpoint.mode != "none":
46-
raise NotImplementedError(
47-
"Expert Parallel with DualPipeV and Activation Checkpointing "
48-
"cannot be used together. Please disable one of them."
49-
)
46+
pass
47+
# raise NotImplementedError(
48+
# "Expert Parallel with DualPipeV and Activation Checkpointing "
49+
# "cannot be used together. Please disable one of them."
50+
# )
5051

5152
return dual_pipe_v
5253

@@ -96,6 +97,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
9697
)
9798

9899

100+
# Thread-local flag to track if we're in the backward thread
101+
# Any SyncHook.forward call from the backward thread is checkpoint recomputation
102+
_backward_thread_flag = threading.local()
103+
104+
99105
class HookCoordinator:
100106
def __init__(self):
101107
# Barrier for 2 threads (forward and backward) to synchronize
@@ -139,6 +145,16 @@ def is_coordination_enabled(self):
139145
return self._coordination_enabled
140146

141147

148+
def _is_in_backward_thread() -> bool:
149+
"""Check if current thread is the backward thread."""
150+
return getattr(_backward_thread_flag, 'value', False)
151+
152+
153+
def _set_backward_thread_flag(value: bool):
154+
"""Set the backward thread flag for current thread."""
155+
_backward_thread_flag.value = value
156+
157+
142158
# Global coordinator
143159
_hook_coordinator = HookCoordinator()
144160

@@ -147,6 +163,16 @@ class SyncHook(torch.autograd.Function):
147163
@staticmethod
148164
def forward(ctx, x, hook_name=""):
149165
ctx.hook_name = hook_name
166+
167+
# Skip barrier if we're in the backward thread - this means we're being called
168+
# during checkpoint recomputation (the forward thread never sets this flag)
169+
if _is_in_backward_thread():
170+
print("skipping backward barrier", flush=True)
171+
ctx.skip_backward_barrier = True
172+
return x
173+
174+
ctx.skip_backward_barrier = False
175+
150176
# handle edge case for transformer level boundary
151177
if _hook_coordinator._coordination_enabled and hook_name == "D":
152178
_hook_coordinator._cycle_count += 1
@@ -161,6 +187,13 @@ def forward(ctx, x, hook_name=""):
161187
def backward(ctx, grad_output):
162188
hook_name = ctx.hook_name
163189

190+
# Skip barrier if this backward corresponds to a checkpoint recompute forward
191+
# These are "extra" backward nodes created by checkpoint that don't have
192+
# corresponding partners in the other thread
193+
if ctx.skip_backward_barrier:
194+
print("skipping backward barrier", flush=True)
195+
return grad_output, None
196+
164197
# Edge case, skip initial barrier, all subsequent backward hooks will acquire
165198
if hook_name == "D" and _hook_coordinator._cycle_count == 0:
166199
return grad_output, None
@@ -262,32 +295,44 @@ def overlap_callback(action: _Action, ctx: _PipelineContext):
262295
main_stream = torch.accelerator.current_stream(device_module)
263296

264297
# Shared container for exception from backward thread
265-
def run_backward():
266-
schedule._assert_unsharded(backward_stage)
267-
# Set the backward thread to use the same stream as forward
268-
device_module.set_stream(main_stream)
269-
with record_function(
270-
f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
271-
):
272-
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
273-
schedule.backward_counter[backward_stage_index] += 1
274-
last_backward = (
275-
schedule.backward_counter[backward_stage_index]
276-
== schedule._n_microbatches
277-
)
278-
backward_stage.backward_one_chunk(
279-
backward_mb_index,
280-
loss=loss,
281-
full_backward=True,
282-
last_backward=last_backward,
283-
)
298+
backward_exception: list[BaseException] = []
284299

285-
if backward_is_prev_stage_on_this_rank:
286-
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
287-
backward_stage.get_local_bwd_output(backward_mb_index),
300+
def run_backward():
301+
# Mark this thread as the backward thread so SyncHook.forward
302+
# can detect checkpoint recomputation (forward called from backward thread)
303+
_set_backward_thread_flag(True)
304+
try:
305+
schedule._assert_unsharded(backward_stage)
306+
# Set the backward thread to use the same stream as forward
307+
device_module.set_stream(main_stream)
308+
with record_function(
309+
f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
310+
):
311+
loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
312+
schedule.backward_counter[backward_stage_index] += 1
313+
last_backward = (
314+
schedule.backward_counter[backward_stage_index]
315+
== schedule._n_microbatches
316+
)
317+
backward_stage.backward_one_chunk(
288318
backward_mb_index,
319+
loss=loss,
320+
full_backward=True,
321+
last_backward=last_backward,
289322
)
290323

324+
if backward_is_prev_stage_on_this_rank:
325+
stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
326+
backward_stage.get_local_bwd_output(backward_mb_index),
327+
backward_mb_index,
328+
)
329+
except BaseException as e:
330+
backward_exception.append(e)
331+
# Abort barrier to unblock forward thread if it's waiting
332+
_hook_coordinator.disable_coordination()
333+
finally:
334+
_set_backward_thread_flag(False)
335+
291336
def run_forward():
292337
schedule._assert_unsharded(forward_stage)
293338
output = forward_stage.forward_one_chunk(
@@ -315,3 +360,7 @@ def run_forward():
315360
thread.join()
316361

317362
_hook_coordinator.disable_coordination()
363+
364+
# Re-raise exception from backward thread with full traceback
365+
if backward_exception:
366+
raise backward_exception[0]

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
dim=256,
3333
inter_dim=1024,
3434
moe_inter_dim=256,
35-
n_layers=6,
35+
n_layers=24,
3636
n_dense_layers=1,
3737
n_heads=16,
3838
moe_args=MoEArgs(

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 debug training"
44
print_config = false
55

66
[profiling]
7-
enable_profiling = true
7+
enable_profiling = false
88
save_traces_folder = "profile_trace"
99
profile_freq = 1
1010
profiler_warmup = 0

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,10 @@
2424
from torchtitan.config.job_config import Compile as CompileConfig
2525
from torchtitan.distributed import NoParallel, ParallelDims
2626
from torchtitan.distributed.activation_checkpoint import apply_ac
27-
<<<<<<< HEAD
2827
from torchtitan.distributed.dual_pipe_v import (
2928
DualPipeExpertParallel,
3029
get_dual_pipe_v_flag,
3130
)
32-
=======
33-
from torchtitan.distributed.dual_pipe_v import DualPipeExpertParallel
34-
>>>>>>> f3e551fb (Enable PP and EP overlap for MoE)
35-
3631
from torchtitan.distributed.expert_parallel import (
3732
BaseExpertParallel,
3833
ExpertParallel,
@@ -123,12 +118,7 @@ def parallelize_llama(
123118
else None
124119
),
125120
etp_enabled=parallel_dims.etp_enabled,
126-
<<<<<<< HEAD
127121
dual_pipe_v=dual_pipe_v,
128-
=======
129-
dual_pipe_v=job_config.parallelism.pipeline_parallel_expert_parallel_overlap
130-
and job_config.parallelism.pipeline_parallel_schedule == "dualpipev",
131-
>>>>>>> f3e551fb (Enable PP and EP overlap for MoE)
132122
)
133123

134124
model_compile_enabled = (
@@ -513,7 +503,7 @@ def apply_moe_ep_tp(
513503
elif tp_mesh is None or not etp_enabled:
514504
experts_mesh = ep_mesh
515505
# input / output sharding on the batch / tokens dim
516-
experts_plan = DualPipeExpertParallel() if dual_pipe_v else ExpertParallel()
506+
experts_plan = ExpertParallel()
517507
else:
518508
experts_mesh = ep_tp_mesh
519509
experts_plan = ExpertTensorParallel()

0 commit comments

Comments (0)