39 changes: 29 additions & 10 deletions modelforge/potential/bayesian_models.py
@@ -2,6 +2,19 @@
 import pyro
 from pyro.nn.module import to_pyro_module_
 
+import functools
+
+def rsetattr(obj, attr, val):
+    pre, _, post = attr.rpartition('.')
+    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427
+
+def rgetattr(obj, attr, *args):
+    def _getattr(obj, attr):
+        return getattr(obj, attr, *args)
+    return functools.reduce(_getattr, [obj] + attr.split('.'))
+
 def init_log_sigma(model, value):
     """Initializes the log_sigma parameters of a model
 
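Editor's note: plain getattr/setattr cannot follow the dotted names that named_parameters() yields (e.g. "linear.weight"), hence the recursive helpers. A quick sketch of their behavior, assuming the helpers above are in scope and using a throwaway Sequential:

import torch

net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())
w = rgetattr(net, "0.weight")  # walks getattr(net, "0"), then .weight
rsetattr(net, "0.weight", torch.nn.Parameter(torch.zeros(2, 2)))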
@@ -14,17 +27,21 @@ def init_log_sigma(model, value):
         The value to initialize the log_sigma parameters to
 
     """
-    log_sigma_params = {
-        name + "_log_sigma": pyro.nn.PyroParam(
-            torch.ones(param.shape) * value,
+    params = {
+        name: pyro.nn.PyroSample(
+            pyro.distributions.Normal(
+                torch.zeros(param.shape),
+                torch.ones(param.shape) * value,
+            )
         )
         for name, param in model.named_parameters()
     }
 
-    for name, param in log_sigma_params.items():
-        setattr(model, name, param)
+    for name, sample in params.items():
+        rsetattr(model, name, sample)
 
 
-class BayesianAutoNormalPotential(torch.nn.Module):
+class BayesianAutoNormalPotential(pyro.nn.PyroModule):
     """A Bayesian model with a normal prior and likelihood.
 
     Parameters
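Editor's note: despite its name, the rewritten init_log_sigma no longer creates `*_log_sigma` parameters; it replaces every parameter with a pyro.nn.PyroSample whose prior is a zero-mean Normal with scale `value` used directly (not exponentiated). A minimal sketch of that move on a single torch.nn.Linear, following Pyro's documented PyroSample pattern; the 0.01 scale is a stand-in for `value`:

import torch
import pyro
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_

lin = torch.nn.Linear(2, 2)
to_pyro_module_(lin)
# swap the point-estimate weight for a zero-mean Normal prior, as the
# dict comprehension above does for every named parameter
lin.weight = PyroSample(
    pyro.distributions.Normal(torch.zeros(2, 2), torch.full((2, 2), 0.01))
)
y = lin(torch.randn(1, 2))  # each forward call now draws a fresh weight

A stricter prior would add .to_event(2) so the whole weight tensor counts as one event rather than a batch of independent scalar sites.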
@@ -39,19 +56,21 @@ class BayesianAutoNormalPotential(torch.nn.Module):
     provide the prior; if `y` is provided, provide the likelihood.
     """
     def __init__(
-        self,
+        self, base_model,
         *args, **kwargs,
     ):
         super().__init__()
+        to_pyro_module_(base_model)
+        self.base_model = base_model
         log_sigma = kwargs.pop("log_sigma", 0.0)
-        init_log_sigma(self, log_sigma)
+        init_log_sigma(self.base_model, log_sigma)
 
-    def model(self, *args, **kwargs):
+    def forward(self, *args, **kwargs):
         """The model function. If no `y` argument is provided,
         provide the prior; if `y` is provided, provide the likelihood.
         """
         y = kwargs.pop("y", None)
-        y_hat = self(*args, **kwargs)
+        y_hat = self.base_model(*args, **kwargs).E
         pyro.sample(
             "obs",
             pyro.distributions.Delta(y_hat),
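Editor's note: putting the pieces together, the constructor converts the base model into a PyroModule, swaps its parameters for priors, and forward emits an "obs" site built from the base model's .E output. A sketch of exercising the class with a toy stand-in; ToyModel and its SimpleNamespace output are hypothetical, kept only to the .E attribute the forward pass expects:

import torch
import pyro
from types import SimpleNamespace

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x):
        # mimic a modelforge-style output object carrying an energy in .E
        return SimpleNamespace(E=self.linear(x).sum())

# with BayesianAutoNormalPotential from the diff above in scope
model = BayesianAutoNormalPotential(ToyModel(), log_sigma=1e-2)
trace = pyro.poutine.trace(model).get_trace(torch.randn(5, 3))
# one prior site per base-model parameter, plus the "obs" Delta site
print(sorted(trace.nodes.keys()))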
5 changes: 5 additions & 0 deletions modelforge/tests/ase.toml
@@ -0,0 +1,5 @@
+1 = -1313.4668615546
+6 = -99366.70745535441
+7 = -143309.9379722722
+8 = -197082.0671774158
+9 = -261811.54555874597
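Editor's note: the table appears to hold per-element atomic self energies keyed by atomic number (1, 6, 7, 8, 9: the H, C, N, O, F element set); the diff does not state the unit. A sketch of loading it, assuming Python 3.11+ for the standard-library tomllib:

import tomllib

with open("modelforge/tests/ase.toml", "rb") as f:
    # TOML bare keys parse as strings, so cast them to int atomic numbers
    self_energies = {int(k): v for k, v in tomllib.load(f).items()}

assert self_energies[8] == -197082.0671774158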
26 changes: 26 additions & 0 deletions modelforge/tests/test_bayesian_model.py
@@ -0,0 +1,26 @@
+import pytest
+import pyro
+from modelforge.potential import SchNet
+from modelforge.potential.bayesian_models import BayesianAutoNormalPotential
+from .helper_functions import SIMPLIFIED_INPUT_DATA
+
+@pytest.mark.parametrize("input_data", SIMPLIFIED_INPUT_DATA)
+def test_bayesian_model(input_data):
+    # initialize a vanilla SchNet model
+    schnet = SchNet()
+
+    # make a Bayesian model from the SchNet
+    schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2).forward
+    guide = pyro.infer.autoguide.AutoDiagonalNormal(schnet)
+    assert guide is not None
+
+    # run SVI using the Bayesian model
+    svi = pyro.infer.SVI(
+        model=schnet,
+        guide=guide,
+        optim=pyro.optim.Adam({"lr": 1e-3}),
+        loss=pyro.infer.Trace_ELBO(),
+    )
+
+    # calculate VI loss
+    svi.step(input_data, y=0.0)
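Editor's note: the test performs a single SVI step as a smoke test. A sketch of what a real fit might add, reusing the test's names (schnet is the model callable, guide the autoguide):

# loop the variational update instead of stepping once
for _ in range(100):
    loss = svi.step(input_data, y=0.0)

# query the fitted posterior; Predictive replays the guide through the model
predictive = pyro.infer.Predictive(schnet, guide=guide, num_samples=10)
posterior_draws = predictive(input_data)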