1818logger = logging .getLogger (__name__ )
1919
2020
21- def check_dataset (
21+ def check_dataset ( # noqa: C901
2222 data_dir : Union [str , Path , None ] = None ,
2323 tag : Optional [str ] = None ,
24+ hexsha : Optional [str ] = None ,
2425) -> dl .Dataset :
2526 """Get or install junifer-data dataset.
2627
@@ -33,6 +34,8 @@ def check_dataset(
3334 tag: str or None, optional
3435 Tag to checkout; for example, for ``v1.0.0``, pass ``"1.0.0"``.
3536 If None, ``"main"`` is checked out (default None).
37+ hexsha: str or None, optional
38+ Commit hash to verify. If None, no verification will be performed.
3639
3740 Returns
3841 -------
@@ -43,6 +46,10 @@ def check_dataset(
4346 ------
4447 RuntimeError
4548 If there is a problem checking the dataset.
49+ ValueError
50+ If ``hexsha`` is provided for the main tag or
51+ if unknown tag is provided or
52+ if ``hexsha`` is provided but does not match the checked out ``tag``.
4653
4754 """
4855 # Check tag
@@ -51,6 +58,10 @@ def check_dataset(
5158 else :
5259 tag = "main"
5360
61+ # Avoid hexsha check for main
62+ if tag == "main" and hexsha is not None :
63+ raise ValueError ("Cannot verify hexsha for main tag." )
64+
5465 # Set dataset location
5566 if data_dir is not None :
5667 data_dir = Path (data_dir ) / tag
@@ -62,6 +73,52 @@ def check_dataset(
6273 if dl .Dataset (data_dir ).is_installed ():
6374 logger .debug (f"Found existing junifer-data at: { data_dir .resolve ()} " )
6475 dataset = dl .Dataset (data_dir )
76+ # Check if dataset is dirty
77+ if dataset .repo .dirty :
78+ raise RuntimeError (
79+ f"Found dirty junifer-data at: { data_dir .resolve ()} . "
80+ "You can clean or delete the directory."
81+ )
82+ if tag == "main" :
83+ # Main tag, use the latest commit
84+ try :
85+ dataset .update ()
86+ except CommandError as e :
87+ raise RuntimeError (
88+ f"Failed to update junifer-data: { e } "
89+ ) from e
90+ else :
91+ logger .debug ("Successfully updated junifer-data" )
92+ else :
93+ # Get commit hash for the tag
94+ tag_hexsha = [
95+ x ["hexsha" ]
96+ for x in dataset .repo .get_tags ()
97+ if x ["name" ] == tag
98+ ]
99+ # Check for incorrect tags
100+ if not tag_hexsha :
101+ raise ValueError (f"Unknown tag: { tag } for junifer-data." )
102+
103+ # Get commit hash for HEAD
104+ head_hexsha = dataset .repo .get_hexsha ()
105+ # Get tag hexsha
106+ tag_hexsha = tag_hexsha [0 ]
107+ # Check that it matches the expected hexsha from the tag info
108+ if head_hexsha != tag_hexsha :
109+ raise ValueError (
110+ f"Wrong commit checked out for tag: { tag } . "
111+ f"Expected: { tag_hexsha } , got: { head_hexsha } ."
112+ )
113+
114+ # Do hexsha verification for other tags if provided
115+ # Check that the hexsha matches the expected hexsha from the user
116+ # head_hexsha is now verified and can be used for checking
117+ if hexsha is not None and head_hexsha != hexsha :
118+ raise ValueError (
119+ f"Commit verification failed for tag: { tag } . "
120+ f"Expected: { head_hexsha } , got: { hexsha } ."
121+ )
65122 else :
66123 logger .debug (f"Cloning junifer-data to: { data_dir .resolve ()} " )
67124 # Clone dataset
@@ -79,23 +136,14 @@ def check_dataset(
79136 logger .debug (
80137 f"Successfully cloned junifer-data to: { data_dir .resolve ()} "
81138 )
82-
83- # Update dataset to stay up-to-date
84- try :
85- dataset .update ()
86- except CommandError as e :
87- raise RuntimeError (f"Failed to update junifer-data: { e } " ) from e
88- else :
89- logger .debug ("Successfully updated junifer-data" )
90-
91- # Checkout correct state
92- try :
93- dataset .recall_state (tag )
94- except CommandError as e :
95- raise RuntimeError (
96- f"Failed to checkout state of junifer-data: { e } "
97- ) from e
98- else :
99- logger .debug ("Successfully checked out state of junifer-data" )
139+ # Checkout correct state
140+ try :
141+ dataset .recall_state (tag )
142+ except CommandError as e :
143+ raise RuntimeError (
144+ f"Failed to checkout state of junifer-data: { e } "
145+ ) from e
146+ else :
147+ logger .debug ("Successfully checked out state of junifer-data" )
100148
101149 return dataset
0 commit comments