Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Release
uses: patrick-kidger/action_update_python_project@v6
uses: patrick-kidger/action_update_python_project@v8
with:
python-version: "3.11"
test-script: |
cp -r ${{ github.workspace }}/test ./test
cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml
python -m pip install -r ./test/requirements.txt
python -m test
uv sync --extra tests --no-install-project --inexact
uv run --no-sync pytest
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
github-token: ${{ github.token }}
11 changes: 7 additions & 4 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -r ./test/requirements.txt

python -m pip install '.[dev,docs,tests]'

- name: Checks with pre-commit
uses: pre-commit/[email protected]
run: |
pre-commit run --all-files

- name: Test with pytest
run: |
python -m pip install .
python -m test

- name: Check that documentation can be built.
run: |
mkdocs build
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ site/
.pymon
.idea/
.venv/
uv.lock
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ repos:
files: ^pyproject\.toml$
additional_dependencies: ["toml-sort==0.23.1"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.13.0
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
types_or: [ python, pyi, jupyter, toml ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
types_or: [ python, pyi, jupyter, toml ]
args: [ --fix ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.350
rev: v1.1.405
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]
18 changes: 5 additions & 13 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started.

First fork the library on GitHub.

Then clone and install the library in development mode:
Then clone and install the library:

```bash
git clone https://github.com/your-username-here/diffrax.git
cd diffrax
pip install -e .
pip install -e '.[dev]'
pre-commit install # `pre-commit` is installed by `pip` on the previous line
```

Then install the pre-commit hook:

```bash
pip install pre-commit
pre-commit install
```

These hooks use ruff to lint and format the code, and pyright to type-check it.

---

**If you're making changes to the code:**
Expand All @@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary.
Next verify the tests all pass:

```bash
pip install -r test/requirements.txt
pytest
pip install -e '.[tests]'
pytest # `pytest` is installed by `pip` on the previous line.
```

Then push your changes back to your fork of the repository:
Expand Down
3 changes: 1 addition & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,7 @@ def loop(
if is_unsafe_sde(terms):
kind = "lax"
msg = (
"Cannot reverse-mode autodifferentiate when using "
"`UnsafeBrownianPath`."
"Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`."
)
elif max_steps is None:
kind = "lax"
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: type[
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
] = eqx.field(static=True)
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
eqx.field(static=True)
)
key: PRNGKeyArray

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ class VirtualBrownianTree(AbstractBrownianPath):
t1: RealScalarLike
tol: RealScalarLike
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: type[
BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea
] = eqx.field(static=True)
levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = (
eqx.field(static=True)
)
key: PyTree[PRNGKeyArray]
_spline: _Spline = eqx.field(static=True)

Expand Down
6 changes: 5 additions & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable
from typing import ( # noqa: UP035
Any,
cast,
get_args,
get_origin,
Tuple,
Expand Down Expand Up @@ -1164,7 +1165,10 @@ def _wrap(term):
def _get_tols(x):
outs = []
for attr in ("rtol", "atol", "norm"):
if getattr(solver.root_finder, attr) is use_stepsize_tol:
if (
getattr(cast(AbstractImplicitSolver, solver).root_finder, attr)
is use_stepsize_tol
):
outs.append(getattr(x, attr))
return tuple(outs)

Expand Down
10 changes: 5 additions & 5 deletions diffrax/_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._path import AbstractPath


class RESULTS(optx.RESULTS): # pyright: ignore
class RESULTS(optx.RESULTS): # pyright: ignore[reportGeneralTypeIssues]
successful = ""
max_steps_reached = (
"The maximum number of solver steps was reached. Try increasing `max_steps`."
Expand Down Expand Up @@ -121,8 +121,8 @@ class Solution(AbstractPath):
# the structure of `subs`.
# SaveAt(fn=...) means that `ys` will then follow with arbitrary sub-dependent
# PyTree structures.
ts: PyTree[Real[Array, " ?times"], " S"] | None
ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None
ts: PyTree[Real[Array, " ?times"], " S"] | None # pyright: ignore[reportUndefinedVariable]
ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None # pyright: ignore
interpolation: DenseInterpolation | None
stats: dict[str, Any]
result: RESULTS
Expand All @@ -133,7 +133,7 @@ class Solution(AbstractPath):

def evaluate(
self, t0: RealScalarLike, t1: RealScalarLike | None = None, left: bool = True
) -> PyTree[Shaped[Array, "?*shape"], " Y"]:
) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable]
"""If dense output was saved, then evaluate the solution at any point in the
region of integration `self.t0` to `self.t1`.

Expand All @@ -153,7 +153,7 @@ def evaluate(

def derivative(
self, t: RealScalarLike, left: bool = True
) -> PyTree[Shaped[Array, "?*shape"], " Y"]:
) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable]
r"""If dense output was saved, then calculate an **approximation** to the
derivative of the solution at any point in the region of integration `self.t0`
to `self.t1`.
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/dopri5.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class Dopri5(AbstractERK):
"""

tableau: ClassVar[ButcherTableau] = _dopri5_tableau
interpolation_cls: ClassVar[
Callable[..., _Dopri5Interpolation]
] = _Dopri5Interpolation
interpolation_cls: ClassVar[Callable[..., _Dopri5Interpolation]] = (
_Dopri5Interpolation
)

def order(self, terms):
del terms
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/dopri8.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ class Dopri8(AbstractERK):
"""

tableau: ClassVar[ButcherTableau] = _dopri8_tableau
interpolation_cls: ClassVar[
Callable[..., _Dopri8Interpolation]
] = _Dopri8Interpolation
interpolation_cls: ClassVar[Callable[..., _Dopri8Interpolation]] = (
_Dopri8Interpolation
)

def order(self, terms):
del terms
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class Euler(AbstractItoSolver):
"""

term_structure: ClassVar = AbstractTerm
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
return 1
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/euler_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class EulerHeun(AbstractStratonovichSolver):
term_structure: ClassVar = MultiTerm[
tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]
]
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
return 1
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/foster_langevin_srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ def _choose(tay_leaf, direct_leaf):
if inner is sentinel:
inner = jtu.tree_structure(out)
else:
assert (
jtu.tree_structure(out) == inner
), f"Expected {inner}, got {jtu.tree_structure(out)}"
assert jtu.tree_structure(out) == inner, (
f"Expected {inner}, got {jtu.tree_structure(out)}"
)

return out

Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/implicit_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class ImplicitEuler(AbstractImplicitSolver, AbstractAdaptiveSolver):
#
# We don't use it as this seems to be quite a bad choice for low-order solvers: it
# produces very oscillatory interpolations.
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)()
root_find_max_steps: int = 10
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/kencarp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver):
_explicit_tableau, _implicit_tableau
)
calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage
interpolation_cls: ClassVar[
Callable[..., _KenCarp3Interpolation]
] = _KenCarp3Interpolation
interpolation_cls: ClassVar[Callable[..., _KenCarp3Interpolation]] = (
_KenCarp3Interpolation
)

root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)()
root_find_max_steps: int = 10
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/kencarp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver):
_explicit_tableau, _implicit_tableau
)
calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage
interpolation_cls: ClassVar[
Callable[..., _KenCarp4Interpolation]
] = _KenCarp4Interpolation
interpolation_cls: ClassVar[Callable[..., _KenCarp4Interpolation]] = (
_KenCarp4Interpolation
)

root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)()
root_find_max_steps: int = 10
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/kencarp5.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver):
_explicit_tableau, _implicit_tableau
)
calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage
interpolation_cls: ClassVar[
Callable[..., _KenCarp5Interpolation]
] = _KenCarp5Interpolation
interpolation_cls: ClassVar[Callable[..., _KenCarp5Interpolation]] = (
_KenCarp5Interpolation
)

root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)()
root_find_max_steps: int = 10
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/leapfrog_midpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class LeapfrogMidpoint(AbstractSolver):
"""

term_structure: ClassVar = AbstractTerm
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
return 2
Expand Down
12 changes: 6 additions & 6 deletions diffrax/_solver/milstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class StratonovichMilstein(AbstractStratonovichSolver):
term_structure: ClassVar = MultiTerm[
tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]
]
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
raise ValueError("`StratonovichMilstein` should not be used to solve ODEs.")
Expand Down Expand Up @@ -123,9 +123,9 @@ class ItoMilstein(AbstractItoSolver):
term_structure: ClassVar = MultiTerm[
tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]
]
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
raise ValueError("`ItoMilstein` should not be used to solve ODEs.")
Expand Down
6 changes: 3 additions & 3 deletions diffrax/_solver/reversible_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver):
"""

term_structure: ClassVar = AbstractTerm
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation # TODO use something better than this?
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation # TODO use something better than this?
)

def order(self, terms):
return 2
Expand Down
7 changes: 6 additions & 1 deletion diffrax/_solver/runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]):
tableau: AbstractClassVar[ButcherTableau | MultiButcherTableau]
calculate_jacobian: AbstractClassVar[CalculateJacobian]

if TYPE_CHECKING:
# Pretend that we're implicit
root_finder: ClassVar[optx.AbstractRootFinder]
root_find_max_steps: ClassVar[int]

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau
Expand Down Expand Up @@ -804,7 +809,7 @@ def embed_c(tab):
)
implicit_predictor = np.zeros(
(num_stages, num_stages),
dtype=np.result_type(*implicit_tableau.a_predictor),
dtype=np.result_type(*cast(tuple, implicit_tableau.a_predictor)),
)
for i, a_predictor_i in enumerate(implicit_tableau.a_predictor): # pyright: ignore
implicit_predictor[i + 1, : i + 1] = a_predictor_i
Expand Down
Loading