diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index ddf55c77f..0985c5255 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -778,13 +778,11 @@ def _univariate_inner_filter_step(self, y, Z_row, d_row, sigma_H, nan_flag, a, P return a_filtered, P_filtered, pt.atleast_1d(y_hat), pt.atleast_2d(F), ll_inner - def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q): - nan_mask = pt.isnan(y) + def kalman_step(self, *args): + y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args) - W = pt.set_subtensor(pt.eye(y.shape[0])[nan_mask, nan_mask], 0.0) - Z_masked = W.dot(Z) - H_masked = W.dot(H) - y_masked = pt.set_subtensor(y[nan_mask], 0.0) + nan_mask = pt.or_(pt.isnan(y), pt.eq(y, self.missing_fill_value)) + y_masked, Z_masked, H_masked, all_nan_flag = self.handle_missing_values(y, Z, H) result = pytensor.scan( self._univariate_inner_filter_step, @@ -805,6 +803,10 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q): P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter) a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q) - ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()) + ll = pt.switch( + all_nan_flag, + 0.0, + -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()), + ) return a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll