I've run into an issue where sampling a large but straightforward model with nutpie results in unexpectedly high memory usage compared to sampling the same model with numpyro, both via PyMC.
Base code:
import numpy as np
import pymc as pm
from scipy.special import expit

N_PERSONS = 60_000
N_ITEMS = 1_000
N_ITEMS_PER_PERSON = 20

RPG = np.random.default_rng(12_1_2025)

# Known person abilities and item difficulties used to simulate responses
theta_known = RPG.normal(0, 1, N_PERSONS)
beta_known = RPG.normal(0, 1, N_ITEMS)

# Each person answers a random subset of N_ITEMS_PER_PERSON distinct items
person_idx = np.repeat(np.arange(N_PERSONS), N_ITEMS_PER_PERSON)
item_idx = np.arange(N_ITEMS)
item_idx = np.array(
    [
        RPG.choice(item_idx, size=N_ITEMS_PER_PERSON, replace=False)
        for _ in range(N_PERSONS)
    ]
).flatten()

# Simulate Bernoulli outcomes under the Rasch model
eta = theta_known[person_idx] - beta_known[item_idx]
score = (RPG.random(N_PERSONS * N_ITEMS_PER_PERSON) < expit(eta)) + 0
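
For scale, this yields N_PERSONS * N_ITEMS_PER_PERSON = 1,200,000 observations. A quick optional sanity check of the simulated data (not needed for the repro itself):

# Shapes line up, and the score rate is near 0.5 since theta and beta
# are drawn from the same N(0, 1) distribution
assert person_idx.shape == item_idx.shape == score.shape == (N_PERSONS * N_ITEMS_PER_PERSON,)
print(score.mean())  # ~0.5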
# ==================================================================
# Baseline Rasch Model
# ==================================================================
def build_rasch(
    person_id: np.ndarray,
    item_id: np.ndarray,
    Y: np.ndarray,
    theta_fixed: np.ndarray,
    coords: dict | None = None,
) -> pm.Model:
    N_obs = len(Y)
    N_persons = len(np.unique(person_id))
    N_items = len(np.unique(item_id))
    if coords is None:
        coords = {
            "obs": np.arange(N_obs),
            "person": np.arange(N_persons),
            "item": np.arange(N_items),
        }
    with pm.Model(coords=coords) as model:
        person_idx = pm.Data("person_idx", person_id, dims="obs")
        item_idx = pm.Data("item_idx", item_id, dims="obs")
        Y_data = pm.Data("Y_data", Y, dims="obs")
        # Person abilities are held fixed; only item difficulties are estimated
        theta = pm.Data("theta_fixed", theta_fixed, dims="person")
        beta = pm.Normal("beta", mu=0, sigma=2, dims="item")
        theta_obs = theta[person_idx]
        beta_obs = beta[item_idx]
        logit_p = pm.Deterministic("logit_p", theta_obs - beta_obs, dims="obs")
        y_obs = pm.Bernoulli("y_obs", logit_p=logit_p, observed=Y_data, dims="obs")
    return model
rasch_model = build_rasch(
    person_id=person_idx,
    item_id=item_idx,
    Y=score,
    theta_fixed=theta_known,
)
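
For context (an optional check, not needed for the repro): the only free parameter is the 1,000-element beta vector, while logit_p is a Deterministic over all 1.2M observations, so logit_p dominates whatever the trace stores.

# beta is the only sampled variable; logit_p is tracked as a Deterministic
print(rasch_model.free_RVs)        # [beta]
print(rasch_model.deterministics)  # [logit_p]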
Using numpyro to sample:
with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="numpyro",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs={"chain_method": "vectorized"},
    )
Yields the following stats:
Mem:
Before sampling: ~3.9 GB
During sampling: ~5.7 GB
After sampling: ~16 GB
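
For anyone replicating, one way to capture comparable before/during/after figures from inside the process (a sketch assuming psutil is installed; this reads host RSS only, not GPU memory):

import os

import psutil

def rss_gb() -> float:
    # Resident set size of the current Python process, in GB (host RAM only)
    return psutil.Process(os.getpid()).memory_info().rss / 1e9

print(f"Before sampling: {rss_gb():.1f} GB")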
After restarting the notebook and clearing memory, running nutpie to sample:
with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="nutpie",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs=dict(backend="jax", gradient_backend="jax"),
    )
Yields the following stats:
Mem:
Before sampling: ~3.9 GB
During sampling: ~63 GB
After sampling: ~34 GB
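
A back-of-the-envelope calculation (my arithmetic, not from profiling): the logit_p Deterministic alone stores chains * draws * N_obs float64 values in the returned trace, which would explain memory staying elevated after sampling in both runs, but not the ~63 GB peak while nutpie is sampling:

n_obs = N_PERSONS * N_ITEMS_PER_PERSON  # 1_200_000
logit_p_gb = 4 * 300 * n_obs * 8 / 1e9  # chains * draws * n_obs * float64 bytes
print(logit_p_gb)                       # ~11.5 GB for logit_p alone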
Notes:
Sampling is conducted on GPU (RTX 5000 Ada) via WSL, running nutpie version 0.16.3.
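
To confirm the GPU is actually what JAX is using under WSL (a quick check, assuming a CUDA-enabled jaxlib is installed):

import jax

# Lists the devices JAX will place computations on; should show the CUDA GPU
print(jax.devices())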
Let me know if you're able to replicate the issue or need additional info.
Thanks,