[CP] Refactor Context Parallel to use new PyTorch CP APIs #2144
Conversation
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom):
* #2145
* __->__ #2144

**Summary**

1. Refactored CP dispatching:
   - A new apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch attention calls.
   - Enables the CP dispatcher for the SDPA attention type inside apply_cp().
2. New CP data sharding approach:
   - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API.
   - Uses _HeadTailLoadBalancer for SDPA attention load balancing; FlexAttention CP support is deferred to a future PR.
   - CP sharding now happens explicitly in post_dataloading_process(), where inputs, labels, and positions are sharded.
   - The new positions argument allows us to avoid sharding the freqs_cis.

Note that this PR requires pytorch/pytorch#170200.

**Test**

```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal
pick 5903566a Improve the loss_compare.sh logic
[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal (__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step  Baseline Loss  Test Loss  Difference
[LOSS_COMPARE] ----  -------------  ---------  ----------
[LOSS_COMPARE] 1     8.1309         8.1309     0.000000
[LOSS_COMPARE] 2     7.8268         7.8268     0.000000
[LOSS_COMPARE] 3     7.2284         7.2284     0.000000
[LOSS_COMPARE] 4     6.4669         6.4669     0.000000
[LOSS_COMPARE] 5     5.4017         5.4017     0.000000
[LOSS_COMPARE] 6     4.7656         4.7656     0.000000
[LOSS_COMPARE] 7     4.3587         4.3587     0.000000
[LOSS_COMPARE] 8     4.0938         4.0938     0.000000
[LOSS_COMPARE] 9     4.4019         4.4019     0.000000
[LOSS_COMPARE] 10    3.7451         3.7451     0.000000
....
[LOSS_COMPARE] 90    2.802          2.802      0.000000
[LOSS_COMPARE] 91    2.7207         2.7207     0.000000
[LOSS_COMPARE] 92    2.7454         2.7454     0.000000
[LOSS_COMPARE] 93    2.6992         2.6992     0.000000
[LOSS_COMPARE] 94    2.743          2.743      0.000000
[LOSS_COMPARE] 95    2.7534         2.7534     0.000000
[LOSS_COMPARE] 96    2.8403         2.8403     0.000000
[LOSS_COMPARE] 97    2.783          2.783      0.000000
[LOSS_COMPARE] 98    3.0892         3.0892     0.000000
[LOSS_COMPARE] 99    2.7905         2.7905     0.000000
[LOSS_COMPARE] 100   2.733          2.733      0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE] Average baseline loss: 3.1414940000000002
[LOSS_COMPARE] Average test loss: 3.1414940000000002
[LOSS_COMPARE] Average difference: 0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified).
```

**TODO**
- This PR will break torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor the nn.Module core logic to avoid checking hook_id.
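For orientation, here is a minimal sketch of what the cp_shard() helper described above could look like. The import path, keyword names, and the _HeadTailLoadBalancer constructor signature are assumptions on my part (these are private, experimental PyTorch APIs referenced in pytorch/pytorch#170200), so treat this as an illustration rather than the actual torchtitan code:

```python
# Hedged sketch only. `_context_parallel_shard` and `_HeadTailLoadBalancer` are the
# experimental, private PyTorch APIs named in this PR; the import path, keyword
# names, and constructor signature below are assumptions, not torchtitan's code.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _HeadTailLoadBalancer,
)


def cp_shard(
    cp_mesh: DeviceMesh,
    inputs: torch.Tensor,     # (batch, seq)
    labels: torch.Tensor,     # (batch, seq)
    positions: torch.Tensor,  # (batch, seq)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shard inputs, labels, and positions on the sequence dim for SDPA CP."""
    seq_len = inputs.shape[1]
    # Head-tail load balancing keeps causal-attention work roughly even per CP rank.
    load_balancer = _HeadTailLoadBalancer(seq_len, cp_mesh.size(0), inputs.device)
    inputs, labels, positions = _context_parallel_shard(
        mesh=cp_mesh,
        buffers=(inputs, labels, positions),
        seq_dims=(1, 1, 1),
        load_balancer=load_balancer,
    )
    return inputs, labels, positions
```

Because positions is sharded alongside inputs, each rank can index the replicated freqs_cis with its local position ids, which is why freqs_cis itself never needs to be sharded.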
    def post_dataloading_process(
        self,
        input_dict: dict[str, torch.Tensor],
        labels: torch.Tensor,
I feel we should consolidate this into `input_dict["labels"]`.
Our dataloader signature returns two variables, input_dict and labels. So unless we consolidate the dataloader type as well, I think it is more consistent to keep this API signature the same as the other APIs.
Right, I think we should also consolidate them in the dataloader, but maybe in a future PR.
    A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
      - inputs: Main input tensor extracted from input_dict["input"].
      - labels: Target labels (potentially modified by CP sharding).
      - extra_inputs: Dict of auxiliary input tensors (all keys except
maybe we can consolidate this into input_dict as well?
I remember extra_inputs is a separate dict because PP doesn't forward them across stages. I admit it is quite annoying to have separate inputs. However, it is clearer to have two separate dicts, because one will be forwarded while the other will only be consumed by the first stage. We should explore how to make the PP input signature more user friendly. cc @H-Huang
We already have extra_inputs and extra_kwargs, so we should be able to merge inputs into at least one of them?
This doesn't need to be updated in this PR.
PP treats extra_inputs and extra_kwargs differently: extra_inputs will not be forwarded to the next stages, while extra_kwargs will be.
We should aim to merge "inputs" and "extra_inputs" if possible.
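To make that asymmetry concrete, here is a toy sketch. The `stages` callables and `run_pipeline` are hypothetical stand-ins for the PP schedule, not torchtitan code; the point is only how the two dicts flow:

```python
# Illustrative only: `stages` are hypothetical callables, not torchtitan's PP schedule.
def run_pipeline(stages, inputs, extra_inputs, extra_kwargs):
    # extra_inputs are consumed only by the first stage and are NOT forwarded.
    activations = stages[0](inputs, **extra_inputs, **extra_kwargs)
    for stage in stages[1:]:
        # extra_kwargs keep flowing to every later stage.
        activations = stage(activations, **extra_kwargs)
    return activations
```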
torchtitan/components/validate.py (outdated)
    # extra_kwargs are.
    extra_kwargs: dict[str, Any] = {}

    attn_type = getattr(self.model_args, "attn_type", "sdpa")
It seems we introduce model_args to the validator only because of this line.
This condition is checked again in get_attention_masks. Can we remove it here? We could return None instead of throwing in get_attention_masks when attn_type is sdpa.
The argument is that the validator is not supposed to know ModelArgs details.
I use try/except to remove the need for model_args. For get_attention_masks, we intended to keep the return type simple; that's why an exception is raised.
try/catch doesn't sound right here -- e.g. if model_args.attn_type is flex but this method throws NotImplementedError, then there's no way the computation can be correct.
If the return value is meant to be optional, its type should be Optional. In the long term we can try to think of a better way to encapsulate the logic between flex and sdpa.
You are right that the current code cannot differentiate a Flex model that lacks a get_attention_masks implementation. But imo that is not an issue with try/except; rather, we should throw the right exception, which is TypeError (per the Python docs, TypeError "may be raised by user code to indicate that an attempted operation on an object is not supported").
In any case, it depends on the definition of get_attention_masks. If it is meant to be called only for flex attention or varlen attention, then its return type is not optional and TypeError is a valid exception in that case. On the other hand, if it is meant to be called even when the attention type is SDPA, then returning None makes sense.
Both implementations are legitimate IMO. But since this function is more like an adapter function, making the return type of get_attention_masks optional seems to leak less information to the trainer at the moment. If that is the clearer semantics, I can change get_attention_masks to return an optional type.
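For illustration, the two options being debated could look roughly like this. This is a hedged sketch: the function names, parameters, and the placeholder mask builder are hypothetical, not the actual torchtitan signature:

```python
from typing import Any, Optional


def _build_flex_masks() -> Any:
    """Placeholder for the real FlexAttention mask construction (elided here)."""
    return object()


# Option A (current behavior being debated): raise for attention types without masks.
def get_attention_masks_strict(attn_type: str) -> Any:
    if attn_type == "sdpa":
        # SDPA needs no explicit mask; calling this for sdpa is treated as a misuse.
        raise TypeError(f"get_attention_masks is not supported for {attn_type}")
    return _build_flex_masks()


# Option B (alternative): Optional return type, so the caller never inspects ModelArgs.
def get_attention_masks_optional(attn_type: str) -> Optional[Any]:
    if attn_type == "sdpa":
        return None  # caller simply skips mask handling
    return _build_flex_masks()
```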
    # NOTE: when the fourth argument (positions) is not None, its input layout
    # and desired input layout should be Replicate()
    # and desired input layout is still None as we don't convert freqs_cis to
    # a DTensor for llama3.
Could you remind me why the treatment for some models (e.g. llama3) is different from others (e.g. dsv3)? Is it because some freqs_cis are DTensors and others are not? Shall we make them all consistently DTensors? Maybe leave a TODO for @wwwjn.
> is it because some of freqs_cis are dtensors and others are not? shall we make them consistent all as DTensors?

#2149 will fix the inconsistency. I left a TODO here.
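A tiny sketch of why sharding positions lets freqs_cis stay unsharded: each CP rank indexes the full, replicated RoPE table with its local position ids. The shapes and names below are illustrative, not the llama3 model code:

```python
import torch

# Illustrative shapes: the full RoPE table stays replicated on every CP rank,
# while each rank only keeps its local (head-tail) slice of position indices.
max_seq_len, head_dim = 4096, 64
freqs_cis = torch.randn(max_seq_len, head_dim // 2, dtype=torch.complex64)  # replicated

# e.g. the head-tail shard owned by this CP rank (already sharded by cp_shard)
positions = torch.tensor([0, 1, 2, 3, 4092, 4093, 4094, 4095])

# Indexing the replicated table with local positions yields the correct per-rank
# rotary factors without ever turning freqs_cis into a DTensor.
local_freqs = freqs_cis[positions]  # shape: (local_seq_len, head_dim // 2)
```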
    cp_world_size = cp_mesh.size(0)
    if attention_masks is not None:
        raise ValueError(
            "FlexAttention CP is not supported yet. Will come in the next PR."
Based on the loss curves I observed, we might want to make the HeadTail load balancer also available for Flex.
I'll make it configurable in the FlexCP PR. That's the PR where PTRR is introduced.
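For context, the head-tail scheme for causal attention splits the sequence into 2×cp_degree chunks and gives rank i the i-th chunk from the front plus the i-th chunk from the back, so each rank sees a mix of cheap (early) and expensive (late) query positions. A small plain-Python illustration of that assignment (my sketch of the general idea, not the _HeadTailLoadBalancer implementation):

```python
def head_tail_assignment(seq_len: int, cp_degree: int) -> list[list[range]]:
    """Split a sequence into 2*cp_degree chunks; rank i gets chunk i plus its
    mirror chunk (2*cp_degree - 1 - i), balancing causal-attention cost."""
    assert seq_len % (2 * cp_degree) == 0
    chunk = seq_len // (2 * cp_degree)
    chunks = [range(i * chunk, (i + 1) * chunk) for i in range(2 * cp_degree)]
    return [[chunks[i], chunks[2 * cp_degree - 1 - i]] for i in range(cp_degree)]


# Example: seq_len=16, cp_degree=2 -> rank 0 gets [0..3] and [12..15],
#          rank 1 gets [4..7] and [8..11].
print(head_tail_assignment(16, 2))
```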
I changed the exception to TypeError anyway, since NotImplementedError is not right for SDPA, as mentioned in the reply above. But feel free to request a change to get_attention_masks if people think that API should return an optional mask. cc @tianyu-l
tianyu-l left a comment
lgtm, some nit comments
    # extra_kwargs are.
    extra_kwargs: dict[str, Any] = {}

    try:
nit: here we have access to self.model_args, so we don't need the try/except.
I prefer to keep a consistent style for the same logic.
    inputs = cast(
        tuple[torch.Tensor, ...],
        _context_parallel_shard(
Last time I mentioned the idea of doing a reverse permute after all-gathering KV, so that forward numerics are bitwise identical with vs. without load balancing. That would require a change to _create_cp_block_mask.
What's your latest thought? I guess it's not absolutely necessary for torchtitan CP right now, but maybe good to keep in mind.
The only concern is whether we make reversing the only option or make it configurable. Making it configurable can be error-prone, since there may be more than one place that is order-aware (BlockMask creation and the all-gather), and users would need to keep the option consistent. Or we can always reverse, but the performance implication is not clear to me.
I think we should make it the only option, as long as perf is OK. This needs to be tested. Not urgent though.
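To spell out the idea: after all-gathering head-tail-sharded KV, one could apply the inverse permutation so downstream math sees the original token order, making results bitwise identical with or without load balancing. A self-contained sketch of that inverse permutation, assuming the gathered layout is rank-by-rank (illustrative; not the proposed PyTorch/torchtitan change):

```python
import torch


def unpermute_headtail_kv(kv_gathered: torch.Tensor, cp_degree: int) -> torch.Tensor:
    """kv_gathered has 2*cp_degree sequence chunks laid out rank-by-rank as
    [c0, c_{2d-1}, c1, c_{2d-2}, ...]; return them in original order [c0, c1, ...]."""
    num_chunks = 2 * cp_degree
    chunks = kv_gathered.chunk(num_chunks, dim=0)
    order = []  # order[slot] = original chunk id sitting at gathered slot
    for rank in range(cp_degree):
        order.extend([rank, num_chunks - 1 - rank])
    inverse = [0] * num_chunks  # inverse[orig] = gathered slot holding that chunk
    for slot, orig in enumerate(order):
        inverse[orig] = slot
    return torch.cat([chunks[inverse[i]] for i in range(num_chunks)], dim=0)


# Sanity check: head-tail sharding then unpermuting restores the original order.
x = torch.arange(8).float()                            # 8 tokens, cp_degree=2 -> 4 chunks
sharded = torch.cat([x[0:2], x[6:8], x[2:4], x[4:6]])  # rank-by-rank gather layout
assert torch.equal(unpermute_headtail_kv(sharded, cp_degree=2), x)
```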
Landing the PR; pytorch/pytorch#172519 is the candidate fix for the compile issue.
Only note on pytorch/pytorch#172519 is that it will require patching