Clessig/jk/develop/1654 fix fsteps #1689
base: jk/develop/1654_fix_fsteps
Conversation
…tput* to be more general. Cleaned up various details.
```python
assert all(step > 0 for step in forecast_cfg.num_steps), valid_forecast_steps_offset1

# check forecast offset
if forecast_cfg.get("offset") is not None:
```
We set a default forecast_offset=0 in the multi_stream_data_sampler, so we should still do the checks below if forecast_cfg.get("offset") is None.
Which line was this? Still cannot see it in the deleted code above.
Here: https://github.com/ecmwf/WeatherGenerator/blob/clessig/jk/develop/1654_fix_fsteps/src/weathergen/datasets/multi_stream_data_sampler.py#L107
So you could specify no offset in the config but still specify num_steps: 1
I now see https://github.com/ecmwf/WeatherGenerator/blob/clessig/jk/develop/1654_fix_fsteps/src/weathergen/train/trainer.py#L126
So I think it should be: if forecast_cfg is not None:
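A minimal sketch of the suggested guard, with forecast_cfg as a plain dict and validate_forecast_cfg as a hypothetical helper (the real config object in the repo may use attribute access; names follow the diff):

```python
def validate_forecast_cfg(forecast_cfg):
    # Hypothetical helper sketching the suggestion: run the offset checks
    # whenever a forecast config exists, since a default offset of 0 is set
    # in multi_stream_data_sampler when none is given explicitly.
    if forecast_cfg is not None:
        offset = forecast_cfg.get("offset", 0)
        assert offset in (0, 1), "forecast offset must be 0 or 1"
        if offset == 1:
            # with offset 1, step 0 (the analysis step) is not a valid target
            assert all(step > 0 for step in forecast_cfg["num_steps"])
    return True
```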
```python
self.forecast_policy = None
self.time_step = np.timedelta64(0, "ms")

fsm = self.list_num_forecast_steps[0]
```
Why [0]? list_num_forecast_steps is not necessarily sorted.
If I understand correctly, the forecast_steps (if several options are given) are drawn further down for each batch step. So maybe the dataset indices should be reduced by the maximum possible forecast length?
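An illustrative sketch of this suggestion: if the number of forecast steps is drawn per batch from several options, the index range should be reduced by the worst case rather than the first list entry. Variable names follow the diff; len_dataset and the example values are made up.

```python
import numpy as np

list_num_forecast_steps = [3, 1, 6]  # options drawn per batch; not necessarily sorted
len_dataset = 100                    # illustrative dataset length

# reserve room for the longest possible forecast, not the first option
fsm = max(list_num_forecast_steps)   # instead of list_num_forecast_steps[0]
usable_idxs = np.arange(len_dataset - fsm)
```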
I didn't rethink the logic here but just moved things to more appropriate places in the function. I think it should also be a separate PR to fix/check this.
…Generator into clessig/jk/develop/1654_fix_fsteps
Jubeku left a comment
output_idxs was not only renamed from forecast_idxs but also the definition changed so that it always contains all output indices, not only in the case for forecasting.
len(output_idxs) == batch.get_output_len() unless forecast.offset==1, then len(output_idxs) == batch.get_output_len() -1.
```diff
- register_tokens=z[:, : self.register_token_idx],
- class_token=z[:, self.register_token_idx : self.class_token_idx],
- patch_tokens=z[:, self.class_token_idx :],
+ register_tokens=z[:, self.register_token_idxs] if z is not None else z,
```
Maybe [...] if z is not None else None is more readable?
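The two spellings are equivalent when z is None; the suggested one just states the else-branch result explicitly. A tiny stand-alone illustration (z and register_token_idxs are stand-ins, not the real tensors):

```python
z = None
register_token_idxs = slice(0, 4)

# current spelling: else-branch returns z, which the reader must resolve to None
a = z[:, register_token_idxs] if z is not None else z
# suggested spelling: else-branch names the result directly
b = z[:, register_token_idxs] if z is not None else None

assert a is None and b is None  # identical behavior, different readability
```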
```python
for step in range(batch.get_output_len()):
    # apply forecasting engine
    if self.forecast_engine:
        tokens = self.forecast_engine(tokens, step)
```
You always want to call forecast_engine already for step 0 and overwrite the tokens at step 1 when your first output_idx is 1 (aka your first forecast step)?
If output_idxs[0] = 1 and self.forecast_engine is not None then this is the correct behavior. If output_idxs[0] = 0 and self.forecast_engine is None we also obtain the correct behavior. We could enforce that output_idxs[0] = 1 and self.forecast_engine is None is an invalid config option.
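A sketch of the proposed config validation; check_output_cfg is a hypothetical helper name encoding the rule discussed above:

```python
def check_output_cfg(output_idxs, forecast_engine):
    # Hypothetical check: a first output index of 1 (i.e. the first target is
    # a forecast step) requires a forecast engine to be configured.
    if output_idxs and output_idxs[0] == 1 and forecast_engine is None:
        raise ValueError("output_idxs starts at 1 but no forecast_engine is set")
```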
```python
tokens = self.forecast_engine(tokens, fstep)
# safe latent prediction
tokens_pre_norm = self.latent_pre_norm(tokens) if step == 0 else None
```
Are tokens_pre_norm the tokens post the pre-model-LayerNorm?
tokens_pre_norm is what was previously called z. But the name is wrong, it should be tokens_post_norm, yes.
```python
output = ModelOutput(batch.get_output_len())

tokens, posteriors = self.encoder(model_params, batch)
output.add_latent_prediction(0, "posteriors", posteriors)
```
This will be overwritten when output_idxs contain 0. But maybe that's supposed to be?
No, posteriors is only written here (latents is set in predict_latent()).
```diff
  # collect all targets, concatenating across batch dimension since this is also how it
  # happens for predictions in the model
- timestep_idxs = [0] if len(forecast_idxs) == 0 else forecast_idxs
+ timestep_idxs = [0] if len(output_idxs) == 0 else output_idxs
```
That's not needed anymore if we always enforce len(output_idxs) > 0 (see assertion in line 102)
```python
timestep_idxs = [0] if len(output_idxs) == 0 else output_idxs
for stream_name in stream_names:
    # collect targets for all forecast steps
    for t_idx in timestep_idxs:
```
We can directly loop over output_idxs.
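A sketch of this simplification: once len(output_idxs) > 0 is enforced upstream, the [0]-fallback is dead code and the loop can iterate output_idxs directly. Example values are made up.

```python
output_idxs = [1, 2, 3]
stream_names = ["stream_a"]

targets = []
for stream_name in stream_names:
    # collect targets for all forecast steps
    for t_idx in output_idxs:  # was: timestep_idxs = [0] if len(...) == 0 else output_idxs
        targets.append((stream_name, t_idx))
```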
Running inference I got: Command:

...it's because
Description
Cleanups and improvements to the target branch.
Issue Number
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60