Skip to content

Commit 5a44f02

Browse files
velochyricardoV94
authored andcommitted
Some cleanup, including trimming test count
1 parent cea7a15 commit 5a44f02

File tree

2 files changed

+85
-150
lines changed

2 files changed

+85
-150
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 78 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from typing import Any, Literal, Self, TypeAlias
2828

2929
import arviz as az
30-
import filelock
3130
import numpy as np
3231
import pymc as pm
3332
import pytensor
@@ -548,54 +547,6 @@ def make_pathfinder_sample_fn(
548547
return fn
549548

550549

551-
def make_elbo_from_state_fn(
552-
model: Model,
553-
N: int,
554-
J: int,
555-
jacobian: bool,
556-
compile_kwargs: dict,
557-
vectorize: bool = False,
558-
) -> Function:
559-
"""Compiled (x, g, alpha, S, Z, u) → (phi, logQ, logP) for fixture/tests.
560-
561-
S, Z are explicit inputs (not shared), for recomputing ELBO from saved state.
562-
"""
563-
(logP_single,), single_input = pm.pytensorf.join_nonshared_inputs(
564-
model.initial_point(),
565-
[model.logp(jacobian=jacobian)],
566-
model.value_vars,
567-
)
568-
x_sym = pt.vector("x", dtype="float64")
569-
g_sym = pt.vector("g", dtype="float64")
570-
alpha_sym = pt.vector("alpha", dtype="float64")
571-
S_sym = pt.matrix("S", dtype="float64")
572-
Z_sym = pt.matrix("Z", dtype="float64")
573-
u_sym = pt.matrix("u", dtype="float64")
574-
phi_sym, logQ_sym, _ = _bfgs_sample_pt(x_sym, g_sym, alpha_sym, S_sym, Z_sym, u_sym, J, N)
575-
if vectorize:
576-
batched_logP_sym = vectorize_graph(logP_single, replace={single_input: phi_sym})
577-
else:
578-
batched_logP_sym = pytensor.map(
579-
fn=lambda x_i: clone_replace([logP_single], replace={single_input: x_i})[0],
580-
sequences=[phi_sym],
581-
return_updates=False,
582-
)
583-
fn = pytensor.function(
584-
[
585-
pytensor.In(x_sym, borrow=True),
586-
pytensor.In(g_sym, borrow=True),
587-
pytensor.In(alpha_sym, borrow=True),
588-
pytensor.In(S_sym, borrow=True),
589-
pytensor.In(Z_sym, borrow=True),
590-
pytensor.In(u_sym, borrow=True),
591-
],
592-
[phi_sym, logQ_sym, batched_logP_sym],
593-
**compile_kwargs,
594-
)
595-
fn.trust_input = True
596-
return fn
597-
598-
599550
class PathStatus(Enum):
600551
"""
601552
Statuses of a single-path pathfinder.
@@ -1583,6 +1534,80 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
15831534
return warnings
15841535

15851536

1537+
def _make_multipath_progress(progressbar: bool) -> CustomProgress:
1538+
return CustomProgress(
1539+
TextColumn("{task.description}", table_column=Column("Path", min_width=7, no_wrap=True)),
1540+
TextColumn(
1541+
"{task.fields[status]}", table_column=Column("Status", min_width=10, no_wrap=True)
1542+
),
1543+
TextColumn(
1544+
"{task.fields[lbfgs_steps]}", table_column=Column("Steps", min_width=6, no_wrap=True)
1545+
),
1546+
TextColumn(
1547+
"{task.fields[steps_per_sec]}",
1548+
table_column=Column("Steps/s", min_width=8, no_wrap=True),
1549+
),
1550+
TextColumn(
1551+
"{task.fields[best_ind]}", table_column=Column("Best step", min_width=9, no_wrap=True)
1552+
),
1553+
TextColumn(
1554+
"{task.fields[best_elbo]}", table_column=Column("Best ELBO", min_width=12, no_wrap=True)
1555+
),
1556+
TextColumn(
1557+
"{task.fields[current_elbo]}",
1558+
table_column=Column("Cur ELBO", min_width=12, no_wrap=True),
1559+
),
1560+
TextColumn(
1561+
"{task.fields[step_size]}",
1562+
table_column=Column("Step size", min_width=10, no_wrap=True),
1563+
),
1564+
TimeElapsedColumn(table_column=Column("Elapsed", min_width=8, no_wrap=True)),
1565+
include_headers=True,
1566+
console=Console(theme=default_progress_theme),
1567+
disable=not progressbar,
1568+
)
1569+
1570+
1571+
def _make_progress_callback(progress: CustomProgress, task_id: int) -> Callable[[dict], None]:
1572+
def cb(info: dict) -> None:
1573+
fields: dict[str, Any] = {}
1574+
if "status" in info and info["status"] is not None:
1575+
fields["status"] = info["status"]
1576+
if "lbfgs_steps" in info:
1577+
fields["lbfgs_steps"] = info["lbfgs_steps"]
1578+
if "best_elbo" in info:
1579+
val = info["best_elbo"]
1580+
fields["best_elbo"] = (
1581+
f"{val:.3f}" if val is not None and np.isfinite(float(val)) else "—"
1582+
)
1583+
if "best_ind" in info:
1584+
val = info["best_ind"]
1585+
fields["best_ind"] = (
1586+
str(int(val)) if val is not None and np.isfinite(float(val)) else "—"
1587+
)
1588+
if "current_elbo" in info:
1589+
val = info["current_elbo"]
1590+
fields["current_elbo"] = (
1591+
f"{val:.3f}" if val is not None and np.isfinite(float(val)) else "—"
1592+
)
1593+
if "step_size" in info:
1594+
val = info["step_size"]
1595+
fields["step_size"] = (
1596+
f"{val:.2e}" if val is not None and np.isfinite(float(val)) else "—"
1597+
)
1598+
if "steps_per_sec" in info:
1599+
val = info["steps_per_sec"]
1600+
fields["steps_per_sec"] = (
1601+
f"{val:.1f}/s" if val is not None and np.isfinite(float(val)) else "—"
1602+
)
1603+
if fields:
1604+
progress.update(task_id, **fields)
1605+
if info.get("status") in ("ok", "elbo@0"):
1606+
progress.stop_task(task_id)
1607+
1608+
return cb
1609+
1610+
15861611
def multipath_pathfinder(
15871612
model: Model,
15881613
num_paths: int,
@@ -1706,42 +1731,7 @@ def multipath_pathfinder(
17061731
compute_start = time.time()
17071732
try:
17081733
# Per-path progress bar (one row per path, updated in real time)
1709-
progress = CustomProgress(
1710-
TextColumn(
1711-
"{task.description}", table_column=Column("Path", min_width=7, no_wrap=True)
1712-
),
1713-
TextColumn(
1714-
"{task.fields[status]}", table_column=Column("Status", min_width=10, no_wrap=True)
1715-
),
1716-
TextColumn(
1717-
"{task.fields[lbfgs_steps]}",
1718-
table_column=Column("Steps", min_width=6, no_wrap=True),
1719-
),
1720-
TextColumn(
1721-
"{task.fields[steps_per_sec]}",
1722-
table_column=Column("Steps/s", min_width=8, no_wrap=True),
1723-
),
1724-
TextColumn(
1725-
"{task.fields[best_ind]}",
1726-
table_column=Column("Best step", min_width=9, no_wrap=True),
1727-
),
1728-
TextColumn(
1729-
"{task.fields[best_elbo]}",
1730-
table_column=Column("Best ELBO", min_width=12, no_wrap=True),
1731-
),
1732-
TextColumn(
1733-
"{task.fields[current_elbo]}",
1734-
table_column=Column("Cur ELBO", min_width=12, no_wrap=True),
1735-
),
1736-
TextColumn(
1737-
"{task.fields[step_size]}",
1738-
table_column=Column("Step size", min_width=10, no_wrap=True),
1739-
),
1740-
TimeElapsedColumn(table_column=Column("Elapsed", min_width=8, no_wrap=True)),
1741-
include_headers=True,
1742-
console=Console(theme=default_progress_theme),
1743-
disable=not progressbar,
1744-
)
1734+
progress = _make_multipath_progress(progressbar)
17451735

17461736
# Create one task per path and build per-path progress callbacks
17471737
task_ids = []
@@ -1760,51 +1750,7 @@ def multipath_pathfinder(
17601750
total=None,
17611751
)
17621752
task_ids.append(tid)
1763-
1764-
def _make_cb(task_id: int) -> Callable:
1765-
def cb(info: dict) -> None:
1766-
fields: dict[str, Any] = {}
1767-
if "status" in info and info["status"] is not None:
1768-
fields["status"] = info["status"]
1769-
if "lbfgs_steps" in info:
1770-
fields["lbfgs_steps"] = info["lbfgs_steps"]
1771-
if "best_elbo" in info:
1772-
val = info["best_elbo"]
1773-
fields["best_elbo"] = (
1774-
f"{val:.3f}" if val is not None and np.isfinite(float(val)) else "—"
1775-
)
1776-
if "best_ind" in info:
1777-
val = info["best_ind"]
1778-
fields["best_ind"] = (
1779-
str(int(val))
1780-
if val is not None and np.isfinite(float(val))
1781-
else "—"
1782-
)
1783-
if "current_elbo" in info:
1784-
val = info["current_elbo"]
1785-
fields["current_elbo"] = (
1786-
f"{val:.3f}" if val is not None and np.isfinite(float(val)) else "—"
1787-
)
1788-
if "step_size" in info:
1789-
val = info["step_size"]
1790-
fields["step_size"] = (
1791-
f"{val:.2e}" if val is not None and np.isfinite(float(val)) else "—"
1792-
)
1793-
if "steps_per_sec" in info:
1794-
val = info["steps_per_sec"]
1795-
fields["steps_per_sec"] = (
1796-
f"{val:.1f}/s"
1797-
if val is not None and np.isfinite(float(val))
1798-
else "—"
1799-
)
1800-
if fields:
1801-
progress.update(task_id, **fields)
1802-
if info.get("status") in ("ok", "elbo@0"):
1803-
progress.stop_task(task_id)
1804-
1805-
return cb
1806-
1807-
path_callbacks.append(_make_cb(tid))
1753+
path_callbacks.append(_make_progress_callback(progress, tid))
18081754

18091755
# concurrent="process" gives true parallelism via separate worker processes
18101756
# (matching PyMC's approach). concurrent=None is serial (useful for debugging).
@@ -1824,18 +1770,6 @@ def cb(info: dict) -> None:
18241770
raise result
18251771
else:
18261772
results.append(result)
1827-
except filelock.Timeout:
1828-
logger.warning("Lock timeout. Retrying...")
1829-
num_attempts = 0
1830-
while num_attempts < 10:
1831-
try:
1832-
results.append(result)
1833-
logger.info("Lock acquired. Continuing...")
1834-
break
1835-
except filelock.Timeout:
1836-
num_attempts += 1
1837-
time.sleep(0.5)
1838-
logger.warning(f"Lock timeout. Retrying... ({num_attempts}/10)")
18391773
except Exception as e:
18401774
logger.warning("Unexpected error in a path: %s", str(e))
18411775
results.append(
@@ -2083,12 +2017,12 @@ def fit_pathfinder(
20832017
if version.parse(blackjax.__version__).major < 1:
20842018
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
20852019

2086-
jitter_seed, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
2020+
_, pathfinder_seed, sample_seed = _get_seeds_per_chain(random_seed, 3)
20872021
# TODO: extend initial points with jitter_scale to blackjax
20882022
# TODO: extend blackjax pathfinder to multiple paths
20892023
x0, _ = DictToArrayBijection.map(model.initial_point())
20902024
logp_func = get_jaxified_logp_of_ravel_inputs(model)
2091-
pathfinder_state, pathfinder_info = blackjax.vi.pathfinder.approximate(
2025+
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
20922026
rng_key=jax.random.key(pathfinder_seed),
20932027
logdensity_fn=logp_func,
20942028
initial_position=x0,

tests/pathfinder/test_model_equivalence.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
EPSILON = 1e-12
5959

6060
COMPILE_KWARGS = {"mode": Mode(linker=DEFAULT_LINKER)}
61+
ALL_MODEL_NAMES = list(MODEL_FACTORIES.keys())
6162

6263

6364
# ---------------------------------------------------------------------------
@@ -279,7 +280,7 @@ def test_inv_hessian_diag_matches_sample_scales():
279280
np.testing.assert_allclose(emp_var, diag_est, rtol=0.2, atol=1e-3)
280281

281282

282-
@pytest.mark.parametrize("model_name", ["ard_regression", "hd_gaussian"])
283+
@pytest.mark.parametrize("model_name", ["ard_regression"])
283284
@pytest.mark.parametrize("vectorize", [False, True])
284285
def test_rng_reproducibility(model_name, vectorize):
285286
"""Two streaming runs with the same seed produce bit-identical results."""
@@ -298,7 +299,7 @@ def test_rng_reproducibility(model_name, vectorize):
298299
assert r1.best_step_idx == r2.best_step_idx if hasattr(r1, "best_step_idx") else True
299300

300301

301-
@pytest.mark.parametrize("model_name", list(MODEL_FACTORIES.keys()))
302+
@pytest.mark.parametrize("model_name", ["ard_regression"])
302303
@pytest.mark.parametrize("vectorize", [False, True])
303304
def test_fixture_match(model_name, vectorize):
304305
"""Streaming path selects an ELBO argmax consistent with the fixture.
@@ -333,7 +334,7 @@ def test_fixture_match(model_name, vectorize):
333334
)
334335

