Conversation

@sophie-xhonneux (Contributor) commented Jan 14, 2026

Description

Make sure register and class tokens are used in the query aggregation engine.

Issue Number

Closes #1608
Closes #1673

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hedgedoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Jan 14, 2026
@sophie-xhonneux sophie-xhonneux requested review from clessig and removed request for shmh40 January 14, 2026 16:58

# streams_directory: "./config/streams/era5_1deg/"
streams_directory: "./config/streams/era5_nppatms_synop/"
streams_directory: "./config/streams/era5_1deg/"
Collaborator
Restore this file

Contributor Author
Some changes make sense and shouldn't be restored, e.g. training_mode including masking.

# currently fixed to 1.0 (due to limitations with flex_attention and triton)
forecast_att_dense_rate: 1.0

sslpred_num_blocks: 12
Collaborator
Move to JEPA loss terms; it also looks like this has already been done.

# get target_aux calculators for different loss terms
self.target_and_aux_calculators = self.get_target_aux_calculators(self.training_cfg)
self.validate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg)
# self.validate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg)
Collaborator
Re-enable

Contributor Author
I actually think this breaks things and should be removed

target_aux.update_state_pre_backward(self.cf.general.istep, batch, self.model)
for _, target_aux in self.target_and_aux_calculators.items()
]
[
Collaborator
Why has this been removed?

target_aux.update_state_post_opt_step(step, batch, self.model)
for _, target_aux in self.target_and_aux_calculators.items()
]
[
Collaborator
Why has this been removed?

.repeat(rs, 1)
)
cell_lens_r = cell_lens.unsqueeze(0).reshape(rs, self.num_healpix_cells)
mask = torch.cat([mask_reg_class_tokens, cell_lens_r.to(torch.bool)], dim=1)
Contributor
@sophie-xhonneux to check: here we have applied the aggregation engine to the unmasked tokens and the register + class tokens; then we create this mask for the reg and class tokens (all 1s), concat it with the appropriate mask for the normal tokens, and then fill tokens_global in the corresponding positions with the output of the aggregation engine on the unmasked tokens?
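For reference, roughly the pattern being described, as a minimal self-contained sketch (the shapes, the random values, and names such as num_reg_class are illustrative assumptions, not the actual encoder code):

import torch

rs, num_reg_class, num_healpix_cells, dim = 2, 4, 12, 8

# hypothetical per-cell token counts; cells with count 0 contain no valid tokens
cell_lens = torch.tensor([1, 0, 2, 0, 1, 1, 0, 3, 0, 1, 2, 0]).repeat(rs)

# register/class positions are always kept, so their mask entries are all True
mask_reg_class_tokens = torch.ones(rs, num_reg_class, dtype=torch.bool)
cell_lens_r = cell_lens.unsqueeze(0).reshape(rs, num_healpix_cells)
mask = torch.cat([mask_reg_class_tokens, cell_lens_r.to(torch.bool)], dim=1)

# output of the aggregation engine for the kept (unmasked) positions only
agg_out = torch.randn(int(mask.sum()), dim)

# scatter the aggregated tokens back into the full global token buffer
tokens_global = torch.zeros(rs, num_reg_class + num_healpix_cells, dim)
tokens_global[mask] = agg_out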

if tokens_c.shape[0] == 0:
# Check if this chunk is empty
if l0 == l1 or toks.shape[0] == 0:
continue
Contributor
@sophie-xhonneux do you think the FSDP problem we had before might also be fixed by the PR to the FSDP package? Worth trying with just a copy-and-paste of the posted PR https://github.com/pytorch/pytorch/pull/170667/files ?

Collaborator
One can just edit it in the local venv and try:

.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py
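A quick way to confirm which file to edit (a small sketch; assumes the torch build in the venv ships the _fully_shard layout shown above):

import torch.distributed.fsdp._fully_shard._fully_shard as fully_shard

# print the installed location, so the changes from the upstream PR can be
# pasted into the local copy
print(fully_shard.__file__)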

Contributor Author
Agreed, we should.


Labels

model Related to model training or definition (not generic infra)

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

[Bug] Tensor Shape and Indexing Mismatches in encoder.py when rs > 1
Add the register tokens before the aggregation engine

5 participants