
Conversation

@fegin (Contributor) commented Dec 11, 2025

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.14.0) (oldest at bottom):
* #2145
* __->__ #2144

**Summary**

1. Refactored CP dispatching:
   - A new apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch the attention call.
   - Enables the CP dispatcher for the SDPA attention type inside apply_cp().
2. New CP data sharding approach (a usage sketch follows this list):
   - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API.
   - Uses _HeadTailLoadBalancer for SDPA attention load balancing.
   - FlexAttention CP support is deferred to a future PR.
   - CP sharding now happens explicitly in post_dataloading_process(), where inputs, labels, and positions are sharded.
   - The new positions argument allows us to not shard the freqs_cis.
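
For illustration, here is a minimal sketch of what the cp_shard() helper described above could look like. The API names (_context_parallel_shard, _HeadTailLoadBalancer) come from this summary, but these are private PyTorch APIs and the argument layout shown here is an assumption; see the diff for the actual implementation.

```python
# Hypothetical sketch of the cp_shard() helper; the signatures of the private
# PyTorch APIs used here are assumptions based on this PR's description.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _HeadTailLoadBalancer,
)


def cp_shard(
    cp_mesh: DeviceMesh,
    inputs: torch.Tensor,     # (batch, seq)
    labels: torch.Tensor,     # (batch, seq)
    positions: torch.Tensor,  # (batch, seq); lets freqs_cis stay unsharded
):
    seq_len = inputs.shape[1]
    # Head-tail load balancing pairs chunk i with chunk (2 * W - 1 - i) so
    # every CP rank gets a comparable amount of causal-attention work.
    load_balancer = _HeadTailLoadBalancer(
        seq_len, cp_mesh.size(), torch.device(cp_mesh.device_type)
    )
    inputs, labels, positions = _context_parallel_shard(
        mesh=cp_mesh,
        buffers=(inputs, labels, positions),
        seq_dims=(1, 1, 1),
        load_balancer=load_balancer,
    )
    return inputs, labels, positions
```

post_dataloading_process() would call such a helper on the raw batch and feed the sharded positions to the model, which is what lets freqs_cis itself stay replicated.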

Note that this PR requires pytorch/pytorch#170200.

**Test**

```
-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal

pick 5903566a Improve the loss_compare.sh logic

[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal
(__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal)
... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================

[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step    Baseline Loss    Test Loss   Difference
[LOSS_COMPARE] ----    -------------    ---------   ----------
[LOSS_COMPARE] 1       8.1309           8.1309           0.000000
[LOSS_COMPARE] 2       7.8268           7.8268           0.000000
[LOSS_COMPARE] 3       7.2284           7.2284           0.000000
[LOSS_COMPARE] 4       6.4669           6.4669           0.000000
[LOSS_COMPARE] 5       5.4017           5.4017           0.000000
[LOSS_COMPARE] 6       4.7656           4.7656           0.000000
[LOSS_COMPARE] 7       4.3587           4.3587           0.000000
[LOSS_COMPARE] 8       4.0938           4.0938           0.000000
[LOSS_COMPARE] 9       4.4019           4.4019           0.000000
[LOSS_COMPARE] 10      3.7451           3.7451           0.000000
....
[LOSS_COMPARE] 90      2.802            2.802            0.000000
[LOSS_COMPARE] 91      2.7207           2.7207           0.000000
[LOSS_COMPARE] 92      2.7454           2.7454           0.000000
[LOSS_COMPARE] 93      2.6992           2.6992           0.000000
[LOSS_COMPARE] 94      2.743            2.743            0.000000
[LOSS_COMPARE] 95      2.7534           2.7534           0.000000
[LOSS_COMPARE] 96      2.8403           2.8403           0.000000
[LOSS_COMPARE] 97      2.783            2.783            0.000000
[LOSS_COMPARE] 98      3.0892           3.0892           0.000000
[LOSS_COMPARE] 99      2.7905           2.7905           0.000000
[LOSS_COMPARE] 100     2.733            2.733            0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE] Average baseline loss:  3.1414940000000002
[LOSS_COMPARE] Average test loss: 3.1414940000000002
[LOSS_COMPARE] Average difference:     0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output folder specified).
```

**TODO**
- This PR will invalidate torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor the nn.Module core logic to avoid checking hook_id.

meta-cla bot added the "CLA Signed" label Dec 11, 2025
@fegin requested a review from acisseJZhong December 11, 2025 21:09
fegin added a commit that referenced this pull request Dec 15, 2025
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* #2145
* #2144
* __->__ #2143

1. Accept a single "." (meaning the current commit) to simplify the command line.
2. Ignore untracked files.
@fegin requested a review from tianyu-l December 16, 2025 07:32
```python
def post_dataloading_process(
    self,
    input_dict: dict[str, torch.Tensor],
    labels: torch.Tensor,
```
Contributor:

I feel we should consolidate this into input_dict["labels"]

Contributor Author:

Our dataloader signature returns two variables, input_dict and labels. So unless we consolidate the dataloader type as well, I think it is more consistent to keep this API signature the same as the other APIs.

Contributor:

Right, I think we should also consolidate them in the data loader, but maybe in a future PR.

```
A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
    - inputs: Main input tensor extracted from input_dict["input"].
    - labels: Target labels (potentially modified by CP sharding).
    - extra_inputs: Dict of auxiliary input tensors (all keys except
```
Contributor:

maybe we can consolidate this into input_dict as well?

Contributor Author (@fegin, Dec 30, 2025):

I remember extra_inputs is a separate dict because PP doesn't forward it across stages. I admit it is quite annoying to have separate inputs. However, it is clearer to have two separate dicts, because one will be forwarded while the other will only be consumed by the first stage. We should explore how to make the PP input signature more user friendly. cc @H-Huang

Contributor:

we already have extra_inputs and extra_kwargs, so we should be able to merge inputs into at least one of them?

This doesn't need to be updated in this PR.

Contributor Author:

PP treats extra_inputs and extra_kwargs differently: extra_inputs will not be forwarded to the next stages, while extra_kwargs will be.

We should aim to merge inputs and extra_inputs if possible.
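
For illustration, a minimal sketch of the distinction being discussed; the names and loop structure are hypothetical, not the actual torchtitan PP runtime:

```python
# Hypothetical sketch: extra_inputs feed only the first pipeline stage,
# while extra_kwargs travel with the activations to every stage.
from typing import Any, Callable


def run_stages(
    stages: list[Callable[..., Any]],
    inputs: Any,
    extra_inputs: dict[str, Any],  # consumed only by the first stage
    extra_kwargs: dict[str, Any],  # forwarded to every stage
) -> Any:
    out = stages[0](inputs, **extra_inputs, **extra_kwargs)
    for stage in stages[1:]:
        # extra_inputs are intentionally not forwarded across stages.
        out = stage(out, **extra_kwargs)
    return out
```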

```python
    # extra_kwargs are.
    extra_kwargs: dict[str, Any] = {}

    attn_type = getattr(self.model_args, "attn_type", "sdpa")
```
Contributor:

It seems we introduce model_args to the validator because of this line.

This condition is checked again in get_attention_masks. Can we remove it here? We can return None instead of throwing in get_attention_masks when attn_type is sdpa.

The argument is that the validator is not supposed to know ModelArgs details.

Contributor Author:

I use try/except to remove the need for model_args. For get_attention_masks, we intended to keep the return type simple; that's why an exception is raised.

Contributor:

try/catch doesn't sound right here -- e.g. if model_args.attn_type is flex but this method throws NotImplementedError, then there's no way the computation can be correct.

If the return type is meant to be optional, its type should be optional. In the long term we can try to think of a better way to encapsulate the logic between flex and sdpa.

Contributor Author:

You are right that the current code cannot differentiate a Flex model without a get_attention_masks implementation. But IMO it is not an issue with try/except; rather, we should throw the correct exception, which is TypeError (according to the Python docs, "this exception may be raised by user code to indicate that an attempted operation on an object is not supported").

In any case, it depends on the definition of get_attention_masks. If it is meant to be called only for flex attention or varlen attention, then its return type is not optional and TypeError is a valid exception in such a case. On the other hand, if it is meant to be called even when the attention type is SDPA, then None makes sense.

Both implementations are legit IMO. But since this function is more like an adapter function, making the return type of get_attention_masks optional seems to leak less information to the trainer, at the moment. If that is clearer semantics, I can change get_attention_masks to return an optional type.
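
To make the two designs concrete, a minimal sketch (the names here are hypothetical, not the actual torchtitan API):

```python
# Hypothetical sketch of the two get_attention_masks designs discussed above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionMasks:
    block_mask: object  # placeholder for, e.g., a FlexAttention BlockMask


# Design A: raise TypeError for attention types that don't use masks;
# callers wrap the call in try/except.
def get_attention_masks_raising(attn_type: str) -> AttentionMasks:
    if attn_type == "sdpa":
        raise TypeError("attention masks are not supported for SDPA")
    return AttentionMasks(block_mask=None)


# Design B: make the return type Optional; SDPA callers receive None and
# skip mask handling without exception-based control flow.
def get_attention_masks_optional(attn_type: str) -> Optional[AttentionMasks]:
    if attn_type == "sdpa":
        return None
    return AttentionMasks(block_mask=None)
```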

@fegin requested reviews from tianyu-l and wwwjn December 30, 2025 04:32

```python
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate(). Here the input layout
# and desired input layout are still None, as we don't convert freqs_cis to
# a DTensor for llama3.
```
Contributor:

could you remind me why the treatment for some models (e.g. llama3) is different from others (e.g. dsv3)? Is it because some freqs_cis are DTensors and others are not? Shall we make them all consistently DTensors? Maybe leave a TODO for @wwwjn.

Contributor Author:

> is it because some of freqs_cis are dtensors and others are not? shall we make them consistent all as DTensors?

#2149 will fix the inconsistency. I left a TODO here.

```python
cp_world_size = cp_mesh.size(0)
if attention_masks is not None:
    raise ValueError(
        "FlexAttention CP is not supported yet. Will come in the next PR."
```
Contributor:

Based on the loss curves I observed, we might want to make the HeadTail load balancer also available for Flex.

Contributor Author:

I'll make it configurable in the FlexCP PR. That's the PR where PTRR is introduced.


fegin added 2 commits January 6, 2026 15:56
fegin commented Jan 7, 2026

I changed the exception to TypeError anyway, since NotImplementedError is not right for SDPA, as mentioned in the reply above. But feel free to request a change to get_attention_masks if people think that API should return an optional mask.

cc @tianyu-l

@fegin requested a review from tianyu-l January 7, 2026 00:04
@tianyu-l (Contributor) left a comment:

lgtm, some nit comments


```python
    # extra_kwargs are.
    extra_kwargs: dict[str, Any] = {}

    try:
```
Contributor:

nit: here we have access to self.model_args, so we don't need the try/except.

Contributor Author:

I prefer to keep a consistent style for the same logic.


```python
inputs = cast(
    tuple[torch.Tensor, ...],
    _context_parallel_shard(
```
Contributor:

Last time I mentioned the idea of doing a reverse permute after all-gathering KV, so that forward numerics with and without load balancing are bitwise identical. That would require a change to _create_cp_block_mask.

What are your latest thoughts? I guess it's not absolutely necessary for torchtitan CP right now, but maybe good to keep in mind.

Contributor Author:

The only concern is whether we make reversing the only option or make it configurable. Making it configurable can be error-prone, as there may be more than one place that is order-aware (BlockMask creation and all-gather), and users would need to keep the options consistent. Or we can always reverse, but the performance implication is not clear to me.

Contributor:

I think we should make it the only option, as long as perf is OK. This needs to be tested. Not urgent though.
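
For reference, a self-contained sketch of the head-tail chunk permutation and the reverse permute discussed above; the function names are hypothetical, and the ordering shown is the usual head-tail scheme where rank r owns chunks r and 2W-1-r:

```python
# Hypothetical sketch of the head-tail permutation and its inverse.
# Assumes the sequence length is divisible by 2 * cp_world_size.
import torch


def _head_tail_order(cp_world_size: int) -> list[int]:
    # Rank r takes chunk r (head) and chunk 2W-1-r (tail).
    order: list[int] = []
    for r in range(cp_world_size):
        order += [r, 2 * cp_world_size - 1 - r]
    return order


def head_tail_permute(x: torch.Tensor, cp_world_size: int, dim: int = 1) -> torch.Tensor:
    chunks = x.chunk(2 * cp_world_size, dim=dim)
    return torch.cat([chunks[i] for i in _head_tail_order(cp_world_size)], dim=dim)


def head_tail_unpermute(x: torch.Tensor, cp_world_size: int, dim: int = 1) -> torch.Tensor:
    # Inverse permutation: applied after all-gathering KV, it restores the
    # original chunk order so numerics match the non-load-balanced run.
    order = _head_tail_order(cp_world_size)
    inverse = [0] * len(order)
    for pos, chunk_idx in enumerate(order):
        inverse[chunk_idx] = pos
    chunks = x.chunk(2 * cp_world_size, dim=dim)
    return torch.cat([chunks[i] for i in inverse], dim=dim)
```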

fegin commented Jan 20, 2026

Landing the PR. pytorch/pytorch#172519 is the candidate fix for the compile issue.

Lucaskabela commented Jan 20, 2026

The only note on pytorch/pytorch#172519 is that it will require patching skip_nnmodule_hook_guards or inline_inbuilt_nn_modules in the Dynamo config to guarantee hook correctness if the hooks differ per layer.
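
For context, a minimal sketch of what such a Dynamo config patch might look like; the flag names come from the comment above, and which value is required depends on the fix, so treat the values as assumptions:

```python
# Hypothetical Dynamo config patch; verify against the actual fix in
# pytorch/pytorch#172519 before relying on it.
import torch._dynamo.config as dynamo_config

# Keep guards on nn.Module hooks so per-layer hook differences are detected...
dynamo_config.skip_nnmodule_hook_guards = False
# ...or inline built-in nn.Module code so hooks are traced rather than guarded.
dynamo_config.inline_inbuilt_nn_modules = True
```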

fegin added 6 commits January 20, 2026 11:06
@fegin changed the base branch from gh/fegin/53/base to main January 21, 2026 00:51
@fegin merged commit 1e8f9ac into main Jan 21, 2026
14 checks passed
wwwjn pushed a commit that referenced this pull request Jan 25, 2026
wwwjn pushed a commit to wwwjn/torchtitan that referenced this pull request Jan 30, 2026