1919from numpyro .ops .provenance import eval_provenance
2020
2121
22+ def _call_bind (primitive , fn , * args ):
23+ try :
24+ return primitive .bind (fn , * args )
25+ except TypeError :
26+ return primitive .bind (* args , subfuns = (fn ,))
27+
28+
2229@pytest .mark .parametrize (
2330 "f, inputs, expected_output" ,
2431 [
@@ -84,7 +91,7 @@ def identity(x):
8491 else {}
8592 )
8693 fn , out_tree = flatten_fun_nokwargs (lu .wrap_init (id_fn , ** id_info ), in_tree )
87- out = call_p . bind ( fn , * args )
94+ out = _call_bind ( call_p , fn , * args )
8895 return jax .tree .unflatten (out_tree (), out )
8996
9097 assert eval_provenance (identity , x = {"v" : 2 }) == {"v" : frozenset ({"x" })}
@@ -99,7 +106,7 @@ def identity(x):
99106 else {}
100107 )
101108 fn , out_tree = flatten_fun_nokwargs (lu .wrap_init (id_fn , ** id_info ), in_tree )
102- out = closed_call_p . bind ( fn , * args )
109+ out = _call_bind ( closed_call_p , fn , * args )
103110 return jax .tree .unflatten (out_tree (), out )
104111
105112 assert eval_provenance (identity , x = {"v" : 2 }) == {"v" : frozenset ({"x" })}
0 commit comments