335336

336-
@pytest.mark.parametrize("model_name", list(MODEL_FACTORIES.keys()))
337+
@pytest.mark.parametrize("model_name", ALL_MODEL_NAMES)
337338
def test_elbo_all_finite(model_name):
338339
"""Reference ELBO values recorded from current code must all be finite."""
339340
*_, elbo_ref = _load_fixture(model_name)
@@ -393,7 +394,7 @@ def _check_statistical_equivalence(
393394
)
394395

395396

396-
@pytest.mark.parametrize("model_name", list(MODEL_FACTORIES.keys()))
397+
@pytest.mark.parametrize("model_name", ALL_MODEL_NAMES)
397398
@pytest.mark.parametrize("vectorize", [False, True])
398399
def test_elbo_statistical_equivalence(model_name, vectorize):
399400
"""Full-stack: LBFGSStreamingCallback computes alpha/s/z internally; ELBO matches reference."""
@@ -406,8 +407,8 @@ def test_elbo_statistical_equivalence(model_name, vectorize):
406407
_check_statistical_equivalence(model_name, elbo, elbo_ref)
407408

408409

409-
@pytest.mark.parametrize("model_name", list(MODEL_FACTORIES.keys()))
410-
@pytest.mark.parametrize("vectorize", [False, True])
410+
@pytest.mark.parametrize("model_name", ["ard_regression"])
411+
@pytest.mark.parametrize("vectorize", [False])
411412
def test_elbo_statistical_equivalence_different_seed(model_name, vectorize):
412413
"""Verify tolerances are meaningful by using a different ELBO seed."""
413414
x_full, _, elbo_ref = _load_fixture(model_name)

0 commit comments

Comments
 (0)