diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 861187e0..82cc5c2f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,14 +10,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Release - uses: patrick-kidger/action_update_python_project@v6 + uses: patrick-kidger/action_update_python_project@v8 with: python-version: "3.11" test-script: | cp -r ${{ github.workspace }}/test ./test cp ${{ github.workspace }}/pyproject.toml ./pyproject.toml - python -m pip install -r ./test/requirements.txt - python -m test + uv sync --extra tests --no-install-project --inexact + uv run --no-sync pytest pypi-token: ${{ secrets.pypi_token }} github-user: patrick-kidger github-token: ${{ github.token }} diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index b209bb3d..0137d757 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -23,13 +23,16 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r ./test/requirements.txt - + python -m pip install '.[dev,docs,tests]' - name: Checks with pre-commit - uses: pre-commit/action@v3.0.1 + run: | + pre-commit run --all-files - name: Test with pytest run: | - python -m pip install . python -m test + + - name: Check that documentation can be built. + run: | + mkdocs build diff --git a/.gitignore b/.gitignore index daf54d8c..176aa2af 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ site/ .pymon .idea/ .venv/ +uv.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f8df8ce..29b072c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,15 +8,15 @@ repos: files: ^pyproject\.toml$ additional_dependencies: ["toml-sort==0.23.1"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.2 + rev: v0.13.0 hooks: - id: ruff-format # formatter - types_or: [ python, pyi, jupyter ] + types_or: [ python, pyi, jupyter, toml ] - id: ruff # linter - types_or: [ python, pyi, jupyter ] + types_or: [ python, pyi, jupyter, toml ] args: [ --fix ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.350 + rev: v1.1.405 hooks: - id: pyright additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1c9b3ced..78f188a3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,23 +8,15 @@ Contributions (pull requests) are very welcome! Here's how to get started. First fork the library on GitHub. -Then clone and install the library in development mode: +Then clone and install the library: ```bash git clone https://github.com/your-username-here/diffrax.git cd diffrax -pip install -e . +pip install -e '.[dev]' +pre-commit install # `pre-commit` is installed by `pip` on the previous line ``` -Then install the pre-commit hook: - -```bash -pip install pre-commit -pre-commit install -``` - -These hooks use ruff to lint and format the code, and pyright to type-check it. - --- **If you're making changes to the code:** @@ -34,8 +26,8 @@ Now make your changes. Make sure to include additional tests if necessary. Next verify the tests all pass: ```bash -pip install -r test/requirements.txt -pytest +pip install -e '.[tests]' +pytest # `pytest` is installed by `pip` on the previous line. ``` Then push your changes back to your fork of the repository: diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 69a59b7a..7bc081b9 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -362,8 +362,7 @@ def loop( if is_unsafe_sde(terms): kind = "lax" msg = ( - "Cannot reverse-mode autodifferentiate when using " - "`UnsafeBrownianPath`." + "Cannot reverse-mode autodifferentiate when using `UnsafeBrownianPath`." ) elif max_steps is None: kind = "lax" diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index f97eebf1..61f49644 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -62,9 +62,9 @@ class UnsafeBrownianPath(AbstractBrownianPath): """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea - ] = eqx.field(static=True) + levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = ( + eqx.field(static=True) + ) key: PRNGKeyArray def __init__( diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index fd0ede84..8a430668 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -235,9 +235,9 @@ class VirtualBrownianTree(AbstractBrownianPath): t1: RealScalarLike tol: RealScalarLike shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea - ] = eqx.field(static=True) + levy_area: type[BrownianIncrement | SpaceTimeLevyArea | SpaceTimeTimeLevyArea] = ( + eqx.field(static=True) + ) key: PyTree[PRNGKeyArray] _spline: _Spline = eqx.field(static=True) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 1855d4ae..6fc38ce3 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -3,6 +3,7 @@ from collections.abc import Callable from typing import ( # noqa: UP035 Any, + cast, get_args, get_origin, Tuple, @@ -1164,7 +1165,10 @@ def _wrap(term): def _get_tols(x): outs = [] for attr in ("rtol", "atol", "norm"): - if getattr(solver.root_finder, attr) is use_stepsize_tol: + if ( + getattr(cast(AbstractImplicitSolver, solver).root_finder, attr) + is use_stepsize_tol + ): outs.append(getattr(x, attr)) return tuple(outs) diff --git a/diffrax/_solution.py b/diffrax/_solution.py index 3abe2725..392c447e 100644 --- a/diffrax/_solution.py +++ b/diffrax/_solution.py @@ -10,7 +10,7 @@ from ._path import AbstractPath -class RESULTS(optx.RESULTS): # pyright: ignore +class RESULTS(optx.RESULTS): # pyright: ignore[reportGeneralTypeIssues] successful = "" max_steps_reached = ( "The maximum number of solver steps was reached. Try increasing `max_steps`." @@ -121,8 +121,8 @@ class Solution(AbstractPath): # the structure of `subs`. # SaveAt(fn=...) means that `ys` will then follow with arbitrary sub-dependent # PyTree structures. - ts: PyTree[Real[Array, " ?times"], " S"] | None - ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None + ts: PyTree[Real[Array, " ?times"], " S"] | None # pyright: ignore[reportUndefinedVariable] + ys: PyTree[Shaped[Array, "?times ?*shape"], "S ..."] | None # pyright: ignore interpolation: DenseInterpolation | None stats: dict[str, Any] result: RESULTS @@ -133,7 +133,7 @@ class Solution(AbstractPath): def evaluate( self, t0: RealScalarLike, t1: RealScalarLike | None = None, left: bool = True - ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: + ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable] """If dense output was saved, then evaluate the solution at any point in the region of integration `self.t0` to `self.t1`. @@ -153,7 +153,7 @@ def evaluate( def derivative( self, t: RealScalarLike, left: bool = True - ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: + ) -> PyTree[Shaped[Array, "?*shape"], " Y"]: # pyright: ignore[reportUndefinedVariable] r"""If dense output was saved, then calculate an **approximation** to the derivative of the solution at any point in the region of integration `self.t0` to `self.t1`. diff --git a/diffrax/_solver/dopri5.py b/diffrax/_solver/dopri5.py index 4a3cedfe..325f717f 100644 --- a/diffrax/_solver/dopri5.py +++ b/diffrax/_solver/dopri5.py @@ -91,9 +91,9 @@ class Dopri5(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _dopri5_tableau - interpolation_cls: ClassVar[ - Callable[..., _Dopri5Interpolation] - ] = _Dopri5Interpolation + interpolation_cls: ClassVar[Callable[..., _Dopri5Interpolation]] = ( + _Dopri5Interpolation + ) def order(self, terms): del terms diff --git a/diffrax/_solver/dopri8.py b/diffrax/_solver/dopri8.py index 4801eccc..958d8819 100644 --- a/diffrax/_solver/dopri8.py +++ b/diffrax/_solver/dopri8.py @@ -340,9 +340,9 @@ class Dopri8(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _dopri8_tableau - interpolation_cls: ClassVar[ - Callable[..., _Dopri8Interpolation] - ] = _Dopri8Interpolation + interpolation_cls: ClassVar[Callable[..., _Dopri8Interpolation]] = ( + _Dopri8Interpolation + ) def order(self, terms): del terms diff --git a/diffrax/_solver/euler.py b/diffrax/_solver/euler.py index 7ed11381..b1a323f7 100644 --- a/diffrax/_solver/euler.py +++ b/diffrax/_solver/euler.py @@ -24,9 +24,9 @@ class Euler(AbstractItoSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/euler_heun.py b/diffrax/_solver/euler_heun.py index 4940cfa5..97fbaee2 100644 --- a/diffrax/_solver/euler_heun.py +++ b/diffrax/_solver/euler_heun.py @@ -29,9 +29,9 @@ class EulerHeun(AbstractStratonovichSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index 717477d5..ea1026fa 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -231,9 +231,9 @@ def _choose(tay_leaf, direct_leaf): if inner is sentinel: inner = jtu.tree_structure(out) else: - assert ( - jtu.tree_structure(out) == inner - ), f"Expected {inner}, got {jtu.tree_structure(out)}" + assert jtu.tree_structure(out) == inner, ( + f"Expected {inner}, got {jtu.tree_structure(out)}" + ) return out diff --git a/diffrax/_solver/implicit_euler.py b/diffrax/_solver/implicit_euler.py index 064209da..68477c78 100644 --- a/diffrax/_solver/implicit_euler.py +++ b/diffrax/_solver/implicit_euler.py @@ -35,9 +35,9 @@ class ImplicitEuler(AbstractImplicitSolver, AbstractAdaptiveSolver): # # We don't use it as this seems to be quite a bad choice for low-order solvers: it # produces very oscillatory interpolations. - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(optx.Chord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp3.py b/diffrax/_solver/kencarp3.py index f15c7c03..cda04489 100644 --- a/diffrax/_solver/kencarp3.py +++ b/diffrax/_solver/kencarp3.py @@ -163,9 +163,9 @@ class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp3Interpolation] - ] = _KenCarp3Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp3Interpolation]] = ( + _KenCarp3Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp4.py b/diffrax/_solver/kencarp4.py index 4764900f..752cf38c 100644 --- a/diffrax/_solver/kencarp4.py +++ b/diffrax/_solver/kencarp4.py @@ -166,9 +166,9 @@ class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp4Interpolation] - ] = _KenCarp4Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp4Interpolation]] = ( + _KenCarp4Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/kencarp5.py b/diffrax/_solver/kencarp5.py index ba9af78c..b5b0f213 100644 --- a/diffrax/_solver/kencarp5.py +++ b/diffrax/_solver/kencarp5.py @@ -233,9 +233,9 @@ class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): _explicit_tableau, _implicit_tableau ) calculate_jacobian: ClassVar[CalculateJacobian] = CalculateJacobian.second_stage - interpolation_cls: ClassVar[ - Callable[..., _KenCarp5Interpolation] - ] = _KenCarp5Interpolation + interpolation_cls: ClassVar[Callable[..., _KenCarp5Interpolation]] = ( + _KenCarp5Interpolation + ) root_finder: optx.AbstractRootFinder = with_stepsize_controller_tols(VeryChord)() root_find_max_steps: int = 10 diff --git a/diffrax/_solver/leapfrog_midpoint.py b/diffrax/_solver/leapfrog_midpoint.py index 76487dfd..ddcaa12e 100644 --- a/diffrax/_solver/leapfrog_midpoint.py +++ b/diffrax/_solver/leapfrog_midpoint.py @@ -44,9 +44,9 @@ class LeapfrogMidpoint(AbstractSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 2 diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 69893716..897bc9c1 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -44,9 +44,9 @@ class StratonovichMilstein(AbstractStratonovichSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): raise ValueError("`StratonovichMilstein` should not be used to solve ODEs.") @@ -123,9 +123,9 @@ class ItoMilstein(AbstractItoSolver): term_structure: ClassVar = MultiTerm[ tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] ] - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): raise ValueError("`ItoMilstein` should not be used to solve ODEs.") diff --git a/diffrax/_solver/reversible_heun.py b/diffrax/_solver/reversible_heun.py index 9288d00d..91617d4f 100644 --- a/diffrax/_solver/reversible_heun.py +++ b/diffrax/_solver/reversible_heun.py @@ -36,9 +36,9 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): """ term_structure: ClassVar = AbstractTerm - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation # TODO use something better than this? + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation # TODO use something better than this? + ) def order(self, terms): return 2 diff --git a/diffrax/_solver/runge_kutta.py b/diffrax/_solver/runge_kutta.py index f46d2045..9473ab44 100644 --- a/diffrax/_solver/runge_kutta.py +++ b/diffrax/_solver/runge_kutta.py @@ -358,6 +358,11 @@ class AbstractRungeKutta(AbstractAdaptiveSolver[_SolverState]): tableau: AbstractClassVar[ButcherTableau | MultiButcherTableau] calculate_jacobian: AbstractClassVar[CalculateJacobian] + if TYPE_CHECKING: + # Pretend that we're implicit + root_finder: ClassVar[optx.AbstractRootFinder] + root_find_max_steps: ClassVar[int] + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau @@ -804,7 +809,7 @@ def embed_c(tab): ) implicit_predictor = np.zeros( (num_stages, num_stages), - dtype=np.result_type(*implicit_tableau.a_predictor), + dtype=np.result_type(*cast(tuple, implicit_tableau.a_predictor)), ) for i, a_predictor_i in enumerate(implicit_tableau.a_predictor): # pyright: ignore implicit_predictor[i + 1, : i + 1] = a_predictor_i diff --git a/diffrax/_solver/semi_implicit_euler.py b/diffrax/_solver/semi_implicit_euler.py index 9e8a92ed..34122fe3 100644 --- a/diffrax/_solver/semi_implicit_euler.py +++ b/diffrax/_solver/semi_implicit_euler.py @@ -14,8 +14,8 @@ _ErrorEstimate: TypeAlias = None _SolverState: TypeAlias = None -Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] -Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] +Ya: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] # pyright: ignore[reportUndefinedVariable] +Yb: TypeAlias = PyTree[Float[ArrayLike, "?*y"], " Y"] # pyright: ignore[reportUndefinedVariable] class SemiImplicitEuler(AbstractSolver): @@ -26,9 +26,9 @@ class SemiImplicitEuler(AbstractSolver): """ term_structure: ClassVar = (AbstractTerm, AbstractTerm) - interpolation_cls: ClassVar[ - Callable[..., LocalLinearInterpolation] - ] = LocalLinearInterpolation + interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( + LocalLinearInterpolation + ) def order(self, terms): return 1 diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 1af9579f..ddb03095 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -54,8 +54,7 @@ class AbstractStochasticCoeffs(eqx.Module): b_error: eqx.AbstractVar[Float[np.ndarray, " s"] | None] @abc.abstractmethod - def check(self) -> int: - ... + def check(self) -> int: ... class AdditiveCoeffs(AbstractStochasticCoeffs): diff --git a/diffrax/_solver/tsit5.py b/diffrax/_solver/tsit5.py index 3060088a..7dc7a14f 100644 --- a/diffrax/_solver/tsit5.py +++ b/diffrax/_solver/tsit5.py @@ -181,9 +181,9 @@ class Tsit5(AbstractERK): """ tableau: ClassVar[ButcherTableau] = _tsit5_tableau - interpolation_cls: ClassVar[ - Callable[..., _Tsit5Interpolation] - ] = _Tsit5Interpolation + interpolation_cls: ClassVar[Callable[..., _Tsit5Interpolation]] = ( + _Tsit5Interpolation + ) def order(self, terms): return 5 diff --git a/diffrax/_step_size_controller/clip.py b/diffrax/_step_size_controller/clip.py index c8a5cb76..0a642d6e 100644 --- a/diffrax/_step_size_controller/clip.py +++ b/diffrax/_step_size_controller/clip.py @@ -220,8 +220,7 @@ def __init__( self.jump_ts = _none_or_sorted_array(jump_ts) if (store_rejected_steps is not None) and (store_rejected_steps <= 0): raise ValueError( - "`store_rejected_steps must either be `None`" - " or a non-negative integer." + "`store_rejected_steps must either be `None` or a non-negative integer." ) self.store_rejected_steps = store_rejected_steps self.callback_on_reject = _callback_on_reject diff --git a/diffrax/_step_size_controller/pid.py b/diffrax/_step_size_controller/pid.py index 7fb034fb..1092f184 100644 --- a/diffrax/_step_size_controller/pid.py +++ b/diffrax/_step_size_controller/pid.py @@ -89,7 +89,7 @@ def intermediate(carry): # PIDController(... step_ts=s, jump_ts=j) this should return a # ClipStepSizeController(PIDController(...), s, j). class _MetaPID(type(eqx.Module)): - def __call__(cls, *args, **kwargs): # pyright: ignore[reportSelfClsParameterName] + def __call__(cls, *args, **kwargs): step_ts = kwargs.pop("step_ts", None) jump_ts = kwargs.pop("jump_ts", None) if step_ts is not None or jump_ts is not None: diff --git a/diffrax/_term.py b/diffrax/_term.py index 41f7af09..9b0c0314 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -3,7 +3,7 @@ import typing import warnings from collections.abc import Callable -from typing import cast, Generic, TypeAlias, TypeVar +from typing import Any, cast, Generic, TypeAlias, TypeVar import equinox as eqx import jax @@ -12,7 +12,7 @@ import lineax as lx import numpy as np from equinox.internal import ω -from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef, Shaped +from jaxtyping import Array, ArrayLike, PyTree, Shaped from ._brownian import AbstractBrownianPath from ._custom_types import ( @@ -835,7 +835,7 @@ def _fn(_control): jac = make_jac(_fn)(control) assert vf_prod_tree is not sentinel - vf_prod_tree = cast(PyTreeDef, vf_prod_tree) + vf_prod_tree = cast(Any, vf_prod_tree) if jtu.tree_structure(None) in (vf_prod_tree, control_tree): # An unusual/not-useful edge case to handle. raise NotImplementedError( @@ -868,7 +868,7 @@ def _get_vf_tree(_, tree): jtu.tree_map(_get_vf_tree, control, vf) assert vf_prod_tree is not sentinel - vf_prod_tree = cast(PyTreeDef, vf_prod_tree) + vf_prod_tree = cast(Any, vf_prod_tree) vf = jtu.tree_transpose(control_tree, vf_prod_tree, vf) diff --git a/diffrax/_typing.py b/diffrax/_typing.py index 74627847..90a0b7b8 100644 --- a/diffrax/_typing.py +++ b/diffrax/_typing.py @@ -184,7 +184,7 @@ def _get_args_of_impl( if len(params) == 0: error_cls = cls else: - error_cls = cls[params] + error_cls = cls[params] # pyright: ignore[reportIndexIssue] raise TypeError( f"{error_cls} inherits from {base_cls} in multiple incompatible ways." ) diff --git a/docs/devdocs/srk_example.ipynb b/docs/devdocs/srk_example.ipynb index 39364def..319beee2 100644 --- a/docs/devdocs/srk_example.ipynb +++ b/docs/devdocs/srk_example.ipynb @@ -55,11 +55,6 @@ "source": [ "%env JAX_PLATFORM_NAME=cuda\n", "\n", - "from test.helpers import (\n", - " get_mlp_sde,\n", - " get_time_sde,\n", - " simple_sde_order,\n", - ")\n", "from warnings import simplefilter\n", "\n", "import diffrax\n", @@ -76,6 +71,11 @@ " SRA1,\n", ")\n", "from jax import config\n", + "from test.helpers import (\n", + " get_mlp_sde,\n", + " get_time_sde,\n", + " simple_sde_order,\n", + ")\n", "\n", "\n", "simplefilter(\"ignore\", category=FutureWarning)\n", diff --git a/pyproject.toml b/pyproject.toml index e7326e29..c7ec6d27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,13 +25,15 @@ keywords = ["jax", "dynamical-systems", "differential-equations", "deep-learning license = {file = "LICENSE"} name = "diffrax" readme = "README.md" -requires-python = ">=3.10,<4.0" +requires-python = ">=3.10" urls = {repository = "https://github.com/patrick-kidger/diffrax"} version = "0.7.0" [project.optional-dependencies] +dev = ["pre-commit"] docs = [ "hippogriffe==0.2.2", + "griffe==1.7.3", "mkdocs==1.6.1", "mkdocs-include-exclude-files==0.1.0", "mkdocs-ipynb==0.1.1", @@ -40,6 +42,14 @@ docs = [ "mkdocstrings-python==1.16.8", "pymdown-extensions==10.14.3" ] +tests = [ + "beartype", + "jaxlib", + "optax", + "pytest", + "scipy", + "tqdm" +] [tool.hatch.build] include = ["diffrax/*"] @@ -60,10 +70,14 @@ src = [] [tool.ruff.lint] fixable = ["I001", "F401", "UP"] -ignore = ["E402", "E721", "E731", "E741", "F722", "UP038"] -ignore-init-module-imports = true +ignore = ["E402", "E721", "E731", "E741", "F722"] select = ["E", "F", "I001", "UP"] +[tool.ruff.lint.flake8-import-conventions.extend-aliases] +"collections" = "co" +"functools" = "ft" +"itertools" = "it" + [tool.ruff.lint.isort] combine-as-imports = true extra-standard-library = ["typing_extensions"] diff --git a/test/helpers.py b/test/helpers.py index b6311065..97b0f074 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -95,8 +95,8 @@ def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False): def path_l2_dist( - ys1: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], - ys2: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], + ys1: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], # pyright: ignore[reportUndefinedVariable] + ys2: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], # pyright: ignore[reportUndefinedVariable] ): # first compute the square of the difference and sum over # all but the first two axes (which represent the number of samples diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index 9de88eb6..00000000 --- a/test/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -beartype -jaxlib -optax -pytest -scipy -tqdm diff --git a/test/test_brownian.py b/test/test_brownian.py index 361c761d..d33bbda8 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -173,6 +173,7 @@ def _eval(key): else: w = values + assert isinstance(w, jax.Array) assert w.shape == (num_samples,) ref_dist = stats.norm(loc=0, scale=math.sqrt(dt)) _, pval = stats.kstest(w, ref_dist.cdf) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 3db4e7ba..7b8b9dfc 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -1,5 +1,6 @@ import contextlib import math +from typing import cast import diffrax import equinox as eqx @@ -119,6 +120,7 @@ def test_saveat_solution(): assert sol.ts.shape == (4096,) # pyright: ignore assert sol.ys.shape == (4096, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) @@ -140,6 +142,7 @@ def test_saveat_solution(): assert sol.ts.shape == (n,) # pyright: ignore assert sol.ys.shape == (n, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) @@ -161,6 +164,7 @@ def test_saveat_solution(): assert sol.ts.shape == (n,) # pyright: ignore assert sol.ys.shape == (n, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) + _ts = cast(jax.Array, _ts) with jax.numpy_rank_promotion("allow"): _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) diff --git a/test/test_sde1.py b/test/test_sde1.py index b50d014f..ad7318e6 100644 --- a/test/test_sde1.py +++ b/test/test_sde1.py @@ -157,7 +157,8 @@ def test_sde_strong_limit( ts_coarse = jnp.linspace(t0, t1, 2**level_coarse + 1, endpoint=True) contr_fine = diffrax.StepTo(ts=ts_fine) contr_coarse = diffrax.StepTo(ts=ts_coarse) - save_ts = jnp.linspace(t0, t1, 2**5 + 1, endpoint=True) + save_ts = ts_coarse[:: 2 ** (level_coarse - 5)] + assert len(save_ts) == 2**5 + 1 assert len(jnp.intersect1d(ts_fine, save_ts)) == len(save_ts) assert len(jnp.intersect1d(ts_coarse, save_ts)) == len(save_ts) saveat = diffrax.SaveAt(ts=save_ts) diff --git a/test/test_solver.py b/test/test_solver.py index aa618712..a022f644 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -58,9 +58,9 @@ class _DoubleDopri5(diffrax.AbstractRungeKutta): tableau: ClassVar[diffrax.MultiButcherTableau] = diffrax.MultiButcherTableau( diffrax.Dopri5.tableau, diffrax.Dopri5.tableau ) - calculate_jacobian: ClassVar[ - diffrax.CalculateJacobian - ] = diffrax.CalculateJacobian.never + calculate_jacobian: ClassVar[diffrax.CalculateJacobian] = ( + diffrax.CalculateJacobian.never + ) @staticmethod def interpolation_cls(**kwargs): diff --git a/test/test_underdamped_langevin.py b/test/test_underdamped_langevin.py index 246506bb..53140fba 100644 --- a/test/test_underdamped_langevin.py +++ b/test/test_underdamped_langevin.py @@ -100,9 +100,9 @@ def test_shape(solver, dtype): # check that the output has the correct pytree structure and shape def check_shape(y0_leaf, sol_leaf): - assert ( - sol_leaf.shape == (7,) + y0_leaf.shape - ), f"shape={sol_leaf.shape}, expected={(7,) + y0_leaf.shape}" + assert sol_leaf.shape == (7,) + y0_leaf.shape, ( + f"shape={sol_leaf.shape}, expected={(7,) + y0_leaf.shape}" + ) assert sol_leaf.dtype == dtype, f"dtype={sol_leaf.dtype}, expected={dtype}" jtu.tree_map(check_shape, sde.y0, sol.ys) @@ -193,9 +193,9 @@ def get_dt_and_controller(level): ref_solution=true_sol, ) - assert ( - -0.2 < order - theoretical_order < 0.25 - ), f"order={order}, theoretical_order={theoretical_order}" + assert -0.2 < order - theoretical_order < 0.25, ( + f"order={order}, theoretical_order={theoretical_order}" + ) @pytest.mark.parametrize("solver_cls", _only_uld_solvers_cls())