Add an option to stop PyroModules from sharing parameters #3149
Changes from all commits
pyro/infer/elbo.py

```diff
@@ -5,13 +5,26 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 
 import torch
 
 import pyro
 import pyro.poutine as poutine
 from pyro.infer.util import is_validation_enabled
 from pyro.poutine.util import prune_subsample_sites
 from pyro.util import check_site_shape
 
 
+class ELBOModule(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, guide: torch.nn.Module, elbo: "ELBO"):
+        super().__init__()
+        self.model = model
+        self.guide = guide
+        self.elbo = elbo
+
+    def forward(self, *args, **kwargs):
+        return self.elbo.differentiable_loss(self.model, self.guide, *args, **kwargs)
+
+
 class ELBO(object, metaclass=ABCMeta):
     """
     :class:`ELBO` is the top-level interface for stochastic variational
@@ -23,6 +36,40 @@ class ELBO(object, metaclass=ABCMeta):
     :class:`~pyro.infer.tracegraph_elbo.TraceGraph_ELBO`, or
     :class:`~pyro.infer.traceenum_elbo.TraceEnum_ELBO`.
 
+    .. note:: Derived classes now provide a more idiomatic PyTorch interface via
+        :meth:`__call__` for (model, guide) pairs that are :class:`~torch.nn.Module` s,
+        which is useful for integrating Pyro's variational inference tooling with
+        standard PyTorch interfaces like :class:`~torch.optim.Optimizer` s
+        and the large ecosystem of libraries like PyTorch Lightning
+        and the PyTorch JIT that work with these interfaces::
+
+            model = Model()
+            guide = pyro.infer.autoguide.AutoNormal(model)
+
+            elbo_ = pyro.infer.Trace_ELBO(num_particles=10)
+
+            # Fix the model/guide pair
+            elbo = elbo_(model, guide)
+
+            # perform any data-dependent initialization
+            elbo(data)
+
+            optim = torch.optim.Adam(elbo.parameters(), lr=0.001)
+
+            for _ in range(100):
+                optim.zero_grad()
+                loss = elbo(data)
+                loss.backward()
+                optim.step()
+
+    Note that Pyro's global parameter store may cause this new interface to
+    behave unexpectedly relative to standard PyTorch when working with
+    :class:`~pyro.nn.PyroModule` s.
+
+    Users are therefore strongly encouraged to use this interface in conjunction
+    with :func:`~pyro.enable_module_local_param` which will override the default
```
> Member: nit: override -> disable or avoid?
```diff
+    implicit sharing of parameters across :class:`~pyro.nn.PyroModule` instances.
+
     :param num_particles: The number of particles/samples used to form the ELBO
         (gradient) estimators.
     :param int max_plate_nesting: Optional bound on max number of nested
@@ -86,6 +133,13 @@ def __init__(
         self.jit_options = jit_options
         self.tail_adaptive_beta = tail_adaptive_beta
 
+    def __call__(self, model: torch.nn.Module, guide: torch.nn.Module) -> ELBOModule:
+        """
+        Given a model and guide, returns a :class:`~torch.nn.Module` which
+        computes the ELBO loss when called with arguments to the model and guide.
+        """
+        return ELBOModule(model, guide, self)
+
     def _guess_max_plate_nesting(self, model, guide, args, kwargs):
         """
         Guesses max_plate_nesting by running the (model,guide) pair once
```
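Together with the `module_local_params` setting added in `pyro/nn/module.py` below, this lets a (model, guide) pair be trained like any other `torch.nn.Module`. A minimal end-to-end sketch of the intended usage, closely following the docstring example above; the `Model` class and `data` tensor here are hypothetical stand-ins, and the setting is assumed to be toggled via `pyro.settings.set` as registered below:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

pyro.settings.set(module_local_params=True)  # opt out of global param sharing

class Model(PyroModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self, data):
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(self.loc, 1.0), obs=data)

data = torch.randn(10) + 2.0
model = Model()
guide = pyro.infer.autoguide.AutoNormal(model)

# ELBO.__call__ fixes the (model, guide) pair and returns an ELBOModule.
elbo = pyro.infer.Trace_ELBO()(model, guide)
elbo(data)  # one call performs data-dependent parameter initialization

# From here on, training is plain PyTorch.
optim = torch.optim.Adam(elbo.parameters(), lr=0.001)
for _ in range(100):
    optim.zero_grad()
    loss = elbo(data)
    loss.backward()
    optim.step()
```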
pyro/nn/module.py
```diff
@@ -23,6 +23,17 @@
 from pyro.ops.provenance import detach_provenance
 from pyro.poutine.runtime import _PYRO_PARAM_STORE
 
+_MODULE_LOCAL_PARAMS: bool = False
+
+
+@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS")
+def _validate_module_local_params(value: bool) -> None:
+    assert isinstance(value, bool)
+
+
+def _is_module_local_param_enabled() -> bool:
+    return pyro.settings.get("module_local_params")
+
 
 class PyroParam(namedtuple("PyroParam", ("init_value", "constraint", "event_dim"))):
     """
```
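For reference, the `pyro.settings` registry used here binds the user-facing name `"module_local_params"` to the module-level `_MODULE_LOCAL_PARAMS` flag, with the decorated function acting as a validator. A rough sketch of how the setting is then read and flipped, assuming (as elsewhere in `pyro.settings`) that `set` runs the registered validator:

```python
import pyro

# The flag defaults to False, preserving Pyro's existing global-store behavior.
assert pyro.settings.get("module_local_params") is False

# Enabling it makes PyroModule parameters local to each module instance.
pyro.settings.set(module_local_params=True)
assert pyro.settings.get("module_local_params") is True

# The registered validator should reject non-boolean values.
try:
    pyro.settings.set(module_local_params=1)
except AssertionError:
    pass  # raised by _validate_module_local_params
```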
```diff
@@ -178,15 +189,23 @@ def __init__(self):
         self.active = 0
         self.cache = {}
         self.used = False
+        if _is_module_local_param_enabled():
+            self.param_state = {"params": {}, "constraints": {}}
 
     def __enter__(self):
+        if not self.active and _is_module_local_param_enabled():
+            self._param_ctx = pyro.get_param_store().scope(state=self.param_state)
```
> Member (Author): Persisting …
```diff
+            self.param_state = self._param_ctx.__enter__()
         self.active += 1
         self.used = True
 
     def __exit__(self, type, value, traceback):
         self.active -= 1
         if not self.active:
             self.cache.clear()
+            if _is_module_local_param_enabled():
+                self._param_ctx.__exit__(type, value, traceback)
+                del self._param_ctx
 
     def get(self, name):
         if self.active:
```
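The practical effect of scoping the param store inside each module's `_pyro_context` is that separately constructed modules no longer alias parameters through the global store, and nothing persists in that store after a call returns. A sketch of the intended behavior under this PR (the `Model` class is a hypothetical stand-in; two `AutoNormal` guides previously shared parameters because both registered sites under the same names):

```python
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

pyro.settings.set(module_local_params=True)

class Model(PyroModule):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0.0, 1.0))

    def forward(self):
        return self.loc

guide1 = pyro.infer.autoguide.AutoNormal(Model())
guide2 = pyro.infer.autoguide.AutoNormal(Model())
guide1()  # lazy initialization creates each guide's parameters
guide2()

# Each guide owns distinct parameter tensors, as in ordinary PyTorch ...
params1 = dict(guide1.named_parameters())
params2 = dict(guide2.named_parameters())
assert all(params1[name] is not params2[name] for name in params1)

# ... and the global parameter store was never populated.
assert not list(pyro.get_param_store().get_all_param_names())
```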
```diff
@@ -409,6 +428,8 @@ def named_pyro_params(self, prefix="", recurse=True):
                 yield elem
 
     def _pyro_set_supermodule(self, name, context):
+        if _is_module_local_param_enabled() and pyro.settings.get("validate_poutine"):
+            self._check_module_local_param_usage()
         self._pyro_name = name
         self._pyro_context = context
         for key, value in self._modules.items():
@@ -424,7 +445,26 @@ def _pyro_get_fullname(self, name):
 
     def __call__(self, *args, **kwargs):
         with self._pyro_context:
-            return super().__call__(*args, **kwargs)
+            result = super().__call__(*args, **kwargs)
+        if (
+            pyro.settings.get("validate_poutine")
+            and not self._pyro_context.active
+            and _is_module_local_param_enabled()
+        ):
+            self._check_module_local_param_usage()
+        return result
+
+    def _check_module_local_param_usage(self) -> None:
+        self_nn_params = set(id(p) for p in self.parameters())
+        self_pyro_params = set(
+            id(p if not hasattr(p, "unconstrained") else p.unconstrained())
+            for p in self._pyro_context.param_state["params"].values()
+        )
+        if not self_pyro_params <= self_nn_params:
+            raise NotImplementedError(
+                "Support for global pyro.param statements in PyroModules "
+                "with local param mode enabled is not yet implemented."
+            )
 
     def __getattr__(self, name):
         # PyroParams trigger pyro.param statements.
```
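`_check_module_local_param_usage` enforces a current limitation: with local params enabled, every `pyro.param` site reached inside a `PyroModule` must be backed by a module attribute (and hence visible to `torch.nn.Module.parameters()`); raw global-store-style `pyro.param` statements are rejected rather than silently mis-scoped. A sketch of the failure mode this guards against (the module is hypothetical; `validate_poutine` is enabled directly through `pyro.settings` for illustration, since the diff reads it via `pyro.settings.get`):

```python
import torch
import pyro
from pyro.nn import PyroModule

pyro.settings.set(module_local_params=True)
pyro.settings.set(validate_poutine=True)

class BadModule(PyroModule):  # hypothetical: issues a raw pyro.param statement
    def forward(self, x):
        # This site lives only in the (scoped) param store; it is not a module
        # attribute, so self.parameters() cannot see it.
        weight = pyro.param("weight", torch.ones(()))
        return weight * x

try:
    BadModule()(torch.randn(3))
except NotImplementedError:
    print("global pyro.param statements are not supported in local param mode")
```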