Skip to content

Commit d1670c9

Browse files
authored
fix test because new jax 0.9.2 release (#2155)
1 parent af37655 commit d1670c9

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

test/ops/test_provenance.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
from numpyro.ops.provenance import eval_provenance
2020

2121

22+
def _call_bind(primitive, fn, *args):
23+
try:
24+
return primitive.bind(fn, *args)
25+
except TypeError:
26+
return primitive.bind(*args, subfuns=(fn,))
27+
28+
2229
@pytest.mark.parametrize(
2330
"f, inputs, expected_output",
2431
[
@@ -84,7 +91,7 @@ def identity(x):
8491
else {}
8592
)
8693
fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(id_fn, **id_info), in_tree)
87-
out = call_p.bind(fn, *args)
94+
out = _call_bind(call_p, fn, *args)
8895
return jax.tree.unflatten(out_tree(), out)
8996

9097
assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})}
@@ -99,7 +106,7 @@ def identity(x):
99106
else {}
100107
)
101108
fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(id_fn, **id_info), in_tree)
102-
out = closed_call_p.bind(fn, *args)
109+
out = _call_bind(closed_call_p, fn, *args)
103110
return jax.tree.unflatten(out_tree(), out)
104111

105112
assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})}

0 commit comments

Comments
 (0)