Commit dfebee0
Update on "[torchtitan][replicate] experimenting new replicate integration with torchtitan"
**Summary:** For this experiment integrating the new replicate function into torchtitan, I built on pytorch/pytorch#162021, which has not landed yet. However, since that PR is about making replicate more efficient rather than changing its core code, pytorch/pytorch#160135, which has landed, should be sufficient. pytorch/pytorch#160133 was the last change to touch replicate_with_fsdp.py and its replicate API.
To enable the new replicate, which uses a 2D device mesh (it is a specialized version of HSDP), I changed the parallelism code to add a dp_shard dimension of size 1 only when dp_replicate > 1, and created the device mesh that I pass down to apply_ddp; a sketch of that mesh construction is below.
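For context, here is a minimal sketch of the mesh construction described above. It assumes the new replicate entry point from replicate_with_fsdp.py is importable as `replicate` and accepts the mesh via a `mesh` keyword; `build_replicate_mesh` and the `apply_ddp` signature shown are illustrative stand-ins, not the actual torchtitan code.

```python
# Minimal sketch (assumptions noted below), not the actual torchtitan diff:
# build the 2-D (dp_replicate, dp_shard=1) mesh that the new replicate path
# expects and hand it to an apply_ddp-style helper.
import torch
from torch.distributed.device_mesh import init_device_mesh


def build_replicate_mesh(dp_replicate: int):
    # The new replicate is a specialized HSDP, so it takes a 2-D mesh even
    # though the shard dimension is degenerate (size 1).
    return init_device_mesh(
        "cuda",
        (dp_replicate, 1),
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )


def apply_ddp(model: torch.nn.Module, dp_mesh) -> torch.nn.Module:
    # Hypothetical stand-in for torchtitan's apply_ddp: wrap the model with the
    # new replicate API on the (dp_replicate, dp_shard) mesh. The import path
    # and keyword name below are assumptions based on replicate_with_fsdp.py.
    from torch.distributed._composable.replicate_with_fsdp import replicate

    replicate(model, mesh=dp_mesh)
    return model
```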
The numeric tests for tp + replicate and pp + replicate can be seen below. To verify that they worked, I also compared them against HSDP with an (n, 1) (replicate, shard) mesh.
<img width="950" height="485" alt="image" src="https://github.com/user-attachments/assets/a7bede55-54af-43f4-9fa0-4430f1992d73" />
https://fburl.com/mlhub/5k9v43w3
**Test Case**
1. `CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh` (set replicate to 8)
Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
**[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]**
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
**[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model**
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)
[ghstack-poisoned]