Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f8a1e0e
StateSpace Module Hurricane Case Study WIP
Dekermanjian Dec 30, 2024
d4e4c45
attempt to fix plot rendering on ReviewNB
Dekermanjian Dec 30, 2024
ff5de09
attempt to fix plot rendering on ReviewNB try jupyterlab renderer
Dekermanjian Dec 30, 2024
386b850
attempt to fix plot rendering on ReviewNB try notebook renderer
Dekermanjian Dec 30, 2024
8ed4c1b
1. Made changes in accordance to reviewer comments
Dekermanjian Jan 19, 2025
611e8b4
1. added forecast to exogenous variables section
Dekermanjian Jan 22, 2025
42bc4bb
1. updated the exogenous covariates section
Dekermanjian Feb 6, 2025
470bb3c
1. Added text to spline section
Dekermanjian Feb 9, 2025
efa1361
1. fixed typos, latex errors, and improved readability of text
Dekermanjian Apr 7, 2025
c22bf40
updated notebook removing exogenous forecasts hack
Dekermanjian Jun 14, 2025
38d2b84
made some figures static to reduce file size
Dekermanjian Jun 14, 2025
b1227df
1. updated docstring indentation
Dekermanjian Jun 14, 2025
4c22c42
attempt to fix latex rendering
Dekermanjian Jun 15, 2025
c350e67
missed spot in cleaning up latex
Dekermanjian Jun 15, 2025
fd148e1
updated equation number to attempt to fix latex rendering
Dekermanjian Jun 15, 2025
661c3f1
removed equation numbering
Dekermanjian Jun 15, 2025
8b52b73
fixed broken reference in closing remarks
Dekermanjian Jun 15, 2025
ab70d9e
updated headers. Only one H1 header in notebook
Dekermanjian Jun 15, 2025
4d8196c
updated authoring dates and fixed missed h3 header
Dekermanjian Jun 15, 2025
e6c5ab9
made change to utilize conditional posteriors for predictive state me…
Dekermanjian Jun 15, 2025
2fb9a2b
Revert "made change to utilize conditional posteriors for predictive …
Dekermanjian Jun 15, 2025
c595df8
try mimetype plotly renderer
Dekermanjian Jun 15, 2025
55a94c4
try sphinx_gallery plotly renderer
Dekermanjian Jun 15, 2025
8617d48
try modifying javascript to render correctly
Dekermanjian Jun 15, 2025
093a2c3
try disabling plotly mathjax
Dekermanjian Jun 15, 2025
d931ae1
try iframe renderer
Dekermanjian Jun 15, 2025
3425565
changed plotly to static rendering
Dekermanjian Jun 15, 2025
818df11
fixed inconsistency where transition matrix was sometimes referred to…
Dekermanjian Jun 16, 2025
e526bd5
try remove-cell tag to change thumbnail
Dekermanjian Jun 16, 2025
fa68d2c
updated the way in which we sample filter_outputs using the sample_fi…
Dekermanjian Aug 31, 2025
a2b40c7
updated to use new arviz
Dekermanjian Feb 14, 2026
3416b03
added arviz-variat style to arviz plots
Dekermanjian Feb 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
569 changes: 370 additions & 199 deletions examples/case_studies/ssm_hurricane_tracking.ipynb

Large diffs are not rendered by default.

72 changes: 48 additions & 24 deletions examples/case_studies/ssm_hurricane_tracking.myst.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jupytext:
format_name: myst
format_version: 0.13
kernelspec:
display_name: pymc_examples_dev
display_name: pymc-examples
language: python
name: python3
myst:
Expand Down Expand Up @@ -91,8 +91,8 @@ import warnings

warnings.filterwarnings("ignore", message="The RandomType SharedVariables", category=UserWarning)

