Skip to content

Commit b2c9224

Browse files
authored
Fix a bunch of range/tuple/array operations (#946)
Generate `std.range`/`std.tuple`, follownig the approach of #939. Then, fix calling functions in QB when range[_T_anypoint] or array[_T_anytype] used, by specifically checking ParametricType subclasses against GenericAlias[T].
1 parent 0194d19 commit b2c9224

File tree

6 files changed

+94
-30
lines changed

6 files changed

+94
-30
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,14 +1506,6 @@ def get_type(
15061506
else ImportTime.runtime
15071507
)
15081508

1509-
if (
1510-
import_time is ImportTime.typecheck_runtime
1511-
or import_time is ImportTime.late_runtime
1512-
):
1513-
foreign_import_time = ImportTime.runtime
1514-
else:
1515-
foreign_import_time = import_time
1516-
15171509
if reflection.is_array_type(stype):
15181510
arr = self.get_object(
15191511
SchemaPath('std', 'array'),
@@ -1529,8 +1521,9 @@ def get_type(
15291521
return f"{arr}[{elem_type}]"
15301522

15311523
elif reflection.is_tuple_type(stype):
1532-
tup = self.import_name(
1533-
BASE_IMPL, "Tuple", import_time=foreign_import_time
1524+
tup = self.get_object(
1525+
SchemaPath('std', 'tuple'),
1526+
aspect=ModuleAspect.SHAPES,
15341527
)
15351528
elem_types = [
15361529
self.get_type(
@@ -1545,8 +1538,9 @@ def get_type(
15451538
return f"{tup}[{', '.join(elem_types)}]"
15461539

15471540
elif reflection.is_range_type(stype):
1548-
rang = self.import_name(
1549-
BASE_IMPL, "Range", import_time=foreign_import_time
1541+
rang = self.get_object(
1542+
SchemaPath('std', 'range'),
1543+
aspect=ModuleAspect.SHAPES,
15501544
)
15511545
elem_type = self.get_type(
15521546
stype.get_element_type(self._types),
@@ -1558,11 +1552,13 @@ def get_type(
15581552
return f"{rang}[{elem_type}]"
15591553

15601554
elif reflection.is_multi_range_type(stype):
1561-
rang_el = self.import_name(
1562-
BASE_IMPL, "Range", import_time=foreign_import_time
1555+
rang_el = self.get_object(
1556+
SchemaPath('std', 'range'),
1557+
aspect=ModuleAspect.SHAPES,
15631558
)
1564-
rang = self.import_name(
1565-
BASE_IMPL, "MultiRange", import_time=foreign_import_time
1559+
rang = self.get_object(
1560+
SchemaPath('std', 'multirange'),
1561+
aspect=ModuleAspect.SHAPES,
15661562
)
15671563
elem_type = self.get_type(
15681564
stype.get_element_type(self._types),
@@ -2420,12 +2416,13 @@ def write_generic_types(
24202416
tname = gt.name
24212417
tmeta = f"__{tname}_meta__"
24222418
with self._class_def(tmeta, meta_bases):
2423-
un_ops = self._write_prefix_operator_methods(ptype)
2424-
bin_ops = self._write_infix_operator_methods(ptype)
2425-
if gt.name == "anytype":
2426-
self.write(f"__hash__ = {type_}.__hash__")
2427-
elif not un_ops and not bin_ops:
2428-
self.write("pass")
2419+
self._write_prefix_operator_methods(ptype)
2420+
self._write_infix_operator_methods(ptype)
2421+
# We have custom __eq__ functions for codegen
2422+
# purposes... but when applied to two normal
2423+
# types they still behave as identity, so keep the
2424+
# normal __hash__.
2425+
self.write(f"__hash__ = {type_}.__hash__")
24292426

24302427
class_kwargs["metaclass"] = tmeta
24312428

gel/_internal/_typing_dispatch.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,57 @@
3636
from gel._internal import _namespace
3737
from gel._internal import _typing_eval
3838
from gel._internal import _typing_inspect
39+
from gel._internal import _typing_parametric
3940
from gel._internal._utils import type_repr
4041

4142
_P = ParamSpec("_P")
4243
_R_co = TypeVar("_R_co", covariant=True)
4344

4445

46+
def _resolve_to_bound(tp: Any, fn: Any) -> Any:
47+
if isinstance(tp, TypeVar):
48+
tp = tp.__bound__
49+
ns = _namespace.module_ns_of(fn)
50+
tp = _typing_eval.resolve_type(tp, globals=ns)
51+
52+
return tp
53+
54+
55+
def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
56+
# NB: Much more limited than _isinstance below.
57+
58+
# The only special case here is handling subtyping on
59+
# our ParametricTypes. ParametricType creates bona fide
60+
# subclasses when indexed with concrete types, but GenericAlias
61+
# when indexed with type variables.
62+
#
63+
# This handles the case where the RHS of issubclass is a
64+
# GenericAlias over one of our ParametricTypes by comparing the
65+
# types for equality and then checking that the concrete types are
66+
# subtypes of the variable bounds.
67+
# This lets us handle cases like:
68+
# std.array[Object] <: std.array[_T_anytype].
69+
if _typing_inspect.is_generic_alias(tp):
70+
origin = typing.get_origin(tp)
71+
args = typing.get_args(tp)
72+
if issubclass(origin, _typing_parametric.ParametricType):
73+
if (
74+
not issubclass(lhs, _typing_parametric.ParametricType)
75+
or lhs.__parametric_origin__ is not origin
76+
or lhs.__parametric_type_args__ is None
77+
):
78+
return False
79+
80+
targs = lhs.__parametric_type_args__[origin]
81+
return all(
82+
_issubclass(l, _resolve_to_bound(r, fn), fn)
83+
for l, r in zip(targs, args, strict=True)
84+
)
85+
86+
# In other cases,
87+
return issubclass(lhs, tp) # pyright: ignore [reportArgumentType]
88+
89+
4590
def _isinstance(obj: Any, tp: Any, fn: Any) -> bool:
4691
# Handle Any type - matches everything
4792
if tp is Any:
@@ -62,17 +107,16 @@ def _isinstance(obj: Any, tp: Any, fn: Any) -> bool:
62107
origin = typing.get_origin(tp)
63108
args = typing.get_args(tp)
64109
if origin is type:
65-
atype = args[0]
66-
if isinstance(atype, TypeVar):
67-
atype = atype.__bound__
68-
ns = _namespace.module_ns_of(fn)
69-
atype = _typing_eval.resolve_type(atype, globals=ns)
110+
atype = _resolve_to_bound(args[0], fn)
70111

71112
if isinstance(obj, type):
72113
return issubclass(obj, atype)
114+
# NB: This is to handle the case where obj is something
115+
# like a qb BaseAlias, where it has some fictitious
116+
# associated type that isn't really its runtime type.
73117
elif (mroent := getattr(obj, "__mro_entries__", None)) is not None:
74118
genalias_mro = mroent((obj,))
75-
return any(issubclass(c, atype) for c in genalias_mro)
119+
return any(_issubclass(c, atype, fn) for c in genalias_mro)
76120
else:
77121
return False
78122

@@ -141,6 +185,10 @@ def _isinstance(obj: Any, tp: Any, fn: Any) -> bool:
141185
else:
142186
# For other generic types, fall back to checking the origin
143187
return isinstance(obj, origin)
188+
189+
elif isinstance(tp, TypeVar):
190+
return _isinstance(obj, _resolve_to_bound(tp, fn), fn)
191+
144192
else:
145193
raise TypeError(f"_isinstance() argument 2 is {tp!r}")
146194

gel/_internal/ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ extend-ignore = [
5959
"UP045", # non-pep604-annotation-optional
6060
"PLR5501", # collapsible-else-if
6161
"FURB103", # don't use open and write
62+
"E741", # ambiguous variable names
6263
]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ extend-ignore = [
222222
"E252", # missing-whitespace-around-parameter-equals
223223
"F541", # f-string-missing-placeholders
224224
"Q000", # prefer double quotes
225+
"E741", # ambiguous variable names
225226
]
226227

227228
[tool.ruff.format]

tests/test_model_generator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6501,9 +6501,17 @@ def test_modelgen_operators_string_contains_and_patterns(self):
65016501
for user in users_with_char:
65026502
self.assertTrue(search_char in user.name.lower())
65036503

6504-
@tb.xfail
6504+
def test_modelgen_operators_range_eq(self):
6505+
from models.orm import default
6506+
6507+
res = self.client.query(
6508+
default.RangeTest.filter(
6509+
lambda u: u.int_range == u.int_range
6510+
)
6511+
)
6512+
self.assertEqual(len(res), 1)
6513+
65056514
def test_modelgen_operators_range_contains(self):
6506-
"""Test string containment and pattern matching operators"""
65076515
from models.orm import default, std
65086516

65096517
res = self.client.query(

tests/test_qb.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,15 @@ def test_qb_poly_07(self):
13431343
self.assertEqual(c.kind, kind)
13441344
self.assertIsInstance(c, default.Chocolate)
13451345

1346+
def test_qb_array_agg_01(self):
1347+
from models.orm import default, std
1348+
1349+
agg = std.array_agg(default.User)
1350+
unpack = std.array_unpack(agg)
1351+
1352+
res = self.client.query(unpack)
1353+
self.assertEqual(len(res), 6)
1354+
13461355

13471356
class TestQueryBuilderModify(tb.ModelTestCase):
13481357
"""This test suite is for data manipulation using QB."""

0 commit comments

Comments
 (0)