Skip to content

Commit 4d86b09

Browse files
committed
polish(incremental): simplify iterator
Replicates graphql/graphql-js@5cd5001
1 parent a7663f4 commit 4d86b09

File tree

2 files changed

+44
-47
lines changed

2 files changed

+44
-47
lines changed

src/graphql/execution/incremental_graph.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
from asyncio import (
6-
CancelledError,
76
Future,
87
Task,
98
ensure_future,
@@ -25,7 +24,7 @@
2524
)
2625

2726
if TYPE_CHECKING:
28-
from collections.abc import AsyncGenerator, Awaitable, Generator, Iterable, Sequence
27+
from collections.abc import Awaitable, Generator, Iterable, Sequence
2928

3029
from ..error.graphql_error import GraphQLError
3130
from .types import (
@@ -48,7 +47,7 @@ class IncrementalGraph:
4847

4948
_root_nodes: dict[SubsequentResultRecord, None]
5049
_completed_queue: list[IncrementalDataRecordResult]
51-
_next_queue: list[Future[Iterable[IncrementalDataRecordResult]]]
50+
_next_queue: list[Future[Iterable[IncrementalDataRecordResult] | None]]
5251

5352
_tasks: set[Task[Any]]
5453

@@ -87,24 +86,31 @@ def add_completed_reconcilable_deferred_grouped_field_set(
8786
incremental_data_records, deferred_records
8887
)
8988

90-
async def completed_incremental_data(
89+
def current_completed_batch(
9190
self,
92-
) -> AsyncGenerator[Iterable[IncrementalDataRecordResult], None]:
93-
"""Asynchronously yield completed incremental data record results."""
91+
) -> Generator[IncrementalDataRecordResult, None, None]:
92+
"""Yield the current completed batch of incremental data record results."""
93+
queue = self._completed_queue
94+
while queue:
95+
yield queue.pop(0)
96+
if not self._root_nodes:
97+
self.abort()
98+
99+
def next_completed_batch(
100+
self,
101+
) -> Future[Iterable[IncrementalDataRecordResult] | None]:
102+
"""Return a future that resolves to the next completed batch."""
94103
loop = get_running_loop()
95-
while True:
96-
if self._completed_queue:
97-
first_result = self._completed_queue.pop(0)
98-
yield self._yield_current_completed_incremental_data(first_result)
99-
else:
100-
future: Future[Iterable[IncrementalDataRecordResult]] = (
101-
loop.create_future()
102-
)
103-
self._next_queue.append(future)
104-
try:
105-
yield await future
106-
except CancelledError:
107-
break # pragma: no cover
104+
future: Future[Iterable[IncrementalDataRecordResult] | None] = (
105+
loop.create_future()
106+
)
107+
self._next_queue.append(future)
108+
return future
109+
110+
def abort(self) -> None:
111+
"""Abort the incremental graph execution."""
112+
for resolve in self._next_queue:
113+
resolve.set_result(None)
108114

109115
def has_next(self) -> bool:
110116
"""Check if there are more results to process."""
@@ -332,11 +338,7 @@ def _yield_current_completed_incremental_data(
332338
) -> Generator[IncrementalDataRecordResult, None, None]:
333339
"""Yield the current completed incremental data."""
334340
yield first_result
335-
queue = self._completed_queue
336-
while queue:
337-
yield queue.pop(0)
338-
if not self._root_nodes:
339-
self.stop_incremental_data()
341+
yield from self.current_completed_batch()
340342

341343
def _enqueue(self, completed: IncrementalDataRecordResult) -> None:
342344
"""Enqueue completed incremental data record result."""

src/graphql/execution/incremental_publisher.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828

2929
if TYPE_CHECKING:
30-
from collections.abc import AsyncGenerator, Sequence
30+
from collections.abc import AsyncGenerator, Iterable, Sequence
3131

3232
from ..error import GraphQLError
3333
from .types import (
@@ -134,38 +134,33 @@ async def _subscribe(
134134
incremental_graph = self._incremental_graph
135135
check_has_next = incremental_graph.has_next
136136
handle_completed_incremental_data = self._handle_completed_incremental_data
137-
completed_incremental_data = incremental_graph.completed_incremental_data()
138-
# use the raw iterator rather than 'async for' so as not to end the iterator
139-
# when exiting the loop with the next value
140-
get_next_results = completed_incremental_data.__aiter__().__anext__
141-
is_done = False
142-
try:
143-
while not is_done:
144-
try:
145-
completed_results = await get_next_results()
146-
except StopAsyncIteration: # pragma: no cover
147-
break
148137

149-
context = SubsequentIncrementalExecutionResultContext([], [], [])
150-
for completed_result in completed_results:
151-
await handle_completed_incremental_data(completed_result, context)
138+
try:
139+
while True:
140+
batch: Iterable[IncrementalDataRecordResult] | None = (
141+
incremental_graph.current_completed_batch()
142+
)
152143

153-
if context.incremental or context.completed: # pragma: no branch
154-
has_next = check_has_next()
144+
while batch is not None:
145+
context = SubsequentIncrementalExecutionResultContext([], [], [])
146+
for completed_result in batch:
147+
await handle_completed_incremental_data(
148+
completed_result, context
149+
)
155150

156-
if not has_next:
157-
is_done = True
151+
if context.incremental or context.completed: # pragma: no branch
152+
has_next = check_has_next()
158153

159-
subsequent_incremental_execution_result = (
160-
SubsequentIncrementalExecutionResult(
154+
yield SubsequentIncrementalExecutionResult(
161155
has_next=has_next,
162156
pending=context.pending or None,
163157
incremental=context.incremental or None,
164158
completed=context.completed or None,
165159
)
166-
)
167160

168-
yield subsequent_incremental_execution_result
161+
if not has_next:
162+
return
163+
batch = await incremental_graph.next_completed_batch()
169164
finally:
170165
await self._stop_async_iterators()
171166

0 commit comments

Comments
 (0)