
Commit 45d98fb

Update more examples to use arviz 1.0 (#833)

* update more examples to use arviz 1.0
* remove variable added unintentionally
* remove aeppl mention
* rerun golf example

1 parent 270f33a commit 45d98fb

13 files changed (+2627 -1802 lines)
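The hunks below all apply one consistent set of arviz 1.0 renames. As a quick illustrative reference (a sketch collected from the diffs, not part of the commit; the names `RENAMES`, `KWARG_RENAMES`, and `new_name` are ours):

```python
# Illustrative summary of the arviz 1.0 migration pattern seen in this commit.
#
# Import and theme changes:
#   import arviz as az              ->  import arviz.preview as az
#   az.style.use("arviz-darkgrid")  ->  az.style.use("arviz-variat")

# Function renames collected from the hunks below:
RENAMES = {
    "plot_ppc": "plot_ppc_dist",
    "plot_posterior": "plot_dist",
    "plot_trace": "plot_trace_dist",
    "plot_dist_comparison": "plot_prior_posterior",
}

# Keyword renames collected from the hunks below (plot_pair keeps its name,
# but divergences=True becomes visuals={"divergence": True}):
KWARG_RENAMES = {
    "az.hdi": {"hdi_prob": "prob"},
    "az.plot_dist": {"hdi_prob": "ci_prob"},
}


def new_name(old: str) -> str:
    # Fall back to the old name when no rename applies.
    return RENAMES.get(old, old)
```

For example, `new_name("plot_posterior")` gives `"plot_dist"`, matching the GEV hunk that replaces `az.plot_posterior(...)` with `az.plot_dist(...)`.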

examples/case_studies/GEV.ipynb

Lines changed: 492 additions & 95 deletions
Large diffs are not rendered by default.

examples/case_studies/GEV.myst.md

Lines changed: 22 additions & 23 deletions
@@ -5,7 +5,7 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: default
+  display_name: eabm
   language: python
   name: python3
 ---
@@ -38,14 +38,13 @@ Note that this parametrization of the shape parameter $\xi$ is opposite in sign
 We will use the example of the Port Pirie annual maximum sea-level data used in {cite:t}`coles2001gev`, and compare with the frequentist results presented there.
 
 ```{code-cell} ipython3
-import arviz as az
+import arviz.preview as az
 import matplotlib.pyplot as plt
 import numpy as np
 import pymc as pm
 import pymc_extras.distributions as pmx
-import pytensor.tensor as pt
 
-from arviz.plots import plot_utils as azpu
+az.style.use("arviz-variat")
 ```
 
 ## Data
@@ -112,18 +111,13 @@ Let's get a feel for how well our selected priors cover the range of the data:
 
 ```{code-cell} ipython3
 idata = pm.sample_prior_predictive(samples=1000, model=model)
-az.plot_ppc(idata, group="prior", figsize=(12, 6))
-ax = plt.gca()
-ax.set_xlim([2, 6])
-ax.set_ylim([0, 2]);
+az.plot_ppc_dist(idata, group="prior_predictive", kind="ecdf")
 ```
 
 And we can look at the sampled values of the parameters, using the `plot_posterior` function, but passing in the `idata` object and specifying the `group` to be `"prior"`:
 
 ```{code-cell} ipython3
-az.plot_posterior(
-    idata, group="prior", var_names=["μ", "σ", "ξ"], hdi_prob="hide", point_estimate=None
-);
+az.plot_dist(idata, group="prior", var_names=["μ", "σ", "ξ"]);
 ```
 
 ## Inference
@@ -144,7 +138,7 @@ idata.extend(trace)
 ```
 
 ```{code-cell} ipython3
-az.plot_trace(idata, var_names=["μ", "σ", "ξ"], figsize=(12, 12));
+az.plot_trace_dist(idata, var_names=["μ", "σ", "ξ"]);
 ```
 
 ### Divergences
@@ -159,27 +153,32 @@ The trace exhibits divergences (usually). The HMC/NUTS sampler can have problems
 The 95% credible interval range of the parameter estimates is:
 
 ```{code-cell} ipython3
-az.hdi(idata, hdi_prob=0.95)
+az.hdi(idata, prob=0.95)
 ```
 
 And examine the prediction distribution, considering parameter variability (and without needing to assume normality):
 
 ```{code-cell} ipython3
-az.plot_posterior(idata, hdi_prob=0.95, var_names=["z_p"], round_to=4);
+az.plot_dist(
+    idata,
+    ci_prob=0.95,
+    var_names=["z_p"],
+    stats={"point_estimate": {"round_to": 2}},
+);
 ```
 
 And let's compare the prior and posterior predictions of $z_p$ to see how the data has influenced things:
 
 ```{code-cell} ipython3
-az.plot_dist_comparison(idata, var_names=["z_p"]);
+az.plot_prior_posterior(idata, var_names=["z_p"]);
 ```
 
 ## Comparison
 To compare with the results given in {cite:t}`coles2001gev`, we approximate the maximum likelihood estimates (MLE) using the mode of the posterior distributions (the *maximum a posteriori* or MAP estimate). These are close when the prior is reasonably flat around the posterior estimate.
 
 The MLE results given in {cite:t}`coles2001gev` are:
 
-$$\left(\hat{\mu}, \hat{\sigma}, \hat{\xi} \right) = \left( 3.87, 0.198, -0.050 \right) $$
+$$\left(\hat{\mu}, \hat{\sigma}, \hat{\xi} \right) = \left( 3.87, 0.198, -0.050 \right)$$
 
 
 And the variance-covariance matrix of the estimates is:
@@ -189,13 +188,8 @@ $$ V = \left[ \begin{array} 0.000780 & 0.000197 & -0.00107 \\
 -0.00107 & -0.000778 & 0.00965
 \end{array} \right] $$
 
-
-Note that extracting the MLE estimates from our inference involves accessing some of the Arviz back end functions to bash the xarray into something examinable:
-
 ```{code-cell} ipython3
-_, vals = az.sel_utils.xarray_to_ndarray(idata["posterior"], var_names=["μ", "σ", "ξ"])
-mle = [azpu.calculate_point_estimate("mode", val) for val in vals]
-mle
+az.mode(idata, var_names=["μ", "σ", "ξ"])
 ```
 
 ```{code-cell} ipython3
@@ -207,12 +201,17 @@ The results are a good match, but the benefit of doing this in a Bayesian settin
 Finally, we examine the pairs plots and see where any difficulties in inference lie using the divergences
 
 ```{code-cell} ipython3
-az.plot_pair(idata, var_names=["μ", "σ", "ξ"], kind="kde", marginals=True, divergences=True);
+az.plot_pair(
+    idata,
+    var_names=["μ", "σ", "ξ"],
+    visuals={"divergence": True},
+);
 ```
 
 ## Authors
 
 * Authored by [Colin Caprani](https://github.com/ccaprani), October 2021
+* Updated by Osvaldo Martin, January 2026
 
 +++
 
examples/case_studies/factor_analysis.ipynb

Lines changed: 107 additions & 206 deletions
Large diffs are not rendered by default.

examples/case_studies/factor_analysis.myst.md

Lines changed: 31 additions & 36 deletions
@@ -6,7 +6,7 @@ jupytext:
     format_name: myst
     format_version: 0.13
 kernelspec:
-  display_name: Python 3 (ipykernel)
+  display_name: eabm
   language: python
   name: python3
 myst:
@@ -33,7 +33,7 @@ Factor analysis is a widely used probabilistic model for identifying low-rank st
 :::
 
 ```{code-cell} ipython3
-import arviz as az
+import arviz.preview as az
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
@@ -42,7 +42,6 @@ import seaborn as sns
 import xarray as xr
 
 from matplotlib import pyplot as plt
-from matplotlib.lines import Line2D
 from numpy.random import default_rng
 from xarray_einstats import linalg
 from xarray_einstats.stats import XrContinuousRV
@@ -52,7 +51,7 @@ print(f"Running on PyMC v{pm.__version__}")
 
 ```{code-cell} ipython3
 %config InlineBackend.figure_format = 'retina'
-az.style.use("arviz-darkgrid")
+az.style.use("arviz-variat")
 
 np.set_printoptions(precision=3, suppress=True)
 RANDOM_SEED = 31415
@@ -128,11 +127,13 @@ with pm.Model(coords=coords) as PPCA:
 At this point, there are already several warnings regarding failed convergence checks. We can see further problems in the trace plot below. This plot shows the path taken by each sampler chain for a single entry in the matrix $W$ as well as the average evaluated over samples for each chain.
 
 ```{code-cell} ipython3
-for i in trace.posterior.chain.values:
-    samples = trace.posterior["W"].sel(chain=i, observed_columns=3, latent_columns=1)
-    plt.plot(samples, label=f"Chain {i + 1}")
-    plt.axhline(samples.mean(), color=f"C{i}")
-plt.legend(ncol=4, loc="upper center", fontsize=12, frameon=True), plt.xlabel("Sample");
+az.plot_trace(
+    trace,
+    var_names="W",
+    coords={"observed_columns": 3, "latent_columns": 1},
+    sample_dims=["draw"],
+    figure_kwargs={"sharey": True},
+);
 ```
 
 Each chain appears to have a different sample mean and we can also see that there is a great deal of autocorrelation across chains, manifest as long-range trends over sampling iterations.
@@ -194,13 +195,7 @@ with pm.Model(coords=coords) as PPCA_identified:
     F = pm.Normal("F", dims=("latent_columns", "rows"))
     sigma = pm.HalfNormal("sigma", 1.0)
     X = pm.Normal("X", mu=W @ F, sigma=sigma, observed=Y, dims=("observed_columns", "rows"))
-    trace = pm.sample(tune=2000, random_seed=rng)  # target_accept=0.9
-
-for i in range(4):
-    samples = trace.posterior["W"].sel(chain=i, observed_columns=3, latent_columns=1)
-    plt.plot(samples, label=f"Chain {i + 1}")
-
-plt.legend(ncol=4, loc="lower center", fontsize=8), plt.xlabel("Sample");
+    trace = pm.sample(tune=2000, random_seed=rng, target_accept=0.9)
 ```
 
 $W$ (and $F$!) now have entries with identical posterior distributions as compared between sampler chains, although it's apparent that some autocorrelation remains.
@@ -251,29 +246,28 @@ When we compare the posteriors calculated using MCMC and VI, we find that (for a
 ```{code-cell} ipython3
 col_selection = dict(observed_columns=3, latent_columns=1)
 
-ax = az.plot_kde(
-    trace.posterior["W"].sel(**col_selection).values,
-    label=f"MCMC posterior for the explicit model",
-    plot_kwargs={"color": f"C{1}"},
-)
-
-az.plot_kde(
-    trace_amortized.posterior["W"].sel(**col_selection).values,
-    label="MCMC posterior for amortized inference",
-    plot_kwargs={"color": f"C{2}", "linestyle": "--"},
+dt = az.from_dict(
+    {
+        "posterior": {
+            "MCMC_explicit": trace.posterior["W"].sel(**col_selection),
+            "MCMC_amortized": trace_amortized.posterior["W"].sel(**col_selection),
+            "FR-ADVI_amortized": trace_vi.posterior["W"].sel(**col_selection),
+        }
+    }
 )
 
-
-az.plot_kde(
-    trace_vi.posterior["W"].sel(**col_selection).squeeze().values,
-    label="FR-ADVI posterior for amortized inference",
-    plot_kwargs={"alpha": 0},
-    fill_kwargs={"alpha": 0.5, "color": f"C{0}"},
+pc = az.plot_dist(
+    dt,
+    cols=None,
+    aes={"color": ["__variable__"]},
+    visuals={
+        "title": False,
+        "point_estimate_text": False,
+        "point_estimate": False,
+        "credible_interval": False,
+    },
 )
-
-
-ax.set_title(rf"PDFs of $W$ estimate at {col_selection}")
-ax.legend(loc="upper left", fontsize=10);
+pc.add_legend("__variable__")
 ```
 
 ### Post-hoc identification of F
@@ -389,6 +383,7 @@ We find that our model does a decent job of capturing the variation in the origi
 * Updated by [Christopher Krapu](https://github.com/ckrapu) on April 4, 2021
 * Updated by Oriol Abril-Pla to use PyMC v4 and xarray-einstats on March, 2022
 * Updated by Erik Werner on Dec, 2023 ([pymc-examples#612](https://github.com/pymc-devs/pymc-examples/pull/612))
+* Updated by Osvaldo Martin on January, 2026
 
 +++
 