From ed415359fa8536d2a617a4953014d941dc86d838 Mon Sep 17 00:00:00 2001 From: Seb Croft Date: Thu, 19 Feb 2026 09:59:09 +0000 Subject: [PATCH 1/2] Added unpack_args to UnivariateFilter kalman_step function so ordering of time-varying statespace matrices is correct --- pymc_extras/statespace/filters/kalman_filter.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index ddf55c77f..163684f58 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -778,7 +778,10 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner - def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q): + def kalman_step(self, *args): + + y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args) + nan_mask = pt.isnan(y) W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0) From 10b46e2256e5a5e459478ea41b01696cf4679d8c Mon Sep 17 00:00:00 2001 From: Seb Croft Date: Sun, 22 Feb 2026 19:51:37 +0000 Subject: [PATCH 2/2] added missing value handling for log likelihood and statespace matrices in UnivariateFilter kalman_step --- pymc_extras/statespace/filters/kalman_filter.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 163684f58..0985c5255 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -779,15 +779,10 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner def kalman_step(self, *args): - y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args) - nan_mask = pt.isnan(y) - - W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0) - Z_masked = W.dot(Z) - H_masked = W.dot(H) - y_masked = pt.set_subtensor(y[nan_mask], 0.0) + nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value)) + y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) result = pytensor.scan( self._univariate_inner_filter_step, @@ -808,6 +803,10 @@ def kalman_step(self, *args): P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter) a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q) - ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()) + ll = pt.switch( + all_nan_flag, + 0.0, + -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()), + ) return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll