Skip to content

Commit cf11699

Browse files
committed
remove metadata and create it dynamically
1 parent 55a57b0 commit cf11699

File tree

2 files changed

+109
-813
lines changed

2 files changed

+109
-813
lines changed

browsergym/experiments/src/browsergym/experiments/benchmark/metadata/utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import csv
2+
import importlib.resources
13
import io
4+
import json
5+
import os
26
import pkgutil
37
from collections import defaultdict
48
from copy import deepcopy
@@ -9,7 +13,112 @@
913
from browsergym.experiments.loop import EnvArgs
1014

1115

16+
def make_webarena_verified_metadata_if_not_exists():
    """
    Create ``webarena_verified.csv`` next to this module if it does not already exist.

    The metadata is built dynamically by combining:
      - the task list shipped inside the ``webarena_verified`` package
        (``assets/dataset/webarena-verified.json``), which provides task ids,
        intent template ids, revisions, sites and evaluator types;
      - the original ``webarena.csv`` in this directory, from which
        ``requires_reset``, ``browsergym_split`` and ``depends_on`` are copied.

    Raises:
        AssertionError: if a task has no evaluators, or if the sites listed in
            the JSON dataset disagree with the original CSV for the same task id.
        FileNotFoundError: if ``webarena.csv`` is missing from this directory.

    Side effects:
        Writes ``webarena_verified.csv`` into this module's directory and prints
        a one-line summary. Returns immediately (no-op) if the file exists.
    """
    metadata_dir = os.path.dirname(__file__)
    output_path = os.path.join(metadata_dir, "webarena_verified.csv")
    # Already generated on a previous run: nothing to do.
    if os.path.exists(output_path):
        return

    # Load the json file from the webarena-verified library
    data = json.loads(
        importlib.resources.files("webarena_verified")
        .joinpath("assets/dataset/webarena-verified.json")
        .read_text()
    )
    # Create a mapping from task_id to intent_template_id and revision for efficient
    # lookup. This is used to find the dependency task name.
    task_id_to_template_id = {task["task_id"]: task["intent_template_id"] for task in data}
    task_id_to_revision = {task["task_id"]: task["revision"] for task in data}

    # Read the original webarena.csv and create a mapping from task_id to original task info
    original_csv_path = os.path.join(metadata_dir, "webarena.csv")
    original_tasks = {}
    with open(original_csv_path, "r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            task_id = int(row["task_id"])
            original_tasks[task_id] = {
                "requires_reset": row["requires_reset"],
                "sites": row["sites"],
                "eval_types": row["eval_types"],
                "browsergym_split": row["browsergym_split"],
                "depends_on": row["depends_on"],
            }

    # Create CSV data
    csv_data = []
    for task in data:
        intent_template_id = task["intent_template_id"]
        task_id = task["task_id"]
        revision = task["revision"]

        # Extract eval_types
        new_eval_types = [evaluator_config["evaluator"] for evaluator_config in task.get("eval", [])]
        assert len(new_eval_types) > 0, f"Task {task_id} has no evaluators"
        new_eval_types_str = " ".join(new_eval_types)

        # Extract new task sites
        sites = task.get("sites", [])
        sites_str = " ".join(sites) if sites else ""

        # Get original task data for comparison and dependency copying
        original_task = original_tasks.get(task_id, {})

        # Assert that new task sites matches the original task sites.
        # NOTE: the lookup is hoisted into a local because nesting double quotes
        # inside a double-quoted f-string is a SyntaxError before Python 3.12.
        original_sites = original_task.get("sites", "")
        assert (
            sites_str == original_sites
        ), f"Task {task_id}: sites mismatch - JSON: {sites_str}, CSV: {original_sites}"

        # Construct the dependency task name from the dependency's own template id
        # and revision (the original CSV only stores "<prefix>.<task_id>").
        if original_dependency := original_task.get("depends_on"):
            dependency_task_id = int(original_dependency.split(".")[-1])
            dependency_template_id = task_id_to_template_id[dependency_task_id]
            dependency_revision = task_id_to_revision[dependency_task_id]
            dependency_task_name = f"webarena_verified.{dependency_template_id}.{dependency_task_id}.{dependency_revision}"
        else:
            dependency_task_name = ""

        # Create metadata row
        row = {
            "task_name": f"webarena_verified.{intent_template_id}.{task_id}.{revision}",
            "requires_reset": str(
                original_task.get("requires_reset", False)
            ),  # copy original requires_reset
            "sites": sites_str,
            "eval_types": new_eval_types_str,
            "task_id": str(task_id),
            "browsergym_split": original_task.get(
                "browsergym_split", "train"
            ),  # copy original browsergym_split
            "depends_on": dependency_task_name,
        }
        csv_data.append(row)

    # Write CSV file
    with open(output_path, "w", newline="") as f:
        fieldnames = [
            "task_name",
            "requires_reset",
            "sites",
            "eval_types",
            "task_id",
            "browsergym_split",
            "depends_on",
        ]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(csv_data)

    print(f"Created {output_path} with {len(csv_data)} tasks")
12118
def task_metadata(benchmark_name: str):
    """Return the task metadata table for *benchmark_name*.

    The metadata is read from the ``<benchmark_name>.csv`` file bundled with
    this package. For ``"webarena_verified"`` the CSV is generated on first
    use before being loaded.
    """
    if benchmark_name == "webarena_verified":
        make_webarena_verified_metadata_if_not_exists()
    raw_bytes = pkgutil.get_data(__name__, f"{benchmark_name}.csv")
    csv_stream = io.StringIO(raw_bytes.decode("utf-8"))
    return task_metadata_from_csv(csv_stream)

0 commit comments

Comments
 (0)