feat(training, models): graph config for skipped noise injector and multiscale loss #855
Conversation
…d_and_multiscale
# Conflicts:
#	training/src/anemoi/training/config/graph/encoder_decoder_only.yaml
#	training/src/anemoi/training/config/graph/hierarchical_2level.yaml
#	training/src/anemoi/training/config/graph/hierarchical_3level.yaml
#	training/src/anemoi/training/config/graph/limited_area.yaml
#	training/src/anemoi/training/config/graph/multi_scale.yaml
#	training/src/anemoi/training/config/graph/stretched_grid.yaml
#	training/src/anemoi/training/train/train.py

…/graph_for_skipped_and_multiscale
# Conflicts:
#	training/tests/unit/losses/test_combined_loss.py

…d_and_multiscale
# Conflicts:
#	training/src/anemoi/training/schemas/training.py

…d_and_multiscale
# Conflicts:
#	training/tests/integration/test_training_cycle.py
```diff
     self,
     *,
-    model_config: DotDict,
+    model_config: Any,
```
why is the typing changing?
```python
    "Whether to use autocast for the noise projection matrix operations."

    @model_validator(mode="after")
    def validate_noise_projection(self) -> "NoiseConditioningSchema":
```
you could move this to
so that the check is then done both when config_validation is on and off

```python
    else:
        model_config_local = model_config

    self.input_times = list(model_config_local.training.explicit_times.input)
```
Is this related to this PR, or left over/not updated with main? Those seem unrelated, so probably best not to change this as part of the PR (same for the self.latent_skip).
```python
        Graph definition
        """
        super().__init__()
        if type(model_config) is dict and not OmegaConf.is_config(model_config):
```
why are we doing this here (and also on the interpolator)?
```python
from anemoi.models.utils.projection_helpers import residual_projection_truncation_node_name
from anemoi.models.utils.projection_helpers import uses_fused_dataset_graph
```
Why do we need all of these utility functions? How easy is it to maintain this?
After trying to look at this more, I believe this should be something in anemoi-graphs, and the projection helpers should also be part of anemoi-graphs. Then in graph_config we could just do one name resolution after the graph is built, so that those names are used by the model or the losses when needed. The graph_factory would then return:

```python
graph_data, projection_data = graph_factory.build()
```

and that object could be consumed by the losses and model with something like:

```python
model = instantiate(config.model, graph=graph_data, projection_data=projection_data)
loss = get_loss_function(
    config.training_loss,
    ...,
    graph_data=graph_data,
    loss_matrices_graph=projection_data.multiscale_loss_matrices_graph,
)
```

Could something like this be considered?

With that, the GraphFactory would live in anemoi-graphs, and so would the projection helpers, which would be significantly simplified since the resolution is done once. And similarly to the GraphCreator object, we could have a ProjectionCreator that is then called in the factory to resolve the projection_metadata.
```python
LOGGER = logging.getLogger(__name__)


class TrainerGraphDataFactory:
```
Is this something that should live in training? I'd say this seems to me very graph-specific functionality that would be better placed in anemoi-graphs, simply taking the part of the training config related to the dataset names that it needs rather than the whole config.
```python
from anemoi.models.layers.graph_provider import ProjectionGraphProvider
from anemoi.models.layers.sparse_projector import SparseProjector
from anemoi.models.utils.projection_helpers import (
    multiscale_loss_matrices_graph as derive_multiscale_loss_matrices_graph,
```
Why change the name here?
```diff
@@ -0,0 +1,2 @@
+---
```
Why do we have this here, while truncation/none.yaml is just an empty file?
```yaml
---
defaults:
  - _self_
```
Could we simplify the definition of these configs by doing something like this?

```yaml
num_scales: 4
base_num_nearest_neighbours: 16
base_sigma: 0.00471
scale_factor: 2  # optional, defaults to 2
edge_weight_attribute: gauss_weight
gaussian_norm: l1
```

and then just a helper that builds it?

```python
def _expand_geometric_smoothers(projection_cfg: Any) -> dict[str, Any] | None:
    """Build an explicit smoothers dict from a compact geometric progression spec."""
    num_scales = projection_cfg.get("num_scales")
    if num_scales is None:
        return None
    base_neighbours = projection_cfg["base_num_nearest_neighbours"]
    base_sigma = projection_cfg["base_sigma"]
    scale_factor = projection_cfg.get("scale_factor", 2)
    edge_weight_attribute = projection_cfg.get("edge_weight_attribute", "gauss_weight")
    gaussian_norm = projection_cfg.get("gaussian_norm", "l1")
    smoothers = {}
    for i in range(num_scales):
        factor = scale_factor**i
        smoothers[f"smooth_{factor}x"] = {
            "edge_weight_attribute": edge_weight_attribute,
            "gaussian_norm": gaussian_norm,
            "num_nearest_neighbours": base_neighbours * factor,
            "sigma": round(base_sigma * factor, 5),
        }
    return smoothers
```

If we still want to support other cases that don't follow this pattern, we could still allow a dictionary of specific smoother layers. But do you think we'd need that case?
```python
        assert len(v) == len(info.data["loss_matrices"]), "weights must have same length as loss_matrices"
        return v

    @model_validator(mode="after")
    def validate_matrix_source(self) -> Self:
```
Similarly, this could be a function that gets called in the SchemaCommonMixin, so it's checked both when config validation is on and off.
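A minimal sketch of what this refactor could look like, assuming hypothetical names (`check_matrix_source`, the field names, and the mixin method are illustrative, not the actual anemoi-training API): the validation logic lives in a plain function, which both the pydantic validator and the mixin can call.

```python
from typing import Any


def check_matrix_source(loss_matrices: Any, loss_matrices_graph: Any) -> None:
    """Raise if both or neither matrix source is configured.

    Hypothetical stand-in for the body of validate_matrix_source.
    """
    if (loss_matrices is None) == (loss_matrices_graph is None):
        raise ValueError(
            "Exactly one of loss_matrices or loss_matrices_graph must be set"
        )


class SchemaCommonMixin:
    """Sketch: runs the check unconditionally, even with config_validation off."""

    def validate_common(self, config: Any) -> None:
        check_matrix_source(
            getattr(config, "loss_matrices", None),
            getattr(config, "loss_matrices_graph", None),
        )
```

The pydantic `@model_validator(mode="after")` would then reduce to a one-line call to `check_matrix_source(...)`, and the same call runs from the mixin when schema validation is disabled.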
```python
    smoothing_matrices: list[ProjectionGraphProvider | None] = []
    for entry in loss_matrices_graph:
        if entry is None or entry is False or entry == "None":
```
Could this be simplified with the suggestion of moving some of the validation checks to the SchemaMixin?
```python
NESTED_LOSS_CLASS_NAMES = {
    "MultiscaleLossWrapper",
}
GRAPH_DATA_WRAPPER_CLASS_NAMES = {
```
What about an alternative solution here, where we add a specific `needs_graph_data: bool = True` to the losses, and then, when instantiating the loss, do:

```python
if "per_scale_loss" in loss_config:
    per_scale_loss_config = loss_config.pop("per_scale_loss")
    per_scale_loss = get_loss_function(
        OmegaConf.create(per_scale_loss_config),
        scalers,
        data_indices,
        graph_data=graph_data,
        **kwargs,
    )
    loss_config["per_scale_loss"] = per_scale_loss

target_cls = hydra.utils.get_class(loss_config["_target_"])
if getattr(target_cls, "needs_graph_data", False) and graph_data is not None:
    kwargs["graph_data"] = graph_data
loss_function = instantiate(loss_config, _recursive_=False, **kwargs)
```
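To show the opt-in mechanism in isolation, here is a stripped-down sketch without hydra (the class names mirror the ones in the diff, but `BaseLoss` and `make_loss` are hypothetical stand-ins for the real base class and `get_loss_function`):

```python
from typing import Any


class BaseLoss:
    # Losses opt into receiving graph_data via a class attribute.
    needs_graph_data: bool = False


class MultiscaleLossWrapper(BaseLoss):
    needs_graph_data = True

    def __init__(self, graph_data: Any = None):
        self.graph_data = graph_data


def make_loss(target_cls: type, graph_data: Any = None, **kwargs: Any) -> Any:
    """Forward graph_data only to losses that declare they need it."""
    if getattr(target_cls, "needs_graph_data", False) and graph_data is not None:
        kwargs["graph_data"] = graph_data
    return target_cls(**kwargs)


loss = make_loss(MultiscaleLossWrapper, graph_data={"edges": []})
```

This would remove the need to maintain the `NESTED_LOSS_CLASS_NAMES` and `GRAPH_DATA_WRAPPER_CLASS_NAMES` string sets, since each loss class declares its own requirement.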
```diff
 from omegaconf import OmegaConf

-DEFAULT_DATASET_NAME = "data"
+from anemoi.models.utils.projection_helpers import DEFAULT_DATASET_NAME
```
I don't think this default name should be defined by the projection_helpers. Why did you choose this? I might be missing something.
Description
This PR adds new configuration support in the training and models components to enable graph-based settings for skipped noise injection and multiscale loss integration.
Main changes
Introduces graph configuration options to control skipped noise injection and the multiscale loss
Adds associated tests covering the new configuration options
Implements support for noise projection in graph-based models
Includes additional fixes and refactors to improve graph config handling
Why?
These changes enhance flexibility in defining how graph-structured models handle noise injection and multiscale loss terms during training, leveraging existing functionality in anemoi-graphs for sparse matrix generation. This facilitates reproducibility via graph recipes and reduces the dependency on truncation files that are not properly tracked alongside the graph.
In addition, it will enable more involved interpolation workflows in the future.
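As background for the sparse-matrix point above, the core operation these graph-derived matrices enable is projecting node features with a precomputed sparse operator. The snippet below is purely illustrative, using scipy rather than the actual anemoi APIs, with a toy matrix where each of 3 target nodes averages 2 of 6 source nodes:

```python
import numpy as np
import scipy.sparse as sp

n_src, n_dst = 6, 3
# Each target node averages two source nodes (weights 0.5 each).
rows = np.array([0, 0, 1, 1, 2, 2])
cols = np.array([0, 1, 2, 3, 4, 5])
weights = np.full(6, 0.5)
P = sp.csr_matrix((weights, (rows, cols)), shape=(n_dst, n_src))

features = np.arange(6, dtype=float).reshape(n_src, 1)  # [[0.], [1.], ..., [5.]]
projected = P @ features  # shape (3, 1): pairwise means of the source values
```

Generating such matrices from a graph recipe, rather than loading them from separately shipped truncation files, is what makes the projection reproducible from the config alone.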
📚 Documentation preview 📚: https://anemoi-training--855.org.readthedocs.build/en/855/
📚 Documentation preview 📚: https://anemoi-graphs--855.org.readthedocs.build/en/855/
📚 Documentation preview 📚: https://anemoi-models--855.org.readthedocs.build/en/855/