diff --git a/pixi.lock b/pixi.lock index b95dee59..e9eac902 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5296,7 +5296,7 @@ packages: - pypi: ./ name: array-api-extra version: 0.10.0.dev0 - sha256: b4608a433d1cc449ea43a8d07ba73364e7ef5b49e402dc424e852445bdc785c0 + sha256: b70a1b46fb858c21713d8ee451f8e8a7ffee05e75416ecdd388e41a9c6cd77f1 requires_dist: - array-api-compat>=1.12.0,<2 requires_python: '>=3.11' diff --git a/pyproject.toml b/pyproject.toml index e22b8d5e..ec070e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,13 +155,13 @@ sparse = ">=0.17.0" [tool.pixi.feature.backends.target.linux-64.dependencies] # On CPU use >=0.7.0 # On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below. -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.osx-64.dependencies] -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.osx-arm64.dependencies] -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.win-64.dependencies] # jax = "*" # unavailable @@ -178,7 +178,7 @@ system-requirements = { cuda = "12" } [tool.pixi.feature.cuda-backends.target.linux-64.dependencies] cupy = ">=13.6.0" # JAX 0.6.2 and 0.7.0 segfault on CUDA -jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" } +jaxlib = { version = ">=0.7.2,!=0.6.2,!=0.7.0", build = "cuda12*" } pytorch = { version = ">=2.7.1", build = "cuda12*" } [tool.pixi.feature.cuda-backends.target.osx-64.dependencies]