Skip to content

Commit dfebee0

Browse files
committed
Update on "[torchtitan][replicate] experimenting new replicate integration with torchtitan"
**Summary:** During this experiment to integrate the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not been landed. However, since this is more about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last time replicate_with_fsdp.py and its replicate api were touched. In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include dp_shard dim = 1 only if dp_replicate > 1, and created device mesh that I pass down in apply_ddp. The numeric tests for tp + replicate and pp + replicate can be seen below. In order to ensure that they worked, I also compared them with HSDP (n, 1) (replicate, shard). <img width="950" height="485" alt="image" src="https://github.com/user-attachments/assets/a7bede55-54af-43f4-9fa0-4430f1992d73" /> https://fburl.com/mlhub/5k9v43w3 **Test Case** 1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh (set replicate to 8) Expected output of this experiment should be something like: [rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training [rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config **[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]** [rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0) [rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory [rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters [rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model **[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model** [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%) [rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format [rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP [rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2) [ghstack-poisoned]
2 parents 096ce87 + dc0f8a6 commit dfebee0

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

torchtitan/distributed/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def maybe_enable_amp(
250250
else:
251251
if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
252252
logger.warning(
253-
"Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP/replicate is enabled."
253+
"Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP/DDP is enabled."
254254
)
255255
logger.info("Mixed precision training is disabled")
256256
return contextlib.nullcontext()

torchtitan/train.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -641,12 +641,10 @@ def train_step(
641641
global_avg_loss = global_max_loss = loss.detach().item()
642642
global_ntokens_seen = self.ntokens_seen
643643

644-
print(global_avg_loss)
645-
646-
extra_metrics = {
647-
"n_tokens_seen": global_ntokens_seen,
648-
"lr": lr,
649-
}
644+
extra_metrics = {
645+
"n_tokens_seen": global_ntokens_seen,
646+
"lr": lr,
647+
}
650648
self.metrics_processor.log(
651649
self.step,
652650
global_avg_loss,

0 commit comments

Comments
 (0)