Skip to content

Commit 510ee42

Browse files
author
Chojan Shang
committed
fix: prevent request processing from blocking permission or cancel handling
Signed-off-by: Chojan Shang <[email protected]>
1 parent e8b1f76 commit 510ee42

File tree

3 files changed

+160
-3
lines changed

3 files changed

+160
-3
lines changed

src/acp/core.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
self._reader = reader
125125
self._next_request_id = 0
126126
self._pending: dict[int, _Pending] = {}
127+
self._inflight: set[asyncio.Task[Any]] = set()
127128
self._write_lock = asyncio.Lock()
128129
self._recv_task = asyncio.create_task(self._receive_loop())
129130

@@ -132,6 +133,13 @@ async def close(self) -> None:
132133
self._recv_task.cancel()
133134
with contextlib.suppress(asyncio.CancelledError):
134135
await self._recv_task
136+
if self._inflight:
137+
tasks = list(self._inflight)
138+
for task in tasks:
139+
task.cancel()
140+
for task in tasks:
141+
with contextlib.suppress(asyncio.CancelledError):
142+
await task
135143
# Do not close writer here; lifecycle owned by caller
136144

137145
# --- IO loops ----------------------------------------------------------------
@@ -158,12 +166,28 @@ async def _process_message(self, message: dict) -> None:
158166
has_id = "id" in message
159167

160168
if method is not None and has_id:
161-
await self._handle_request(message)
162-
elif method is not None and not has_id:
169+
self._schedule(self._handle_request(message))
170+
return
171+
if method is not None and not has_id:
163172
await self._handle_notification(message)
164-
elif has_id:
173+
return
174+
if has_id:
165175
await self._handle_response(message)
166176

177+
def _schedule(self, coro: Awaitable[Any]) -> None:
178+
task = asyncio.create_task(coro)
179+
self._inflight.add(task)
180+
task.add_done_callback(self._task_done)
181+
182+
def _task_done(self, task: asyncio.Task[Any]) -> None:
183+
self._inflight.discard(task)
184+
if task.cancelled():
185+
return
186+
try:
187+
task.result()
188+
except Exception:
189+
logging.exception("Unhandled error in JSON-RPC request handler")
190+
167191
async def _handle_request(self, message: dict) -> None:
168192
"""Handle JSON-RPC request."""
169193
payload = {"jsonrpc": "2.0", "id": message["id"]}

tests/test_cancel_prompt_flow.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from acp import (
6+
AgentSideConnection,
7+
CancelNotification,
8+
ClientSideConnection,
9+
PromptRequest,
10+
PromptResponse,
11+
)
12+
from acp.schema import TextContentBlock
13+
from tests.test_rpc import TestAgent, TestClient, _Server
14+
15+
16+
class LongRunningAgent(TestAgent):
17+
"""Agent variant whose prompt waits for a cancel notification."""
18+
19+
def __init__(self) -> None:
20+
super().__init__()
21+
self.prompt_started = asyncio.Event()
22+
self.cancel_received = asyncio.Event()
23+
24+
async def prompt(self, params: PromptRequest) -> PromptResponse:
25+
self.prompts.append(params)
26+
self.prompt_started.set()
27+
try:
28+
await asyncio.wait_for(self.cancel_received.wait(), timeout=1.0)
29+
except asyncio.TimeoutError as exc:
30+
msg = "Cancel notification did not arrive while prompt pending"
31+
raise AssertionError(msg) from exc
32+
return PromptResponse(stopReason="cancelled")
33+
34+
async def cancel(self, params: CancelNotification) -> None:
35+
await super().cancel(params)
36+
self.cancel_received.set()
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_cancel_reaches_agent_during_prompt() -> None:
41+
async with _Server() as server:
42+
agent = LongRunningAgent()
43+
client = TestClient()
44+
agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader)
45+
_client_conn = AgentSideConnection(lambda _conn: agent, server.server_writer, server.server_reader)
46+
47+
prompt_request = PromptRequest(
48+
sessionId="sess-xyz",
49+
prompt=[TextContentBlock(type="text", text="hello")],
50+
)
51+
prompt_task = asyncio.create_task(agent_conn.prompt(prompt_request))
52+
53+
await agent.prompt_started.wait()
54+
assert not prompt_task.done(), "Prompt finished before cancel was sent"
55+
56+
await agent_conn.cancel(CancelNotification(sessionId="sess-xyz"))
57+
58+
await asyncio.wait_for(agent.cancel_received.wait(), timeout=1.0)
59+
60+
response = await asyncio.wait_for(prompt_task, timeout=1.0)
61+
assert response.stopReason == "cancelled"
62+
assert agent.cancellations == ["sess-xyz"]

tests/test_permission_flow.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from acp import (
6+
AgentSideConnection,
7+
ClientSideConnection,
8+
PromptRequest,
9+
PromptResponse,
10+
RequestPermissionRequest,
11+
RequestPermissionResponse,
12+
)
13+
from acp.schema import PermissionOption, TextContentBlock, ToolCallUpdate
14+
from tests.test_rpc import TestAgent, TestClient, _Server
15+
16+
17+
class PermissionRequestAgent(TestAgent):
18+
"""Agent that asks the client for permission during a prompt."""
19+
20+
def __init__(self, conn: AgentSideConnection) -> None:
21+
super().__init__()
22+
self._conn = conn
23+
self.permission_responses: list[RequestPermissionResponse] = []
24+
25+
async def prompt(self, params: PromptRequest) -> PromptResponse:
26+
permission = await self._conn.requestPermission(
27+
RequestPermissionRequest(
28+
sessionId=params.sessionId,
29+
options=[
30+
PermissionOption(optionId="allow", name="Allow", kind="allow"),
31+
PermissionOption(optionId="deny", name="Deny", kind="deny"),
32+
],
33+
toolCall=ToolCallUpdate(toolCallId="call-1", title="Write File"),
34+
)
35+
)
36+
self.permission_responses.append(permission)
37+
return await super().prompt(params)
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_agent_request_permission_roundtrip() -> None:
42+
async with _Server() as server:
43+
client = TestClient()
44+
client.queue_permission_selected("allow")
45+
46+
captured_agent: list[PermissionRequestAgent] = []
47+
48+
agent_conn = ClientSideConnection(lambda _conn: client, server.client_writer, server.client_reader)
49+
_agent_conn = AgentSideConnection(
50+
lambda conn: captured_agent.append(PermissionRequestAgent(conn)) or captured_agent[-1],
51+
server.server_writer,
52+
server.server_reader,
53+
)
54+
55+
response = await asyncio.wait_for(
56+
agent_conn.prompt(
57+
PromptRequest(
58+
sessionId="sess-perm",
59+
prompt=[TextContentBlock(type="text", text="needs approval")],
60+
)
61+
),
62+
timeout=1.0,
63+
)
64+
assert response.stopReason == "end_turn"
65+
66+
assert captured_agent, "Agent was not constructed"
67+
[agent] = captured_agent
68+
assert agent.permission_responses, "Agent did not receive permission response"
69+
permission_response = agent.permission_responses[0]
70+
assert permission_response.outcome.outcome == "selected"
71+
assert permission_response.outcome.optionId == "allow"

0 commit comments

Comments
 (0)