Skip to content

Commit d642ba7

Browse files
authored
test python 3.14 (#2157)
1 parent 13f16a7 commit d642ba7

File tree

4 files changed

+24
-13
lines changed

4 files changed

+24
-13
lines changed

.github/workflows/ci.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
runs-on: ubuntu-latest
2626
strategy:
2727
matrix:
28-
python-version: ["3.11", "3.13"]
28+
python-version: ["3.11", "3.14"]
2929

3030
steps:
3131
- uses: actions/checkout@v6
@@ -60,7 +60,7 @@ jobs:
6060
needs: [lint, prek]
6161
strategy:
6262
matrix:
63-
python-version: ["3.11", "3.13"]
63+
python-version: ["3.11", "3.14"]
6464
env:
6565
UV_PYTHON: ${{ matrix.python-version }}
6666

@@ -88,7 +88,7 @@ jobs:
8888
run: |
8989
JAX_ENABLE_X64=1 uv run pytest -vs test/test_distributions.py -k "powerLaw or Dagum"
9090
- name: Test tracer leak
91-
if: matrix.python-version == '3.13'
91+
if: matrix.python-version == '3.14'
9292
env:
9393
JAX_CHECK_TRACER_LEAKS: 1
9494
run: |
@@ -99,7 +99,7 @@ jobs:
9999
test/infer/test_mcmc.py::test_reuse_mcmc_run
100100
uv run pytest -vs test/test_distributions.py::test_mean_var -k Gompertz
101101
- name: Coveralls
102-
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
102+
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.14'
103103
uses: coverallsapp/github-action@v2
104104
with:
105105
github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -112,7 +112,7 @@ jobs:
112112
needs: [lint, prek]
113113
strategy:
114114
matrix:
115-
python-version: ["3.11", "3.13"]
115+
python-version: ["3.11", "3.14"]
116116
env:
117117
UV_PYTHON: ${{ matrix.python-version }}
118118

@@ -155,7 +155,7 @@ jobs:
155155
run: |
156156
JAX_ENABLE_X64=1 uv run pytest -vs test/contrib/test_nested_sampling.py
157157
- name: Coveralls
158-
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
158+
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.14'
159159
uses: coverallsapp/github-action@v2
160160
with:
161161
github-token: ${{ secrets.GITHUB_TOKEN }}
@@ -168,7 +168,7 @@ jobs:
168168
needs: [lint, prek]
169169
strategy:
170170
matrix:
171-
python-version: ["3.13"]
171+
python-version: ["3.14"]
172172
env:
173173
UV_PYTHON: ${{ matrix.python-version }}
174174

@@ -194,7 +194,7 @@ jobs:
194194
run: |
195195
CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" uv run pytest -vs -k test_example
196196
- name: Coveralls
197-
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13'
197+
if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.14'
198198
uses: coverallsapp/github-action@v2
199199
with:
200200
github-token: ${{ secrets.GITHUB_TOKEN }}

numpyro/contrib/nested_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from jaxns.public import DefaultNestedSampler
2222
from jaxns.utils import NestedSamplerResults
2323

24-
except ImportError as e:
24+
except (ImportError, AttributeError) as e:
2525
raise ImportError(
2626
f"{e} \n "
2727
f"To use this module, please install `jaxns>2.5` package. It can be"

numpyro/examples/datasets.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,21 @@ def _download_with_retries(url: str, out_path: str) -> None:
172172
request = Request(url, headers={"User-Agent": "numpyro-datasets"})
173173
with urlopen(request) as response, open(out_path, "wb") as f:
174174
shutil.copyfileobj(response, f)
175+
if isinstance(last_exc, HTTPError):
176+
last_exc.close()
175177
return
176178
except (HTTPError, URLError) as exc:
177179
retryable = _is_retryable_error(exc)
178180
delay = _download_backoff_delay(exc, attempt) if retryable else None
179181
detached_exc = _detached_download_error(exc)
180182
if isinstance(exc, HTTPError):
181183
exc.close()
184+
if isinstance(last_exc, HTTPError):
185+
last_exc.close()
182186
last_exc = detached_exc
183187
if not retryable:
188+
if isinstance(detached_exc, HTTPError):
189+
detached_exc.close()
184190
raise detached_exc
185191
if attempt < _DOWNLOAD_MAX_RETRIES - 1:
186192
print(
@@ -189,6 +195,8 @@ def _download_with_retries(url: str, out_path: str) -> None:
189195
)
190196
)
191197
time.sleep(delay)
198+
if isinstance(last_exc, HTTPError):
199+
last_exc.close()
192200
raise last_exc
193201

194202

@@ -221,6 +229,8 @@ def _download(dset: dset) -> None:
221229
print("Download complete.")
222230
break
223231
else:
232+
if isinstance(last_exc, HTTPError):
233+
last_exc.close()
224234
raise last_exc
225235

226236

test/contrib/test_nested_sampling.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@
1111
import jax.numpy as jnp
1212

1313
import numpyro
14+
import numpyro.distributions as dist
15+
from numpyro.distributions.transforms import AffineTransform, ExpTransform
1416

17+
_jaxns_available = True
1518
try:
1619
if os.environ.get("JAX_ENABLE_X64"):
1720
from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
1821

1922
except ImportError:
20-
pytestmark = pytest.mark.skip(reason="jaxns is not installed")
23+
_jaxns_available = False
2124

22-
import numpyro.distributions as dist
23-
from numpyro.distributions.transforms import AffineTransform, ExpTransform
2425

2526
pytestmark = [
2627
pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning"),
2728
pytest.mark.filterwarnings("ignore:JAX x64"),
2829
pytest.mark.skipif(
29-
not os.environ.get("JAX_ENABLE_X64"),
30+
not os.environ.get("JAX_ENABLE_X64") or not _jaxns_available,
3031
reason="test suite for jaxns requires double precision",
3132
),
3233
]

0 commit comments

Comments
 (0)