2727from typing import Any , Literal , Self , TypeAlias
2828
2929import arviz as az
30- import filelock
3130import numpy as np
3231import pymc as pm
3332import pytensor
@@ -548,54 +547,6 @@ def make_pathfinder_sample_fn(
548547 return fn
549548
550549
551- def make_elbo_from_state_fn (
552- model : Model ,
553- N : int ,
554- J : int ,
555- jacobian : bool ,
556- compile_kwargs : dict ,
557- vectorize : bool = False ,
558- ) -> Function :
559- """Compiled (x, g, alpha, S, Z, u) → (phi, logQ, logP) for fixture/tests.
560-
561- S, Z are explicit inputs (not shared), for recomputing ELBO from saved state.
562- """
563- (logP_single ,), single_input = pm .pytensorf .join_nonshared_inputs (
564- model .initial_point (),
565- [model .logp (jacobian = jacobian )],
566- model .value_vars ,
567- )
568- x_sym = pt .vector ("x" , dtype = "float64" )
569- g_sym = pt .vector ("g" , dtype = "float64" )
570- alpha_sym = pt .vector ("alpha" , dtype = "float64" )
571- S_sym = pt .matrix ("S" , dtype = "float64" )
572- Z_sym = pt .matrix ("Z" , dtype = "float64" )
573- u_sym = pt .matrix ("u" , dtype = "float64" )
574- phi_sym , logQ_sym , _ = _bfgs_sample_pt (x_sym , g_sym , alpha_sym , S_sym , Z_sym , u_sym , J , N )
575- if vectorize :
576- batched_logP_sym = vectorize_graph (logP_single , replace = {single_input : phi_sym })
577- else :
578- batched_logP_sym = pytensor .map (
579- fn = lambda x_i : clone_replace ([logP_single ], replace = {single_input : x_i })[0 ],
580- sequences = [phi_sym ],
581- return_updates = False ,
582- )
583- fn = pytensor .function (
584- [
585- pytensor .In (x_sym , borrow = True ),
586- pytensor .In (g_sym , borrow = True ),
587- pytensor .In (alpha_sym , borrow = True ),
588- pytensor .In (S_sym , borrow = True ),
589- pytensor .In (Z_sym , borrow = True ),
590- pytensor .In (u_sym , borrow = True ),
591- ],
592- [phi_sym , logQ_sym , batched_logP_sym ],
593- ** compile_kwargs ,
594- )
595- fn .trust_input = True
596- return fn
597-
598-
599550class PathStatus (Enum ):
600551 """
601552 Statuses of a single-path pathfinder.
@@ -1583,6 +1534,80 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
15831534 return warnings
15841535
15851536
1537+ def _make_multipath_progress (progressbar : bool ) -> CustomProgress :
1538+ return CustomProgress (
1539+ TextColumn ("{task.description}" , table_column = Column ("Path" , min_width = 7 , no_wrap = True )),
1540+ TextColumn (
1541+ "{task.fields[status]}" , table_column = Column ("Status" , min_width = 10 , no_wrap = True )
1542+ ),
1543+ TextColumn (
1544+ "{task.fields[lbfgs_steps]}" , table_column = Column ("Steps" , min_width = 6 , no_wrap = True )
1545+ ),
1546+ TextColumn (
1547+ "{task.fields[steps_per_sec]}" ,
1548+ table_column = Column ("Steps/s" , min_width = 8 , no_wrap = True ),
1549+ ),
1550+ TextColumn (
1551+ "{task.fields[best_ind]}" , table_column = Column ("Best step" , min_width = 9 , no_wrap = True )
1552+ ),
1553+ TextColumn (
1554+ "{task.fields[best_elbo]}" , table_column = Column ("Best ELBO" , min_width = 12 , no_wrap = True )
1555+ ),
1556+ TextColumn (
1557+ "{task.fields[current_elbo]}" ,
1558+ table_column = Column ("Cur ELBO" , min_width = 12 , no_wrap = True ),
1559+ ),
1560+ TextColumn (
1561+ "{task.fields[step_size]}" ,
1562+ table_column = Column ("Step size" , min_width = 10 , no_wrap = True ),
1563+ ),
1564+ TimeElapsedColumn (table_column = Column ("Elapsed" , min_width = 8 , no_wrap = True )),
1565+ include_headers = True ,
1566+ console = Console (theme = default_progress_theme ),
1567+ disable = not progressbar ,
1568+ )
1569+
1570+
1571+ def _make_progress_callback (progress : CustomProgress , task_id : int ) -> Callable [[dict ], None ]:
1572+ def cb (info : dict ) -> None :
1573+ fields : dict [str , Any ] = {}
1574+ if "status" in info and info ["status" ] is not None :
1575+ fields ["status" ] = info ["status" ]
1576+ if "lbfgs_steps" in info :
1577+ fields ["lbfgs_steps" ] = info ["lbfgs_steps" ]
1578+ if "best_elbo" in info :
1579+ val = info ["best_elbo" ]
1580+ fields ["best_elbo" ] = (
1581+ f"{ val :.3f} " if val is not None and np .isfinite (float (val )) else "—"
1582+ )
1583+ if "best_ind" in info :
1584+ val = info ["best_ind" ]
1585+ fields ["best_ind" ] = (
1586+ str (int (val )) if val is not None and np .isfinite (float (val )) else "—"
1587+ )
1588+ if "current_elbo" in info :
1589+ val = info ["current_elbo" ]
1590+ fields ["current_elbo" ] = (
1591+ f"{ val :.3f} " if val is not None and np .isfinite (float (val )) else "—"
1592+ )
1593+ if "step_size" in info :
1594+ val = info ["step_size" ]
1595+ fields ["step_size" ] = (
1596+ f"{ val :.2e} " if val is not None and np .isfinite (float (val )) else "—"
1597+ )
1598+ if "steps_per_sec" in info :
1599+ val = info ["steps_per_sec" ]
1600+ fields ["steps_per_sec" ] = (
1601+ f"{ val :.1f} /s" if val is not None and np .isfinite (float (val )) else "—"
1602+ )
1603+ if fields :
1604+ progress .update (task_id , ** fields )
1605+ if info .get ("status" ) in ("ok" , "elbo@0" ):
1606+ progress .stop_task (task_id )
1607+
1608+ return cb
1609+
1610+
15861611def multipath_pathfinder (
15871612 model : Model ,
15881613 num_paths : int ,
@@ -1706,42 +1731,7 @@ def multipath_pathfinder(
17061731 compute_start = time .time ()
17071732 try :
17081733 # Per-path progress bar (one row per path, updated in real time)
1709- progress = CustomProgress (
1710- TextColumn (
1711- "{task.description}" , table_column = Column ("Path" , min_width = 7 , no_wrap = True )
1712- ),
1713- TextColumn (
1714- "{task.fields[status]}" , table_column = Column ("Status" , min_width = 10 , no_wrap = True )
1715- ),
1716- TextColumn (
1717- "{task.fields[lbfgs_steps]}" ,
1718- table_column = Column ("Steps" , min_width = 6 , no_wrap = True ),
1719- ),
1720- TextColumn (
1721- "{task.fields[steps_per_sec]}" ,
1722- table_column = Column ("Steps/s" , min_width = 8 , no_wrap = True ),
1723- ),
1724- TextColumn (
1725- "{task.fields[best_ind]}" ,
1726- table_column = Column ("Best step" , min_width = 9 , no_wrap = True ),
1727- ),
1728- TextColumn (
1729- "{task.fields[best_elbo]}" ,
1730- table_column = Column ("Best ELBO" , min_width = 12 , no_wrap = True ),
1731- ),
1732- TextColumn (
1733- "{task.fields[current_elbo]}" ,
1734- table_column = Column ("Cur ELBO" , min_width = 12 , no_wrap = True ),
1735- ),
1736- TextColumn (
1737- "{task.fields[step_size]}" ,
1738- table_column = Column ("Step size" , min_width = 10 , no_wrap = True ),
1739- ),
1740- TimeElapsedColumn (table_column = Column ("Elapsed" , min_width = 8 , no_wrap = True )),
1741- include_headers = True ,
1742- console = Console (theme = default_progress_theme ),
1743- disable = not progressbar ,
1744- )
1734+ progress = _make_multipath_progress (progressbar )
17451735
17461736 # Create one task per path and build per-path progress callbacks
17471737 task_ids = []
@@ -1760,51 +1750,7 @@ def multipath_pathfinder(
17601750 total = None ,
17611751 )
17621752 task_ids .append (tid )
1763-
1764- def _make_cb (task_id : int ) -> Callable :
1765- def cb (info : dict ) -> None :
1766- fields : dict [str , Any ] = {}
1767- if "status" in info and info ["status" ] is not None :
1768- fields ["status" ] = info ["status" ]
1769- if "lbfgs_steps" in info :
1770- fields ["lbfgs_steps" ] = info ["lbfgs_steps" ]
1771- if "best_elbo" in info :
1772- val = info ["best_elbo" ]
1773- fields ["best_elbo" ] = (
1774- f"{ val :.3f} " if val is not None and np .isfinite (float (val )) else "—"
1775- )
1776- if "best_ind" in info :
1777- val = info ["best_ind" ]
1778- fields ["best_ind" ] = (
1779- str (int (val ))
1780- if val is not None and np .isfinite (float (val ))
1781- else "—"
1782- )
1783- if "current_elbo" in info :
1784- val = info ["current_elbo" ]
1785- fields ["current_elbo" ] = (
1786- f"{ val :.3f} " if val is not None and np .isfinite (float (val )) else "—"
1787- )
1788- if "step_size" in info :
1789- val = info ["step_size" ]
1790- fields ["step_size" ] = (
1791- f"{ val :.2e} " if val is not None and np .isfinite (float (val )) else "—"
1792- )
1793- if "steps_per_sec" in info :
1794- val = info ["steps_per_sec" ]
1795- fields ["steps_per_sec" ] = (
1796- f"{ val :.1f} /s"
1797- if val is not None and np .isfinite (float (val ))
1798- else "—"
1799- )
1800- if fields :
1801- progress .update (task_id , ** fields )
1802- if info .get ("status" ) in ("ok" , "elbo@0" ):
1803- progress .stop_task (task_id )
1804-
1805- return cb
1806-
1807- path_callbacks .append (_make_cb (tid ))
1753+ path_callbacks .append (_make_progress_callback (progress , tid ))
18081754
18091755 # concurrent="process" gives true parallelism via separate worker processes
18101756 # (matching PyMC's approach). concurrent=None is serial (useful for debugging).
@@ -1824,18 +1770,6 @@ def cb(info: dict) -> None:
18241770 raise result
18251771 else :
18261772 results .append (result )
1827- except filelock .Timeout :
1828- logger .warning ("Lock timeout. Retrying..." )
1829- num_attempts = 0
1830- while num_attempts < 10 :
1831- try :
1832- results .append (result )
1833- logger .info ("Lock acquired. Continuing..." )
1834- break
1835- except filelock .Timeout :
1836- num_attempts += 1
1837- time .sleep (0.5 )
1838- logger .warning (f"Lock timeout. Retrying... ({ num_attempts } /10)" )
18391773 except Exception as e :
18401774 logger .warning ("Unexpected error in a path: %s" , str (e ))
18411775 results .append (
@@ -2083,12 +2017,12 @@ def fit_pathfinder(
20832017 if version .parse (blackjax .__version__ ).major < 1 :
20842018 raise ImportError ("fit_pathfinder requires blackjax 1.0 or above" )
20852019
2086- jitter_seed , pathfinder_seed , sample_seed = _get_seeds_per_chain (random_seed , 3 )
2020+ _ , pathfinder_seed , sample_seed = _get_seeds_per_chain (random_seed , 3 )
20872021 # TODO: extend initial points with jitter_scale to blackjax
20882022 # TODO: extend blackjax pathfinder to multiple paths
20892023 x0 , _ = DictToArrayBijection .map (model .initial_point ())
20902024 logp_func = get_jaxified_logp_of_ravel_inputs (model )
2091- pathfinder_state , pathfinder_info = blackjax .vi .pathfinder .approximate (
2025+ pathfinder_state , _ = blackjax .vi .pathfinder .approximate (
20922026 rng_key = jax .random .key (pathfinder_seed ),
20932027 logdensity_fn = logp_func ,
20942028 initial_position = x0 ,
0 commit comments