diff --git a/pixi.lock b/pixi.lock index 5017102b..db43c5c9 100644 --- a/pixi.lock +++ b/pixi.lock @@ -5296,7 +5296,7 @@ packages: - pypi: ./ name: array-api-extra version: 0.10.0.dev0 - sha256: aa2a2cd7d3add680efbe7c7bd783710ac6ea512d16d7dc55ef53f6b538b10aea + sha256: b7a2669b3a14d47901d142f6d4d4c40d2f8c41a38e518e26ab2a7161dbbe267e requires_dist: - array-api-compat>=1.12.0,<2 requires_python: '>=3.11' diff --git a/pyproject.toml b/pyproject.toml index 715ef1e7..e9c4169f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,15 +153,13 @@ dask-core = ">=2025.12.0" # No distributed, tornado, etc. sparse = ">=0.17.0" [tool.pixi.feature.backends.target.linux-64.dependencies] -# On CPU use >=0.7.0 -# On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below. -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.osx-64.dependencies] -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.osx-arm64.dependencies] -jax = ">=0.6.0" +jax = ">=0.7.2" [tool.pixi.feature.backends.target.win-64.dependencies] # jax = "*" # unavailable @@ -175,23 +173,17 @@ jax = ">=0.6.0" [tool.pixi.feature.cuda-backends] system-requirements = { cuda = "12" } -[tool.pixi.feature.cuda-backends.target.linux-64.dependencies] +[tool.pixi.feature.cuda-backends.target.linux.dependencies] cupy = ">=13.6.0" -# JAX 0.6.2 and 0.7.0 segfault on CUDA -jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" } +jaxlib = { version = ">=0.7.2", build = "cuda12*" } pytorch = { version = ">=2.9.1", build = "cuda12*" } -[tool.pixi.feature.cuda-backends.target.osx-64.dependencies] +[tool.pixi.feature.cuda-backends.target.osx.dependencies] # cupy = "*" # unavailable # jaxlib = { version = "*", build = "cuda12*" } # unavailable # pytorch = { version = "*", build = "cuda12*" } # unavailable -[tool.pixi.feature.cuda-backends.target.osx-arm64.dependencies] -# cupy = "*" # unavailable -# jaxlib = { version = "*", build = "cuda12*" } # unavailable -# pytorch = { version = "*", build = "cuda12*" } # unavailable - -[tool.pixi.feature.cuda-backends.target.win-64.dependencies] +[tool.pixi.feature.cuda-backends.target.win.dependencies] cupy = ">=13.6.0" # jaxlib = { version = "*", build = "cuda12*" } # unavailable pytorch = { version = ">=2.9.1", build = "cuda12*" } diff --git a/tests/test_funcs.py b/tests/test_funcs.py index ea3ca267..db64fdb6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -521,6 +521,7 @@ def test_complex(self, xp: ModuleType): expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128) xp_assert_close(actual, expect) + @pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296") @pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296") @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877") def test_empty(self, xp: ModuleType): @@ -989,14 +990,14 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): assert get_device(res) == device def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): - a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) + a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device, dtype=xp.float64) b = 1 res = isclose(a, b) assert get_device(res) == device xp_assert_equal(res, xp.asarray([False, False, False, False, True])) a = 0.1 - b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device) + b = xp.asarray([0.01, 0.5, 0.8, 0.9, 0.100001], device=device, dtype=xp.float64) res = isclose(a, b) assert get_device(res) == device xp_assert_equal(res, xp.asarray([False, False, False, False, True]))