Skip to content

Commit 7e7b06f

Browse files
committed
[torchtitan][replicate] experimenting new replicate integration with torchtitan
ghstack-source-id: 69101fb Pull Request resolved: #1714
1 parent e38d7ab commit 7e7b06f

File tree

9 files changed

+83
-68
lines changed

9 files changed

+83
-68
lines changed

torchtitan/distributed/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,14 @@ def context():
243243
def maybe_enable_amp(
244244
parallel_dims: ParallelDims, mixed_precision_param: str, device_type: str
245245
) -> contextlib.AbstractContextManager[None] | torch.autocast:
246-
if parallel_dims.fsdp_enabled:
246+
if parallel_dims.fsdp_enabled or parallel_dims.dp_replicate_enabled:
247247
# FSDP or replicate handles mixed precision internally
248-
logger.info("Mixed precision training is handled by fully_shard")
248+
logger.info("Mixed precision training is handled by fully_shard or replicate")
249249
return contextlib.nullcontext()
250250
else:
251251
if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
252252
logger.warning(
253-
"Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP is enabled."
253+
"Mixed precision training with TP or PP is only supported when FSDP/HSDP/CP/replicate is enabled."
254254
)
255255
logger.info("Mixed precision training is disabled")
256256
return contextlib.nullcontext()

