@@ -399,7 +399,7 @@ def _bfgs_sample_pt(
399399 u : TensorVariable ,
400400 J : int ,
401401 N : int ,
402- ) -> tuple [TensorVariable , TensorVariable ]:
402+ ) -> tuple [TensorVariable , TensorVariable , TensorVariable ]:
403403 """Symbolic L-BFGS inverse-Hessian sample.
404404
405405 The dense vs sparse path is selected at graph-construction time from
@@ -420,6 +420,7 @@ def _bfgs_sample_pt(
420420 -------
421421 phi : (M, N) samples
422422 logQ : (M,) log-density under the approximation
423+ inv_hessian_diag : (N,) diagonal of sampler covariance used for `phi`
423424 """
424425 J2 = 2 * J
425426
@@ -447,6 +448,7 @@ def _bfgs_sample_pt(
447448 logdet = 2.0 * pt .sum (pt .log (pt .abs (pt .diag (Lchol ))))
448449 mu = x - H_inv @ g
449450 phi = (mu [:, None ] + Lchol @ u .T ).T # (M, N)
451+ inv_hessian_diag = pt .diag (H_inv )
450452 else :
451453 # Sparse path: economy QR avoids O(N²) matrices (large-N regime)
452454 Q , R = pt .linalg .qr (beta * inv_sqrt_alpha [:, None ], mode = "reduced" ) # Q:(N,2J), R:(2J,2J)
@@ -457,9 +459,12 @@ def _bfgs_sample_pt(
457459 mu = x - ((alpha * g )[:, None ] + beta @ (gamma @ btg ))[:, 0 ] # (N,)
458460 QtU = Q .T @ u .T # (2J, M)
459461 phi = (mu [:, None ] + sqrt_alpha [:, None ] * (Q @ ((Lchol - pt .eye (J2 )) @ QtU ) + u .T )).T
462+ # diag(Q (L L^T - I) Q^T) = row_norm_sq(Q L) - row_norm_sq(Q)
463+ ql = Q @ Lchol
464+ inv_hessian_diag = alpha * (1.0 + pt .sum (ql * ql , axis = 1 ) - pt .sum (Q * Q , axis = 1 ))
460465
461466 logQ = - 0.5 * (logdet + pt .sum (u * u , axis = - 1 ) + N * np .log (2.0 * np .pi ))
462- return phi , logQ
467+ return phi , logQ , inv_hessian_diag
463468
464469
465470def make_pathfinder_sample_fn (
@@ -469,6 +474,7 @@ def make_pathfinder_sample_fn(
469474 jacobian : bool ,
470475 compile_kwargs : dict ,
471476 vectorize : bool = False ,
477+ return_inv_hessian_diag : bool = False ,
472478) -> Function :
473479 """Compile a single PyTensor function covering bfgs sample + batched logP evaluation.
474480
@@ -487,11 +493,13 @@ def make_pathfinder_sample_fn(
487493 compile_kwargs : dict
488494 vectorize : bool, optional
489495 If True, use vectorize_graph instead of pytensor.map. Default False.
496+ return_inv_hessian_diag : bool, optional
497+ If True, include inverse-Hessian diagonal as 4th output.
490498
491499 Returns
492500 -------
493501 fn : Function
494- Compiled: (x, g, alpha, s_win, z_win, u) → (phi, logQ, logP)
502+ Compiled: (x, g, alpha, s_win, z_win, u) → (phi, logQ, logP[, inv_hessian_diag] )
495503 where s_win, z_win are (N, J), u is (M, N), and M is a dynamic dimension.
496504 """
497505 (logP_single ,), single_input = pm .pytensorf .join_nonshared_inputs (
@@ -507,7 +515,9 @@ def make_pathfinder_sample_fn(
507515 z_win_sym = pt .matrix ("z_win" , dtype = "float64" ) # (N, J)
508516 u_sym = pt .matrix ("u" , dtype = "float64" ) # (M, N) — M is dynamic
509517
510- phi_sym , logQ_sym = _bfgs_sample_pt (x_sym , g_sym , alpha_sym , s_win_sym , z_win_sym , u_sym , J , N )
518+ phi_sym , logQ_sym , inv_hessian_diag_sym = _bfgs_sample_pt (
519+ x_sym , g_sym , alpha_sym , s_win_sym , z_win_sym , u_sym , J , N
520+ )
511521
512522 if vectorize :
513523 batched_logP_sym = vectorize_graph (logP_single , replace = {single_input : phi_sym })
@@ -518,6 +528,10 @@ def make_pathfinder_sample_fn(
518528 return_updates = False ,
519529 )
520530
531+ outputs = [phi_sym , logQ_sym , batched_logP_sym ]
532+ if return_inv_hessian_diag :
533+ outputs .append (inv_hessian_diag_sym )
534+
521535 fn = pytensor .function (
522536 [
523537 pytensor .In (x_sym , borrow = True ),
@@ -527,7 +541,7 @@ def make_pathfinder_sample_fn(
527541 pytensor .In (z_win_sym , borrow = True ),
528542 pytensor .In (u_sym , borrow = True ),
529543 ],
530- [ phi_sym , logQ_sym , batched_logP_sym ] ,
544+ outputs ,
531545 ** compile_kwargs ,
532546 )
533547 fn .trust_input = True
@@ -557,7 +571,7 @@ def make_elbo_from_state_fn(
557571 S_sym = pt .matrix ("S" , dtype = "float64" )
558572 Z_sym = pt .matrix ("Z" , dtype = "float64" )
559573 u_sym = pt .matrix ("u" , dtype = "float64" )
560- phi_sym , logQ_sym = _bfgs_sample_pt (x_sym , g_sym , alpha_sym , S_sym , Z_sym , u_sym , J , N )
574+ phi_sym , logQ_sym , _ = _bfgs_sample_pt (x_sym , g_sym , alpha_sym , S_sym , Z_sym , u_sym , J , N )
561575 if vectorize :
562576 batched_logP_sym = vectorize_graph (logP_single , replace = {single_input : phi_sym })
563577 else :
@@ -747,7 +761,8 @@ def __call__(self, x: NDArray) -> float | None:
747761 # Sample + logP in a single compiled call. Pass s_win/z_win as inputs.
748762 u = self ._rng .standard_normal ((self .num_elbo_draws , self ._N ))
749763 try :
750- _ , logQ , logP = self .sample_logp_fn (x , g , alpha , self .s_win , self .z_win , u )
764+ sample_out = self .sample_logp_fn (x , g , alpha , self .s_win , self .z_win , u )
765+ _ , logQ , logP = sample_out [:3 ]
751766 logP = np .asarray (logP )
752767 logQ = np .asarray (logQ )
753768 finite = np .isfinite (logP )
@@ -880,7 +895,21 @@ def make_single_pathfinder_fn(
880895 N = x_base .shape [0 ]
881896
882897 sample_logp_func = make_pathfinder_sample_fn (
883- model , N , maxcor , jacobian = jacobian , compile_kwargs = compile_kwargs , vectorize = vectorize
898+ model ,
899+ N ,
900+ maxcor ,
901+ jacobian = jacobian ,
902+ compile_kwargs = compile_kwargs ,
903+ vectorize = vectorize ,
904+ )
905+ final_sample_logp_func = make_pathfinder_sample_fn (
906+ model ,
907+ N ,
908+ maxcor ,
909+ jacobian = jacobian ,
910+ compile_kwargs = compile_kwargs ,
911+ vectorize = vectorize ,
912+ return_inv_hessian_diag = True ,
884913 )
885914
886915 def _check_lbfgs_status (status ):
@@ -889,7 +918,9 @@ def _check_lbfgs_status(status):
889918 elif status == LBFGSStatus .LBFGS_FAILED :
890919 raise LBFGSException ()
891920
892- def _make_result (psi , logP_psi , logQ_psi , lbfgs_niter , elbo_argmax , lbfgs_status ):
921+ def _make_result (
922+ psi , logP_psi , logQ_psi , lbfgs_niter , elbo_argmax , inv_hessian_diag , lbfgs_status
923+ ):
893924 if np .all (~ np .isfinite (logQ_psi )):
894925 raise PathInvalidLogQ ()
895926 path_status = PathStatus .ELBO_ARGMAX_AT_ZERO if elbo_argmax == 0 else PathStatus .SUCCESS
@@ -899,6 +930,7 @@ def _make_result(psi, logP_psi, logQ_psi, lbfgs_niter, elbo_argmax, lbfgs_status
899930 logQ = logQ_psi ,
900931 lbfgs_niter = lbfgs_niter ,
901932 elbo_argmax = elbo_argmax ,
933+ inv_hessian_diag = inv_hessian_diag ,
902934 lbfgs_status = lbfgs_status ,
903935 path_status = path_status ,
904936 )
@@ -916,6 +948,7 @@ def single_pathfinder_fn(
916948 local_lbfgs = LBFGS (cached_fn , maxcor , maxiter , ftol , gtol , maxls , epsilon )
917949
918950 local_sample_logp = sample_logp_func .copy (share_memory = False )
951+ local_final_sample_logp = final_sample_logp_func .copy (share_memory = False )
919952
920953 lbfgs_status = LBFGSStatus .LBFGS_FAILED # default before LBFGS runs
921954 try :
@@ -965,17 +998,19 @@ def single_pathfinder_fn(
965998 {"status" : "sampling" , "current_elbo" : None , "step_size" : None }
966999 )
9671000
968- phi_final , logQ_psi_flat , logP_psi_flat = local_sample_logp (
1001+ sample_out = local_final_sample_logp (
9691002 best_state ["x" ],
9701003 best_state ["g" ],
9711004 best_state ["alpha" ],
9721005 best_state ["s_win" ],
9731006 best_state ["z_win" ],
9741007 u_final ,
9751008 )
1009+ phi_final , logQ_psi_flat , logP_psi_flat , inv_hessian_diag = sample_out
9761010 phi_final = np .asarray (phi_final )
9771011 logQ_psi_flat = np .asarray (logQ_psi_flat )
9781012 logP_psi_flat = np .asarray (logP_psi_flat )
1013+ inv_hessian_diag = np .asarray (inv_hessian_diag )[None ]
9791014 # Add batch dim L=1 to match downstream expectations
9801015 psi = phi_final [None ] # (1, M, N)
9811016 logP_psi = logP_psi_flat [None ] # (1, M)
@@ -1012,7 +1047,15 @@ def single_pathfinder_fn(
10121047 path_status = PathStatus .SINGLE_STEP ,
10131048 )
10141049
1015- result = _make_result (psi , logP_psi , logQ_psi , lbfgs_niter , elbo_argmax , lbfgs_status )
1050+ result = _make_result (
1051+ psi ,
1052+ logP_psi ,
1053+ logQ_psi ,
1054+ lbfgs_niter ,
1055+ elbo_argmax ,
1056+ inv_hessian_diag ,
1057+ lbfgs_status ,
1058+ )
10161059 if progress_callback is not None :
10171060 status_str = (
10181061 "elbo@0" if result .path_status == PathStatus .ELBO_ARGMAX_AT_ZERO else "ok"
@@ -1260,6 +1303,7 @@ class PathfinderResult:
12601303 logQ : NDArray | None = None
12611304 lbfgs_niter : NDArray | None = None
12621305 elbo_argmax : NDArray | None = None
1306+ inv_hessian_diag : NDArray | None = None
12631307 lbfgs_status : LBFGSStatus = LBFGSStatus .LBFGS_FAILED
12641308 path_status : PathStatus = PathStatus .PATH_FAILED
12651309
@@ -1310,6 +1354,7 @@ class MultiPathfinderResult:
13101354 logQ : NDArray | None = None
13111355 lbfgs_niter : NDArray | None = None
13121356 elbo_argmax : NDArray | None = None
1357+ inv_hessian_diag : NDArray | None = None
13131358 lbfgs_status : Counter = field (default_factory = Counter )
13141359 path_status : Counter = field (default_factory = Counter )
13151360 importance_sampling : str | None = "psis"
@@ -1331,7 +1376,14 @@ class MultiPathfinderResult:
13311376 def from_path_results (cls , path_results : list [PathfinderResult ]) -> "MultiPathfinderResult" :
13321377 """aggregate successful pathfinder results and count the occurrences of each status in PathStatus and LBFGSStatus"""
13331378
1334- NUMERIC_ATTRIBUTES = ["samples" , "logP" , "logQ" , "lbfgs_niter" , "elbo_argmax" ]
1379+ NUMERIC_ATTRIBUTES = [
1380+ "samples" ,
1381+ "logP" ,
1382+ "logQ" ,
1383+ "lbfgs_niter" ,
1384+ "elbo_argmax" ,
1385+ "inv_hessian_diag" ,
1386+ ]
13351387
13361388 success_results = []
13371389 mpr = cls ()
0 commit comments