1654 Fixing forecast steps in model, loss, and data loader #1656
base: develop
Conversation
…ainer... and the entire trail of follow-up changes ("Rattenschwanz")
src/weathergen/datasets/batch.py
(Outdated diff)
-        forecast_dt = sdata.get_forecast_steps()
-    return forecast_dt
+        forecast_steps = sdata.get_forecast_steps()
+    return forecast_steps
I am still a bit concerned about the for loops here, in which the variable is overwritten at each iteration. I assume both forecast_steps and output_len are constant across all streams within a batch? If so, we can probably use that to handle this more cleanly.
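If that invariant does hold, a minimal sketch of the cleaner handling could look like the following; the helper name and the streams_data shape are assumptions for illustration, not the actual API:

```python
def batch_forecast_steps(streams_data: dict):
    """Return the forecast steps shared by all streams in a batch (hypothetical helper)."""
    streams = list(streams_data.values())
    forecast_steps = streams[0].get_forecast_steps()
    # fail loudly if the assumed invariant (identical value across streams) ever breaks
    assert all(s.get_forecast_steps() == forecast_steps for s in streams[1:])
    return forecast_steps
```

This reads the value once instead of silently overwriting it on every loop iteration.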
MatKbauer left a comment:
Looks good so far. I'd like to make one major suggestion: Can we rename forecast_steps to forecast_idxs (or forecast_idcs) to be very explicit in our naming? I would even go so far as to change fstep to fidx then.
Most of my comments in the review relate to this suggested change in naming convention.
src/weathergen/datasets/batch.py
(Outdated diff)
-        forecast_dt = sdata.get_forecast_steps()
-    return forecast_dt
+        forecast_steps = sdata.get_forecast_steps()
+    return forecast_steps
I still don't quite understand why we have to loop over the streams here and overwrite output_len and forecast_steps, respectively. In a multi-stream setup, output_len and forecast_steps are only returned for the last stream. So why loop over the streams instead of just returning the last one, if that is what we want?
In line with comments below, I'd advocate for get_num_forecast_steps() instead of get_forecast_steps()
        dt = self._get_output_length(forecast_dt)
        stream_data = StreamData(
            base_idx, num_steps_input, dt, forecast_dt, self.forecast_offset, self.num_healpix_cells
        )
Can we rename num_steps_input to num_input_steps, dt to num_forecast_steps, and forecast_dt to forecast_steps?
         Get number of forecast steps
         """
-        return self.forecast_steps
+        return self.forecast_idxs
Rename the function from get_forecast_steps() to get_forecast_idxs() and adapt the function header. (Side note: I'd prefer idcs over idxs, but we already have idxs in various places in the code.)
        self.input_steps = input_steps
        self.forecast_steps = forecast_steps
        self.output_steps = output_steps
Can we use self.num_output_steps and self.num_input_steps to be consistent with num_forecast_steps in the multi_stream_data_sampler?
| """ | ||
| Get length of output | ||
| """ | ||
| return self.output_steps |
return self.num_output_steps
        self.latent = [{} for _ in range(len_target)]
        self.aux_outputs = {}

    def add_physical_target(self, fstep: int, stream_name: StreamName, pred: torch.Tensor) -> None:
fidx
    latent: list[dict[str, torch.Tensor | LatentState]]
    aux_outputs: dict[str, torch.Tensor]

    def __init__(self, len_target: int, forecast_steps: list) -> None:
forecast_idxs
        # collect all targets, concatenating across batch dimension since this is also how it
        # happens for predictions in the model
        targets = {}
        fstep_idxs = [0] if len(forecast_steps) == 0 else forecast_steps
How about using fidxs here and fidx in the forecast loop below?
I also considered something like target_idxs to make clear that this does not only apply to forecasting. But target_idx is also used later in the loss module to refer to the source-target correspondence, so that would invite risky confusion.
    def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput:
        # TODO: properly retrieve/define these
        stream_names = [k for k, _ in batch.samples[0].streams_data.items()]
        forecast_steps = batch.get_forecast_steps()
forecast_idxs = batch.get_forecast_idxs()
-        for fstep in range(window_offset_prediction, forecast_steps):
+        # TODO why does this stop at forecast_steps? Maybe explains #1657
+        # for fstep in range(forecast_offset, forecast_steps + 1):
+        for fstep in range(forecast_offset, forecast_steps):
Crucial to keep in mind. We may have to extend this to +1 here, or iterate over forecast_idxs instead.
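For reference, a small sketch of what the two loop bounds cover; the values are made up, and forecast_steps is treated as a plain count here, which may not match the codebase exactly:

```python
forecast_offset = 1
forecast_steps = 3  # interpreted as a count for this illustration

current = list(range(forecast_offset, forecast_steps))       # [1, 2]
extended = list(range(forecast_offset, forecast_steps + 1))  # [1, 2, 3], the "+1" variant

# iterating over explicit indices sidesteps the off-by-one question entirely
forecast_idxs = [1, 2, 3]
for fidx in forecast_idxs:
    pass  # process forecast step fidx
```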
        self.rng = np.random.default_rng(self.data_loader_rng_seed)

        fsm = (
            self.forecast_steps[min(self.mini_epoch, len(self.forecast_steps) - 1)]
@clessig, why is len(self.forecast_steps) used here?
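One possible reading, sketched under the assumption that forecast_steps is a per-mini-epoch schedule (not confirmed in this PR): min(mini_epoch, len(...) - 1) clamps the lookup so later mini-epochs keep reusing the last schedule entry.

```python
forecast_steps = [1, 2, 4, 8]  # hypothetical per-mini-epoch schedule

def steps_for_mini_epoch(mini_epoch: int) -> int:
    # clamp the index so mini-epochs beyond the schedule reuse its last entry
    return forecast_steps[min(mini_epoch, len(forecast_steps) - 1)]

assert steps_for_mini_epoch(0) == 1
assert steps_for_mini_epoch(3) == 8
assert steps_for_mini_epoch(10) == 8  # clamped to the last entry
```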
        elif self.forecast_policy == "random" or self.forecast_policy == "sequential_random":
            # randint high=one-past
            self.perms_forecast_dt = self.rng.integers(
                low=self.forecast_steps.min(), high=fsm + 1, size=len_dt_samples, dtype=np.int64
Why is high=fsm + 1 used here?
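The "# randint high=one-past" comment in the diff suggests the answer: np.random.Generator.integers excludes high by default, so high=fsm + 1 makes fsm itself a possible draw. A quick illustration with made-up values:

```python
import numpy as np

rng = np.random.default_rng(0)
fsm = 4
draws = rng.integers(low=1, high=fsm + 1, size=1000, dtype=np.int64)
assert draws.min() >= 1 and draws.max() <= fsm  # all draws lie in [1, fsm]; fsm can occur
```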
MatKbauer left a comment:
Works well for me when testing forecasting with offset: 0 or 1 and different num_steps. I launched a 2-node forecast pre-training run (2 forecast steps) and would like to quickly verify 8-forecast-step fine-tuning subsequently.
Description
This PR makes sure that:
- In forecast mode we get len(pred) = len(forecast_steps) + forecast_offset, and batch.get_forecast_steps() corresponds to forecast.num_steps (see the worked sketch after this list).
- The loops over forecast steps in the model forward and in the loss calculator go from forecast_offset to forecast_offset + num_forecast_steps.
- The tokens first go through the forecasting engine before going through the decoder.
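A worked sketch of that bookkeeping, with concrete values chosen purely for illustration; the names follow the description above rather than the exact code:

```python
forecast_offset = 1
num_forecast_steps = 2

# loop range used in the model forward and in the loss calculator
fsteps = list(range(forecast_offset, forecast_offset + num_forecast_steps))
assert fsteps == [1, 2]

# predictions additionally cover the offset window(s)
len_pred = len(fsteps) + forecast_offset
assert len_pred == 3  # len(pred) == len(forecast_steps) + forecast_offset
```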
Issue Number
Closes #1654
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60