Skip to content

Commit b8ba67d

Browse files
Panditjkelling
authored and committed
refactoring such that ML model related stuff is
out of the openpmd-streaming-continual-learning.py and put into a separate directory
1 parent 8277da1 commit b8ba67d

5 files changed

Lines changed: 288 additions & 286 deletions

File tree

scripts/job_hemera.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ nvidia-smi
3232

3333
BATCH_SIZE="${BATCH_SIZE:-4}"
3434
mpiexec bash $INSITUML/scripts/ompi_CUDA_VISIBLE_DEVICES_wrapper.sh python $INSITUML/tools/openpmd-streaming-continual-learning.py --io_config $INSITUML/share/configs/io_config_hemera.py --model_config $INSITUML/share/configs/model_config.py
35+
# mpiexec bash $INSITUML/scripts/ompi_CUDA_VISIBLE_DEVICES_wrapper.sh python $INSITUML/tools/openpmd-streaming-continual-learning.py --io_config $INSITUML/share/configs/io_config_hemera.py --model_config $INSITUML/share/configs/model_config.py
3536
# mpiexec python $INSITUML/tools/openpmd-streaming-continual-learning.py --io_config $INSITUML/share/configs/io_config_hemera.py --model_config $INSITUML/share/configs/model_config.py --type_streamer offline

