Skip to content

Commit 074a7c9

Browse files
committed
Finish async iterables in the non incremental delivery case
Replicates graphql/graphql-js@d9fc656
1 parent 6531cf5 commit 074a7c9

File tree

2 files changed

+105
-68
lines changed

2 files changed

+105
-68
lines changed

src/graphql/execution/execute.py

Lines changed: 74 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -941,82 +941,88 @@ async def complete_async_iterator_value(
941941
awaitable_indices: list[int] = []
942942
append_awaitable = awaitable_indices.append
943943
stream_usage = self.get_stream_usage(field_group, path)
944+
try:
945+
early_return = async_iterator.aclose # type: ignore[attr-defined]
946+
except AttributeError:
947+
early_return = None
944948
index = 0
945-
while True:
946-
if stream_usage and index >= stream_usage.initial_count:
947-
stream_item_queue = self.build_async_stream_item_queue(
948-
index,
949-
path,
950-
async_iterator,
951-
stream_usage.field_group,
952-
info,
953-
item_type,
954-
)
955-
956-
try:
957-
early_return = async_iterator.aclose() # type: ignore
958-
except AttributeError:
959-
early_return = None
960-
stream_record: StreamRecord
961-
962-
if early_return is None:
963-
stream_record = StreamRecord(
964-
stream_item_queue, path, stream_usage.label
965-
)
966-
else:
967-
stream_record = CancellableStreamRecord(
968-
early_return,
969-
stream_item_queue,
949+
try:
950+
while True:
951+
if stream_usage and index >= stream_usage.initial_count:
952+
stream_item_queue = self.build_async_stream_item_queue(
953+
index,
970954
path,
971-
stream_usage.label,
955+
async_iterator,
956+
stream_usage.field_group,
957+
info,
958+
item_type,
972959
)
973-
if self.cancellable_streams is None: # pragma: no branch
974-
self.cancellable_streams = set()
975-
self.cancellable_streams.add(stream_record)
976-
self._canceled_iterators.add(async_iterator)
977960

978-
add_increment(stream_record)
979-
break
961+
stream_record: StreamRecord
980962

981-
item_path = path.add_key(index, None)
982-
try:
983-
item = await anext(async_iterator)
984-
except StopAsyncIteration:
985-
break
986-
except Exception as raw_error:
987-
raise located_error(
988-
raw_error, to_nodes(field_group), path.as_list()
989-
) from raw_error
963+
if early_return is None:
964+
stream_record = StreamRecord(
965+
stream_item_queue, path, stream_usage.label
966+
)
967+
else:
968+
stream_record = CancellableStreamRecord(
969+
early_return(),
970+
stream_item_queue,
971+
path,
972+
stream_usage.label,
973+
)
974+
if self.cancellable_streams is None: # pragma: no branch
975+
self.cancellable_streams = set()
976+
self.cancellable_streams.add(stream_record)
977+
self._canceled_iterators.add(async_iterator)
990978

991-
if is_awaitable(item):
992-
append_completed(
993-
complete_awaitable_list_item_value(
994-
item,
995-
graphql_wrapped_result,
996-
item_type,
997-
field_group,
998-
info,
999-
item_path,
1000-
incremental_context,
1001-
defer_map,
1002-
)
1003-
)
1004-
append_awaitable(index)
979+
add_increment(stream_record)
980+
break
1005981

1006-
elif complete_list_item_value(
1007-
item,
1008-
completed_results,
1009-
graphql_wrapped_result,
1010-
item_type,
1011-
field_group,
1012-
info,
1013-
item_path,
1014-
incremental_context,
1015-
defer_map,
1016-
):
1017-
append_awaitable(index)
982+
item_path = path.add_key(index, None)
983+
try:
984+
item = await anext(async_iterator)
985+
except StopAsyncIteration:
986+
break
987+
except Exception as raw_error:
988+
raise located_error(
989+
raw_error, to_nodes(field_group), path.as_list()
990+
) from raw_error
991+
992+
if is_awaitable(item):
993+
append_completed(
994+
complete_awaitable_list_item_value(
995+
item,
996+
graphql_wrapped_result,
997+
item_type,
998+
field_group,
999+
info,
1000+
item_path,
1001+
incremental_context,
1002+
defer_map,
1003+
)
1004+
)
1005+
append_awaitable(index)
10181006

1019-
index += 1
1007+
elif complete_list_item_value(
1008+
item,
1009+
completed_results,
1010+
graphql_wrapped_result,
1011+
item_type,
1012+
field_group,
1013+
info,
1014+
item_path,
1015+
incremental_context,
1016+
defer_map,
1017+
):
1018+
append_awaitable(index)
1019+
1020+
index += 1
1021+
except Exception:
1022+
if early_return is not None: # pragma: no branch
1023+
with suppress_exceptions:
1024+
await early_return()
1025+
raise
10201026

10211027
if not awaitable_indices:
10221028
return graphql_wrapped_result

tests/execution/test_lists.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,37 @@ async def list_field():
356356
assert await _complete(list_field(), "[Int!]") == ({"listField": None}, errors)
357357
assert await _complete(list_field(), "[Int!]!") == (None, errors)
358358

359+
async def calls_aclose_when_non_null_list_item_errors():
360+
values = (1, None, 2)
361+
362+
class ListField:
363+
def __init__(self) -> None:
364+
self.index = 0
365+
self.closed = False
366+
367+
def __aiter__(self):
368+
return self
369+
370+
async def __anext__(self):
371+
value = values[self.index]
372+
self.index += 1
373+
return value
374+
375+
async def aclose(self) -> None:
376+
self.closed = True
377+
378+
list_field = ListField()
379+
errors = [
380+
{
381+
"message": "Cannot return null for non-nullable field Query.listField.",
382+
"locations": [(1, 3)],
383+
"path": ["listField", 1],
384+
}
385+
]
386+
387+
assert await _complete(list_field, "[Int!]") == ({"listField": None}, errors)
388+
assert list_field.closed
389+
359390

360391
def describe_execute_handles_list_nullability():
361392
async def _complete(list_field: Any, as_type: str) -> ExecutionResult:

0 commit comments

Comments
 (0)