33import json
44from dataclasses import dataclass
55from pathlib import Path
6+ import random
67
7- from datasets import load_dataset
8+ from datasets import load_dataset , concatenate_datasets
89from open_data_scientist .codeagent import ReActDataScienceAgent
910
1011
@@ -150,31 +151,49 @@ def write_jsonl(data: list[dict], filepath: Path) -> None:
150151
151152
152153def main (
153- test_first_only = False ,
154154 submit = False ,
155155 data_dir = None ,
156156 which_split = "dev" ,
157157 skip_hard = False ,
158+ reduced_test = False ,
158159):
160+ if skip_hard and reduced_test :
161+ raise ValueError ("Cannot use both --skip-hard and --reduced-test at the same time" )
162+
159163 # Load the dataset
160164 ds = load_dataset ("adyen/DABstep" , "tasks" )
161165
162166 dataset = ds [which_split ]
163167
164168 # Keep the tasks we are about to drop (hard tasks, or everything outside the reduced sample)
165169 skipped_tasks = []
166- if skip_hard and submit :
167- skipped_tasks = [task for task in dataset if task .get ("level" ) == "hard" ]
168-
169170 if skip_hard :
171+ skipped_tasks = [task for task in dataset if task .get ("level" ) == "hard" ]
170172 dataset = dataset .filter (lambda example : example .get ("level" ) != "hard" )
171-
172- if test_first_only :
173- dataset = dataset .select ([0 , 1 , 2 ])
173+ elif reduced_test :
174+ dataset = dataset .shuffle (seed = 42 )
175+ easy_tasks = dataset .filter (lambda x : x ["level" ] == "easy" )
176+ hard_tasks = dataset .filter (lambda x : x ["level" ] == "hard" )
177+
178+ # Sample 20 tasks from each difficulty level
179+ sampled_easy = easy_tasks .select (range (20 ))
180+ sampled_hard = hard_tasks .select (range (20 ))
181+
182+ sampled_ids = set ()
183+ for task in sampled_easy :
184+ sampled_ids .add (task ["task_id" ])
185+ for task in sampled_hard :
186+ sampled_ids .add (task ["task_id" ])
187+
188+ skipped_tasks = [task for task in dataset if task ["task_id" ] not in sampled_ids ]
189+ dataset = concatenate_datasets ([sampled_easy , sampled_hard ])
190+ dataset = dataset .shuffle (seed = 42 )
191+ else :
192+ print ("Running all tasks" )
174193
175194 number_of_examples = len (dataset )
176195 results = []
177- with concurrent .futures .ThreadPoolExecutor (max_workers = 10 ) as executor :
196+ with concurrent .futures .ThreadPoolExecutor (max_workers = 3 ) as executor :
178197 future_to_task = {
179198 executor .submit (process_task , task , submit , data_dir ): task
180199 for task in dataset
@@ -222,9 +241,6 @@ def main(
222241
223242if __name__ == "__main__" :
224243 parser = argparse .ArgumentParser (description = "Run DABstep evaluation" )
225- parser .add_argument (
226- "--test-first-only" , action = "store_true" , help = "Test only the first example"
227- )
228244 parser .add_argument (
229245 "--submit" , action = "store_true" , help = "Submit the results to the leaderboard"
230246 )
@@ -237,6 +253,9 @@ def main(
237253 parser .add_argument (
238254 "--skip-hard" , action = "store_true" , help = "Skip examples with level=hard"
239255 )
256+ parser .add_argument (
257+ "--reduced-test" , action = "store_true" , help = "Sample 20 easy and 20 hard tasks"
258+ )
240259 parser .add_argument (
241260 "--data-dir" ,
242261 default = None ,
@@ -245,9 +264,9 @@ def main(
245264 args = parser .parse_args ()
246265
247266 main (
248- test_first_only = args .test_first_only ,
249267 submit = args .submit ,
250268 data_dir = args .data_dir ,
251269 which_split = args .which_split ,
252270 skip_hard = args .skip_hard ,
271+ reduced_test = args .reduced_test ,
253272 )
0 commit comments