Add an option to stop PyroModules from sharing parameters#3149
Conversation
fritzo
left a comment
There was a problem hiding this comment.
Nice rethinking towards more idiomatic PyTorch!
|
|
||
| def __enter__(self): | ||
| if not self.active and _is_module_local_param_enabled(): | ||
| self._param_ctx = pyro.get_param_store().scope(state=self.param_state) |
There was a problem hiding this comment.
Persisting self.param_state like this (and in _pyro_set_supermodule below) seems to be a reasonable solution for the behavior of vanilla pyro.param statements. Values of these parameters are now local to the outermost PyroModule in a nested PyroModule instance.
fritzo
left a comment
There was a problem hiding this comment.
Hey thanks for your patience in reviewing this subtle PR. The ELBOModule changes look clean. I'm still working through understanding the module_local_param changes...
pyro/primitives.py
Outdated
| poutine.enable_validation(poutine_validation_status) | ||
|
|
||
|
|
||
| def enable_module_local_param(is_enabled: bool = False) -> None: |
There was a problem hiding this comment.
It would be nice to make it super clear that users can now decide between (i) a global param store or (ii) local nn.Module style parameters. Like maybe
with pyro.param_storage("local"): ...
with pyro.param_storage("global"): ...or pyro.disable_param_store(True) or pyro.enable_param_store(False). Whatever we call it I think it would be good in the first docstring sentence to mention the phrase "param store" and the word "nn.Module".
pyro/nn/module.py
Outdated
| if _is_module_local_param_enabled(): | ||
| with pyro.get_param_store().scope( | ||
| state=self._pyro_context.param_state | ||
| ) as vanilla_param_state: |
There was a problem hiding this comment.
nit: Would another word for "vanilla" be "global" or "global-only" or "raw" or "nonmodule" or something? We might want to avoid "vanilla" because PyTorch users new to Pyro might think of "vanilla" as "an nn.Param attribute of an nn.Module".
| :class:`~pyro.nn.PyroModule` s. | ||
|
|
||
| Users are therefore strongly encouraged to use this interface in conjunction | ||
| with :func:`~pyro.enable_module_local_param` which will override the default |
There was a problem hiding this comment.
nit: override -> disable or avoid?
fritzo
left a comment
There was a problem hiding this comment.
LGTM after minor comment on .set() vs .context() in tests
Replaces #2996
This PR adds two small related features for easier Pyro-PyTorch integration:
__call__method for the basepyro.infer.elbo.ELBOthat bindsELBOinstances to specificnn.Modulemodel/guide pairs in aModulethat exposes their PyTorch parametersPyroModuleinstances from sharing parameter values with one another through the global Pyro parameter store, and a primitive and context manager for toggling it. One context where this is useful is for workflows that involve multiple models and autoguides with overlapping parameter names.An edge case I haven't handled here is the behavior under the new local parameter setting of regular
pyro.paramstatements (as opposed toPyroParam) within aPyroModulethat don't have their data associated with any underlyingnn.Module. I've raised an error rather than attempt to get this working, since I think it's usually aPyroModuleprogramming anti-pattern to mix global and local parameter states in this way.I am also hopeful that these changes will simplify the use of Pyro with the PyTorch JIT and other PyTorch compilers, but I have left testing this for future work, since I suspect it will require additional engineering that is out of scope for this PR.
Tasks:
pyro.settingsmodule from Clean up handling of global settings #3152pyro.paramstatement inside aPyroModuleTested:
PyroModuletests intests/nn/test_module.pypyro.param