Skip to content

Commit 565efa7

Browse files
committed
add n_tries parameter
1 parent de73f7d commit 565efa7

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

tests/test_20_open_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import xarray as xr
55

66

7-
@pytest.mark.parametrize("download", [False])
7+
@pytest.mark.parametrize("download", [True, False])
88
def test_open_dataset(tmp_path: Path, index_node: str, download: bool) -> None:
99
esgpull_path = tmp_path / "esgpull"
1010
selection = {

xarray_esgf/client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Client:
4444
selection: dict[str, str | list[str]]
4545
esgpull_path: str | Path | None = None
4646
index_node: str | None = None
47+
n_tries: int = 1
4748

4849
@cached_property
4950
def _client(self) -> Esgpull:
@@ -75,9 +76,15 @@ def missing_files(self) -> list[File]:
7576
]
7677

7778
def download(self) -> list[File]:
78-
downloaded, errors = asyncio.run(
79-
self._client.download(self.missing_files, use_db=False)
80-
)
79+
files = []
80+
for _ in range(self.n_tries):
81+
downloaded, errors = asyncio.run(
82+
self._client.download(self.missing_files, use_db=False)
83+
)
84+
files.extend(downloaded)
85+
if not errors:
86+
break
87+
8188
exceptions = []
8289
for error in errors:
8390
err = error.err
@@ -86,7 +93,7 @@ def download(self) -> list[File]:
8693
if exceptions:
8794
msg = "Download errors"
8895
raise ExceptionGroup(msg, exceptions)
89-
return downloaded
96+
return files
9097

9198
@use_new_combine_kwarg_defaults
9299
def open_dataset(

xarray_esgf/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ def open_dataset( # type: ignore[override]
1919
concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None = None,
2020
download: bool = False,
2121
show_progress: bool = True,
22+
n_tries: int = 1,
2223
) -> Dataset:
2324
client = Client(
2425
selection=filename_or_obj,
2526
esgpull_path=esgpull_path,
2627
index_node=index_node,
28+
n_tries=n_tries,
2729
)
2830
return client.open_dataset(
2931
concat_dims=concat_dims,

0 commit comments

Comments
 (0)