diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index bc319d40..1855d4ae 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -161,10 +161,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): pass elif n_term_args == 2: vf_type_expected, control_type_expected = term_args - try: - vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) - except Exception as e: - raise ValueError(f"Error while tracing {term}.vf: " + str(e)) + vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) vf_type_compatible = eqx.filter_eval_shape( better_isinstance, vf_type, vf_type_expected ) @@ -173,10 +170,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 - try: - control_type = eqx.filter_eval_shape(contr, t, t) - except Exception as e: - raise ValueError(f"Error while tracing {term}.contr: " + str(e)) + control_type = eqx.filter_eval_shape(contr, t, t) control_type_compatible = eqx.filter_eval_shape( better_isinstance, control_type, control_type_expected )