Skip to content

Commit e8677d5

Browse files
authored
Apply splat to schema sets and path aliases without explicit select. (#929)
Ensures a splat is applied in the following cases: ```py default.User default.User.filter(...) default.User.groups default.User.groups.filter(...) default.User.filter(...).groups default.User.filter(...).groups.filter(...) ``` Also ensures that splats are applied if such a case is used within another object's select. ```py default.Group.select( users=default.Post.author # splat applied here ) ```
1 parent 16492ad commit e8677d5

File tree

5 files changed

+554
-8
lines changed

5 files changed

+554
-8
lines changed

gel/_internal/_qb/_expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def wrap(
661661
kwargs = {}
662662
if isinstance(expr, ShapeOp):
663663
kwargs["body_scope"] = expr.scope
664-
elif isinstance(expr, SchemaSet):
664+
elif isinstance(expr, (SchemaSet, Path)):
665665
if splat_cb is not None:
666666
shape = splat_cb()
667667
else:

gel/_internal/_qb/_generics.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
edgeql_qb_expr,
4646
is_exprmethod,
4747
)
48-
from ._reflection import GelTypeMetadata
48+
from ._reflection import GelObjectTypeMetadata, GelTypeMetadata
4949

5050
if TYPE_CHECKING:
5151
from collections.abc import Iterable
@@ -366,7 +366,10 @@ def __infix_op__(
366366

367367
def __edgeql__(self) -> tuple[type, tuple[str, dict[str, object]]]:
368368
type_ = self.__gel_origin__
369-
if issubclass(type_, GelTypeMetadata):
369+
if issubclass(type_, GelObjectTypeMetadata) and issubclass(
370+
type_.__gel_reflection__,
371+
GelTypeMetadata.__gel_reflection__,
372+
):
370373
splat_cb = functools.partial(get_object_type_splat, type_)
371374
else:
372375
splat_cb = None

gel/_internal/_qbmodel/_abstract/_expressions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ def select(
189189
shape_elements.append(shape_el)
190190
else:
191191
el_expr = _qb.edgeql_qb_expr(kwarg, var=prefix_alias)
192+
if isinstance(el_expr, (_qb.SchemaSet, _qb.Path)):
193+
# If the expression is a schema set or path without an explicit
194+
# select, apply a splat.
195+
if (
196+
el_type := (
197+
kwarg
198+
if isinstance(kwarg, type)
199+
else kwarg.__gel_origin__
200+
if isinstance(kwarg, _qb.BaseAlias)
201+
else None
202+
)
203+
) and issubclass(el_type, _qb.GelObjectTypeMetadata):
204+
el_expr = _qb.ShapeOp(
205+
iter_expr=el_expr,
206+
shape=_qb.get_object_type_splat(el_type),
207+
)
208+
192209
shape_el = _qb.ShapeElement(
193210
name=ptrname,
194211
expr=el_expr,

gel/_internal/_testbase/_models.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
TypeVar,
1111
TYPE_CHECKING,
1212
)
13-
from collections.abc import Awaitable
13+
from collections.abc import Awaitable, Collection
1414
from typing_extensions import Self
1515

1616
import argparse
@@ -76,6 +76,7 @@
7676
if TYPE_CHECKING:
7777
from collections.abc import Callable, Iterator, Mapping, Sequence
7878
import pydantic
79+
from gel._internal._qbmodel._pydantic._models import GelModel
7980

8081

8182
_unset = object()
@@ -547,6 +548,83 @@ def assertPydanticPickles(
547548
getattr(model2, "__gel_changed_fields__", ...),
548549
)
549550

551+
def _assertObjectsWithFields(
552+
self,
553+
models: Collection[GelModel],
554+
identifying_field: str,
555+
expected_obj_fields: list[tuple[type[GelModel], dict[str, Any]]],
556+
) -> None:
557+
"""Test that models match the expected object fields.
558+
Pairs models with their expected fields using the identifying field.
559+
"""
560+
self.assertEqual(len(models), len(expected_obj_fields))
561+
562+
# Get models per identifier
563+
for model in models:
564+
self._assertHasFields(model, {identifying_field})
565+
566+
object_by_identifier = {
567+
expected_fields[identifying_field]: next(
568+
iter(
569+
m
570+
for m in models
571+
if getattr(m, identifying_field)
572+
== expected_fields[identifying_field]
573+
),
574+
None,
575+
)
576+
for _, expected_fields in expected_obj_fields
577+
}
578+
579+
# Check that models match obj_fields one to one
580+
for identifier, obj in object_by_identifier.items():
581+
self.assertIsNotNone(
582+
obj, f"No model with identifier '{identifier}'"
583+
)
584+
self.assertEqual(
585+
len(object_by_identifier),
586+
len(expected_obj_fields),
587+
"Duplicate identifier'",
588+
)
589+
590+
# Check each model
591+
for expected_type, expected_fields in expected_obj_fields:
592+
identifier = expected_fields[identifying_field]
593+
obj = object_by_identifier[identifier]
594+
assert obj is not None
595+
self.assertIsInstance(obj, expected_type)
596+
self._assertHasFields(obj, expected_fields)
597+
598+
def _assertHasFields(
599+
self,
600+
model: GelModel,
601+
expected_fields: dict[str, Any] | set[str],
602+
) -> None:
603+
for field_name in expected_fields:
604+
self.assertTrue(
605+
field_name in model.__pydantic_fields_set__,
606+
f"Model is missing field '{field_name}'",
607+
)
608+
609+
if isinstance(expected_fields, dict):
610+
expected = expected_fields[field_name]
611+
actual = getattr(model, field_name)
612+
self.assertEqual(
613+
expected,
614+
actual,
615+
f"Field '{field_name}' value ({actual}) different from "
616+
f"expected ({expected})",
617+
)
618+
619+
def _assertNotHasFields(
620+
self, model: GelModel, expected_fields: set[str]
621+
) -> None:
622+
for field_name in expected_fields:
623+
self.assertTrue(
624+
field_name not in model.__pydantic_fields_set__,
625+
f"Model has unexpected field '{field_name}'",
626+
)
627+
550628

551629
class ModelTestCase(SyncQueryTestCase, BaseModelTestCase): # pyright: ignore[reportIncompatibleVariableOverride, reportIncompatibleMethodOverride]
552630
pass

0 commit comments

Comments
 (0)