Skip to content

Commit 61234fd

Browse files
authored
Merge pull request #3 from juaml/enh/add_hexsha_check
refactor: use hexsha for dataset check logic
2 parents 9a706a4 + 07524e8 commit 61234fd

File tree

5 files changed

+87
-21
lines changed

5 files changed

+87
-21
lines changed

changelog.d/3.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `hexsha` parameter to the `get` operation (API and CLI) and to `check_dataset`

changelog.d/3.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Check commit hash instead of checking out state if dataset is installed

junifer_data/_cli.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,27 @@ def cli() -> None: # pragma: no cover
6262
metavar="<tag>",
6363
help="Tag to checkout",
6464
)
65+
@click.option(
66+
"-s",
67+
"--hexsha",
68+
default=None,
69+
type=str,
70+
metavar="<hexsha>",
71+
help="Commit hash to verify",
72+
)
6573
@click.option("-v", "--verbose", count=True, type=int)
6674
def get(
6775
file_path: click.Path,
6876
dataset_path: click.Path,
6977
tag: str,
78+
hexsha: str,
7079
verbose: int,
7180
) -> None:
7281
"""Get FILE_PATH.
7382
7483
FILE_PATH should be relative to <dataset>/<tag>, if provided.
7584
If not provided, <dataset> defaults to "$HOME/junifer_data/<tag>" and <tag>
76-
defaults to "main".
85+
defaults to "main". If <hexsha> is provided, commit hash is verified.
7786
7887
"""
7988
_set_log_config(verbose)
@@ -82,6 +91,7 @@ def get(
8291
file_path=file_path,
8392
dataset_path=dataset_path,
8493
tag=tag,
94+
hexsha=hexsha,
8595
)
8696
except RuntimeError as err:
8797
click.echo(f"Failure: {err}", err=True)

junifer_data/_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get(
2424
file_path: Path,
2525
dataset_path: Optional[Path] = None,
2626
tag: Optional[str] = None,
27+
hexsha: Optional[str] = None,
2728
) -> Path:
2829
"""Fetch ``file_path`` from junifer-data dataset.
2930
@@ -38,6 +39,8 @@ def get(
3839
tag : str or None, optional
3940
Tag to checkout; for example, for ``v1.0.0``, pass ``"1.0.0"``.
4041
If None, ``"main"`` is checked out (default None).
42+
hexsha : str or None, optional
43+
Commit hash to verify. If None, no verification will be performed.
4144
4245
Returns
4346
-------
@@ -48,10 +51,13 @@ def get(
4851
------
4952
RuntimeError
5053
If there is a problem fetching the file.
54+
ValueError
55+
If ``hexsha`` is provided but does not match the checked out tag.
56+
If ``hexsha`` is provided for the main tag.
5157
5258
"""
5359
# Get dataset
54-
dataset = check_dataset(data_dir=dataset_path, tag=tag)
60+
dataset = check_dataset(data_dir=dataset_path, tag=tag, hexsha=hexsha)
5561
# Fetch file
5662
try:
5763
got = dataset.get(file_path, result_renderer="disabled")

