1+ import os , sys
2+ import torch
3+ import torch .optim as optim
4+ from inSituML .utilities import MMD_multiscale , fit , load_checkpoint
5+ from inSituML .ks_models import INNModel
6+
7+ from inSituML .args_transform import MAPPING_TO_LOSS
8+ from inSituML .encoder_decoder import Encoder
9+ from inSituML .encoder_decoder import Conv3DDecoder
10+ from inSituML .loss_functions import EarthMoversLoss
11+ from inSituML .networks import VAE
12+ from models .architectures import ModelFinal
13+
def get_world_size():
    """Determine the distributed-training world size from the environment.

    Prefers the standard ``WORLD_SIZE`` variable; falls back to
    ``SLURM_NTASKS`` (with a warning on stderr) when running under SLURM
    without the torch launcher.

    Returns:
        int: number of distributed workers.

    Raises:
        RuntimeError: if neither environment variable is set.
    """
    env = os.environ
    if "WORLD_SIZE" in env:
        return int(env["WORLD_SIZE"])
    if "SLURM_NTASKS" in env:
        # Not launched via torchrun; trust the SLURM task count instead.
        print(
            "[WW] WORLD_SIZE not defined in env, "
            + "falling back to SLURM_NTASKS.",
            file=sys.stderr,
        )
        return int(env["SLURM_NTASKS"])
    raise RuntimeError("cannot determine WORLD_SIZE")
31+
def get_VAE_encoder_kwargs(io_config, model_config):
    """Assemble the keyword arguments used to build the VAE encoder.

    Args:
        io_config: provides ``ps_dims`` (phase-space input dimensionality).
        model_config: provides ``latent_space_dims`` (size of z).

    Returns:
        dict: encoder constructor kwargs (conv/fc layer sizes are fixed
        to the architecture used by this training setup).
    """
    encoder_kwargs = dict(
        ae_config="non_deterministic",
        z_dim=model_config.latent_space_dims,
        input_dim=io_config.ps_dims,
        conv_layer_config=[16, 32, 64, 128, 256, 608],
        conv_add_bn=False,
        fc_layer_config=[544],
    )
    return encoder_kwargs
42+
def get_VAE_decoder_kwargs(io_config, model_config):
    """Assemble the keyword arguments used to build the Conv3D decoder.

    Args:
        io_config: provides ``ps_dims`` (phase-space output dimensionality).
        model_config: provides ``latent_space_dims`` (size of z).

    Returns:
        dict: decoder constructor kwargs (initial 3D volume and fc layer
        sizes are fixed to the architecture used by this training setup).
    """
    decoder_kwargs = dict(
        z_dim=model_config.latent_space_dims,
        input_dim=io_config.ps_dims,
        initial_conv3d_size=[16, 4, 4, 4],
        add_batch_normalisation=False,
        fc_layer_config=[1024],
    )
    return decoder_kwargs
52+
def load_objects(rank, io_config, model_config, world_size):
    """Build the model (VAE + INN), optimizer, and LR scheduler.

    Optionally restores weights from a full state dict
    (``config["load_model"]``) or a training checkpoint
    (``config["load_model_checkpoint"]``), and scales the learning rate
    by the configured batch-size factor.

    Args:
        rank: device index for this process (used as torch device and to
            remap checkpoints saved on ``cuda:0``).
        io_config: I/O configuration (particle counts, batch sizes, dims).
        model_config: model configuration; ``model_config.config`` holds
            the hyperparameter dictionary.
        world_size: number of distributed workers (enters LR scaling).

    Returns:
        tuple: ``(optimizer, scheduler, model)`` — ``scheduler`` is
        ``None`` when no annealing rate is configured.
    """
    cfg = model_config.config

    # Reconstruction loss for the autoencoder, selected by name.
    vae_loss = MAPPING_TO_LOSS[cfg["loss_function"]](**cfg["loss_kwargs"])

    vae = VAE(
        encoder=Encoder,
        encoder_kwargs=get_VAE_encoder_kwargs(io_config, model_config),
        decoder=Conv3DDecoder,
        z_dim=model_config.latent_space_dims,
        decoder_kwargs=get_VAE_decoder_kwargs(io_config, model_config),
        loss_function=vae_loss,
        property_="momentum_force",
        particles_to_sample=io_config.number_of_particles,
        ae_config="non_deterministic",
        use_encoding_in_decoder=False,
        weight_kl=cfg["lambd_kl"],
        device=rank,
    )

    # Invertible network mapping between latent spaces.
    inn = INNModel(
        ndim_tot=cfg["ndim_tot"],
        ndim_x=cfg["ndim_x"],
        ndim_y=cfg["ndim_y"],
        ndim_z=cfg["ndim_z"],
        loss_fit=fit,
        loss_latent=MMD_multiscale,
        loss_backward=MMD_multiscale,
        lambd_predict=cfg["lambd_predict"],
        lambd_latent=cfg["lambd_latent"],
        lambd_rev=cfg["lambd_rev"],
        zeros_noise_scale=cfg["zeros_noise_scale"],
        y_noise_scale=cfg["y_noise_scale"],
        hidden_size=cfg["hidden_size"],
        activation=cfg["activation"],
        num_coupling_layers=cfg["num_coupling_layers"],
        device=rank,
    )

    model = ModelFinal(
        vae,
        inn,
        EarthMoversLoss(),
        weight_AE=cfg["lambd_AE"],
        weight_IM=cfg["lambd_IM"],
    )

    # Remap tensors saved on GPU 0 onto this rank's device.
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    if cfg["load_model"] is not None:
        state_dict = torch.load(cfg["load_model"], map_location=map_location)
        model.load_state_dict(state_dict)
        print("Loaded pre-trained model successfully", flush=True)
    elif cfg["load_model_checkpoint"] is not None:
        # load_checkpoint also returns optimizer/scheduler state etc.;
        # only the model weights are used here.
        model, _, _, _, _, _ = load_checkpoint(
            cfg["load_model_checkpoint"],
            model,
            map_location=map_location,
        )
        print("Loaded model checkpoint successfully", flush=True)
    # Otherwise the model keeps its random initialization.

    # Scale the learning rate by the effective batch-size factor
    # (per-rank batch size / 2, multiplied across all ranks).
    base_lr = cfg["lr"]
    bs_factor = (
        io_config.trainBatchBuffer_config["training_bs"] / 2 * world_size
    )
    scaled_lr = base_lr * cfg["lr_scaling"](bs_factor)
    print(
        "Scaling learning rate from {} to {} due to bs factor {}".format(
            base_lr, scaled_lr, bs_factor
        ),
        flush=True,
    )

    # The autoencoder part gets its own LR multiplier; the INN part
    # uses the group default.
    optimizer = optim.Adam(
        [
            {
                "params": model.base_network.parameters(),
                "lr": scaled_lr * cfg["lrAEmult"],
            },
            {"params": model.inner_model.parameters()},
        ],
        lr=scaled_lr,
        betas=cfg["betas"],
        eps=cfg["eps"],
        weight_decay=cfg["weight_decay"],
    )

    if "lr_annealingRate" in cfg and cfg["lr_annealingRate"] is not None:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=500, gamma=cfg["lr_annealingRate"]
        )
    else:
        scheduler = None

    return optimizer, scheduler, model