feat: automatic data nodes creation based on dataset entries #968
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # (C) Copyright 2025 Anemoi contributors. | ||
| # | ||
| # This software is licensed under the terms of the Apache Licence Version 2.0 | ||
| # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
| # | ||
| # In applying this licence, ECMWF does not waive the privileges and immunities | ||
| # granted to it by virtue of its status as an intergovernmental organisation | ||
| # nor does it submit to any jurisdiction. | ||
| from .compile import NodesAxis | ||
| from .compile import concat_edges | ||
| from .compile import get_distributed_device | ||
| from .compile import get_edge_attributes | ||
| from .compile import get_grid_reference_distance | ||
| from .compile import get_nearest_neighbour | ||
| from .compile import haversine_distance | ||
|
|
||
| __all__ = [ | ||
| "get_distributed_device", | ||
| "get_nearest_neighbour", | ||
| "get_grid_reference_distance", | ||
| "concat_edges", | ||
| "haversine_distance", | ||
| "NodesAxis", | ||
| "get_edge_attributes", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # (C) Copyright 2026- Anemoi contributors. | ||
| # | ||
| # This software is licensed under the terms of the Apache Licence Version 2.0 | ||
| # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
| # | ||
| # In applying this licence, ECMWF does not waive the privileges and immunities | ||
| # granted to it by virtue of its status as an intergovernmental organisation | ||
| # nor does it submit to any jurisdiction. | ||
|
|
||
| import logging | ||
|
|
||
| from omegaconf import DictConfig | ||
| from omegaconf import OmegaConf | ||
|
|
||
| LOGGER = logging.getLogger(__name__) | ||
|
|
||
| DEFAULT_DATASET_NAME = "data" | ||
| DEFAULT_NODE_ATTR = { | ||
| "area_weight": { | ||
| "_target_": "anemoi.graphs.nodes.attributes.SphericalAreaWeights", | ||
| "norm": "unit-max", | ||
| "fill_value": 0, | ||
| } | ||
| } | ||
|
|
||
|
|
||
| def get_multiple_datasets_config(config: DictConfig, default_dataset_name: str = DEFAULT_DATASET_NAME) -> dict: | ||
| """Get multiple datasets configuration for old configs. | ||
| Use 'data' as the default dataset name. | ||
| """ | ||
| if "datasets" in config: | ||
| if isinstance(config, dict): | ||
| return config["datasets"] | ||
| return config.datasets | ||
|
|
||
| return OmegaConf.create({default_dataset_name: config}) | ||
|
|
||
|
|
||
| def integrate_data_nodes_in_config(config: DictConfig): | ||
|
Contributor
I'm not sure how it's handled elsewhere in the code, but I would suggest noting here that |
||
|
|
||
| # Introduce Data Nodes in graph config | ||
| train_configs = get_multiple_datasets_config(config.dataloader.training) | ||
|
|
||
| val_configs = {} | ||
| if hasattr(config.dataloader, "validation"): | ||
| val_configs = get_multiple_datasets_config(config.dataloader.validation) | ||
|
|
||
| test_configs = {} | ||
| if hasattr(config.dataloader, "test"): | ||
| test_configs = get_multiple_datasets_config(config.dataloader.test) | ||
|
|
||
| dataset_configs = { | ||
| **train_configs, | ||
| **val_configs, | ||
| **test_configs, | ||
| } | ||
|
|
||
| for dataset_name, dataset_config in dataset_configs.items(): | ||
| if dataset_name not in config.graph.nodes: | ||
| LOGGER.info("Creating graph node entry for dataset '%s'", dataset_name) | ||
| dataset_reader_config = dataset_config.dataset_config | ||
| if isinstance(dataset_reader_config, (DictConfig, dict)): | ||
| if "dataset" not in dataset_reader_config: | ||
| msg = f"Dataset '{dataset_name}' is missing 'dataset' key." | ||
| raise ValueError(msg) | ||
| dataset_source = dataset_reader_config["dataset"] | ||
| else: | ||
| dataset_source = dataset_reader_config | ||
|
|
||
| if dataset_source is None: | ||
| msg = f"Dataset source is None for dataset '{dataset_name}'. Check dataloader.dataset_config.dataset." | ||
| raise ValueError(msg) | ||
|
|
||
| # Add dataset nodes from dataloader into graph recipe | ||
| config.graph.nodes[dataset_name] = { | ||
| "node_builder": {"_target_": "anemoi.graphs.nodes.AnemoiDatasetNodes", "dataset": dataset_source}, | ||
| "attributes": ( | ||
| config.graph.attributes.nodes if hasattr(config.graph, "attributes") else DEFAULT_NODE_ATTR | ||
| ), | ||
| } | ||
| else: | ||
| LOGGER.info("Graph node entry for dataset '%s' is already specified in the config.", dataset_name) | ||
|
|
||
| return config | ||
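
The config handling in this file can be tried out without omegaconf. The following is a minimal sketch that assumes plain dicts stand in for `DictConfig`; the helper names `normalise_datasets` and `resolve_dataset_source` and the dataset file names are hypothetical, chosen only for illustration. It shows how an old single-dataset config and a new multi-dataset config end up in the same name-to-config mapping, and how the dataset source is extracted from each entry.

```python
DEFAULT_DATASET_NAME = "data"  # same default as in the diff


def normalise_datasets(split_config: dict, default_name: str = DEFAULT_DATASET_NAME) -> dict:
    """Return a {dataset_name: dataset_config} mapping for old and new configs."""
    if "datasets" in split_config:
        # New-style config: datasets are already keyed by name.
        return split_config["datasets"]
    # Old-style config: wrap the single dataset under the default name.
    return {default_name: split_config}


def resolve_dataset_source(dataset_name: str, dataset_config: dict):
    """Extract the dataset source, mirroring the checks in the diff."""
    reader_config = dataset_config["dataset_config"]
    if isinstance(reader_config, dict):
        if "dataset" not in reader_config:
            raise ValueError(f"Dataset '{dataset_name}' is missing 'dataset' key.")
        source = reader_config["dataset"]
    else:
        # The reader config may be the source itself, e.g. a path string.
        source = reader_config
    if source is None:
        raise ValueError(f"Dataset source is None for dataset '{dataset_name}'.")
    return source


# Hypothetical inputs: the file names are placeholders, not real datasets.
old_style = {"dataset_config": {"dataset": "example.zarr"}}
new_style = {
    "datasets": {
        "global": {"dataset_config": {"dataset": "global.zarr"}},
        "regional": {"dataset_config": "regional.zarr"},
    }
}

old_sources = {
    name: resolve_dataset_source(name, cfg)
    for name, cfg in normalise_datasets(old_style).items()
}
new_sources = {
    name: resolve_dataset_source(name, cfg)
    for name, cfg in normalise_datasets(new_style).items()
}
print(old_sources)
print(new_sources)
```

In the diff, the training, validation and test mappings are merged with dict unpacking, so if the same dataset name appears in several splits, the entry from the last split in the merge (test) is the one that is kept.
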
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |||||
| from pytorch_lightning.utilities.rank_zero import rank_zero_only | ||||||
| from torch_geometric.data import HeteroData | ||||||
|
|
||||||
| from anemoi.graphs.utils.config import integrate_data_nodes_in_config | ||||||
|
Contributor
Is it possible to re-use the existing copy of this function?
Contributor
Author
The idea was to remove the dependency of graphs on models! |
||||||
| from anemoi.models.utils.compile import mark_for_compilation | ||||||
| from anemoi.training.data.datamodule import AnemoiDatasetsDataModule | ||||||
| from anemoi.training.diagnostics.callbacks import get_callbacks | ||||||
|
|
@@ -166,6 +167,9 @@ def graph_data(self) -> HeteroData: | |||||
| else: | ||||||
| graph_filename = None | ||||||
|
|
||||||
| # Introduce Data Nodes in graph config | ||||||
| self.config = integrate_data_nodes_in_config(self.config) | ||||||
|
Contributor
Are you manipulating the global Did you consider adding a flag to the configuration that explicitly enables the automatic generation of
Contributor
Author
Very good point @MeraX, will look into this! |
||||||
|
|
||||||
| # Create new graph | ||||||
| from anemoi.graphs.create import GraphCreator | ||||||
|
|
||||||
|
|
||||||
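
The review concern about mutating the shared config, and the suggestion of an explicit opt-in flag, could be addressed along the following lines. This is only a sketch under stated assumptions: plain dicts stand in for the real config object, and the function name `with_data_nodes` and the flag `graph.auto_data_nodes` are hypothetical and do not exist in the PR. The node builder target string is the one used in the diff.

```python
from copy import deepcopy

NODE_BUILDER_TARGET = "anemoi.graphs.nodes.AnemoiDatasetNodes"  # taken from the diff


def with_data_nodes(config: dict, dataset_sources: dict, attributes: dict) -> dict:
    """Return a config with missing data-node entries added.

    The input config is never modified. Nothing is added unless the
    hypothetical flag config["graph"]["auto_data_nodes"] is true.
    """
    if not config.get("graph", {}).get("auto_data_nodes", False):
        return config  # feature is off: hand back the config untouched

    new_config = deepcopy(config)
    nodes = new_config["graph"].setdefault("nodes", {})
    for name, source in dataset_sources.items():
        if name in nodes:
            continue  # entries written by the user take precedence
        nodes[name] = {
            "node_builder": {"_target_": NODE_BUILDER_TARGET, "dataset": source},
            "attributes": deepcopy(attributes),
        }
    return new_config


# Hypothetical example config; "dataset.zarr" is a placeholder.
config = {"graph": {"auto_data_nodes": True, "nodes": {"hidden": {}}}}
updated = with_data_nodes(config, {"data": "dataset.zarr"}, {"area_weight": {}})
```

Because the function returns a copy, the caller decides whether to keep the original config or replace it, and runs that do not set the flag behave exactly as before.
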
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,6 +162,7 @@ def global_config( | |
| "node_weights", | ||
| "output_steps", | ||
| ] | ||
|
|
||
| return cfg, url, model_architecture | ||
|
|
||
|
|
||
|
|
||
This PR also introduces yet another place where defaults are stored (DEFAULT_NODE_ATTR). Are these defaults actually necessary, or would it be possible to set them simply as the default Python argument values of anemoi.graphs.nodes.AnemoiDatasetNodes?

I would agree, these global variables seem unnecessary.

For the time being, "area_weight" is a necessary attribute for any data nodes. This isn't enforced in the code, and tests will fail if it is not provided.
I could try to find another workaround by enforcing its creation somewhere else, but I'm not sure where. Suggestions?
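
One way to enforce the attribute in a single place, offered as a sketch rather than something tested against the code base, is a small helper that fills in "area_weight" only when the user has not provided it. The helper name `ensure_area_weight` is hypothetical; the default values are copied from DEFAULT_NODE_ATTR in the diff. Where such a helper would live (for example inside the node builder, as suggested in the review) is left open.

```python
from copy import deepcopy
from typing import Optional

# Values copied from DEFAULT_NODE_ATTR in the diff.
AREA_WEIGHT_DEFAULT = {
    "_target_": "anemoi.graphs.nodes.attributes.SphericalAreaWeights",
    "norm": "unit-max",
    "fill_value": 0,
}


def ensure_area_weight(attributes: Optional[dict]) -> dict:
    """Return a copy of attributes that always contains 'area_weight'.

    A user-provided 'area_weight' is kept as is; the default is only
    used to fill the gap. The input mapping is not modified.
    """
    merged = deepcopy(attributes) if attributes else {}
    merged.setdefault("area_weight", deepcopy(AREA_WEIGHT_DEFAULT))
    return merged


# No attributes given: the default is filled in.
filled = ensure_area_weight(None)

# A custom definition is left untouched ("my.module.CustomWeights" is a placeholder).
custom = ensure_area_weight({"area_weight": {"_target_": "my.module.CustomWeights"}})
```

Compared with choosing between the user's attributes and the defaults as a whole, merging per key means a config that defines other node attributes but omits "area_weight" still gets a working entry.
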