Commit 8e4bb4f

fegin authored and wwwjn committed
[CP] Refactor Context Parallel to use new PyTorch CP APIs (#2144)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom):

* #2145
* __->__ #2144

**Summary**

1. Refactored CP dispatching:
   - A new `apply_cp()` function uses PyTorch's `_ContextParallel` parallelization plan to dispatch the attention call.
   - Enables the CP dispatcher for the SDPA attention type inside `apply_cp()`.
2. New CP data sharding approach:
   - Added a `cp_shard()` helper function that wraps PyTorch's `_context_parallel_shard` API.
   - Uses `_HeadTailLoadBalancer` for SDPA attention load balancing. FlexAttention CP support is deferred to a future PR.
   - CP sharding now happens explicitly in `post_dataloading_process()`, where inputs, labels, and positions are sharded.
   - The new `positions` argument allows us to not shard `freqs_cis`.

Note that this PR requires pytorch/pytorch#170200.

**Test**

```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal
pick 5903566a Improve the loss_compare.sh logic
[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step  Baseline Loss  Test Loss  Difference
[LOSS_COMPARE] ----  -------------  ---------  ----------
[LOSS_COMPARE]    1         8.1309     8.1309    0.000000
[LOSS_COMPARE]    2         7.8268     7.8268    0.000000
[LOSS_COMPARE]    3         7.2284     7.2284    0.000000
[LOSS_COMPARE]    4         6.4669     6.4669    0.000000
[LOSS_COMPARE]    5         5.4017     5.4017    0.000000
[LOSS_COMPARE]    6         4.7656     4.7656    0.000000
[LOSS_COMPARE]    7         4.3587     4.3587    0.000000
[LOSS_COMPARE]    8         4.0938     4.0938    0.000000
[LOSS_COMPARE]    9         4.4019     4.4019    0.000000
[LOSS_COMPARE]   10         3.7451     3.7451    0.000000
....
[LOSS_COMPARE]   90         2.802      2.802     0.000000
[LOSS_COMPARE]   91         2.7207     2.7207    0.000000
[LOSS_COMPARE]   92         2.7454     2.7454    0.000000
[LOSS_COMPARE]   93         2.6992     2.6992    0.000000
[LOSS_COMPARE]   94         2.743      2.743     0.000000
[LOSS_COMPARE]   95         2.7534     2.7534    0.000000
[LOSS_COMPARE]   96         2.8403     2.8403    0.000000
[LOSS_COMPARE]   97         2.783      2.783     0.000000
[LOSS_COMPARE]   98         3.0892     3.0892    0.000000
[LOSS_COMPARE]   99         2.7905     2.7905    0.000000
[LOSS_COMPARE]  100         2.733      2.733     0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE]   Average baseline loss: 3.1414940000000002
[LOSS_COMPARE]   Average test loss: 3.1414940000000002
[LOSS_COMPARE]   Average difference: 0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified).
```

**TODO**

- This PR will invalidate torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor the nn.Module core logic to avoid checking hook_id.
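For intuition, the `_HeadTailLoadBalancer` named above pairs each rank's chunk from the head of the sequence with the mirrored chunk from the tail, so causal-attention work is roughly even across CP ranks. A minimal pure-Python sketch of that assignment (the exact chunking is an assumption based on the scheme's name and common ring-attention practice, not TorchTitan's API):

```python
def head_tail_assignment(seq_len: int, world_size: int) -> list[list[range]]:
    """Split a sequence into 2 * world_size chunks and give rank i the
    i-th chunk from the head plus the mirrored chunk from the tail, so
    causal-attention cost is roughly equal across ranks."""
    assert seq_len % (2 * world_size) == 0
    chunk = seq_len // (2 * world_size)
    chunks = [range(i * chunk, (i + 1) * chunk) for i in range(2 * world_size)]
    return [[chunks[i], chunks[2 * world_size - 1 - i]] for i in range(world_size)]

# Example: seq_len=16 across 4 CP ranks -> each rank holds 4 tokens.
shards = head_tail_assignment(16, 4)
# rank 0 gets tokens 0-1 and 14-15; rank 3 gets tokens 6-7 and 8-9.
```

Early tokens attend to few keys and late tokens to many, so pairing head with tail evens out the per-rank cost that a plain contiguous split would skew.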
1 parent c5fd490 commit 8e4bb4f

File tree

20 files changed: +515 −169 lines

torchtitan/components/validate.py

Lines changed: 85 additions & 10 deletions
```diff
@@ -6,7 +6,7 @@
 
 from collections.abc import Callable
 from contextlib import AbstractContextManager
-from typing import TypeAlias
+from typing import Any, TypeAlias
 
 import torch
 import torch.nn as nn
@@ -17,14 +17,12 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.distributed.context_parallel import prepare_context_parallel_input
 from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader
 from torchtitan.tools import utils
 from torchtitan.tools.logging import logger
 
-ValidationContext: TypeAlias = Callable[
-    [AbstractContextManager[None] | None],
-    AbstractContextManager[None],
-]
+ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]]
 
 
 class BaseValidator:
@@ -67,6 +65,7 @@ def __init__(
         pp_has_last_stage: bool | None = None,
     ):
         self.job_config = job_config
+        self.tokenizer = tokenizer
         self.parallel_dims = parallel_dims
         self.loss_fn = loss_fn
         self.validation_dataloader = build_text_validation_dataloader(
@@ -89,6 +88,70 @@ def __init__(
                 "unequal sample counts across ranks when dataset is exhausted."
             )
 
+    def post_dataloading_process(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        model_parts: list[nn.Module],
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
+        """
+        Post-processing hook after data loading and before model forward pass.
+
+        This method processes the raw data from the dataloader and prepares it for
+        the model's forward pass. It separates the main input tensor from auxiliary
+        inputs and constructs additional keyword arguments (e.g., attention masks).
+
+        Args:
+            input_dict: Dictionary containing tensors from the dataloader. Must
+                contain an "input" key with the main input tensor. May contain
+                additional keys for auxiliary inputs (e.g., position ids).
+            labels: Target labels for the batch.
+            model_parts: List of model parts for accessing model methods.
+
+        Returns:
+            A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
+            - inputs: Main input tensor extracted from input_dict["input"].
+            - labels: Target labels (potentially modified by CP sharding).
+            - extra_inputs: Dict of auxiliary input tensors (all keys except
+              "input" from input_dict). These are passed to the model forward
+              but are NOT forwarded across pipeline parallel stages.
+            - extra_kwargs: Dict of additional keyword arguments for model forward.
+              These ARE forwarded across pipeline parallel stages. Contains
+              attention_masks if flex attention is enabled.
+
+        Note:
+            The distinction between extra_inputs and extra_kwargs is important for
+            pipeline parallelism: extra_kwargs are forwarded to all pipeline stages,
+            while extra_inputs are only available to the first stage.
+        """
+        inputs = input_dict["input"]
+        extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
+        # For arguments, like attention_masks, we have to put them in a separate
+        # dict as extra_inputs are not forwarded to other stages in PP, but
+        # extra_kwargs are.
+        extra_kwargs: dict[str, Any] = {}
+
+        try:
+            # pyrefly: ignore [not-callable]
+            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+                input_batch=inputs,
+                tokenizer=self.tokenizer,
+                extra_inputs=extra_inputs,
+            )
+        except TypeError:
+            pass
+
+        if self.parallel_dims.cp_enabled:
+            inputs, labels, extra_kwargs = prepare_context_parallel_input(
+                inputs,
+                labels,
+                extra_kwargs,
+                self.parallel_dims.get_mesh("cp"),
+                inputs.device,
+            )
+
+        return inputs, labels, extra_inputs, extra_kwargs
+
     @torch.no_grad()
     # pyrefly: ignore [bad-override]
     def validate(
@@ -117,9 +180,13 @@ def validate(
             self.metrics_processor.ntokens_since_last_log += labels.numel()
             for k, v in input_dict.items():
                 input_dict[k] = v.to(device_type)
-            inputs = input_dict["input"]
             labels = labels.to(device_type)
 
+            # Process data (extract inputs, handle attention masks, CP sharding)
+            inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
+                input_dict, labels, model_parts
+            )
+
             # Count valid tokens for this batch
             local_valid_tokens = torch.tensor(0, dtype=torch.int64, device=device_type)
             # pyrefly: ignore [missing-attribute]
@@ -150,18 +217,24 @@ def validate(
                 assert self.pp_has_first_stage is not None
                 assert self.pp_has_last_stage is not None
                 # Pipeline Parallel forward inside eval() call
-                with self.validation_context(optional_context_parallel_ctx):
+                with self.validation_context():
                     targets, losses = (
                         (labels, []) if self.pp_has_last_stage else (None, None)
                     )
                     if self.pp_has_first_stage:
                         self.pp_schedule.eval(
                             inputs,
+                            **extra_inputs,
+                            **extra_kwargs,
                             target=targets,
                             losses=losses,
                         )
                     else:
-                        self.pp_schedule.eval(target=targets, losses=losses)
+                        self.pp_schedule.eval(
+                            **extra_kwargs,
+                            target=targets,
+                            losses=losses,
+                        )
 
                 # accumulate losses across pipeline microbatches
                 # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -172,10 +245,12 @@ def validate(
                     else torch.tensor([-1.0], device=device_type)
                 )
             else:
-                with self.validation_context(optional_context_parallel_ctx):
+                with self.validation_context():
                     assert len(model_parts) == 1
                     with self.maybe_enable_amp:
-                        predictions = model_parts[0](inputs)
+                        predictions = model_parts[0](
+                            inputs, **extra_inputs, **extra_kwargs
+                        )
                     loss_sum = self.loss_fn(predictions, labels)
 
                 accumulated_losses.append(loss_sum.detach() / global_valid_tokens)
```
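The `positions` tensor mentioned in the summary is what lets each rank keep the full `freqs_cis` table: after head-tail sharding, a rank's tokens are non-contiguous, so it gathers its own rotary rows by position instead of sharding the table itself. A toy stand-in with plain lists (illustrative names only, not the torchtitan API):

```python
# After CP sharding, each rank holds a non-contiguous slice of positions
# and indexes the *full* frequency table locally.
seq_len, world_size, rank = 16, 4, 0
chunk = seq_len // (2 * world_size)

# Head-tail shard for this rank: one head chunk plus the mirrored tail chunk.
positions = list(range(rank * chunk, (rank + 1) * chunk)) + list(
    range(seq_len - (rank + 1) * chunk, seq_len - rank * chunk)
)

freqs = [p * 0.5 for p in range(seq_len)]  # stand-in for freqs_cis rows
local_freqs = [freqs[p] for p in positions]  # gather by position; no sharding of freqs
```

Because the gather is driven by the sharded `positions`, the frequency table stays replicated and never needs to know about the CP layout.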
torchtitan/distributed/context_parallel.py (new file)

Lines changed: 192 additions & 0 deletions

```diff
@@ -0,0 +1,192 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections.abc import Sequence
+from typing import Any, cast
+
+import torch
+import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.experimental._attention import (
+    _context_parallel_shard,
+    _ContextParallel,
+    _enable_context_parallel_dispatcher,
+    _HeadTailLoadBalancer,
+)
+from torch.distributed.tensor.parallel import parallelize_module
+
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.tools.logging import logger
+
+
+def apply_cp_to_attention_module(
+    attention_modules: Sequence[nn.Module],
+    cp_mesh: DeviceMesh,
+    attention_type: str,
+) -> None:
+    """
+    Apply context parallelism to attention modules.
+
+    CP splits the sequence dimension across devices to enable training with
+    longer sequences. This function applies CP to the provided attention
+    modules.
+
+    Args:
+        attention_modules: Sequence of attention modules to apply CP to
+        cp_mesh: Device mesh for context parallel dimension
+        attention_type: Type of attention mechanism. Must be one of:
+            - "sdpa": scaled_dot_product_attention()
+            - "flex": flex_attention()
+            - "varlen": varlen_attn() (not yet implemented)
+
+    Raises:
+        NotImplementedError: If attention_type is "varlen"
+    """
+    # Apply context parallelism to every attention module
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    match attention_type:
+        case "flex":
+            cp_plan = _ContextParallel(
+                seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+            )
+        case "sdpa":
+            # Enable the DTensor dispatcher to route SDPA operations to the
+            # Context Parallel implementation. This is required for CP to work
+            # with SDPA (but not FlexAttention).
+            # Note: Use _disable_context_parallel_dispatcher() if you need to
+            # turn this off. In TorchTitan, we currently don't disable the CP
+            # dispatcher.
+            _enable_context_parallel_dispatcher()
+            cp_plan = _ContextParallel(
+                seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+            )
+        case "varlen":
+            raise NotImplementedError(
+                "Variable-length attention CP is not yet supported"
+            )
+        case _:
+            raise ValueError(
+                f"Invalid attention_type '{attention_type}'. "
+                f"Must be one of: 'sdpa', 'flex', 'varlen'"
+            )
+
+    for attention_module in attention_modules:
+        parallelize_module(
+            module=attention_module,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+    logger.info("Applied Context Parallel to the model")
+
+
+def prepare_context_parallel_input(
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    extra_kwargs: dict[str, Any],
+    cp_mesh: DeviceMesh,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
+    """
+    Prepare inputs, labels, and attention masks for Context Parallel forward pass.
+
+    This function prepares tensors for context parallel by:
+    1. Creating position indices based on input sequence length
+    2. Sharding inputs, labels, and positions across the CP mesh
+    3. Sharding attention masks if present
+
+    Args:
+        inputs: Input tensor of shape [batch_size, seq_len]
+        labels: Label tensor of shape [batch_size, seq_len]
+        extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded
+        cp_mesh: Device mesh for context parallel dimension
+        device: Device to create position tensor on
+
+    Returns:
+        Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where:
+        - sharded_inputs: Inputs sharded along sequence dimension
+        - sharded_labels: Labels sharded along sequence dimension
+        - updated_extra_kwargs: Dict with sharded 'positions' and optionally
+          sharded 'attention_masks'
+    """
+    attention_masks = extra_kwargs.get("attention_masks", None)
+    positions = torch.arange(
+        0, inputs.shape[1], dtype=torch.int32, device=device
+    ).expand(inputs.shape)
+    (inputs, labels, positions), attention_masks = cp_shard(
+        cp_mesh,
+        (inputs, labels, positions),
+        attention_masks,
+    )
+    extra_kwargs["positions"] = positions
+    if attention_masks is not None:
+        extra_kwargs["attention_masks"] = attention_masks
+
+    return inputs, labels, extra_kwargs
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: tuple[torch.Tensor, ...],
+    attention_masks: AttentionMasksType | None,
+    disable_load_balancer: bool = False,
+    input_seq_dim: int = 1,
+) -> tuple[tuple[torch.Tensor, ...], AttentionMasksType | None]:
+    """
+    Shard inputs and attention masks across the context parallel mesh.
+
+    This function distributes input tensors across devices in the CP mesh
+    along the sequence dimension. It optionally uses a load balancer to
+    handle uneven computation workload. Currently, HeadTailLoadBalancer is
+    used for SDPA + CP, which is the only supported configuration.
+
+    Args:
+        cp_mesh: Device mesh for context parallel dimension
+        inputs: Tuple of input tensors to be sharded along the sequence
+            dimension
+        attention_masks: Attention masks to be sharded (currently raises
+            error as FlexAttention CP is not yet supported)
+        disable_load_balancer: If True, disables load balancing. If False
+            (default), uses HeadTailLoadBalancer for SDPA to handle uneven
+            computation workload.
+        input_seq_dim: Sequence dimension index for sharding. Defaults to 1,
+            which covers most use cases where tensors have shape
+            [batch_size, seq_len, ...]. Can be changed by passing a
+            different value if your tensors use a different sequence
+            dimension layout.
+
+    Returns:
+        Tuple of (sharded_inputs, attention_masks) where:
+        - sharded_inputs: Tuple of input tensors sharded along the
+          sequence dimension
+        - attention_masks: Attention masks (currently unchanged/None)
+    """
+    seq_len = inputs[0].size(input_seq_dim)
+    cp_world_size = cp_mesh.size(0)
+    if attention_masks is not None:
+        raise ValueError(
+            "FlexAttention CP is not supported yet. Will come in the next PR."
+        )
+    else:
+        # For SDPA, we use the _HeadTailLoadBalancer.
+        load_balancer = (
+            None
+            if disable_load_balancer
+            else _HeadTailLoadBalancer(seq_len, cp_world_size, cp_mesh.device_type)
+        )
+
+    inputs = cast(
+        tuple[torch.Tensor, ...],
+        _context_parallel_shard(
+            mesh=cp_mesh,
+            buffers=inputs,
+            seq_dims=tuple(input_seq_dim for _ in inputs),
+            load_balancer=load_balancer,
+        ),
+    )
+
+    return inputs, attention_masks
```
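With the load balancer disabled, `cp_shard()` conceptually reduces to a contiguous split of every buffer along its sequence dimension, one slice per CP rank. A toy stand-in using nested lists (an illustrative sketch, not the real `_context_parallel_shard`, which operates on tensors and device meshes):

```python
def naive_cp_shard(buffers, world_size, rank, seq_dim=1):
    """Toy stand-in for a CP shard with load balancing disabled: each rank
    keeps its contiguous 1/world_size slice of the sequence dimension.
    Buffers are nested lists shaped [batch, seq] (seq_dim=1) or [seq] (seq_dim=0)."""
    out = []
    for buf in buffers:
        seq_len = len(buf[0]) if seq_dim == 1 else len(buf)
        assert seq_len % world_size == 0
        n = seq_len // world_size
        if seq_dim == 1:
            out.append([row[rank * n:(rank + 1) * n] for row in buf])
        else:
            out.append(buf[rank * n:(rank + 1) * n])
    return tuple(out)

inputs = [[10, 11, 12, 13]]  # [batch=1, seq=4]
labels = [[11, 12, 13, 14]]
sharded = naive_cp_shard((inputs, labels), world_size=2, rank=1)
# rank 1 keeps the second half of each buffer: ([[12, 13]], [[13, 14]])
```

Note how every buffer in the tuple is sliced with the same sequence dimension, mirroring how `cp_shard()` passes one `input_seq_dim` for all entries of `seq_dims`.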

torchtitan/distributed/utils.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -224,23 +224,17 @@ def create_context_parallel_ctx(
 
 class TrainContext(Protocol):
     @abstractmethod
-    def __call__(
-        self,
-        cp_context: contextlib.AbstractContextManager[None] | None = None,
-    ) -> contextlib.AbstractContextManager[None]:
+    def __call__(self) -> contextlib.AbstractContextManager[None]:
         pass
 
 
 def get_train_context(enable_loss_parallel: bool) -> TrainContext:
     @contextlib.contextmanager
-    def context(cp_context: contextlib.AbstractContextManager[None] | None = None):
+    def context():
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if cp_context:
-                stack.enter_context(cp_context)
-
             yield
 
     return context
```
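The simplified `get_train_context` above is just an `ExitStack` that conditionally enters the loss-parallel context; with the CP context gone, nothing else is stacked. A runnable sketch of the same pattern, with a stand-in recorder in place of `torch.distributed.tensor.parallel.loss_parallel()` so it runs without torch:

```python
import contextlib


def get_train_context(enable_loss_parallel: bool):
    """Sketch of the simplified train context: an ExitStack that enters
    the loss-parallel context only when enabled. Returns the context
    factory plus a list recording which contexts were entered."""
    entered = []

    @contextlib.contextmanager
    def loss_parallel():  # stand-in for torch's loss_parallel()
        entered.append("loss_parallel")
        yield

    @contextlib.contextmanager
    def context():
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(loss_parallel())
            yield

    return context, entered


context, entered = get_train_context(enable_loss_parallel=True)
with context():
    pass
# entered == ["loss_parallel"]
```

Callers now invoke `self.validation_context()` with no arguments, because CP sharding moved out of the context manager and into `post_dataloading_process()`.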

0 commit comments
