diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 91ac0e8a..17461ba4 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -262,7 +262,7 @@ def _clip_to_end(tprev, tnext, t1, keep_step): # The tolerance means that we don't end up with too-small intervals for # dense output, which then gives numerically unstable answers due to floating # point errors. - if tnext.dtype is jnp.dtype("float64"): + if tnext.dtype == jnp.dtype("float64"): tol = 1e-10 else: tol = 1e-6 @@ -385,7 +385,7 @@ def body_fun_aux(state): error_order, state.controller_state, ) - assert jnp.result_type(keep_step) is jnp.dtype(bool) + assert jnp.result_type(keep_step) == jnp.dtype(bool) # # Do some book-keeping.