Bug description
I am trying to run DeepSeek-V3 with SimpleFSDP and torch.compile enabled, but compilation fails. The error appears to be caused by data-dependent tensor shapes (the per-expert token counts) in the MoE layers.
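For context, the failing line in the log (`x[: sum(num_tokens_per_expert_list)]` in moe.py) slices with a bound that only exists at runtime. The following is a torch-free sketch of that pattern, not the actual torchtitan code: the slice bound is the sum of routing-dependent counts, so a compiler cannot treat it as a constant shape (the `u9 + u10 + ...` unbacked symbols in the assertion above).

```python
def run_experts_for_loop(tokens, num_tokens_per_expert):
    # The bound is data-dependent: it changes with every batch's routing.
    bound = sum(num_tokens_per_expert)
    routed = tokens[:bound]
    out = []
    offset = 0
    for n in num_tokens_per_expert:
        expert_tokens = routed[offset:offset + n]
        out.extend(t * 2 for t in expert_tokens)  # stand-in for an expert MLP
        offset += n
    return out

# Two batches with different routing produce different slice bounds:
print(len(run_experts_for_loop(list(range(10)), [2, 3])))  # 5 tokens routed
print(len(run_experts_for_loop(list(range(10)), [4, 4])))  # 8 tokens routed
```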
Error log
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 696, in train
self.train_step(data_iterator)
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 588, in train_step
loss = self.forward_backward_step(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 536, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 469, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1006, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2575, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2550, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/__init__.py", line 2530, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/ctml/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 153, in simple_fsdp_custom_pass
return backend(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py", line 31, in inductor
return compile_fx(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2564, in compile_fx
return _maybe_wrap_and_compile_fx_main(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2641, in _maybe_wrap_and_compile_fx_main
return _compile_fx_main(
^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2836, in _compile_fx_main
return aot_autograd(
^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 124, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1149, in aot_module_simplified
aot_state = create_aot_state(
^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 583, in create_aot_state
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 221, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1414, in functional_call
out = PropagateUnbackedSymInts(mod).run(*args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7957, in run_node
rebind_unbacked(fake_mode.shape_env, n, result)
File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 673, in rebind_unbacked
assert not raw_u1.free_symbols, (
^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='simple_fsdp_custom_pass' raised:
AssertionError: should have been constant, but got u10 + u11 + u12 + u13 + u14 + u15 + u16 + u9
While executing %getitem_83 : [num_users=2] = call_function[target=operator.getitem](args = (%routed_input_1, slice(None, sym_sum_11, None)), kwargs = {})
Original traceback:
File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 502, in forward
h = layer(h, self.freqs_cis, attention_masks, positions)
File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 391, in forward
x = x + self.moe(self.ffn_norm(x))
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 518, in forward
routed_output = self.experts(routed_input, num_tokens_per_expert)
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 177, in forward
return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 91, in _run_experts_for_loop
x[: sum(num_tokens_per_expert_list)],
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Versions
GPU: Nvidia TITAN V
Python 3.12.12
torchtitan version: 02661e8
torch version: 2.11.0.dev20260202+cu126
Reproduce
TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+dynamo,+inductor,graph_code" \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHDYNAMO_DISABLE_CACHE=1 TORCHINDUCTOR_CACHE_DISABLE=1 \
torchrun --nproc_per_node=1 -m torchtitan.train --model.name simple_fsdp.deepseek_v3 \
  --job.config_file=torchtitan/models/deepseek_v3/train_configs/debug_model.toml \
  --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config \
  --compile.enable \
  --parallelism.data_parallel_shard_degree 1 \
  --parallelism.tensor_parallel_degree 1 \
  --parallelism.expert_parallel_degree=1 \
  --training.mixed_precision_param=float32 \
  --training.local_batch_size 1 \
  --training.seq_len 256 \
  --activation_checkpoint.mode "none"
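As an aside (an assumption about a possible direction, not a proposed patch): one common way MoE implementations avoid data-dependent shapes is fixed expert capacity, where each expert's token group is padded or truncated to a constant size so every shape the compiler sees is static regardless of routing. A torch-free sketch of the idea, with `pad_to_capacity` being a hypothetical helper:

```python
def pad_to_capacity(groups, capacity, pad_value=0.0):
    """Pad/truncate each expert's token list to exactly `capacity` entries,
    so the resulting shape (num_experts, capacity) is routing-independent."""
    padded = []
    for g in groups:
        g = g[:capacity]                              # drop overflow tokens
        padded.append(g + [pad_value] * (capacity - len(g)))
    return padded

# Uneven routing still yields a static (3, 4) shape:
groups = [[1.0, 2.0], [3.0, 4.0, 5.0], []]
out = pad_to_capacity(groups, capacity=4)
print([len(g) for g in out])  # [4, 4, 4]
```

The trade-off is wasted compute on padding and possible token dropping, which is why the for-loop implementation in moe.py presumably avoids it; this is only to illustrate why the dynamic-shape path is hard for the compiler.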