-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
127 lines (111 loc) · 4.49 KB
/
main.py
File metadata and controls
127 lines (111 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from dataclasses import dataclass
from typing import Optional
from minisweagent.agents.default import (
DefaultAgent,
NonTerminatingException,
AgentConfig,
Submitted,
)
import subprocess
from minisweagent.models import get_model
from minisweagent.environments.local import LocalEnvironment
# Module-wide debug switch. main() assigns it from the --debug CLI flag;
# ValidatingAgent reads it to decide whether to echo messages and the
# validation command's output to stdout.
debug: bool = False
@dataclass
class ValidatingAgentConfig(AgentConfig):
    """AgentConfig extended with the shell command used to validate a submission."""

    # Shell command run when the agent signals completion; a falsy value
    # (None or "") skips validation entirely.
    exec_command: Optional[str] = None
class ValidatingAgent(DefaultAgent):
    """DefaultAgent that gates submission on an external validation command.

    The agent always runs on a ``LocalEnvironment`` with the model built by
    ``get_model`` from ``model_name``. When the model's command output starts
    with a completion sentinel line, ``config.exec_command`` (if set) is run in
    a shell. A non-zero exit status (or a timeout) raises
    ``NonTerminatingException`` carrying the command's output. Otherwise
    ``Submitted`` is raised with everything after the sentinel line.
    """

    # Seconds to allow the validation command before giving up. None (the
    # default) waits indefinitely, matching subprocess.run's default.
    validation_timeout: Optional[float] = None

    def __init__(self, *args, model_name: str, **kwargs):
        """Build the agent with a local environment and the named model.

        Args:
            *args: Forwarded unchanged to ``DefaultAgent.__init__``.
            model_name: Passed to ``get_model`` as ``input_model_name``.
            **kwargs: Forwarded to ``DefaultAgent.__init__``; presumably used to
                populate ``ValidatingAgentConfig`` (e.g. ``exec_command``) —
                confirm against DefaultAgent.
        """
        super().__init__(
            *args,
            **kwargs,
            config_class=ValidatingAgentConfig,
            model=get_model(input_model_name=model_name),
            env=LocalEnvironment(),
        )

    # More informative logging of the agent's message log when --debug is set
    def add_message(self, role: str, content: str, **kwargs):
        """Record a message via DefaultAgent and, in debug mode, echo it.

        The echo is tagged "step NN" (NN = assistant messages so far, including
        this one) for assistant/user roles and "setup" for any other role.
        Assistant messages also show the model's call count and cost when
        those attributes are available.
        """
        super().add_message(role, content, **kwargs)
        if not debug:
            return
        assistant_steps = sum(1 for m in self.messages if m.get("role") == "assistant")
        tag = f"step {assistant_steps:02d}" if role in ("assistant", "user") else "setup"
        cost_str = ""
        try:
            if role == "assistant":
                cost_val = getattr(self.model, "cost", 0.0)
                n_calls_val = getattr(self.model, "n_calls", 0)
                cost_str = f" (calls={n_calls_val}, cost={cost_val:.4f})"
        except Exception:
            # Best effort: debug logging must never break the agent run.
            pass
        print(f"[{tag}] {role}{cost_str}:")
        print(content if isinstance(content, str) else str(content))
        print()

    def _summarize_for_log(self, text: str, limit: int = 800) -> str:
        """Return ``text`` right-stripped and cut to ``limit`` chars with a truncation note.

        Non-string input is returned as ``str(text)`` without truncation.
        """
        if not isinstance(text, str):
            return str(text)
        t = text.rstrip()
        if len(t) <= limit:
            return t
        return t[:limit] + f"\n... [truncated {len(t) - limit} chars]"

    def has_finished(self, output: dict[str, str]) -> None:
        """Only validate when the agent signals completion via sentinel line.

        Args:
            output: Command result; only its "output" entry is inspected.

        Raises:
            NonTerminatingException: the validation command failed or timed out.
            Submitted: sentinel seen and validation passed (or no exec_command
                configured); carries the text after the sentinel line.
        """
        lines = output.get("output", "").lstrip().splitlines(keepends=True)
        if not lines or lines[0].strip() not in (
            "MINI_SWE_AGENT_FINAL_OUTPUT",
            "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT",
        ):
            # Not a completion signal; continue stepping without validation
            return
        # Agent signaled completion; run validation if configured
        if self.config.exec_command:
            self._validate(self.config.exec_command)
        # Validation passed (or not configured) — submit final output (everything after the sentinel)
        raise Submitted("".join(lines[1:]))

    def _validate(self, command: str) -> None:
        """Run the validation shell ``command``; raise NonTerminatingException unless it exits 0."""
        try:
            result = subprocess.run(
                command,
                shell=True,  # exec_command is an operator-supplied shell snippet
                capture_output=True,
                text=True,
                # Decode leniently: tools often emit non-UTF-8 bytes, and a
                # UnicodeDecodeError here would abort the whole agent run.
                encoding="utf-8",
                errors="replace",
                timeout=self.validation_timeout,
            )
        except subprocess.TimeoutExpired as err:
            raise NonTerminatingException(
                f"validation timed out after {self.validation_timeout} seconds"
            ) from err
        if debug:
            print(f"validation result: {result.stdout}{result.stderr}")
        if result.returncode != 0:
            raise NonTerminatingException(
                "validation failed\nSTDOUT:\n"
                + (result.stdout or "")
                + "\nSTDERR:\n"
                + (result.stderr or "")
            )
def main():
    """Command-line entry point: run a ValidatingAgent on a task.

    Options:
        --task   Task description handed to the agent (required).
        --exec   Shell command used to validate a submission (required).
        --debug  Echo messages while running, then the final status and the
                 full transcript.
        --model  Model name passed to the agent (default shown below).
    """
    import argparse

    # Declared up front: main() rebinds the module-level debug flag.
    global debug

    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, required=True, help="task description for the agent")
    parser.add_argument("--exec", type=str, required=True, help="shell command that validates a submission")
    parser.add_argument("--debug", action="store_true", help="print messages, status and transcript")
    parser.add_argument(
        "--model",
        type=str,
        default="claude-sonnet-4-20250514",
        help="model name passed to the agent",
    )
    args = parser.parse_args()

    # Set the flag before constructing the agent so anything logged during
    # construction already honours --debug.
    debug = args.debug
    agent = ValidatingAgent(exec_command=args.exec, model_name=args.model)
    status, message = agent.run(args.task)
    if not debug:
        return

    try:
        cost_val = getattr(agent.model, "cost", 0.0)
        n_calls_val = getattr(agent.model, "n_calls", 0)
        print(f"finished with status={status}, steps={n_calls_val}, cost={cost_val:.4f}")
    except Exception:
        # Best effort: fall back to the status alone if cost can't be formatted.
        print(f"finished with status={status}")

    # Show the full transcript
    if agent.messages:
        print("full transcript:")
        for idx, m in enumerate(agent.messages, start=1):
            role = m.get("role", "?")
            content = m.get("content", "")
            if not isinstance(content, str):
                content = str(content)
            print(f"----- message {idx} ({role}) -----")
            print(content)
            print("----- end message -----\n")


if __name__ == "__main__":
    main()