torchtitan/experiments/rl/unified/actors/trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ def __init__(
6262

6363
# apply PT-D Parallelism
6464
# TODO: right now it only works for the qwen3 model; need to formalize this to use parallelize_fn from train_spec
65-
from torchtitan.models.llama3.infra.parallelize import apply_ddp
65+
from torchtitan.models.llama3.infra.parallelize import apply_replicate
6666

67-
apply_ddp(
67+
apply_replicate(
6868
self.model,
6969
self.parallel_dims.get_mesh("dp_replicate"),
70-
enable_compile=False,
70+
param_dtype=torch.bfloat16,
71+
reduce_dtype=torch.float32,
7172
)
7273

7374
self.model = self.model.to(device)

torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@
1818
)
1919
from torchtitan.config import TORCH_DTYPE_MAP
2020
from torchtitan.distributed import NoParallel, ParallelDims
21-
2221
from torchtitan.distributed.activation_checkpoint import apply_ac
23-
2422
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
2523
from torchtitan.experiments.transformers_modeling_backend.job_config import JobConfig
2624
from torchtitan.models.llama3.infra.parallelize import (
2725
apply_compile,
28-
apply_ddp,
26+
apply_replicate,
2927
disable_fsdp_gradient_division,
3028
)
3129
from torchtitan.tools.logging import logger
@@ -113,13 +111,13 @@ def parallelize_hf_transformers(
113111
if job_config.training.enable_cpu_offload:
114112
logger.info("Applied CPU Offloading to the model")
115113
elif parallel_dims.dp_replicate_enabled:
116-
dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate")
117-
if parallel_dims.world_size != dp_replicate_mesh.size():
118-
raise RuntimeError("DDP has not supported > 1D parallelism")
119-
apply_ddp(
114+
dp_mesh_names = ["dp_replicate"]
115+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
116+
apply_replicate(
120117
model,
121-
dp_replicate_mesh,
122-
enable_compile=model_compile_enabled,
118+
dp_mesh,
119+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
120+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
123121
)
124122

125123
return model

torchtitan/experiments/vlm/infra/parallelize.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,15 @@
99

1010
import torch
1111
import torch.nn as nn
12-
1312
from torch.distributed.device_mesh import DeviceMesh
1413
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
15-
1614
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
1715
from torchtitan.distributed import ParallelDims
1816
from torchtitan.distributed.activation_checkpoint import apply_ac
19-
2017
from torchtitan.models.llama3.infra.parallelize import (
2118
_op_sac_save_list,
2219
apply_compile,
23-
apply_ddp,
20+
apply_replicate,
2421
disable_fsdp_gradient_division,
2522
)
2623
from torchtitan.tools.logging import logger
@@ -99,13 +96,13 @@ def parallelize_vlm(
9996
if job_config.training.enable_cpu_offload:
10097
logger.info("Applied CPU Offloading to the model")
10198
elif parallel_dims.dp_replicate_enabled:
102-
dp_mesh = parallel_dims.get_mesh("dp_replicate")
103-
if dp_mesh is not None and dp_mesh.ndim > 1:
104-
raise RuntimeError("DDP has not supported > 1D parallelism")
105-
apply_ddp(
99+
dp_mesh_names = ["dp_replicate"]
100+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
101+
apply_replicate(
106102
model,
107103
dp_mesh,
108-
enable_compile=job_config.compile.enable,
104+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
105+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
109106
)
110107

111108
return model

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from torchtitan.distributed.context_parallel import apply_cp_to_attention_module
2222
from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag
2323
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
24-
from torchtitan.models.llama3.infra.parallelize import apply_ddp
24+
from torchtitan.models.llama3.infra.parallelize import apply_replicate
2525
from torchtitan.models.llama4.infra.parallelize import (
2626
apply_compile,
2727
apply_fsdp,
@@ -195,13 +195,13 @@ def parallelize_deepseekv3(
195195
if job_config.training.enable_cpu_offload:
196196
logger.info("Applied CPU Offloading to the model")
197197
elif parallel_dims.dp_replicate_enabled:
198-
dp_mesh = parallel_dims.get_mesh("dp_replicate")
199-
if dp_mesh.ndim > 1:
200-
raise RuntimeError("DDP has not supported > 1D parallelism")
201-
apply_ddp(
198+
dp_mesh_names = ["dp_replicate"]
199+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
200+
apply_replicate(
202201
model,
203202
dp_mesh,
204-
enable_compile=model_compile_enabled,
203+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
204+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
205205
)
206206

207207
return model

torchtitan/models/gpt_oss/infra/parallelize.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch._inductor.config
99
import torch.nn as nn
1010
from torch.distributed.device_mesh import DeviceMesh
11-
1211
from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard
1312
from torch.distributed.tensor.parallel import (
1413
ColwiseParallel,
@@ -31,7 +30,7 @@
3130
ExpertParallel,
3231
ReordererSequenceParallel,
3332
)
34-
from torchtitan.models.llama3.infra.parallelize import apply_ddp
33+
from torchtitan.models.llama3.infra.parallelize import apply_replicate
3534
from torchtitan.models.llama4.infra.parallelize import apply_fsdp
3635
from torchtitan.tools.logging import logger
3736

@@ -160,13 +159,13 @@ def parallelize_gptoss(
160159
if job_config.training.enable_cpu_offload:
161160
logger.info("Applied CPU Offloading to the model")
162161
elif parallel_dims.dp_replicate_enabled:
163-
dp_mesh = parallel_dims.get_mesh("dp_replicate")
164-
if dp_mesh is not None and dp_mesh.ndim > 1:
165-
raise RuntimeError("DDP has not supported > 1D parallelism")
166-
apply_ddp(
162+
dp_mesh_names = ["dp_replicate"]
163+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
164+
apply_replicate(
167165
model,
168166
dp_mesh,
169-
enable_compile=model_compile_enabled,
167+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
168+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
170169
)
171170

172171
return model

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
import torch
1111
import torch.nn as nn
1212
from torch.distributed._composable.fsdp import FSDPModule
13-
from torch.distributed._composable.replicate import replicate
14-
13+
from torch.distributed._composable.replicate_with_fsdp import replicate
1514
from torch.distributed.device_mesh import DeviceMesh
1615
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
1716
from torch.distributed.tensor import Replicate, Shard
@@ -22,7 +21,6 @@
2221
RowwiseParallel,
2322
SequenceParallel,
2423
)
25-
2624
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
2725
from torchtitan.config.job_config import Compile as CompileConfig
2826
from torchtitan.distributed import ParallelDims
@@ -146,13 +144,13 @@ def parallelize_llama(
146144
if job_config.training.enable_cpu_offload:
147145
logger.info("Applied CPU Offloading to the model")
148146
elif parallel_dims.dp_replicate_enabled:
149-
dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate")
150-
if parallel_dims.world_size != dp_replicate_mesh.size():
151-
raise RuntimeError("DDP has not supported > 1D parallelism")
152-
apply_ddp(
147+
names = ["dp_replicate"]
148+
dp_mesh = parallel_dims.get_mesh(names)
149+
apply_replicate(
153150
model,
154-
dp_replicate_mesh,
155-
enable_compile=model_compile_enabled,
151+
dp_mesh,
152+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
153+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
156154
)
157155

158156
return model
@@ -276,8 +274,10 @@ def disable_fsdp_gradient_division(model: nn.Module) -> None:
276274
Set gradient_divide_factor=1.0 to disable FSDP's automatic gradient division.
277275
We handle gradient scaling ourselves in the training loop with global token count.
278276
277+
Note: This also works for ReplicateModule since it inherits from FSDPModule.
278+
279279
Args:
280-
model: The model containing FSDP-wrapped modules
280+
model: The model containing FSDP-wrapped or Replicate-wrapped modules
281281
"""
282282
for module in model.modules():
283283
if isinstance(module, FSDPModule):
@@ -360,15 +360,37 @@ def apply_fsdp(
360360
disable_fsdp_gradient_division(model)
361361

362362

363-
def apply_ddp(
363+
def apply_replicate(
364364
model: nn.Module,
365365
dp_mesh: DeviceMesh,
366-
enable_compile: bool,
366+
param_dtype: torch.dtype,
367+
reduce_dtype: torch.dtype,
367368
):
368-
if enable_compile:
369-
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
369+
mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
370+
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
370371

371-
# pyrefly: ignore [invalid-param-spec]
372-
replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
372+
if model.tok_embeddings is not None:
373+
# pyrefly: ignore [no-matching-overload]
374+
replicate(
375+
model.tok_embeddings,
376+
**replicate_config,
377+
)
378+
# pyrefly: ignore [missing-attribute]
379+
for layer_id, transformer_block in model.layers.items():
380+
replicate(
381+
transformer_block,
382+
**replicate_config,
383+
)
384+
385+
if model.norm is not None and model.output is not None:
386+
# pyrefly: ignore [no-matching-overload]
387+
replicate(
388+
[model.norm, model.output],
389+
**replicate_config,
390+
)
391+
replicate(model, **replicate_config)
392+
393+
# Disable Replicate's automatic gradient division (ReplicateModule inherits from FSDPModule)
394+
disable_fsdp_gradient_division(model)
373395

374-
logger.info("Applied DDP to the model")
396+
logger.info("Applied replicate to the model")

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
4343
from torchtitan.models.llama3.infra.parallelize import (
44-
apply_ddp,
44+
apply_replicate,
4545
disable_fsdp_gradient_division,
4646
)
4747
from torchtitan.models.moe import moe as moe_module
@@ -214,13 +214,13 @@ def parallelize_llama(
214214
if job_config.training.enable_cpu_offload:
215215
logger.info("Applied CPU Offloading to the model")
216216
elif parallel_dims.dp_replicate_enabled:
217-
dp_mesh = parallel_dims.get_mesh("dp_replicate")
218-
if parallel_dims.world_size != dp_mesh.size():
219-
raise RuntimeError("DDP has not supported > 1D parallelism")
220-
apply_ddp(
217+
dp_mesh_names = ["dp_replicate"]
218+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
219+
apply_replicate(
221220
model,
222221
dp_mesh,
223-
enable_compile=model_compile_enabled,
222+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
223+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
224224
)
225225

226226
return model

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import torch
1111
import torch._inductor.config
1212
import torch.nn as nn
13-
1413
from torch.distributed.device_mesh import DeviceMesh
1514
from torch.distributed.tensor import Replicate, Shard
1615
from torch.distributed.tensor.parallel import (
@@ -20,13 +19,12 @@
2019
RowwiseParallel,
2120
SequenceParallel,
2221
)
23-
2422
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
2523
from torchtitan.distributed import ParallelDims
2624
from torchtitan.distributed.activation_checkpoint import apply_ac
2725
from torchtitan.distributed.context_parallel import apply_cp_to_attention_module
2826
from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag
29-
from torchtitan.models.llama3.infra.parallelize import apply_ddp
27+
from torchtitan.models.llama3.infra.parallelize import apply_replicate
3028
from torchtitan.models.llama4.infra.parallelize import (
3129
apply_compile,
3230
apply_fsdp,
@@ -175,13 +173,13 @@ def parallelize_qwen3(
175173
if job_config.training.enable_cpu_offload:
176174
logger.info("Applied CPU Offloading to the model")
177175
elif parallel_dims.dp_replicate_enabled:
178-
dp_mesh = parallel_dims.get_mesh("dp_replicate")
179-
if dp_mesh.ndim > 1:
180-
raise RuntimeError("DDP has not supported > 1D parallelism")
181-
apply_ddp(
176+
dp_mesh_names = ["dp_replicate"]
177+
dp_mesh = parallel_dims.get_mesh(dp_mesh_names)
178+
apply_replicate(
182179
model,
183180
dp_mesh,
184-
enable_compile=model_compile_enabled,
181+
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
182+
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
185183
)
186184

187185
return model

0 commit comments

Comments
 (0)