Commit 8e4bb4f
[CP] Refactor Context Parallel to use new PyTorch CP APIs (#2144)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0)
(oldest at bottom):
* #2145
* __->__ #2144
**Summary**
1. Refactored CP dispatching:
- A new apply_cp() function uses PyTorch's _ContextParallel
parallelization plan to dispatch the attention call.
- Enables the CP dispatcher for the SDPA attention type inside apply_cp().
2. New CP data sharding approach:
- Added a cp_shard() helper function that wraps PyTorch's
_context_parallel_shard API.
- Uses _HeadTailLoadBalancer for SDPA attention load balancing.
- FlexAttention CP support is deferred to a future PR.
- CP sharding now happens explicitly in post_dataloading_process(), where
inputs, labels, and positions are sharded.
- The new positions argument allows us to not shard freqs_cis.
Note that this PR requires pytorch/pytorch#170200.
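The head-tail scheme behind the SDPA load balancing, and why a sharded positions tensor makes sharding freqs_cis unnecessary, can be sketched in plain Python. This is a hypothetical illustration of the balancing scheme _HeadTailLoadBalancer implements, not the actual PyTorch or torchtitan API; head_tail_shard and the freqs_cis stand-in below are made up for the example.

```python
def head_tail_shard(seq, cp_degree, rank):
    """Split seq into 2*cp_degree chunks; rank r keeps chunk r (a "head",
    cheap under causal masking) and chunk 2*cp_degree-1-r (a "tail",
    expensive), so causal-attention work is roughly equal across ranks.
    Hypothetical helper, not the real _HeadTailLoadBalancer API."""
    n_chunks = 2 * cp_degree
    assert len(seq) % n_chunks == 0
    chunk = len(seq) // n_chunks
    head = seq[rank * chunk : (rank + 1) * chunk]
    tail_idx = n_chunks - 1 - rank
    tail = seq[tail_idx * chunk : (tail_idx + 1) * chunk]
    return head + tail

# The full RoPE frequency table can stay unsharded: each rank gathers
# its own rows via its (sharded) position ids instead of sharding the
# table itself.
positions = list(range(16))                   # global positions 0..15
local_pos = head_tail_shard(positions, cp_degree=4, rank=0)
# rank 0 holds the first and last chunks: [0, 1, 14, 15]
freqs_cis = [p * 0.5 for p in range(16)]      # stand-in for the table
local_freqs = [freqs_cis[p] for p in local_pos]
```

Because the sharded positions are global indices, indexing the unsharded table per rank is equivalent to sharding the table, without any extra communication.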
**Test**
```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal
pick 5903566a Improve the loss_compare.sh logic
[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal
(__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal)
... ok
----------------------------------------------------------------------
Ran 1 test in 0.000s
OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step Baseline Loss Test Loss Difference
[LOSS_COMPARE] ---- ------------- --------- ----------
[LOSS_COMPARE] 1 8.1309 8.1309 0.000000
[LOSS_COMPARE] 2 7.8268 7.8268 0.000000
[LOSS_COMPARE] 3 7.2284 7.2284 0.000000
[LOSS_COMPARE] 4 6.4669 6.4669 0.000000
[LOSS_COMPARE] 5 5.4017 5.4017 0.000000
[LOSS_COMPARE] 6 4.7656 4.7656 0.000000
[LOSS_COMPARE] 7 4.3587 4.3587 0.000000
[LOSS_COMPARE] 8 4.0938 4.0938 0.000000
[LOSS_COMPARE] 9 4.4019 4.4019 0.000000
[LOSS_COMPARE] 10 3.7451 3.7451 0.000000
....
[LOSS_COMPARE] 90 2.802 2.802 0.000000
[LOSS_COMPARE] 91 2.7207 2.7207 0.000000
[LOSS_COMPARE] 92 2.7454 2.7454 0.000000
[LOSS_COMPARE] 93 2.6992 2.6992 0.000000
[LOSS_COMPARE] 94 2.743 2.743 0.000000
[LOSS_COMPARE] 95 2.7534 2.7534 0.000000
[LOSS_COMPARE] 96 2.8403 2.8403 0.000000
[LOSS_COMPARE] 97 2.783 2.783 0.000000
[LOSS_COMPARE] 98 3.0892 3.0892 0.000000
[LOSS_COMPARE] 99 2.7905 2.7905 0.000000
[LOSS_COMPARE] 100 2.733 2.733 0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE] Average baseline loss: 3.1414940000000002
[LOSS_COMPARE] Average test loss: 3.1414940000000002
[LOSS_COMPARE] Average difference: 0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output
folder specified).
```
**TODO**
- This PR will invalidate torch.compile + CP due to
pytorch/pytorch#170110. We will have to wait
for Dynamo to fix the issue or refactor the nn.Module core logic to avoid
checking hook_id.
20 files changed, +515 −169 lines changed