tools/models/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Model-related modules for openpmd-learning.
3+
4+
This package contains:
5+
- Model architectures
6+
- Model initialization and loading functions
7+
- Configuration utilities
8+
"""
9+
10+
from models.architectures import ModelFinal
11+
from models.model_factory import load_objects, get_VAE_encoder_kwargs, get_VAE_decoder_kwargs
12+
13+
__all__ = [
14+
'ModelFinal',
15+
'load_objects',
16+
'get_VAE_encoder_kwargs',
17+
'get_VAE_decoder_kwargs',
18+
]

tools/models/architectures.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch.nn as nn
2+
from inSituML.loss_functions import EarthMoversLoss
3+
from inSituML.ks_models import INNModel
4+
5+
6+
class ModelFinal(nn.Module):
    """Joint autoencoder + latent-space model.

    Couples a (variational) autoencoder ``base_network`` with an inner
    latent model (an ``INNModel``, or a flow-style model such as MAF)
    and combines their losses using ``weight_AE`` and ``weight_IM``.
    """

    def __init__(
        self,
        base_network,
        inner_model,
        loss_function_IM=None,
        weight_AE=1.0,
        weight_IM=1.0,
    ):
        """
        Args:
            base_network: autoencoder; calling it on a batch must return
                ``(loss, reconstruction_loss, kl_loss, _, encoding)``.
            inner_model: latent-space model; either an ``INNModel`` or a
                model callable as ``inner_model(inputs=..., context=...)``
                that returns its loss directly (e.g. MAF).
            loss_function_IM: currently unused; kept so existing callers
                that pass it keep working.
            weight_AE: weight of the autoencoder loss in the total loss.
            weight_IM: weight of the inner-model loss in the total loss.
        """
        super().__init__()

        self.base_network = base_network
        self.inner_model = inner_model
        self.loss_function_IM = loss_function_IM
        self.weight_AE = weight_AE
        self.weight_IM = weight_IM

    def forward(self, x, y):
        """Compute the combined training losses for one batch.

        Returns:
            dict with ``"total_loss"``, the weighted ``"loss_AE"`` /
            ``"loss_IM"``, and the AE components ``"loss_ae_reconst"``
            and ``"kl_loss"``.  When the inner model is an ``INNModel``,
            its loss components ``"l_fit"``, ``"l_latent"`` and
            ``"l_rev"`` are included as well.
        """
        loss_AE, loss_ae_reconst, kl_loss, _, encoded = self.base_network(
            x
        )

        # The two inner-model families expose different APIs and report
        # different loss components; the rest of the bookkeeping is
        # shared, so only the branch-specific parts differ here.
        if isinstance(self.inner_model, INNModel):
            loss_IM, l_fit, l_latent, l_rev = self.inner_model.compute_losses(
                encoded, y
            )
            inn_components = {
                "l_fit": l_fit,
                "l_latent": l_latent,
                "l_rev": l_rev,
            }
        else:
            # Other model types, such as MAF, return their loss directly.
            loss_IM = self.inner_model(inputs=encoded, context=y)
            inn_components = {}

        weighted_AE = loss_AE * self.weight_AE
        weighted_IM = loss_IM * self.weight_IM
        losses = {
            "total_loss": weighted_AE + weighted_IM,
            "loss_AE": weighted_AE,
            "loss_IM": weighted_IM,
            "loss_ae_reconst": loss_ae_reconst,
            "kl_loss": kl_loss,
        }
        losses.update(inn_components)
        return losses

    def reconstruct(self, x, y, num_samples=1):
        """Predict latent codes with the inner model and decode them.

        Args:
            x: input batch (used only by the INN reverse pass).
            y: conditioning information.
            num_samples: number of samples to draw when the inner model
                is sampling-based (ignored for ``INNModel``).

        Returns:
            Tuple ``(decoded, lat_z_pred)``.
        """
        if isinstance(self.inner_model, INNModel):
            # Reverse pass of the INN predicts the latent code from (x, y).
            lat_z_pred = self.inner_model(x, y, rev=True)
        else:
            lat_z_pred = self.inner_model.sample_pointcloud(
                num_samples=num_samples, cond=y
            )
        # Both branches decode the predicted latent code the same way.
        y = self.base_network.decoder(lat_z_pred)

        return y, lat_z_pred

tools/models/model_factory.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os, sys
2+
import torch
3+
import torch.optim as optim
4+
from inSituML.utilities import MMD_multiscale, fit, load_checkpoint
5+
from inSituML.ks_models import INNModel
6+
7+
from inSituML.args_transform import MAPPING_TO_LOSS
8+
from inSituML.encoder_decoder import Encoder
9+
from inSituML.encoder_decoder import Conv3DDecoder
10+
from inSituML.loss_functions import EarthMoversLoss
11+
from inSituML.networks import VAE
12+
from models.architectures import ModelFinal
13+
14+
def get_world_size():
    """Determine the number of ranks for distributed training.

    Prefers the ``WORLD_SIZE`` environment variable and falls back to
    ``SLURM_NTASKS`` (emitting a warning on stderr).

    Returns:
        int: the world size.

    Raises:
        RuntimeError: if neither environment variable is set.
    """
    env = os.environ
    if "WORLD_SIZE" in env:
        return int(env["WORLD_SIZE"])
    if "SLURM_NTASKS" in env:
        print(
            (
                "[WW] WORLD_SIZE not defined in env, "
                + "falling back to SLURM_NTASKS."
            ),
            file=sys.stderr,
        )
        return int(env["SLURM_NTASKS"])
    raise RuntimeError("cannot determine WORLD_SIZE")
31+
32+
def get_VAE_encoder_kwargs(io_config, model_config):
    """Assemble the constructor kwargs for the VAE encoder.

    Args:
        io_config: IO configuration providing ``ps_dims``.
        model_config: model configuration providing ``latent_space_dims``.

    Returns:
        dict: keyword arguments for building the encoder.
    """
    return dict(
        ae_config="non_deterministic",
        z_dim=model_config.latent_space_dims,
        input_dim=io_config.ps_dims,
        conv_layer_config=[16, 32, 64, 128, 256, 608],
        conv_add_bn=False,
        fc_layer_config=[544],
    )
42+
43+
def get_VAE_decoder_kwargs(io_config, model_config):
    """Assemble the constructor kwargs for the VAE decoder.

    Args:
        io_config: IO configuration providing ``ps_dims``.
        model_config: model configuration providing ``latent_space_dims``.

    Returns:
        dict: keyword arguments for building the decoder.
    """
    return dict(
        z_dim=model_config.latent_space_dims,
        input_dim=io_config.ps_dims,
        initial_conv3d_size=[16, 4, 4, 4],
        add_batch_normalisation=False,
        fc_layer_config=[1024],
    )
52+
53+
def load_objects(rank, io_config, model_config, world_size):
    """Load and initialize model objects, optimizer, and scheduler.

    Builds the VAE, the INN inner model and the combined ``ModelFinal``,
    optionally restores weights from a saved state dict or a training
    checkpoint, scales the learning rate by an effective batch-size
    factor, and constructs the Adam optimizer plus an optional StepLR
    scheduler.

    Args:
        rank: device identifier for this process; passed to the models
            as their device and used as the checkpoint map target.
        io_config: IO configuration (provides ``ps_dims``,
            ``number_of_particles`` and ``trainBatchBuffer_config``).
        model_config: model configuration (provides ``config`` and
            ``latent_space_dims``).
        world_size: number of distributed ranks, used for LR scaling.

    Returns:
        Tuple ``(optimizer, scheduler, model)``; ``scheduler`` is
        ``None`` when no annealing rate is configured.
    """

    # Get configuration
    config = model_config.config

    # Get model parameters
    VAE_encoder_kwargs = get_VAE_encoder_kwargs(io_config, model_config)
    VAE_decoder_kwargs = get_VAE_decoder_kwargs(io_config, model_config)

    # Initialize loss function
    # The config names a loss class; MAPPING_TO_LOSS resolves the name
    # and "loss_kwargs" parameterizes the instance.
    loss_fn_for_VAE = MAPPING_TO_LOSS[
        model_config.config["loss_function"]
    ](**model_config.config["loss_kwargs"])

    # Initialize VAE
    VAE_obj = VAE(
        encoder=Encoder,
        encoder_kwargs=VAE_encoder_kwargs,
        decoder=Conv3DDecoder,
        z_dim=model_config.latent_space_dims,
        decoder_kwargs=VAE_decoder_kwargs,
        loss_function=loss_fn_for_VAE,
        property_="momentum_force",
        particles_to_sample=io_config.number_of_particles,
        ae_config="non_deterministic",
        use_encoding_in_decoder=False,
        weight_kl=model_config.config["lambd_kl"],
        device=rank,
    )

    # Initialize inner model
    inner_model = INNModel(
        ndim_tot=config["ndim_tot"],
        ndim_x=config["ndim_x"],
        ndim_y=config["ndim_y"],
        ndim_z=config["ndim_z"],
        loss_fit=fit,
        loss_latent=MMD_multiscale,
        loss_backward=MMD_multiscale,
        lambd_predict=config["lambd_predict"],
        lambd_latent=config["lambd_latent"],
        lambd_rev=config["lambd_rev"],
        zeros_noise_scale=config["zeros_noise_scale"],
        y_noise_scale=config["y_noise_scale"],
        hidden_size=config["hidden_size"],
        activation=config["activation"],
        num_coupling_layers=config["num_coupling_layers"],
        device=rank,
    )

    # Initialize final model
    model = ModelFinal(
        VAE_obj,
        inner_model,
        EarthMoversLoss(),
        weight_AE=config["lambd_AE"],
        weight_IM=config["lambd_IM"],
    )

    # Load a pre-trained model
    # Remap tensors that were saved on device "cuda:0" onto this rank's
    # device so every rank loads onto its own GPU.
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    if config["load_model"] is not None:
        original_state_dict = torch.load(
            config["load_model"], map_location=map_location
        )
        # updated_state_dict = {key.replace('VAE.', 'base_network.'):
        # value for key, value in original_state_dict.items()}
        model.load_state_dict(original_state_dict)
        print("Loaded pre-trained model successfully", flush=True)

    elif config["load_model_checkpoint"] is not None:
        # load_checkpoint returns additional state (e.g. optimizer and
        # progress counters) that is deliberately discarded here.
        model, _, _, _, _, _ = load_checkpoint(
            config["load_model_checkpoint"],
            model,
            map_location=map_location,
        )
        print("Loaded model checkpoint successfully", flush=True)
    else:
        pass  # run with random init

    # Scale the base learning rate by the effective batch-size factor.
    # NOTE(review): the factor is (training_bs / 2) * world_size —
    # presumably the "/ 2" normalizes to the batch size the base LR was
    # tuned with; confirm against the original training setup.
    lr = config["lr"]
    bs_factor = (
        io_config.trainBatchBuffer_config["training_bs"] / 2 * world_size
    )
    lr = lr * config["lr_scaling"](bs_factor)
    print(
        "Scaling learning rate from {} to {} due to bs factor {}".format(
            config["lr"], lr, bs_factor
        ),
        flush=True,
    )

    # Two parameter groups: the autoencoder gets its own LR multiplier,
    # the inner model uses the (scaled) base LR.
    optimizer = optim.Adam(
        [
            {
                "params": model.base_network.parameters(),
                "lr": lr * config["lrAEmult"],
            },
            {"params": model.inner_model.parameters()},
        ],  # model.parameters()
        lr=lr,
        betas=config["betas"],
        eps=config["eps"],
        weight_decay=config["weight_decay"],
    )
    # LR annealing is optional: absent or None means a constant LR.
    if ("lr_annealingRate" not in config) or config[
        "lr_annealingRate"
    ] is None:
        scheduler = None
    else:
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=500, gamma=config["lr_annealingRate"]
        )

    return optimizer, scheduler, model

0 commit comments

Comments
 (0)