feat(training, models): graph config for skipped noise injector and multiscale loss #855

Open
ssmmnn11 wants to merge 45 commits into main from feat/graph_for_skipped_and_multiscale

Conversation

Member

@ssmmnn11 ssmmnn11 commented Feb 3, 2026

Description
This PR adds new configuration support in the training and models components to enable graph-based settings for skipped noise injection and multiscale loss integration.

  • Main changes

  • Introduces graph configuration options to control:

    • Noise injector behavior when using graph-based models
    • Multiscale loss projections across graph representations
  • Adds associated tests covering:

    • Graph-based multiscale loss scales
    • Truncation and projection behavior
  • Implements support for noise projection in graph-based models

  • Includes additional fixes and refactors to improve graph config handling

  • Why?

These changes enhance flexibility in defining how graph-structured models handle noise injection and multiscale loss terms during training, leveraging existing functionality in anemoi-graphs for sparse matrix generation. This facilitates reproducibility via graph recipes and reduces the dependency on truncation files that are not properly tracked alongside the graph.
In addition, it will enable more involved interpolation workflows in the future.
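To make the sparse-projection idea concrete, here is a minimal hypothetical sketch (not the PR's actual implementation; all names, shapes, and weights are invented) of projecting coarse-grid noise onto fine-grid nodes through a row-normalized sparse interpolation matrix, as a graph recipe could produce:

```python
# Hypothetical sketch: project coarse-node noise onto fine nodes through a
# sparse (n_fine x n_coarse) weight matrix stored as COO triplets.
# Names and shapes are illustrative only, not the PR's actual API.

def project_noise(triplets, noise, n_fine):
    """Apply a row-normalized sparse projection to a coarse noise vector.

    triplets: iterable of (fine_idx, coarse_idx, weight); weights in each
    fine row are normalized so every fine node gets a convex combination.
    """
    # accumulate weights per fine row for normalization
    row_sums = [0.0] * n_fine
    for r, _, w in triplets:
        row_sums[r] += w

    out = [0.0] * n_fine
    for r, c, w in triplets:
        out[r] += (w / row_sums[r]) * noise[c]
    return out


# two fine nodes interpolating from three coarse nodes
triplets = [(0, 0, 1.0), (0, 1, 1.0), (1, 1, 3.0), (1, 2, 1.0)]
fine_noise = project_noise(triplets, [0.2, 0.4, 0.8], n_fine=2)
```

In practice the matrix would come from an anemoi-graphs recipe rather than hand-written triplets, which is what makes the setup reproducible without external truncation files.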


📚 Documentation preview 📚: https://anemoi-training--855.org.readthedocs.build/en/855/


📚 Documentation preview 📚: https://anemoi-graphs--855.org.readthedocs.build/en/855/


📚 Documentation preview 📚: https://anemoi-models--855.org.readthedocs.build/en/855/

@ssmmnn11 ssmmnn11 self-assigned this Feb 3, 2026
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Feb 3, 2026
@github-actions github-actions bot added graphs and removed graphs labels Feb 3, 2026
@ssmmnn11 ssmmnn11 marked this pull request as draft February 3, 2026 09:40
@github-actions github-actions bot added the graphs label Feb 3, 2026
@ssmmnn11 ssmmnn11 changed the title feature(training, models): graph config for skipped, noise injector and multiscale loss feature(training, models): graph config for skipped noise injector and multiscale loss Feb 3, 2026
@ssmmnn11 ssmmnn11 changed the title feature(training, models): graph config for skipped noise injector and multiscale loss feat(training, models): graph config for skipped noise injector and multiscale loss Feb 3, 2026
@ssmmnn11 ssmmnn11 marked this pull request as ready for review February 6, 2026 11:20
ssmmnn11 added 12 commits March 12, 2026 16:36
…d_and_multiscale

# Conflicts:
#	training/src/anemoi/training/config/graph/encoder_decoder_only.yaml
#	training/src/anemoi/training/config/graph/hierarchical_2level.yaml
#	training/src/anemoi/training/config/graph/hierarchical_3level.yaml
#	training/src/anemoi/training/config/graph/limited_area.yaml
#	training/src/anemoi/training/config/graph/multi_scale.yaml
#	training/src/anemoi/training/config/graph/stretched_grid.yaml
#	training/src/anemoi/training/train/train.py
…/graph_for_skipped_and_multiscale

# Conflicts:
#	training/tests/unit/losses/test_combined_loss.py
…d_and_multiscale

# Conflicts:
#	training/src/anemoi/training/schemas/training.py
…d_and_multiscale

# Conflicts:
#	training/tests/integration/test_training_cycle.py
self,
*,
-    model_config: DotDict,
+    model_config: Any,
Contributor

why is the typing changing?

"Whether to use autocast for the noise projection matrix operations."

@model_validator(mode="after")
def validate_noise_projection(self) -> "NoiseConditioningSchema":
Contributor

you could move this to

def model_post_init(self, _: Any) -> None:

so that the check is then done both when config_validation is on and off
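For illustration, a standalone sketch of the difference (assuming recent pydantic v2, where model_construct invokes model_post_init but skips validators; the schema below is invented, not the PR's NoiseConditioningSchema):

```python
from typing import Any

from pydantic import BaseModel, model_validator


class WithValidator(BaseModel):
    scale: int

    @model_validator(mode="after")
    def check_scale(self) -> "WithValidator":
        if self.scale <= 0:
            raise ValueError("scale must be positive")
        return self


class WithPostInit(BaseModel):
    scale: int

    def model_post_init(self, _: Any) -> None:
        if self.scale <= 0:
            raise ValueError("scale must be positive")


# model_construct bypasses validation (analogous to config validation off)
a = WithValidator.model_construct(scale=-1)  # no error: validator skipped
try:
    WithPostInit.model_construct(scale=-1)  # post-init hook still runs
    post_init_ran = False
except ValueError:
    post_init_ran = True
```

Normal construction (`WithValidator(scale=-1)`) would raise in both variants; only the post-init hook also guards the unvalidated path.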

else:
model_config_local = model_config

self.input_times = list(model_config_local.training.explicit_times.input)
Contributor

is this related to this PR, or left over/not updated with main? These seem unrelated, so it's probably best not to change them as part of this PR (same for the self.latent_skip)

Graph definition
"""
super().__init__()
if type(model_config) is dict and not OmegaConf.is_config(model_config):
Contributor

why are we doing this here (and also on the interpolator)?

from anemoi.models.utils.projection_helpers import residual_projection_truncation_node_name
from anemoi.models.utils.projection_helpers import uses_fused_dataset_graph


Contributor

Why do we need all of these utility functions? How easy is it to maintain this?

Contributor

@anaprietonem anaprietonem Mar 26, 2026

After trying to look at this more, I believe this should be something in anemoi-graphs, and the projection helpers should also be part of anemoi-graphs. Then in graph_config we could just do one name resolution after the graph is built, so that the results are used by the model or the losses when needed. The graph_factory would then return:

graph_data, projection_data = graph_factory.build()

and then that object can be consumed by the losses and model with something like

model = instantiate(config.model, graph=graph_data, projection_data=projection_data)
loss = get_loss_function(config.training_loss, ..., graph_data=graph_data,
                         loss_matrices_graph=projection_data.multiscale_loss_matrices_graph)

Could something like this be considered?
With that, the GraphFactory would live in anemoi-graphs, and so would the projection helpers, which would be significantly simplified since the resolution is done once. And, similarly to the GraphCreator object, we could have a ProjectionCreator that is then called in the factory to resolve the projection_metadata
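A minimal sketch of the proposed split (all class and field names here are hypothetical, not an existing anemoi-graphs API):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ProjectionData:
    """Projection artifacts resolved once, alongside the graph build."""

    multiscale_loss_matrices_graph: list = field(default_factory=list)
    noise_projection: Any = None


@dataclass
class GraphFactory:
    """Hypothetical factory: resolves projection names a single time at build()."""

    recipe: dict

    def build(self) -> tuple:
        graph_data = {"nodes": self.recipe.get("nodes", [])}
        projection_data = ProjectionData(
            multiscale_loss_matrices_graph=self.recipe.get("loss_matrices", []),
        )
        return graph_data, projection_data


graph_data, projection_data = GraphFactory({"nodes": ["data", "hidden"]}).build()
```

The model and losses would then consume `graph_data` and `projection_data` as plain inputs, with no further name resolution on their side.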

LOGGER = logging.getLogger(__name__)


class TrainerGraphDataFactory:
Contributor

is this something that should live in training? I would say this seems to me very graph-specific functionality that would be better placed in anemoi-graphs, which could simply take the part of the training config related to the dataset names that it needs rather than the whole config

from anemoi.models.layers.graph_provider import ProjectionGraphProvider
from anemoi.models.layers.sparse_projector import SparseProjector
from anemoi.models.utils.projection_helpers import (
multiscale_loss_matrices_graph as derive_multiscale_loss_matrices_graph,
Contributor

why changing the name here?

@@ -0,0 +1,2 @@
---
Contributor

why do we have this here, while truncation/none.yaml is just an empty file?

---
defaults:
- _self_

Contributor

@anaprietonem anaprietonem Mar 26, 2026

could we simplify the definition of these configs by doing something like?

num_scales: 4
base_num_nearest_neighbours: 16
base_sigma: 0.00471
scale_factor: 2          # optional, defaults to 2
edge_weight_attribute: gauss_weight
gaussian_norm: l1

and then just a helper that builds this?

def _expand_geometric_smoothers(projection_cfg: Any) -> dict[str, Any] | None:
    """Build an explicit smoothers dict from a compact geometric progression spec."""
    num_scales = projection_cfg.get("num_scales")
    if num_scales is None:
        return None

    base_neighbours = projection_cfg["base_num_nearest_neighbours"]
    base_sigma = projection_cfg["base_sigma"]
    scale_factor = projection_cfg.get("scale_factor", 2)
    edge_weight_attribute = projection_cfg.get("edge_weight_attribute", "gauss_weight")
    gaussian_norm = projection_cfg.get("gaussian_norm", "l1")

    smoothers = {}
    for i in range(num_scales):
        factor = scale_factor**i
        smoothers[f"smooth_{factor}x"] = {
            "edge_weight_attribute": edge_weight_attribute,
            "gaussian_norm": gaussian_norm,
            "num_nearest_neighbours": base_neighbours * factor,
            "sigma": round(base_sigma * factor, 5),
        }
    return smoothers

If we still want to provide support for other cases that don't follow this pattern, we could still allow for a dictionary of specific smoothers. But do you think we'd need that case?

assert len(v) == len(info.data["loss_matrices"]), "weights must have same length as loss_matrices"
return v
@model_validator(mode="after")
def validate_matrix_source(self) -> Self:
Contributor

Similarly, this could be a function that gets called in the SchemaCommonMixin so that it's checked both when config validation is on and off


smoothing_matrices: list[ProjectionGraphProvider | None] = []
for entry in loss_matrices_graph:
if entry is None or entry is False or entry == "None":
Contributor

could this be simplified with the suggestions of moving some of the validation checks to the SchemaMixin?

NESTED_LOSS_CLASS_NAMES = {
"MultiscaleLossWrapper",
}
GRAPH_DATA_WRAPPER_CLASS_NAMES = {
Contributor

What about an alternative solution here, where we could add a specific needs_graph_data: bool = True to the losses and then when instantiating the loss do:

    if "per_scale_loss" in loss_config:
        per_scale_loss_config = loss_config.pop("per_scale_loss")
        per_scale_loss = get_loss_function(
            OmegaConf.create(per_scale_loss_config),
            scalers,
            data_indices,
            graph_data=graph_data,
            **kwargs,
        )
        loss_config["per_scale_loss"] = per_scale_loss

    target_cls = hydra.utils.get_class(loss_config["_target_"])
    if getattr(target_cls, "needs_graph_data", False) and graph_data is not None:
        kwargs["graph_data"] = graph_data

    loss_function = instantiate(loss_config, _recursive_=False, **kwargs)
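The class-attribute dispatch can be shown in isolation (toy loss classes; only the getattr check mirrors the suggestion above, with no hydra involved):

```python
class BaseLoss:
    # losses opt in to receiving graph_data via this class attribute
    needs_graph_data: bool = False


class MultiscaleLoss(BaseLoss):
    needs_graph_data = True

    def __init__(self, graph_data=None):
        self.graph_data = graph_data


class PlainMSELoss(BaseLoss):
    def __init__(self):
        pass


def instantiate_loss(target_cls, graph_data=None, **kwargs):
    # only forward graph_data to losses that declare they need it
    if getattr(target_cls, "needs_graph_data", False) and graph_data is not None:
        kwargs["graph_data"] = graph_data
    return target_cls(**kwargs)


loss_a = instantiate_loss(MultiscaleLoss, graph_data={"edges": []})
loss_b = instantiate_loss(PlainMSELoss, graph_data={"edges": []})
```

This keeps the loss factory free of hard-coded class-name sets: new losses opt in by setting the attribute rather than being registered in a dictionary.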

from omegaconf import OmegaConf

DEFAULT_DATASET_NAME = "data"
from anemoi.models.utils.projection_helpers import DEFAULT_DATASET_NAME
Contributor

I don't think this default name should be defined by the projection_helpers. Why did you choose this? I might be missing something.

@github-project-automation github-project-automation bot moved this from To be triaged to Under Review in Anemoi-dev Mar 26, 2026

Projects

Status: Under Review


5 participants