|
@@ -10,8 +10,7 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable.fsdp import FSDPModule
-from torch.distributed._composable.replicate import replicate
-
+from torch.distributed._composable.replicate_with_fsdp import replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
@@ -22,7 +21,6 @@
     RowwiseParallel,
     SequenceParallel,
 )
-
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import ParallelDims
@@ -146,13 +144,13 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate")
-        if parallel_dims.world_size != dp_replicate_mesh.size():
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        names = ["dp_replicate"]
+        dp_mesh = parallel_dims.get_mesh(names)
+        apply_replicate(
             model,
-            dp_replicate_mesh,
-            enable_compile=model_compile_enabled,
+            dp_mesh,
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
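
As context for the new keyword arguments: `TORCH_DTYPE_MAP` translates the config's mixed-precision strings into `torch.dtype` objects before they reach `apply_replicate`. A minimal sketch of that lookup (the exact contents of the real table in `torchtitan.config` are an assumption here):

    import torch

    # Hypothetical subset of torchtitan's TORCH_DTYPE_MAP, shown only to make
    # the param/reduce dtype plumbing above concrete.
    TORCH_DTYPE_MAP = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }

    param_dtype = TORCH_DTYPE_MAP["bfloat16"]  # training.mixed_precision_param
    reduce_dtype = TORCH_DTYPE_MAP["float32"]  # training.mixed_precision_reduce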
@@ -276,8 +274,10 @@ def disable_fsdp_gradient_division(model: nn.Module) -> None:
     Set gradient_divide_factor=1.0 to disable FSDP's automatic gradient division.
     We handle gradient scaling ourselves in the training loop with global token count.

+    Note: This also works for ReplicateModule since it inherits from FSDPModule.
+
     Args:
-        model: The model containing FSDP-wrapped modules
+        model: The model containing FSDP-wrapped or Replicate-wrapped modules
     """
     for module in model.modules():
         if isinstance(module, FSDPModule):
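
The loop body is truncated by the diff context; given the docstring, it presumably sets the factor to 1.0 via FSDP2's per-module API. A minimal sketch of the full helper under that assumption:

    import torch.nn as nn
    from torch.distributed._composable.fsdp import FSDPModule

    def disable_fsdp_gradient_division(model: nn.Module) -> None:
        for module in model.modules():
            # ReplicateModule subclasses FSDPModule, so this covers both wrappers.
            if isinstance(module, FSDPModule):
                # A factor of 1.0 turns off the default division by data-parallel
                # world size; the training loop rescales by global token count.
                module.set_gradient_divide_factor(1.0)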
@@ -360,15 +360,37 @@ def apply_fsdp(
     disable_fsdp_gradient_division(model)


-def apply_ddp(
+def apply_replicate(
     model: nn.Module,
     dp_mesh: DeviceMesh,
-    enable_compile: bool,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
 ):
-    if enable_compile:
-        torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}

-    # pyrefly: ignore [invalid-param-spec]
-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.tok_embeddings is not None:
+        # pyrefly: ignore [no-matching-overload]
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    # pyrefly: ignore [missing-attribute]
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
+    if model.norm is not None and model.output is not None:
+        # pyrefly: ignore [no-matching-overload]
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)
+
+    # Disable Replicate's automatic gradient division (ReplicateModule inherits from FSDPModule)
+    disable_fsdp_gradient_division(model)

-    logger.info("Applied DDP to the model")
+    logger.info("Applied replicate to the model")
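
For reference, a minimal driver for the new `apply_replicate` outside torchtitan; the mesh construction and dtype choices here are illustrative assumptions, not part of this PR:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    # Assumes torch.distributed is already initialized (e.g. via torchrun) and
    # that `model` is a Llama-style module exposing tok_embeddings / layers /
    # norm / output, as apply_replicate expects.
    world_size = dist.get_world_size()
    dp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp_replicate",))

    apply_replicate(
        model,
        dp_mesh,
        param_dtype=torch.bfloat16,  # e.g. training.mixed_precision_param = "bfloat16"
        reduce_dtype=torch.float32,  # e.g. training.mixed_precision_reduce = "float32"
    )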