55import random
66from typing import Tuple , Callable , List
77
8+ import torch
89import tqdm
910
11+ from pydgn .data .dataset import DatasetInterface
1012from pydgn .experiment .util import s2c
13+ from pydgn .model .interface import ModelInterface
1114from pydgn .static import *
1215
1316
@@ -349,6 +352,9 @@ def retrieve_experiments(model_selection_folder) -> List[dict]:
349352 """
350353 config_directory = os .path .join (model_selection_folder )
351354
355+ if not os .path .exists (config_directory ):
356+ raise FileNotFoundError (f"Directory not found: { config_directory } " )
357+
352358 folder_names = []
353359 for _ , dirs , _ in os .walk (config_directory ):
354360 for d in dirs :
@@ -448,3 +454,81 @@ def _finditem(obj, key):
448454 filtered_config_list .append (config )
449455
450456 return filtered_config_list
457+
458+
459+ def retrieve_best_configuration (model_selection_folder ) -> dict :
460+ """
461+ Once the experiments are done, retrieves the winning configuration from
462+ a specific model selection folder, and returns it as a dictionaries
463+
464+ :param model_selection_folder: path to the folder of a model selection,
465+ that is, your_results_path/..../MODEL_SELECTION/
466+ :return: a dictionary with info about the best configuration
467+ """
468+ config_directory = os .path .join (model_selection_folder )
469+
470+ if not os .path .exists (config_directory ):
471+ raise FileNotFoundError (f"Directory not found: { config_directory } " )
472+
473+ best_config = json .load (
474+ open (os .path .join (config_directory , "winner_config.json" ), "rb" )
475+ )
476+ return best_config
477+
478+
479+ def instantiate_dataset_from_config (config : dict ) -> DatasetInterface :
480+ """
481+ Instantiate a dataset from a configuration file.
482+ :param config (dict): the configuration file
483+ :return: an instance of DatasetInterface, i.e., the dataset
484+ """
485+ data_root = config [CONFIG ][DATA_ROOT ]
486+ dataset_name = config [CONFIG ][DATASET ]
487+ dataset_class = s2c (config [CONFIG ][DATASET_CLASS ])
488+
489+ return dataset_class (data_root , dataset_name )
490+
491+
492+ def instantiate_model_from_config (config : dict ,
493+ dataset : DatasetInterface ,
494+ config_type : str = "supervised_config" ) -> ModelInterface :
495+ """
496+ Instantiate a model from a configuration file.
497+ :param config (dict): the configuration file
498+ :param dataset (DatasetInterface): the dataset used in the experiment
499+ :param config_type (str): the type of model in ["supervised_config", "unsupervised_config"],
500+ as written on the YAML experiment configuration file. Defaults to "supervised_config"
501+ :return: an instance of ModelInterface, i.e., the model
502+ """
503+ config_ = config [CONFIG ][config_type ]
504+ readout_class = s2c (config_ ["readout" ])
505+
506+ model_class = s2c (config_ [MODEL ])
507+ model = model_class (dataset .dim_node_features ,
508+ dataset .dim_edge_features ,
509+ dataset .dim_target ,
510+ readout_class ,
511+ config = config_ )
512+
513+ return model
514+
515+
516+ def load_checkpoint (checkpoint_path : str , model : ModelInterface ,
517+ device : torch .device ):
518+ """
519+ Load a checkpoint from a checkpoint file into a model.
520+ :param checkpoint_path: the checkpoint file path
521+ :param model (ModelInterface): the model
522+ :param device (torch.device): the device, e.g, "cpu" or "cuda"
523+ """
524+ ckpt_dict = torch .load (
525+ checkpoint_path , map_location = "cpu" if device == "cpu" else None
526+ )
527+ model_state = ckpt_dict [MODEL_STATE ]
528+
529+ # Needed only when moving from cpu to cuda (due to changes in config
530+ # file). Move all parameters to cuda.
531+ for param in model_state .keys ():
532+ model_state [param ] = model_state [param ].to (device )
533+
534+ model .load_state_dict (model_state )
0 commit comments