Add modified Bayesian regression tutorial with more direct PyTorch usage#2996
Open
Add modified Bayesian regression tutorial with more direct PyTorch usage#2996
Conversation
… than the global parameter store. Add backwards-compatible __call__ method to pyro.infer.ELBO that returns a Module bound to a specific model and guide, allowing direct use of the PyTorch JIT API. Fork Bayesian regression tutorial into a PyTorch API usage tutorial to illustrate a PyTorch-native programming style facilitated by these changes and PyroModule.
Member
Author
|
For context, here is a condensed training loop from the tutorial notebook that I was trying to enable: # new: keep PyroParams out of the global parameter store
pyro.enable_module_local_param(True)
class BayesianRegression(PyroModule):
...
# Create fresh copies of model, guide, elbo
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
elbo = Trace_ELBO(num_particles=10)
# new: bind elbo to (model, guide) pair
elbo = elbo(model, guide)
# Populate guide parameters
elbo(x_data, y_data);
# new: use torch.optim directly
optim = torch.optim.Adam(guide.parameters(), lr=0.03)
# Temporarily disable runtime validation and compile ELBO
with pyro.validation_enabled(False):
# new: use torch.jit.trace directly
elbo = torch.jit.trace(elbo, (x_data, y_data), check_trace=False, strict=False)
# optimize
for j in range(1500):
loss = elbo(x_data, y_data)
optim.zero_grad()
loss.backward()
optim.step()
# prediction
predict_fn = Predictive(model, guide=guide, num_samples=800)
# new: use torch.jit.trace directly
predict_fn = torch.jit.trace(predict_fn, (x_data,), check_trace=False, strict=False)
samples = predict_fn(x_data) |
This was referenced Oct 20, 2022
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR is an attempt at making a couple small changes to the Pyro and
PyroModuleAPI to makePyroModulemore compatible with vanilla PyTorch programming idioms. The API changes are simple, although the implementations insidePyroModuleare a bit hacky and may not yet be correct.Changes:
pyro.enable_module_local_param()forPyroModuleparameters to be stored locally, rather than the global parameter store. Currently implemented by associating a newParamStoreDictobject with eachPyroModuleinstance, which may not be ideal.__call__method topyro.infer.ELBOthat returns atorch.nn.Modulebound to a specific model and guide, allowing direct use of the PyTorch JIT API (e.g.torch.jit.trace)PyroModule