@Jubeku
Contributor

@Jubeku Jubeku commented Jan 19, 2026

Description

This PR makes sure that:

  • In forecast mode we get len(pred) = len(forecast_steps) + forecast_offset.

  • batch.get_forecast_steps() corresponds to forecast.num_steps.

  • The loops over forecast steps in the model forward pass and in the loss calculator go from forecast_offset to forecast_offset + num_forecast_steps.

  • The tokens pass through the forecasting engine first and only then through the decoder.
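
A minimal sketch of the indexing these bullets describe, under one reading of them: `forecast_offset` and `num_forecast_steps` are plain integers, and the prediction list keeps one slot per index up to the last forecast index. The helper `forecast_indices` and the placeholder `preds` list are hypothetical and not taken from the code base.

```python
def forecast_indices(forecast_offset: int, num_forecast_steps: int) -> list[int]:
    """Indices visited by the forecast loops (hypothetical helper)."""
    return list(range(forecast_offset, forecast_offset + num_forecast_steps))


forecast_offset, num_forecast_steps = 1, 2
idxs = forecast_indices(forecast_offset, num_forecast_steps)

# One slot per index up to the last forecast index, so that
# len(preds) == len(idxs) + forecast_offset.
preds = [None] * (forecast_offset + len(idxs))
for fidx in idxs:
    preds[fidx] = f"prediction for step {fidx}"  # stand-in for a model output
```

With an offset of 1 the slot at index 0 stays empty, which is why the length of the prediction list exceeds the number of forecast indices by the offset.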

Issue Number

Closes #1654

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

- forecast_dt = sdata.get_forecast_steps()
- return forecast_dt
+ forecast_steps = sdata.get_forecast_steps()
+ return forecast_steps
Contributor Author

I am still a bit concerned about the for loops here, in which the variable is overwritten at each iteration. I assume both forecast_steps and output_len are constant within a batch across all streams? If so, we can probably use that to handle this more cleanly.
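
If the values really are constant across streams, one option is to check that explicitly instead of letting the loop overwrite the variable. A sketch, assuming each stream object exposes `get_forecast_steps()` and that it returns an integer; the `FakeStream` class and the `common_value` helper are hypothetical stand-ins, not existing code:

```python
from dataclasses import dataclass


@dataclass
class FakeStream:
    """Hypothetical stand-in for a per-stream data object."""

    forecast_steps: int

    def get_forecast_steps(self) -> int:
        return self.forecast_steps


def common_value(values: list[int]) -> int:
    """Return the single value shared by all entries, or raise if they differ."""
    unique = set(values)
    if len(unique) != 1:
        raise ValueError(f"expected one value across streams, got {sorted(unique)}")
    return unique.pop()


streams = [FakeStream(2), FakeStream(2)]
forecast_steps = common_value([s.get_forecast_steps() for s in streams])
```

This keeps the single return value but fails loudly if the assumption is ever violated.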

Contributor

@MatKbauer MatKbauer left a comment

Looks good so far. I'd like to make one major suggestion: Can we rename forecast_steps to forecast_idxs (or forecast_idcs) to be very explicit in our naming? I would even go down to changing fstep to fidx then.

Most of my comments in the review relate to this suggested change in naming convention.

- forecast_dt = sdata.get_forecast_steps()
- return forecast_dt
+ forecast_steps = sdata.get_forecast_steps()
+ return forecast_steps
Contributor

I still don't quite understand why we have to loop over the streams here and overwrite output_len and forecast_steps, respectively. In a multi-stream setup, output_len and forecast_steps will only be returned for the last stream. So why loop over the streams instead of returning the values of the last one, if that is what we want?

Contributor

In line with comments below, I'd advocate for get_num_forecast_steps() instead of get_forecast_steps()

dt = self._get_output_length(forecast_dt)
stream_data = StreamData(
base_idx, num_steps_input, dt, forecast_dt, self.forecast_offset, self.num_healpix_cells
)
Contributor

Can we rename

  • num_steps_input to num_input_steps,
  • dt to num_forecast_steps, and
  • forecast_dt to forecast_steps?

Get number of forecast steps
"""
return self.forecast_steps
return self.forecast_idxs
Contributor

Rename the function from get_forecast_steps() to get_forecast_idxs() and adapt the docstring accordingly. (Side note: I'd prefer idcs over idxs, but we already use idxs in various places in the code.)


self.input_steps = input_steps
self.forecast_steps = forecast_steps
self.output_steps = output_steps
Contributor

Can we use self.num_output_steps and self.num_input_steps to be consistent with num_forecast_steps in the multi_stream_data_sampler?

"""
Get length of output
"""
return self.output_steps
Contributor

return self.num_output_steps

self.latent = [{} for _ in range(len_target)]
self.aux_outputs = {}

def add_physical_target(self, fstep: int, stream_name: StreamName, pred: torch.Tensor) -> None:
Contributor

fidx

latent: list[dict[str, torch.Tensor | LatentState]]
aux_outputs: dict[str, torch.Tensor]

def __init__(self, len_target: int, forecast_steps: list) -> None:
Contributor

forecast_idxs

# collect all targets, concatenating across batch dimension since this is also how it
# happens for predictions in the model
targets = {}
fstep_idxs = [0] if len(forecast_steps) == 0 else forecast_steps
Contributor

How about using fidxs here and fidx in the forecast loop below?

Contributor Author

I also considered something like target_idxs to make clear that this holds not only for forecasting. But target_idx is also used later in the loss module to refer to the source-target correspondence, so that name would invite confusion.

def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput:
# TODO: properly retrieve/define these
stream_names = [k for k, _ in batch.samples[0].streams_data.items()]
forecast_steps = batch.get_forecast_steps()
Contributor

forecast_idxs = batch.get_forecast_idxs()

for fstep in range(window_offset_prediction, forecast_steps):
# TODO why does this stop at forecast_steps? Maybe explains #1657
# for fstep in range(forecast_offset, forecast_steps + 1):
for fstep in range(forecast_offset, forecast_steps):
Contributor

Crucial to keep in mind. We may have to extend the upper bound to forecast_steps + 1 here, or iterate over forecast_idxs directly.
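
To make the possible off-by-one concrete, here is a small comparison of the three candidate loop bounds. It assumes that `forecast_steps` in the loop above is the number of forecast steps, which is an assumption; the function and key names are illustrative only.

```python
def loop_variants(forecast_offset: int, num_steps: int) -> dict[str, list[int]]:
    """Indices visited by each candidate loop bound."""
    return {
        "stop at num_steps": list(range(forecast_offset, num_steps)),
        "stop at num_steps + 1": list(range(forecast_offset, num_steps + 1)),
        "stop at offset + num_steps": list(
            range(forecast_offset, forecast_offset + num_steps)
        ),
    }


for offset in (0, 1):
    print(offset, loop_variants(offset, num_steps=2))
```

Under this assumption, only the last variant visits exactly `num_steps` indices for both offsets. In this example the `+ 1` variant agrees with it for offset 1 but visits one index too many for offset 0.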

self.rng = np.random.default_rng(self.data_loader_rng_seed)

fsm = (
self.forecast_steps[min(self.mini_epoch, len(self.forecast_steps) - 1)]
Contributor Author

@clessig, why is len(self.forecast_steps) used here?
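
One possible reading, offered as a guess rather than as the author's intent: if `self.forecast_steps` holds one entry per mini epoch, the `min(...)` clamps the index so that the last entry is reused once `mini_epoch` runs past the end of the schedule. A sketch with a hypothetical schedule:

```python
import numpy as np

# Hypothetical per-mini-epoch schedule of the maximum number of forecast steps.
forecast_steps = np.array([1, 2, 4])


def max_steps_for(mini_epoch: int) -> int:
    """Pick the schedule entry for this mini epoch, reusing the last one at the end."""
    idx = min(mini_epoch, len(forecast_steps) - 1)
    return int(forecast_steps[idx])


schedule = [max_steps_for(e) for e in range(5)]
```

Here mini epochs 3 and 4 fall back to the last schedule entry instead of raising an IndexError.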

elif self.forecast_policy == "random" or self.forecast_policy == "sequential_random":
# randint high=one-past
self.perms_forecast_dt = self.rng.integers(
low=self.forecast_steps.min(), high=fsm + 1, size=len_dt_samples, dtype=np.int64
Contributor Author

Why is high=fsm + 1 used here?
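
For context on the `# randint high=one-past` comment: `numpy.random.Generator.integers` treats `high` as exclusive by default, so `fsm + 1` is what makes `fsm` itself a possible draw. A small check, with illustrative values for `low` and `fsm` that are not taken from any config:

```python
import numpy as np

rng = np.random.default_rng(0)
low, fsm = 1, 3  # illustrative values only

# high is exclusive by default, so fsm can only be drawn with high=fsm + 1.
exclusive = rng.integers(low=low, high=fsm, size=1000, dtype=np.int64)
inclusive = rng.integers(low=low, high=fsm + 1, size=1000, dtype=np.int64)

# endpoint=True is an equivalent way to include fsm in the range.
same_range = rng.integers(low=low, high=fsm, size=1000, dtype=np.int64, endpoint=True)
```

Whether `fsm` is meant to be included is a separate question for the author; the sketch only shows what the `+ 1` does.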

@Jubeku Jubeku marked this pull request as ready for review January 22, 2026 16:39
Contributor

@MatKbauer MatKbauer left a comment

Works well for me when testing forecasting with offset 0 or 1 and different values of num_steps. I launched a 2-node forecast pre-training run (2 forecast steps) and would like to quickly verify fine-tuning with 8 forecast steps afterwards.

Development

Successfully merging this pull request may close these issues.

Clean-up forecast data-flow

4 participants