Skip to content

Conversation

@patrick-kidger
Copy link
Owner

Fixes #688.

Tagging @lockwo @johannahaffner WDYT?

(Whilst I'm here I've removed the 'compilation is taking a long time' which is usually much less of an issue these days, ever since we landed the scan-over-stages refactor. I also don't like that it recommends noinline these days, which is pretty deep magic!)


### Compilation is taking a long time.

- Set `dt0=<not None>`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the first two of these points worth keeping anywhere?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They feel niche enough (i.e. not meeting the 'F' in 'FAQ') that I was feeling inclined to cut them.

If we ever find ourselves with a longer list of tips-and-tricks to tackle compilation time then I'd be happy to give them a home there, however.

import timeit

@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have any overhead? I.e. should it be used to debug, then excluded from the final timing analysis (which is this code block)?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good catch. Only overhead at compile time, I think! So we'll be fine for the use-case here.

execution_time = min(timeit.repeat(lambda: jax.block_until_ready(run(x)), number=1, repeat=20))
```

2. Use the same ODE solver in both implementations to get an apples-to-apples comparison. It's not surprising that different solvers give different performance characteristics. (And if one implementation does not provide a solver that the other does, then no comparison can be made.)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"same ODE solver" -> "same solver"?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So here (and for your comment below) I'm intentionally only focusing on ODE solvers.

The reason for this is that I already know that we could do more to improve performance of SDE solvers, and I've just never really found the time to figure that one out. (And on this note you've been very gracious around not merging the stateful-controls PR on this topic!)

execution_time = min(timeit.repeat(lambda: jax.block_until_ready(run(x)), number=1, repeat=20))
```

2. Use the same ODE solver in both implementations to get an apples-to-apples comparison. It's not surprising that different solvers give different performance characteristics. (And if one implementation does not provide a solver that the other does, then no comparison can be made.)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess timing has been mostly ODE focused, but for SDEs is there anything of note (e.g. in how brownian noise is computed between implementations)?

@johannahaffner
Copy link
Contributor

No notes, this looks great! Very useful to be able to point people to this.

@patrick-kidger patrick-kidger merged commit 2fd3ef3 into dev Sep 18, 2025
1 of 2 checks passed
@patrick-kidger
Copy link
Owner Author

Thank you both for the comments!

@patrick-kidger patrick-kidger deleted the bookmark-faq branch September 18, 2025 22:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants