diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 1092f184..c0ccfbec 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -81,8 +81,8 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error) -_PidState = tuple[RealScalarLike, RealScalarLike] +# _PidState = (prev_inv_scaled_error, prev_prev_inv_scaled_error, at_dtmin) +_PidState = tuple[RealScalarLike, RealScalarLike, BoolScalarLike] # We use a metaclass for backwards compatibility. When a user calls @@ -388,6 +388,7 @@ def init( return t1, ( jnp.array(1.0, dtype=real_dtype), jnp.array(1.0, dtype=real_dtype), + False, ) def adapt_step_size( @@ -469,6 +470,7 @@ def adapt_step_size( ( prev_inv_scaled_error, prev_prev_inv_scaled_error, + at_dtmin, ) = controller_state error_order = self._get_error_order(error_order) prev_dt = t1 - t0 @@ -489,9 +491,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 - # Automatically keep the step if we're at dtmin. + # Automatically keep the step if it was at dtmin. if self.dtmin is not None: - keep_step = keep_step | (prev_dt <= self.dtmin) + keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) inv_scaled_error = lax.stop_gradient( @@ -545,6 +547,11 @@ def _scale(_y0, _y1_candidate, _y_error): if self.dtmin is not None: if not self.force_dtmin: result = RESULTS.where(dt < self.dtmin, RESULTS.dt_min_reached, result) + # if we are already at dtmin and dt is unchanged (factor == 1), + # reset dt to dtmin to avoid accumulating float precision errors + dt = jnp.where(at_dtmin & (factor == 1), self.dtmin, dt) + # this flags the next loop to accept step + at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) next_t0 = jnp.where(keep_step, t1, t0) @@ -554,7 +561,7 @@ def _scale(_y0, _y1_candidate, _y_error): prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) - controller_state = inv_scaled_error, prev_inv_scaled_error + controller_state = inv_scaled_error, prev_inv_scaled_error, at_dtmin # made_jump is handled by ClipStepSizeController, so we automatically set it to # False return keep_step, next_t0, next_t1, False, controller_state, result