Move register & class tokens to be added earlier #1610
base: develop
…ae_aggregation_engine. More checking needed.
…iex/dev/include-reg-tokens-in-query-agg-engine
…herGenerator into sophiex/dev/include-reg-tokens-in-query-agg-engine
…/dev/include-reg-tokens-in-query-agg-engine
config/config_dinov2.yml
Outdated
    # streams_directory: "./config/streams/era5_1deg/"
    streams_directory: "./config/streams/era5_nppatms_synop/"
    streams_directory: "./config/streams/era5_1deg/"
Restore this file
Some changes make sense and shouldn't be restored, e.g. training_mode including masking.
config/config_physical_jepa.yml
Outdated
    # currently fixed to 1.0 (due to limitations with flex_attention and triton)
    forecast_att_dense_rate: 1.0
    sslpred_num_blocks: 12
Move this to the JEPA loss terms; it also looks like this has already been done.
src/weathergen/train/trainer.py
Outdated
    # get target_aux calculators for different loss terms
    self.target_and_aux_calculators = self.get_target_aux_calculators(self.training_cfg)
    self.validate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg)
    # self.validate_with_ema_cfg = self.get_target_aux_calculators(self.validation_cfg)
Re-enable
I actually think this breaks things and should be removed.
    target_aux.update_state_pre_backward(self.cf.general.istep, batch, self.model)
    for _, target_aux in self.target_and_aux_calculators.items()
    ]
    [
Why has this been removed?
    target_aux.update_state_post_opt_step(step, batch, self.model)
    for _, target_aux in self.target_and_aux_calculators.items()
    ]
    [
Why has this been removed?
…d_shard() to better support all use cases.
    .repeat(rs, 1)
    )
    cell_lens_r = cell_lens.unsqueeze(0).reshape(rs, self.num_healpix_cells)
    mask = torch.cat([mask_reg_class_tokens, cell_lens_r.to(torch.bool)], dim=1)
@sophie-xhonneux to check: here we have applied the aggregation engine to the unmasked tokens plus the register and class tokens. We then create a mask for the register and class tokens (all 1s), concatenate it with the appropriate mask for the normal tokens, and fill tokens_global at the corresponding positions with the aggregation engine output for the unmasked tokens. Is that right?
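For reference, a minimal sketch of the masking and fill step as described in this comment. It is not the code from this PR: the function name scatter_unmasked, the argument names, and the assumption that cell_lens already has shape (rs, num_cells) with 0 marking a masked cell are all hypothetical.

```python
import torch


def scatter_unmasked(tokens_out, cell_lens, num_reg_class):
    """Place aggregation-engine outputs back into a dense token tensor.

    Assumptions (hypothetical, not taken from the PR):
      tokens_out:    (n_kept, dim) outputs for the register/class tokens and
                     the unmasked cells, ordered row by row.
      cell_lens:     (rs, num_cells) integer lengths, 0 for a masked cell.
      num_reg_class: number of register + class tokens per row.
    """
    rs, num_cells = cell_lens.shape
    # Register and class tokens are always kept, so their mask is all ones.
    mask_reg_class = torch.ones(rs, num_reg_class, dtype=torch.bool)
    # A cell is kept when its length is non-zero.
    mask = torch.cat([mask_reg_class, cell_lens.to(torch.bool)], dim=1)
    # Dense output; positions of masked cells stay zero.
    tokens_global = torch.zeros(
        rs, num_reg_class + num_cells, tokens_out.shape[-1], dtype=tokens_out.dtype
    )
    # Boolean indexing fills the kept positions in row-major order.
    tokens_global[mask] = tokens_out
    return tokens_global, mask
```

The fill only lines up if tokens_out is ordered the same way the mask is traversed (row by row, register and class tokens first), which is the part worth double-checking against the real code.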
    if tokens_c.shape[0] == 0:
    # Check if this chunk is empty
    if l0 == l1 or toks.shape[0] == 0:
        continue
@sophie-xhonneux do you think the FSDP problem we had before might also be fixed here by the PR on the FSDP package? It may be worth trying with just a copy and paste of the posted PR: https://github.com/pytorch/pytorch/pull/170667/files
One can just edit in the local pyenv and try:
.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py
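If the environment is laid out differently, the file can be located without hard-coding the .venv path. A small sketch using only the standard library; the helper name locate_module_file is hypothetical, and the dotted module name is inferred from the path above, so it may differ between torch versions.

```python
import importlib.util
from typing import Optional


def locate_module_file(dotted_name: str) -> Optional[str]:
    """Return the source file of an importable module, or None if not found."""
    try:
        # For a dotted name, find_spec imports the parent packages first.
        spec = importlib.util.find_spec(dotted_name)
    except ImportError:
        # Raised when a parent package (e.g. torch itself) cannot be imported.
        return None
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Module name inferred from the path in the comment above (an assumption).
    print(locate_module_file("torch.distributed.fsdp._fully_shard._fully_shard"))
```

Running this inside the project environment prints the file to patch, or None if that module does not exist in the installed torch.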
Agreed, we should.
…github.com:ecmwf/WeatherGenerator into sophiex/dev/include-reg-tokens-in-query-agg-engine
I think it still hangs in multi-GPU mode even with just DDP :/
…ed and removing dependence on default_config.yml
…github.com:ecmwf/WeatherGenerator into sophiex/dev/include-reg-tokens-in-query-agg-engine
Description
Make sure register and class tokens are used in query aggregation engine.
Issue Number
Closes #1608
Closes #1673
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60