Skip to content

refactor(models, training): shard_shapes#964

Draft
japols wants to merge 5 commits intomainfrom
refactor/shard-shapes
Draft

refactor(models, training): shard_shapes#964
japols wants to merge 5 commits intomainfrom
refactor/shard-shapes

Conversation

@japols
Copy link
Copy Markdown
Member

@japols japols commented Mar 10, 2026

Description

This PR simplifies the sharding metadata used to track tensor shapes across ranks.

Previously, we stored the full tensor shape for each rank (e.g. list[list[int]]). This required layers to manually track reshapes and dimension changes whenever tensors were transformed, which made the sharding logic fragile and tightly coupled to tensor layouts.

This refactor introduces ShardSizes = Union[list[int], None], representing the per-rank shard sizes along only the sharded dimension. Layers now propagate this information through a bundled GraphShardInfo / BipartiteGraphInfo which tracks shard metadata for both nodes and edges. The full shape expansion only happens at the level of the communication primitive where shapes are assumed equal for non-sharded dimensions across ranks.

Also refactor the all-to-all primitives for head/channel <-> grid sharding into a single common all-to-all primitive.

Additional notes

For now I've tested sharding for a global model across combinations of:

  • ensemble/single
  • keep_batch_sharded True/False
  • transformer/graphtransformer
  • head/edge sharding

please feel free to also test your favourite use case.

As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/

By opening this pull request, I affirm that all authors agree to the Contributor License Agreement.

@japols japols requested a review from ssmmnn11 March 10, 2026 14:45
@japols japols self-assigned this Mar 10, 2026
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Mar 10, 2026
@cathalobrien
Copy link
Copy Markdown
Contributor

nice, the benchmark tests pass

@japols japols requested a review from cathalobrien March 10, 2026 16:21
@japols japols marked this pull request as draft March 11, 2026 09:53
@mchantry mchantry added ATS Approval Not Needed No approval needed by ATS and removed ATS Approval Needed Approval needed by ATS labels Mar 11, 2026
@japols
Copy link
Copy Markdown
Member Author

japols commented Mar 17, 2026

waiting for #931 to be merged

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: To be triaged

Development

Successfully merging this pull request may close these issues.

4 participants