Skip to content

Commit 61c547b

Browse files
long horizon execution (#1119)
* ready for review * some fixes * addressing comments + fixing multi turn * fit the task in one file --------- Co-authored-by: Akshath Mangudi <[email protected]>
1 parent e7048c3 commit 61c547b

File tree

1 file changed

+188
-0
lines changed

1 file changed

+188
-0
lines changed
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""
2+
name:
3+
Long Horizon Execution
4+
5+
dataset:
6+
arvindh75/Long-Horizon-Execution
7+
8+
abstract:
9+
Evaluation benchmark for long-context execution capabilities of language models.
10+
Tests a model's ability to maintain state and perform cumulative operations over
11+
long sequences of inputs. Supports both single-turn (all inputs at once) and
12+
multi-turn (inputs provided incrementally) evaluation modes.
13+
The task requires models to:
14+
1. Maintain a dictionary mapping keys to values
15+
2. Process a sequence of keys
16+
3. Calculate cumulative sums after each key or group of keys
17+
4. Handle varying context sizes and turn complexities
18+
Single-turn evaluation (Section 3.3): Model outputs only the final cumulative sum
19+
after processing all keys, allowing any aggregation strategy.
20+
21+
Multi-turn evaluation: Model processes keys in batches of K per turn, maintaining
22+
conversation history and outputting cumulative sums incrementally. Evaluates
23+
fractional accuracy (correct turns / total turns).
24+
25+
languages:
26+
english
27+
28+
tags:
29+
long-context, state-tracking, arithmetic, execution
30+
31+
paper:
32+
https://arxiv.org/abs/2509.09677
33+
34+
starred:
35+
true
36+
"""
37+
38+
import functools
39+
import itertools
40+
import re
41+
42+
from inspect_ai.dataset import Sample
43+
from inspect_ai.model import ChatMessageUser
44+
from inspect_ai.scorer import Score, Target, accuracy, mean, scorer
45+
from inspect_ai.solver import Generate, TaskState, solver
46+
47+
from lighteval.metrics.metrics import Metrics
48+
from lighteval.tasks.lighteval_task import LightevalTaskConfig
49+
50+
51+
# Prompt for turns 2..N of a multi-turn episode: delivers the next batch of
# {k} keys; the model must reply with the updated running sum in <answer> tags.
PROMPT_TEMPLATE_MULTI_FOLLOWUP = """
I will now provide you with the next {k} keys to process:

{keys_str}
""".strip()

# Prompt for the first turn: states the task (keep a running total over a
# fixed key->value dictionary), embeds the dictionary, and gives the first
# batch of keys.
PROMPT_TEMPLATE_MULTI_START = """
I will provide you with a dictionary and then give you the first {k} keys to process.
Your task is to keep a running total (starting from 0) by adding the values associated with the keys I provide.
In each turn, I'll provide {k} keys (comma-separated).
Respond with the current running sum, enclosed in <answer> tags.

Dictionary to maintain:
{dict_str}

Ready to start!

{keys_str}
""".strip()
70+
71+
72+
def record_to_sample(record, k=1, max_turns=5):
    """Convert a raw dataset record into an inspect_ai ``Sample``.

    Args:
        record: Dataset row with parallel lists ``"input"`` (keys) and
            ``"values"`` (the value associated with each key).
        k: Number of keys presented per conversation turn.
        max_turns: Maximum number of turns to keep; extra key batches
            beyond this limit are dropped.

    Returns:
        A ``Sample`` whose input is the first-turn prompt and whose target
        is the final cumulative sum; per-turn keys/values/targets are
        stored in metadata for the multi-turn solver and scorer.
    """
    input_keys, input_values = record["input"], record["values"]

    dictionary = dict(zip(input_keys, input_values))
    dictionary_str = str(dictionary)

    # Chunk keys and values into batches of k, capped at max_turns batches.
    keys_per_turn = [input_keys[i : i + k] for i in range(0, len(input_keys), k)][:max_turns]
    values_per_turn = [input_values[i : i + k] for i in range(0, len(input_values), k)][:max_turns]

    # Running total expected after each turn.
    targets_per_turn = list(itertools.accumulate(sum(values) for values in values_per_turn))

    # Fix: render the first batch comma-separated, matching both the prompt's
    # own promise ("{k} keys (comma-separated)") and the follow-up turns in
    # the solver, instead of the Python list repr produced by str().
    initial_prompt = PROMPT_TEMPLATE_MULTI_START.format(
        dict_str=dictionary_str, keys_str=", ".join(keys_per_turn[0]), k=k
    )

    metadata = {
        "keys_per_turn": keys_per_turn,
        "values_per_turn": values_per_turn,
        "targets_per_turn": targets_per_turn,
        "k": k,
        "max_turns": max_turns,
    }

    return Sample(
        input=initial_prompt,
        target=str(targets_per_turn[-1]),  # last turn cumulative sum
        metadata=metadata,
    )
98+
99+
100+
@solver
def solver():
    """Multi-turn solver: elicits one running-sum answer per batch of keys.

    The first batch is already embedded in the sample's initial prompt; each
    subsequent batch is appended as a follow-up user message before the next
    generation. Raw completions for every turn are stashed in
    ``state.metadata["all_turn_outputs"]`` for the scorer.
    """

    async def solve(state: TaskState, generate: Generate):
        batches = state.metadata["keys_per_turn"]
        k = state.metadata["k"]

        # Turn 1: the initial prompt already contains the first key batch.
        state = await generate(state)
        completions = [state.output.completion]

        # Turns 2..N: send the next batch of keys, then generate again.
        for batch in batches[1:]:
            followup = PROMPT_TEMPLATE_MULTI_FOLLOWUP.format(keys_str=", ".join(batch), k=k)
            state.messages.append(ChatMessageUser(content=followup))
            state = await generate(state)
            completions.append(state.output.completion)

        state.metadata["all_turn_outputs"] = completions
        return state

    return solve
124+
125+
126+
@scorer(metrics={"horizon": [mean()], "turn_accuracy": [mean()], "all_correct": [accuracy()]})
def scorer():
    """Score a multi-turn episode against the per-turn cumulative sums.

    Emits three values per sample:
      * ``turn_accuracy`` — fraction of turns whose <answer> matches the target.
      * ``horizon`` — index of the first incorrect turn (equals the number of
        turns completed when every turn is correct).
      * ``all_correct`` — whether every turn was correct.
    """
    answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    async def score(state: TaskState, target: Target):
        all_turn_outputs = state.metadata.get("all_turn_outputs", [])
        targets_per_turn = state.metadata.get("targets_per_turn", [])

        parsed_outputs = []
        for turn_output in all_turn_outputs:
            match = answer_pattern.search(turn_output)
            if match:
                try:
                    parsed_outputs.append(int(match.group(1).strip()))
                except ValueError:
                    parsed_outputs.append(None)
            else:
                # Fix: record a placeholder when no <answer> tag is present.
                # Previously the turn was silently skipped, which shifted
                # every later parsed value onto the wrong turn's target.
                parsed_outputs.append(None)

        # Note: `expected` avoids shadowing the `target: Target` parameter.
        turn_results = []
        for parsed, expected in zip(parsed_outputs, targets_per_turn):
            is_correct = (parsed is not None) and (parsed == expected)
            turn_results.append({"output": parsed, "target": expected, "correct": is_correct})

        # Guard: an empty episode (e.g. the solver produced no outputs) would
        # otherwise raise ZeroDivisionError below.
        if not turn_results:
            return Score(
                value={"turn_accuracy": 0.0, "horizon": 0, "all_correct": False},
                answer=str(turn_results),
                explanation=state.output.completion,
            )

        turn_accuracy = sum(result["correct"] for result in turn_results) / len(turn_results)

        # Horizon: first turn (0-indexed) where the model was not correct anymore.
        # If all turns are correct, horizon is len(turn_results) (number of turns completed).
        horizon = len(turn_results)
        for turn_idx, result in enumerate(turn_results):
            if not result["correct"]:
                horizon = turn_idx
                break

        return Score(
            value={
                "turn_accuracy": turn_accuracy,
                "horizon": horizon,
                "all_correct": all(result["correct"] for result in turn_results),
            },
            answer=str(turn_results),
            explanation=state.output.completion,
        )

    return score
172+
173+
174+
# Task configuration: k=10 keys per turn, up to 30 turns per sample.
long_horizon_execution_10 = LightevalTaskConfig(
    name="long_horizon_execution",
    # Records are converted by `sample_fields`; the prompt function is a
    # pass-through of the raw line.
    prompt_function=lambda line, task_name: line,
    sample_fields=functools.partial(record_to_sample, k=10, max_turns=30),
    solver=[solver()],
    scorer=[scorer()],
    hf_repo="arvindh75/Long-Horizon-Execution",
    hf_subset="default",
    evaluation_splits=("test",),
    metrics=[Metrics.exact_match],
)

# Registry of tasks exported by this module.
TASKS_TABLE = [
    long_horizon_execution_10,
]

0 commit comments

Comments
 (0)