import arviz as az
import arviz.labels as azl
import arviz.preview as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
Expand All @@ -119,6 +119,8 @@ from pymc_extras.statespace.utils.constants import (

# make all plotly figures static
pio.renderers.default = "svg"
# set arviz style
az.style.use("arviz-variat")
```

## Helper Functions
Expand Down Expand Up @@ -313,9 +315,7 @@ def generate_period_forecasts(
longitude_cppc = az.extract(forecasts["forecast_observed"].sel(observed_state="x"))
latitude_cppc = az.extract(forecasts["forecast_observed"].sel(observed_state="y"))
cppc_var = forecasts["forecast_observed"].var(("chain", "draw"))
cppc_covs = xr.cov(
latitude_cppc["forecast_observed"], longitude_cppc["forecast_observed"], dim="sample"
)
cppc_covs = xr.cov(latitude_cppc, longitude_cppc, dim="sample")
covs_list = []
for i in range(cppc_covs.shape[0]):
covs_list.append(
Expand Down Expand Up @@ -964,23 +964,32 @@ with pm.Model(coords=n_ssm.coords) as newtonian:

n_ssm.build_statespace_graph(
data=fiona_df.select("longitude", "latitude").to_numpy(),
save_kalman_filter_outputs_in_idata=True,
)
newtonian_idata = pm.sample(
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"}
)
```

```{code-cell} ipython3
az.summary(newtonian_idata, var_names="acceleration_innovations", kind="stats")
az.summary(newtonian_idata, var_names="acceleration_innovations", kind="stats", round_to=4)
```

```{code-cell} ipython3
n_ssm_filter_outputs = n_ssm.sample_filter_outputs(
newtonian_idata, filter_output_names=["predicted_covariances", "predicted_observed_states"]
)
```

```{code-cell} ipython3
predicted_covs = newtonian_idata.posterior["predicted_covariance"].mean(("chain", "draw"))
predicted_covs = n_ssm_filter_outputs.posterior_predictive["predicted_covariances"].mean(
("chain", "draw")
)
```

```{code-cell} ipython3
post_mean = newtonian_idata.posterior["predicted_observed_state"].mean(("chain", "draw"))
post_mean = n_ssm_filter_outputs.posterior_predictive["predicted_observed_states"].mean(
("chain", "draw")
)
```

Not bad for a model with only one parameter. We can see that the forecast gets wonky in the middle where the trajectory of the Hurricane changes directions over short time periods. Again, it is important to keep in mind that what we are plotting are the one-step/period ahead forecast. In our case, our periods are six hours apart. Unfortunately, a 6-hour ahead hurricane forecast is not very practical. Let's see what we get when we generate a 4-period (24-hour) ahead forecast.
Expand Down Expand Up @@ -1343,10 +1352,7 @@ with pm.Model(coords=exog_ssm.coords) as exogenous:

acceleration_innovations = pm.Gamma("acceleration_innovations", 0.1, 5, shape=(1,))

exog_ssm.build_statespace_graph(
data=fiona_df.select("longitude", "latitude").to_numpy(),
save_kalman_filter_outputs_in_idata=True,
)
exog_ssm.build_statespace_graph(data=fiona_df.select("longitude", "latitude").to_numpy())
exogenous_idata = pm.sample(
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"}
)
Expand All @@ -1355,7 +1361,7 @@ with pm.Model(coords=exog_ssm.coords) as exogenous:
Typically, the surface wind speed and the central pressure of a hurricane carry little information on the path the hurricane will take. The path of a hurricane is, generally, influenced by surrounding atmospheric conditions like pressure gradients. Knowing this, it makes sense to see that many of our beta parameters are close to zero, indicating little to no influence on the hurricanes' path.

```{code-cell} ipython3
az.plot_trace(exogenous_idata, var_names="acceleration_innovations");
az.plot_trace_dist(exogenous_idata, var_names="acceleration_innovations");
```

```{code-cell} ipython3
Expand All @@ -1367,8 +1373,18 @@ az.plot_forest(
### Make in-sample forecasts with new exogenous model

```{code-cell} ipython3
predicted_covs = exogenous_idata.posterior["predicted_covariance"].mean(("chain", "draw"))
post_mean = exogenous_idata.posterior["predicted_observed_state"].mean(("chain", "draw"))
exog_ssm_filter_outputs = exog_ssm.sample_filter_outputs(
exogenous_idata, filter_output_names=["predicted_covariances", "predicted_observed_states"]
)
```

```{code-cell} ipython3
predicted_covs = exog_ssm_filter_outputs.posterior_predictive["predicted_covariances"].mean(
("chain", "draw")
)
post_mean = exog_ssm_filter_outputs.posterior_predictive["predicted_observed_states"].mean(
("chain", "draw")
)
```

Our one-period ahead forecasts seem to be slightly worse than our Newtonian model. You will notice that at the end of the forecast we see that our trajectory is erroneously more north rather than north-east. Since the exogenous variables we added to the model don't carry additional information with respect to the hurricane's trajectory, this results are expected.
Expand Down Expand Up @@ -1805,10 +1821,7 @@ with pm.Model(coords=spline_ssm.coords) as spline_model:

acceleration_innovations = pm.Gamma("acceleration_innovations", 0.1, 5, shape=(1,))

spline_ssm.build_statespace_graph(
data=fiona_df.select("longitude", "latitude").to_numpy(),
save_kalman_filter_outputs_in_idata=True,
)
spline_ssm.build_statespace_graph(data=fiona_df.select("longitude", "latitude").to_numpy())
spline_idata = pm.sample(
nuts_sampler="nutpie", nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"}
)
Expand All @@ -1817,11 +1830,11 @@ with pm.Model(coords=spline_ssm.coords) as spline_model:
Most of our spline parameters are around zero, with a handful of exceptions. Let's take a look at how these effect our forecasts.

```{code-cell} ipython3
az.plot_trace(spline_idata, var_names="acceleration_innovations");
az.plot_trace_dist(spline_idata, var_names="acceleration_innovations");
```

```{code-cell} ipython3
az.plot_trace(spline_idata, var_names=["beta_exog"], compact=True, figsize=(20, 8));
az.plot_trace_dist(spline_idata, var_names=["beta_exog"], compact=True);
```

### Make in-sample forecasts with new spline model
Expand All @@ -1831,8 +1844,18 @@ az.plot_trace(spline_idata, var_names=["beta_exog"], compact=True, figsize=(20,
Our one-period ahead forecasts, look better than the ones we generated from the Exogenous covariates model, but worse than the original model that purely follows Newtonian kinematics.

```{code-cell} ipython3
predicted_covs = spline_idata.posterior["predicted_covariance"].mean(("chain", "draw"))
post_mean = spline_idata.posterior["predicted_observed_state"].mean(("chain", "draw"))
spline_ssm_filter_outputs = spline_ssm.sample_filter_outputs(
spline_idata, filter_output_names=["predicted_covariances", "predicted_observed_states"]
)
```

```{code-cell} ipython3
predicted_covs = spline_ssm_filter_outputs.posterior_predictive["predicted_covariances"].mean(
("chain", "draw")
)
post_mean = spline_ssm_filter_outputs.posterior_predictive["predicted_observed_states"].mean(
("chain", "draw")
)
```

```{code-cell} ipython3
Expand Down Expand Up @@ -1941,6 +1964,7 @@ fig.show(width=1000, renderer="png", config={"displayModeBar": False})
```

## Authors
* Updated by Jonathan Dekermanjian in August, 2025 to use the `sample_filter_outputs` method
* Authored by Jonathan Dekermanjian in June, 2025

+++
Expand Down