
torch.compile fails with DeepSeekV3 + SimpleFSDP #2312

@axeber01

Bug description

I am trying to run DeepSeekV3 with SimpleFSDP and torch.compile enabled, but the compilation fails. The error seems to be related to dynamic tensor shapes (the number of tokens per expert) in the MoE model.
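For context, a minimal standalone sketch (not the torchtitan code itself; `take_routed_prefix` is a hypothetical name) of the data-dependent slicing pattern involved: the slice bound is the sum of per-expert token counts read out of a tensor, so under torch.compile each count becomes an unbacked SymInt and their sum is not a compile-time constant.

```python
import torch

def take_routed_prefix(x, num_tokens_per_expert):
    # tolist() materializes runtime tensor values as Python ints; when
    # traced by torch.compile these become unbacked symbols (u9, u10, ...)
    # and their sum is data-dependent rather than constant.
    counts = num_tokens_per_expert.tolist()
    return x[: sum(counts)]

x = torch.randn(8, 4)
counts = torch.tensor([2, 3])
out = take_routed_prefix(x, counts)
assert out.shape == (5, 4)  # 2 + 3 = 5 routed tokens
```

This runs fine eagerly; the failure only appears when the slice bound has to be symbolically tracked during compilation.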

Error log

    File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 696, in train      
      self.train_step(data_iterator)                                                                                                                                                 
    File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 588, in train_step        
      loss = self.forward_backward_step(                                                  
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                             
    File "/home/axel/ctml/torchtitan/torchtitan/train.py", line 536, in forward_backward_step    
      pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)                                                                                                                  
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 469, in __call__
      return super().__call__(*args, **kwargs)                                                                                                                                       
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                     
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)                                                                                                                                        
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                        
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
      return forward_call(*args, **kwargs)                                                                                                                                           
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                           
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1006, in compile_wrapper
      raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1                                                                                                          
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                       
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2575, in _call_user_compiler                                         
      raise BackendCompilerFailed(                                                                                                                                                   
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 2550, in _call_user_compiler
      compiled_fn = compiler_fn(gm, example_inputs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/__init__.py", line 2530, in __call__
      return self.compiler_fn(model_, inputs_, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/ctml/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 153, in simple_fsdp_custom_pass
      return backend(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/inductor.py", line 31, in inductor
      return compile_fx(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2564, in compile_fx
      return _maybe_wrap_and_compile_fx_main( 
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2641, in _maybe_wrap_and_compile_fx_main
      return _compile_fx_main(
             ^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2836, in _compile_fx_main
      return aot_autograd(
             ^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 124, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1149, in aot_module_simplified
      aot_state = create_aot_state(
                  ^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 583, in create_aot_state
      fw_metadata = run_functionalized_fw_and_collect_metadata(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 221, in inner
      flat_f_outs = f(*flat_f_args)
                    ^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1414, in functional_call
      out = PropagateUnbackedSymInts(mod).run(*args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
      self.env[node] = self.run_node(node)
                       ^^^^^^^^^^^^^^^^^^^
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7957, in run_node
      rebind_unbacked(fake_mode.shape_env, n, result)
    File "/home/axel/condaenv/envs/torchtitan/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 673, in rebind_unbacked
      assert not raw_u1.free_symbols, (
             ^^^^^^^^^^^^^^^^^^^^^^^
  torch._dynamo.exc.BackendCompilerFailed: backend='simple_fsdp_custom_pass' raised:
  AssertionError: should have been constant, but got u10 + u11 + u12 + u13 + u14 + u15 + u16 + u9
   
  While executing %getitem_83 : [num_users=2] = call_function[target=operator.getitem](args = (%routed_input_1, slice(None, sym_sum_11, None)), kwargs = {})
  Original traceback:
    File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 502, in forward
      h = layer(h, self.freqs_cis, attention_masks, positions)
    File "/home/axel/ctml/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 391, in forward
      x = x + self.moe(self.ffn_norm(x))
    File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 518, in forward
      routed_output = self.experts(routed_input, num_tokens_per_expert)
    File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 177, in forward
      return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert)
    File "/home/axel/ctml/torchtitan/torchtitan/models/moe/moe.py", line 91, in _run_experts_for_loop
      x[: sum(num_tokens_per_expert_list)],
   
  Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)
   
  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
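As an aside, one common way to sidestep "should have been constant" failures of this kind (a hedged sketch, not a claim about the right fix for torchtitan; `run_prefix` is a hypothetical name) is to compute the slice bound as a host-side Python int before entering the compiled region, so the compiler sees a constant instead of a sum of unbacked symbols:

```python
import torch

def run_prefix(x, total):
    # total is a plain Python int here; if this function were wrapped in
    # torch.compile, the bound would be baked in as a constant rather than
    # tracked as an unbacked SymInt sum.
    return x[:total] * 2.0

x = torch.randn(8, 4)
counts = [2, 3, 1]          # per-expert token counts, known on the host
total = sum(counts)          # constant from the compiler's point of view
out = run_prefix(x, total)
assert out.shape == (6, 4)
```

Whether this is applicable inside the MoE for-loop path depends on where `num_tokens_per_expert` becomes available.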

Versions

GPU: Nvidia TITAN V
Python 3.12.12
torchtitan version: 02661e8
torch version: 2.11.0.dev20260202+cu126

Reproduce

TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+dynamo,+inductor,graph_code" \
  TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHDYNAMO_DISABLE_CACHE=1 TORCHINDUCTOR_CACHE_DISABLE=1 \
  torchrun --nproc_per_node=1 -m torchtitan.train --model.name simple_fsdp.deepseek_v3 \
  --job.config_file=torchtitan/models/deepseek_v3/train_configs/debug_model.toml \
  --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config \
  --compile.enable \
  --parallelism.data_parallel_shard_degree 1 \
  --parallelism.tensor_parallel_degree 1 \
  --parallelism.expert_parallel_degree=1 \
  --training.mixed_precision_param=float32 \
  --training.local_batch_size 1 \
  --training.seq_len 256 \
  --activation_checkpoint.mode "none" 
