diff --git a/modelforge/potential/bayesian_models.py b/modelforge/potential/bayesian_models.py index 2a6f285b8..e0d3db7c5 100644 --- a/modelforge/potential/bayesian_models.py +++ b/modelforge/potential/bayesian_models.py @@ -2,6 +2,19 @@ import pyro from pyro.nn.module import to_pyro_module_ +import functools + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + +# using wonder's beautiful simplification: https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects/31174427?noredirect=1#comment86638618_31174427 + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + return functools.reduce(_getattr, [obj] + attr.split('.')) + def init_log_sigma(model, value): """Initializes the log_sigma parameters of a model @@ -14,17 +27,21 @@ def init_log_sigma(model, value): The value to initialize the log_sigma parameters to """ - log_sigma_params = { - name + "_log_sigma": pyro.nn.PyroParam( - torch.ones(param.shape) * value, + params = { + name: pyro.nn.PyroSample( + pyro.distributions.Normal( + torch.zeros(param.shape), + torch.ones(param.shape) * value, + ) ) for name, param in model.named_parameters() } - for name, param in log_sigma_params.items(): - setattr(model, name, param) + for name, param in model.named_parameters(): + rsetattr(model, name, params[name]) + -class BayesianAutoNormalPotential(torch.nn.Module): +class BayesianAutoNormalPotential(pyro.nn.PyroModule): """A Bayesian model with a normal prior and likelihood. Parameters @@ -39,19 +56,21 @@ class BayesianAutoNormalPotential(torch.nn.Module): provide the prior; if `y` is provided, provide the likelihood. """ def __init__( - self, + self, base_model, *args, **kwargs, ): super().__init__() + to_pyro_module_(base_model) + self.base_model = base_model log_sigma = kwargs.pop("log_sigma", 0.0) - init_log_sigma(self, log_sigma) + init_log_sigma(self.base_model, log_sigma) - def model(self, *args, **kwargs): + def forward(self, *args, **kwargs): """The model function. If no `y` argument is provided, provide the prior; if `y` is provided, provide the likelihood. """ y = kwargs.pop("y", None) - y_hat = self(*args, **kwargs) + y_hat = self.base_model(*args, **kwargs).E pyro.sample( "obs", pyro.distributions.Delta(y_hat), diff --git a/modelforge/tests/ase.toml b/modelforge/tests/ase.toml new file mode 100644 index 000000000..bbe625e8a --- /dev/null +++ b/modelforge/tests/ase.toml @@ -0,0 +1,5 @@ +1 = -1313.4668615546 +6 = -99366.70745535441 +7 = -143309.9379722722 +8 = -197082.0671774158 +9 = -261811.54555874597 diff --git a/modelforge/tests/test_bayesian_model.py b/modelforge/tests/test_bayesian_model.py new file mode 100644 index 000000000..88c37669a --- /dev/null +++ b/modelforge/tests/test_bayesian_model.py @@ -0,0 +1,26 @@ +import pytest +import pyro +from modelforge.potential import SchNet +from modelforge.potential.bayesian_models import BayesianAutoNormalPotential +from .helper_functions import SIMPLIFIED_INPUT_DATA + +@pytest.mark.parametrize("input_data", SIMPLIFIED_INPUT_DATA) +def test_bayesian_model(input_data): + # initialize a vanilla SchNet model + schnet = SchNet() + + # make a Bayesian model from the SchNet + schnet = BayesianAutoNormalPotential(schnet, log_sigma=1e-2).forward + guide = pyro.infer.autoguide.AutoDiagonalNormal(schnet) + assert guide is not None + + # run SVI using the Bayesian model + svi = pyro.infer.SVI( + model=schnet, + guide=guide, + optim=pyro.optim.Adam({"lr": 1e-3}), + loss=pyro.infer.Trace_ELBO(), + ) + + # calculate VI loss + svi.step(input_data, y=0.0)