Skip to content

Commit cea7a15

Browse files
velochy and ricardoV94
authored and committed
Added `paths_inv_hessian_diag` to sample_stats to make it easier to use for NUTS initialization
1 parent a46b99d commit cea7a15

File tree

4 files changed

+167
-33
lines changed

4 files changed

+167
-33
lines changed

pymc_extras/inference/pathfinder/idata.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,13 @@ def pathfinder_result_to_xarray(
145145
if result.elbo_argmax is not None:
146146
data_vars["elbo_argmax"] = xr.DataArray(_extract_scalar(result.elbo_argmax))
147147

148+
if n_params is not None and result.inv_hessian_diag is not None:
149+
data_vars["inv_hessian_diag"] = xr.DataArray(
150+
result.inv_hessian_diag[0],
151+
dims=["param"],
152+
coords={"param": coords["param"]},
153+
)
154+
148155
data_vars["lbfgs_status_code"] = xr.DataArray(result.lbfgs_status.value)
149156
data_vars["lbfgs_status_name"] = xr.DataArray(result.lbfgs_status.name)
150157
data_vars["path_status_code"] = xr.DataArray(result.path_status.value)
@@ -340,6 +347,13 @@ def _add_path_scalar(name: str, data):
340347
_add_path_scalar("logQ_mean", np.mean(result.logQ, axis=1))
341348
_add_path_scalar("logQ_max", np.max(result.logQ, axis=1))
342349

350+
if n_params is not None and result.inv_hessian_diag is not None:
351+
data_vars["paths_inv_hessian_diag"] = xr.DataArray(
352+
result.inv_hessian_diag,
353+
dims=["path", "param"],
354+
coords={"path": coords["path"], "param": coords["param"]},
355+
)
356+
343357
if n_params is not None and result.samples is not None and result.samples.ndim >= 3:
344358
final_samples = result.samples[:, -1, :] # (S, N)
345359
data_vars["paths_final_sample"] = xr.DataArray(

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def _bfgs_sample_pt(
399399
u: TensorVariable,
400400
J: int,
401401
N: int,
402-
) -> tuple[TensorVariable, TensorVariable]:
402+
) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
403403
"""Symbolic L-BFGS inverse-Hessian sample.
404404
405405
The dense vs sparse path is selected at graph-construction time from
@@ -420,6 +420,7 @@ def _bfgs_sample_pt(
420420
-------
421421
phi : (M, N) samples
422422
logQ : (M,) log-density under the approximation
423+
inv_hessian_diag : (N,) diagonal of sampler covariance used for `phi`
423424
"""
424425
J2 = 2 * J
425426

@@ -447,6 +448,7 @@ def _bfgs_sample_pt(
447448
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diag(Lchol))))
448449
mu = x - H_inv @ g
449450
phi = (mu[:, None] + Lchol @ u.T).T # (M, N)
451+
inv_hessian_diag = pt.diag(H_inv)
450452
else:
451453
# Sparse path: economy QR avoids O(N²) matrices (large-N regime)
452454
Q, R = pt.linalg.qr(beta * inv_sqrt_alpha[:, None], mode="reduced") # Q:(N,2J), R:(2J,2J)
@@ -457,9 +459,12 @@ def _bfgs_sample_pt(
457459
mu = x - ((alpha * g)[:, None] + beta @ (gamma @ btg))[:, 0] # (N,)
458460
QtU = Q.T @ u.T # (2J, M)
459461
phi = (mu[:, None] + sqrt_alpha[:, None] * (Q @ ((Lchol - pt.eye(J2)) @ QtU) + u.T)).T
462+
# diag(Q (L L^T - I) Q^T) = row_norm_sq(Q L) - row_norm_sq(Q)
463+
ql = Q @ Lchol
464+
inv_hessian_diag = alpha * (1.0 + pt.sum(ql * ql, axis=1) - pt.sum(Q * Q, axis=1))
460465

461466
logQ = -0.5 * (logdet + pt.sum(u * u, axis=-1) + N * np.log(2.0 * np.pi))
462-
return phi, logQ
467+
return phi, logQ, inv_hessian_diag
463468

464469

465470
def make_pathfinder_sample_fn(
@@ -469,6 +474,7 @@ def make_pathfinder_sample_fn(
469474
jacobian: bool,
470475
compile_kwargs: dict,
471476
vectorize: bool = False,
477+
return_inv_hessian_diag: bool = False,
472478
) -> Function:
473479
"""Compile a single PyTensor function covering bfgs sample + batched logP evaluation.
474480
@@ -487,11 +493,13 @@ def make_pathfinder_sample_fn(
487493
compile_kwargs : dict
488494
vectorize : bool, optional
489495
If True, use vectorize_graph instead of pytensor.map. Default False.
496+
return_inv_hessian_diag : bool, optional
497+
If True, include inverse-Hessian diagonal as 4th output.
490498
491499
Returns
492500
-------
493501
fn : Function
494-
Compiled: (x, g, alpha, s_win, z_win, u) → (phi, logQ, logP)
502+
Compiled: (x, g, alpha, s_win, z_win, u) → (phi, logQ, logP[, inv_hessian_diag])
495503
where s_win, z_win are (N, J), u is (M, N), and M is a dynamic dimension.
496504
"""
497505
(logP_single,), single_input = pm.pytensorf.join_nonshared_inputs(
@@ -507,7 +515,9 @@ def make_pathfinder_sample_fn(
507515
z_win_sym = pt.matrix("z_win", dtype="float64") # (N, J)
508516
u_sym = pt.matrix("u", dtype="float64") # (M, N) — M is dynamic
509517

510-
phi_sym, logQ_sym = _bfgs_sample_pt(x_sym, g_sym, alpha_sym, s_win_sym, z_win_sym, u_sym, J, N)
518+
phi_sym, logQ_sym, inv_hessian_diag_sym = _bfgs_sample_pt(
519+
x_sym, g_sym, alpha_sym, s_win_sym, z_win_sym, u_sym, J, N
520+
)
511521

512522
if vectorize:
513523
batched_logP_sym = vectorize_graph(logP_single, replace={single_input: phi_sym})
@@ -518,6 +528,10 @@ def make_pathfinder_sample_fn(
518528
return_updates=False,
519529
)
520530

531+
outputs = [phi_sym, logQ_sym, batched_logP_sym]
532+
if return_inv_hessian_diag:
533+
outputs.append(inv_hessian_diag_sym)
534+
521535
fn = pytensor.function(
522536
[
523537
pytensor.In(x_sym, borrow=True),
@@ -527,7 +541,7 @@ def make_pathfinder_sample_fn(
527541
pytensor.In(z_win_sym, borrow=True),
528542
pytensor.In(u_sym, borrow=True),
529543
],
530-
[phi_sym, logQ_sym, batched_logP_sym],
544+
outputs,
531545
**compile_kwargs,
532546
)
533547
fn.trust_input = True
@@ -557,7 +571,7 @@ def make_elbo_from_state_fn(
557571
S_sym = pt.matrix("S", dtype="float64")
558572
Z_sym = pt.matrix("Z", dtype="float64")
559573
u_sym = pt.matrix("u", dtype="float64")
560-
phi_sym, logQ_sym = _bfgs_sample_pt(x_sym, g_sym, alpha_sym, S_sym, Z_sym, u_sym, J, N)
574+
phi_sym, logQ_sym, _ = _bfgs_sample_pt(x_sym, g_sym, alpha_sym, S_sym, Z_sym, u_sym, J, N)
561575
if vectorize:
562576
batched_logP_sym = vectorize_graph(logP_single, replace={single_input: phi_sym})
563577
else:
@@ -747,7 +761,8 @@ def __call__(self, x: NDArray) -> float | None:
747761
# Sample + logP in a single compiled call. Pass s_win/z_win as inputs.
748762
u = self._rng.standard_normal((self.num_elbo_draws, self._N))
749763
try:
750-
_, logQ, logP = self.sample_logp_fn(x, g, alpha, self.s_win, self.z_win, u)
764+
sample_out = self.sample_logp_fn(x, g, alpha, self.s_win, self.z_win, u)
765+
_, logQ, logP = sample_out[:3]
751766
logP = np.asarray(logP)
752767
logQ = np.asarray(logQ)
753768
finite = np.isfinite(logP)
@@ -880,7 +895,21 @@ def make_single_pathfinder_fn(
880895
N = x_base.shape[0]
881896

882897
sample_logp_func = make_pathfinder_sample_fn(
883-
model, N, maxcor, jacobian=jacobian, compile_kwargs=compile_kwargs, vectorize=vectorize
898+
model,
899+
N,
900+
maxcor,
901+
jacobian=jacobian,
902+
compile_kwargs=compile_kwargs,
903+
vectorize=vectorize,
904+
)
905+
final_sample_logp_func = make_pathfinder_sample_fn(
906+
model,
907+
N,
908+
maxcor,
909+
jacobian=jacobian,
910+
compile_kwargs=compile_kwargs,
911+
vectorize=vectorize,
912+
return_inv_hessian_diag=True,
884913
)
885914

886915
def _check_lbfgs_status(status):
@@ -889,7 +918,9 @@ def _check_lbfgs_status(status):
889918
elif status == LBFGSStatus.LBFGS_FAILED:
890919
raise LBFGSException()
891920

892-
def _make_result(psi, logP_psi, logQ_psi, lbfgs_niter, elbo_argmax, lbfgs_status):
921+
def _make_result(
922+
psi, logP_psi, logQ_psi, lbfgs_niter, elbo_argmax, inv_hessian_diag, lbfgs_status
923+
):
893924
if np.all(~np.isfinite(logQ_psi)):
894925
raise PathInvalidLogQ()
895926
path_status = PathStatus.ELBO_ARGMAX_AT_ZERO if elbo_argmax == 0 else PathStatus.SUCCESS
@@ -899,6 +930,7 @@ def _make_result(psi, logP_psi, logQ_psi, lbfgs_niter, elbo_argmax, lbfgs_status
899930
logQ=logQ_psi,
900931
lbfgs_niter=lbfgs_niter,
901932
elbo_argmax=elbo_argmax,
933+
inv_hessian_diag=inv_hessian_diag,
902934
lbfgs_status=lbfgs_status,
903935
path_status=path_status,
904936
)
@@ -916,6 +948,7 @@ def single_pathfinder_fn(
916948
local_lbfgs = LBFGS(cached_fn, maxcor, maxiter, ftol, gtol, maxls, epsilon)
917949

918950
local_sample_logp = sample_logp_func.copy(share_memory=False)
951+
local_final_sample_logp = final_sample_logp_func.copy(share_memory=False)
919952

920953
lbfgs_status = LBFGSStatus.LBFGS_FAILED # default before LBFGS runs
921954
try:
@@ -965,17 +998,19 @@ def single_pathfinder_fn(
965998
{"status": "sampling", "current_elbo": None, "step_size": None}
966999
)
9671000

968-
phi_final, logQ_psi_flat, logP_psi_flat = local_sample_logp(
1001+
sample_out = local_final_sample_logp(
9691002
best_state["x"],
9701003
best_state["g"],
9711004
best_state["alpha"],
9721005
best_state["s_win"],
9731006
best_state["z_win"],
9741007
u_final,
9751008
)
1009+
phi_final, logQ_psi_flat, logP_psi_flat, inv_hessian_diag = sample_out
9761010
phi_final = np.asarray(phi_final)
9771011
logQ_psi_flat = np.asarray(logQ_psi_flat)
9781012
logP_psi_flat = np.asarray(logP_psi_flat)
1013+
inv_hessian_diag = np.asarray(inv_hessian_diag)[None]
9791014
# Add batch dim L=1 to match downstream expectations
9801015
psi = phi_final[None] # (1, M, N)
9811016
logP_psi = logP_psi_flat[None] # (1, M)
@@ -1012,7 +1047,15 @@ def single_pathfinder_fn(
10121047
path_status=PathStatus.SINGLE_STEP,
10131048
)
10141049

1015-
result = _make_result(psi, logP_psi, logQ_psi, lbfgs_niter, elbo_argmax, lbfgs_status)
1050+
result = _make_result(
1051+
psi,
1052+
logP_psi,
1053+
logQ_psi,
1054+
lbfgs_niter,
1055+
elbo_argmax,
1056+
inv_hessian_diag,
1057+
lbfgs_status,
1058+
)
10161059
if progress_callback is not None:
10171060
status_str = (
10181061
"elbo@0" if result.path_status == PathStatus.ELBO_ARGMAX_AT_ZERO else "ok"
@@ -1260,6 +1303,7 @@ class PathfinderResult:
12601303
logQ: NDArray | None = None
12611304
lbfgs_niter: NDArray | None = None
12621305
elbo_argmax: NDArray | None = None
1306+
inv_hessian_diag: NDArray | None = None
12631307
lbfgs_status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED
12641308
path_status: PathStatus = PathStatus.PATH_FAILED
12651309

@@ -1310,6 +1354,7 @@ class MultiPathfinderResult:
13101354
logQ: NDArray | None = None
13111355
lbfgs_niter: NDArray | None = None
13121356
elbo_argmax: NDArray | None = None
1357+
inv_hessian_diag: NDArray | None = None
13131358
lbfgs_status: Counter = field(default_factory=Counter)
13141359
path_status: Counter = field(default_factory=Counter)
13151360
importance_sampling: str | None = "psis"
@@ -1331,7 +1376,14 @@ class MultiPathfinderResult:
13311376
def from_path_results(cls, path_results: list[PathfinderResult]) -> "MultiPathfinderResult":
13321377
"""aggregate successful pathfinder results and count the occurrences of each status in PathStatus and LBFGSStatus"""
13331378

1334-
NUMERIC_ATTRIBUTES = ["samples", "logP", "logQ", "lbfgs_niter", "elbo_argmax"]
1379+
NUMERIC_ATTRIBUTES = [
1380+
"samples",
1381+
"logP",
1382+
"logQ",
1383+
"lbfgs_niter",
1384+
"elbo_argmax",
1385+
"inv_hessian_diag",
1386+
]
13351387

13361388
success_results = []
13371389
mpr = cls()

tests/pathfinder/test_idata.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class MockPathfinderResult:
2121
logQ: np.ndarray = None
2222
lbfgs_niter: np.ndarray = None
2323
elbo_argmax: np.ndarray = None
24+
inv_hessian_diag: np.ndarray = None
2425
lbfgs_status: LBFGSStatus = LBFGSStatus.CONVERGED
2526
path_status: PathStatus = PathStatus.SUCCESS
2627

@@ -34,6 +35,7 @@ class MockMultiPathfinderResult:
3435
logQ: np.ndarray = None
3536
lbfgs_niter: np.ndarray = None
3637
elbo_argmax: np.ndarray = None
38+
inv_hessian_diag: np.ndarray = None
3739
lbfgs_status: Counter = None
3840
path_status: Counter = None
3941
importance_sampling: str = "psis"
@@ -232,6 +234,7 @@ def test_multi_result_conversion(self):
232234
logQ=np.random.normal(-11, 1, (3, 100)),
233235
lbfgs_niter=np.array([50, 45, 55]),
234236
elbo_argmax=np.array([25, 30, 20]),
237+
inv_hessian_diag=np.abs(np.random.normal(1.0, 0.1, (3, 2))),
235238
lbfgs_status=Counter({LBFGSStatus.CONVERGED: 3}),
236239
path_status=Counter({PathStatus.SUCCESS: 3}),
237240
num_paths=3,
@@ -261,6 +264,7 @@ def test_multi_result_conversion(self):
261264
# Check per-path data (paths_ prefix)
262265
assert "paths_lbfgs_niter" in ds.data_vars
263266
assert "paths_elbo_argmax" in ds.data_vars
267+
assert "paths_inv_hessian_diag" in ds.data_vars
264268
assert "paths_logP_mean" in ds.data_vars
265269
assert "paths_logQ_mean" in ds.data_vars
266270
assert "paths_final_sample" in ds.data_vars
@@ -269,6 +273,7 @@ def test_multi_result_conversion(self):
269273
assert "path" in ds.dims
270274
assert ds.sizes["path"] == 3
271275
assert ds["paths_lbfgs_niter"].shape == (3,)
276+
assert ds["paths_inv_hessian_diag"].shape == (3, 2)
272277

273278
# Check config data (config_ prefix)
274279
assert "config_num_draws" in ds.data_vars

0 commit comments

Comments
 (0)