Conversation

@clessig (Collaborator) commented Jan 24, 2026

Description

Cleanups and improvements to the target branch.

Issue Number

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and updated the docstrings
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and listed the run_id(s) in a comment: launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

assert all(step > 0 for step in forecast_cfg.num_steps), valid_forecast_steps_offset1

# check forecast offset
if forecast_cfg.get("offset") is not None:
Contributor:

We set a default forecast_offset=0 in the multi_stream_data_sampler, so we should still do the checks below if forecast_cfg.get("offset") is None.
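
A minimal sketch of the suggested flow, assuming forecast_cfg supports a dict-style get with a default (the structure mirrors the snippet above; not the actual project code):

    # Fall back to the sampler's default offset of 0 so the checks below
    # always run, even when "offset" is absent from the config.
    offset = forecast_cfg.get("offset", 0)
    if offset == 1:
        assert all(step > 0 for step in forecast_cfg.num_steps), valid_forecast_steps_offset1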

Collaborator (Author):

Which line was this? Still cannot see it in the deleted code above.


self.forecast_policy = None
self.time_step = np.timedelta64(0, "ms")

fsm = self.list_num_forecast_steps[0]
Contributor:

Why [0]? list_num_forecast_steps is not necessarily sorted.

Contributor:

If I understand correctly, the forecast_steps (if several options are given) are drawn further down for each batch step. So maybe the dataset indices should be reduced by the maximum possible forecast length?
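
A sketch of the suggestion; len_dataset and num_usable_samples are hypothetical names standing in for however the sampler derives its index range:

    # Size the usable index range by the worst-case horizon instead of
    # whatever happens to be first in the (possibly unsorted) list.
    max_forecast_steps = max(self.list_num_forecast_steps)
    num_usable_samples = len_dataset - max_forecast_steps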

Collaborator (Author):

I didn't rethink the logic here but just moved things to more appropriate places in the function. I think fixing/checking this should be a separate PR.

github-project-automation bot moved this to In Progress in WeatherGen-dev (Jan 26, 2026)
@Jubeku (Contributor) left a comment:

output_idxs was not only renamed from forecast_idxs but also the definition changed so that it always contains all output indices, not only in the case for forecasting.
len(output_idxs) == batch.get_output_len(), unless forecast.offset == 1, in which case len(output_idxs) == batch.get_output_len() - 1.
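
An illustrative assertion capturing the invariant described above (not actual project code):

    expected_len = batch.get_output_len() - (1 if forecast.offset == 1 else 0)
    assert len(output_idxs) == expected_len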

register_tokens=z[:, : self.register_token_idx],
class_token=z[:, self.register_token_idx : self.class_token_idx],
patch_tokens=z[:, self.class_token_idx :],
register_tokens=z[:, self.register_token_idxs] if z is not None else z,
Contributor:

Maybe [...] if z is not None else None is more readable?
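
The suggested form, applied to the new line above:

    register_tokens=z[:, self.register_token_idxs] if z is not None else None,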

for step in range(batch.get_output_len()):
    # apply forecasting engine
    if self.forecast_engine:
        tokens = self.forecast_engine(tokens, step)
Contributor:

Do you always want to call forecast_engine already at step 0, and overwrite the tokens at step 1, when your first output_idx is 1 (i.e. your first forecast step)?

Collaborator (Author):

If output_idxs[0] == 1 and self.forecast_engine is not None, then this is the correct behavior. If output_idxs[0] == 0 and self.forecast_engine is None, we also obtain the correct behavior. We could enforce that output_idxs[0] == 1 with self.forecast_engine is None is an invalid config option.
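
A sketch of the config guard the author proposes (placement and message are illustrative):

    # Reject the invalid combination: a forecast-only first output step
    # requires a forecast engine to produce it.
    if output_idxs[0] == 1 and self.forecast_engine is None:
        raise ValueError("output_idxs starts at 1 but no forecast_engine is configured")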


tokens = self.forecast_engine(tokens, fstep)
# save latent prediction
tokens_pre_norm = self.latent_pre_norm(tokens) if step == 0 else None
Contributor:

Are tokens_pre_norm the tokens after the pre-model LayerNorm?

Collaborator (Author):

tokens_pre_norm is what was previously called z. But yes, the name is wrong; it should be tokens_post_norm.

output = ModelOutput(batch.get_output_len())

tokens, posteriors = self.encoder(model_params, batch)
output.add_latent_prediction(0, "posteriors", posteriors)
Contributor:

This will be overwritten when output_idxs contains 0. But maybe that's intended?

Collaborator (Author):

No, posteriors is only written here (latents is set in predict_latent()).

# collect all targets, concatenating across batch dimension since this is also how it
# happens for predictions in the model
timestep_idxs = [0] if len(forecast_idxs) == 0 else forecast_idxs
timestep_idxs = [0] if len(output_idxs) == 0 else output_idxs
Contributor:

That's not needed anymore if we always enforce len(output_idxs) > 0 (see the assertion in line 102).
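
Under that assumption the fallback collapses (a sketch; the assertion is the one the reviewer references in the reviewed file):

    assert len(output_idxs) > 0
    # The [0]-fallback then becomes unnecessary:
    timestep_idxs = output_idxs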

timestep_idxs = [0] if len(output_idxs) == 0 else output_idxs
for stream_name in stream_names:
    # collect targets for all forecast steps
    for t_idx in timestep_idxs:
Contributor:

We can directly loop over output_idxs.
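
The suggested simplification (a sketch, assuming len(output_idxs) > 0 is enforced upstream):

    for stream_name in stream_names:
        # collect targets for all forecast steps
        for t_idx in output_idxs:
            ...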

@Jubeku (Contributor) commented Jan 26, 2026

Running inference I got:

num_forecast_steps at mini_epoch=0 : 6
  0%|                                                                                                                                           | 0/16 [00:20<?, ?it/s]
Traceback (most recent call last):
  File "/users/jkuehner/CODE/WeatherGenerator/src/weathergen/run_train.py", line 78, in inference_from_args
    trainer.inference(cf, devices, args.from_run_id, args.mini_epoch)
  File "/users/jkuehner/CODE/WeatherGenerator/src/weathergen/train/trainer.py", line 213, in inference
    self.validate(0, self.test_cfg, self.batch_size_test_per_gpu)
  File "/users/jkuehner/CODE/WeatherGenerator/src/weathergen/train/trainer.py", line 535, in validate
    preds = self.model(
            ^^^^^^^^^^^
  File "/users/jkuehner/CODE/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/jkuehner/CODE/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/jkuehner/CODE/WeatherGenerator/src/weathergen/model/model.py", line 590, in forward
    output = self.predict_decoders(model_params, step, tokens, batch, output)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/jkuehner/CODE/WeatherGenerator/src/weathergen/model/model.py", line 649, in predict_decoders
    tokens = tokens[:, (self.aux_token_idxs[-1] + 1) :]
                        ~~~~~~~~~~~~~~~~~~~^^^^
IndexError: list index out of range
[6] > /users/jkuehner/CODE/WeatherGenerator/src/weathergen/model/model.py(649)predict_decoders()
-> tokens = tokens[:, (self.aux_token_idxs[-1] + 1) :]

Command:
uv run inference --from_run_id=eu7fp1nm --samples=16 --options training_config.forecast.num_steps=6 zarr_store=zip

@Jubeku (Contributor) commented Jan 26, 2026

...it's because self.aux_token_idxs = [] is empty, so self.aux_token_idxs[-1] throws an IndexError.
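
A minimal sketch of a guard for the failing line in predict_decoders (treating an empty aux_token_idxs as "no auxiliary tokens to strip" is an assumption, not a confirmed fix):

    # Hypothetical guard: only strip auxiliary tokens when any exist.
    if len(self.aux_token_idxs) > 0:
        tokens = tokens[:, (self.aux_token_idxs[-1] + 1) :]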
