Skip to content

fix test because new jax 0.9.2 release#2155

Merged
Qazalbash merged 2 commits intopyro-ppl:masterfrom
juanitorduz:fix-jax-0.9.2
Mar 20, 2026
Merged

fix test because new jax 0.9.2 release#2155
Qazalbash merged 2 commits intopyro-ppl:masterfrom
juanitorduz:fix-jax-0.9.2

Conversation

@juanitorduz
Copy link
Copy Markdown
Collaborator

@juanitorduz juanitorduz commented Mar 19, 2026

Closes #2154

@Qazalbash
Copy link
Copy Markdown
Collaborator

Is this change working for older JAX?

@juanitorduz
Copy link
Copy Markdown
Collaborator Author

juanitorduz commented Mar 20, 2026

Is this change working for older JAX?

Thanks for the feedback! With the latest commit b010649 it does work (I have tested locally)

The fix is a one-line swap in _call_bind -- try the old positional API (bind(fn, *args)) first, which works on JAX <= 0.9.0. On JAX 0.9.2 it cleanly raises TypeError, triggering the fallback to the new subfuns API (bind(*args, subfuns=(fn,))). All 10 tests pass on JAX 0.9.0.

@juanitorduz juanitorduz requested a review from fehiepsi March 20, 2026 07:41
@juanitorduz juanitorduz marked this pull request as ready for review March 20, 2026 07:41
@juanitorduz juanitorduz requested a review from Qazalbash March 20, 2026 07:42
Copy link
Copy Markdown
Collaborator

@Qazalbash Qazalbash left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Qazalbash Qazalbash merged commit d1670c9 into pyro-ppl:master Mar 20, 2026
9 checks passed
def _call_bind(primitive, fn, *args):
try:
return primitive.bind(fn, *args)
except TypeError:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The typeerror might come from the fn, not from the bind method. Maybe check for jax version instead?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Tests failing with jax==0.9.2 in test/ops/test_provenance.py

3 participants