junifer_data/_utils.py

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
def check_dataset(
21+
def check_dataset( # noqa: C901
2222
data_dir: Union[str, Path, None] = None,
2323
tag: Optional[str] = None,
24+
hexsha: Optional[str] = None,
2425
) -> dl.Dataset:
2526
"""Get or install junifer-data dataset.
2627
@@ -33,6 +34,8 @@ def check_dataset(
3334
tag: str or None, optional
3435
Tag to checkout; for example, for ``v1.0.0``, pass ``"1.0.0"``.
3536
If None, ``"main"`` is checked out (default None).
37+
hexsha: str or None, optional
38+
Commit hash to verify. If None, no verification will be performed.
3639
3740
Returns
3841
-------
@@ -43,6 +46,10 @@ def check_dataset(
4346
------
4447
RuntimeError
4548
If there is a problem checking the dataset.
49+
ValueError
50+
If ``hexsha`` is provided for the main tag or
51+
if unknown tag is provided or
52+
if ``hexsha`` is provided but does not match the checked out ``tag``.
4653
4754
"""
4855
# Check tag
@@ -51,6 +58,10 @@ def check_dataset(
5158
else:
5259
tag = "main"
5360

61+
# Avoid hexsha check for main
62+
if tag == "main" and hexsha is not None:
63+
raise ValueError("Cannot verify hexsha for main tag.")
64+
5465
# Set dataset location
5566
if data_dir is not None:
5667
data_dir = Path(data_dir) / tag
@@ -62,6 +73,52 @@ def check_dataset(
6273
if dl.Dataset(data_dir).is_installed():
6374
logger.debug(f"Found existing junifer-data at: {data_dir.resolve()}")
6475
dataset = dl.Dataset(data_dir)
76+
# Check if dataset is dirty
77+
if dataset.repo.dirty:
78+
raise RuntimeError(
79+
f"Found dirty junifer-data at: {data_dir.resolve()} . "
80+
"You can clean or delete the directory."
81+
)
82+
if tag == "main":
83+
# Main tag, use the latest commit
84+
try:
85+
dataset.update()
86+
except CommandError as e:
87+
raise RuntimeError(
88+
f"Failed to update junifer-data: {e}"
89+
) from e
90+
else:
91+
logger.debug("Successfully updated junifer-data")
92+
else:
93+
# Get commit hash for the tag
94+
tag_hexsha = [
95+
x["hexsha"]
96+
for x in dataset.repo.get_tags()
97+
if x["name"] == tag
98+
]
99+
# Check for incorrect tags
100+
if not tag_hexsha:
101+
raise ValueError(f"Unknown tag: {tag} for junifer-data.")
102+
103+
# Get commit hash for HEAD
104+
head_hexsha = dataset.repo.get_hexsha()
105+
# Get tag hexsha
106+
tag_hexsha = tag_hexsha[0]
107+
# Check that it matches the expected hexsha from the tag info
108+
if head_hexsha != tag_hexsha:
109+
raise ValueError(
110+
f"Wrong commit checked out for tag: {tag}. "
111+
f"Expected: {tag_hexsha}, got: {head_hexsha}."
112+
)
113+
114+
# Do hexsha verification for other tags if provided
115+
# Check that the hexsha matches the expected hexsha from the user
116+
# head_hexsha is now verified and can be used for checking
117+
if hexsha is not None and head_hexsha != hexsha:
118+
raise ValueError(
119+
f"Commit verification failed for tag: {tag}. "
120+
f"Expected: {head_hexsha}, got: {hexsha}."
121+
)
65122
else:
66123
logger.debug(f"Cloning junifer-data to: {data_dir.resolve()}")
67124
# Clone dataset
@@ -79,23 +136,14 @@ def check_dataset(
79136
logger.debug(
80137
f"Successfully cloned junifer-data to: {data_dir.resolve()}"
81138
)
82-
83-
# Update dataset to stay up-to-date
84-
try:
85-
dataset.update()
86-
except CommandError as e:
87-
raise RuntimeError(f"Failed to update junifer-data: {e}") from e
88-
else:
89-
logger.debug("Successfully updated junifer-data")
90-
91-
# Checkout correct state
92-
try:
93-
dataset.recall_state(tag)
94-
except CommandError as e:
95-
raise RuntimeError(
96-
f"Failed to checkout state of junifer-data: {e}"
97-
) from e
98-
else:
99-
logger.debug("Successfully checked out state of junifer-data")
139+
# Checkout correct state
140+
try:
141+
dataset.recall_state(tag)
142+
except CommandError as e:
143+
raise RuntimeError(
144+
f"Failed to checkout state of junifer-data: {e}"
145+
) from e
146+
else:
147+
logger.debug("Successfully checked out state of junifer-data")
100148

101149
return dataset

0 commit comments

Comments
 (0)