Skip to content

Commit f19b3c3

Browse files
committed
Add cast functions.
1 parent bbe53d2 commit f19b3c3

File tree

3 files changed

+132
-8
lines changed

3 files changed

+132
-8
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2356,37 +2356,39 @@ def write_generic_types(
23562356
t_anypt = self.declare_typevar("_T_anypoint", bound=anypoint)
23572357
self.write(f'_Tt = {typevartup}("_Tt")')
23582358

2359+
# Order the bases with more specific types first so that
2360+
# __gel_reflection__ is resolved correctly.
23592361
generics = {
23602362
SchemaPath("std", "anytype"): [
23612363
geltype,
23622364
],
23632365
SchemaPath("std", "anyobject"): [
2364-
"anytype",
23652366
gelmodel,
2367+
"anytype",
23662368
],
23672369
SchemaPath("std", "anytuple"): [
2368-
"anytype",
23692370
anytuple,
2371+
"anytype",
23702372
],
23712373
SchemaPath("std", "anynamedtuple"): [
2372-
"anytuple",
23732374
anynamedtuple,
2375+
"anytuple",
23742376
],
23752377
SchemaPath("std", "tuple"): [
2376-
"anytuple",
23772378
f"{tup}[{unpack}[_Tt]]",
2379+
"anytuple",
23782380
],
23792381
SchemaPath("std", "array"): [
2380-
"anytype",
23812382
f"{arr}[{t_anytype}]",
2383+
"anytype",
23822384
],
23832385
SchemaPath("std", "range"): [
2384-
"anytype",
23852386
f"{rang}[{t_anypt}]",
2387+
"anytype",
23862388
],
23872389
SchemaPath("std", "multirange"): [
2388-
"anytype",
23892390
f"{mrang}[{t_anypt}]",
2391+
"anytype",
23902392
],
23912393
}
23922394

@@ -2441,7 +2443,7 @@ def write_generic_types(
24412443
with self._class_def(tmeta, meta_bases):
24422444
un_ops = self._write_prefix_operator_methods(ptype)
24432445
bin_ops = self._write_infix_operator_methods(ptype)
2444-
if gt.name == "anytype":
2446+
if gt.name in {"anytype", "anytuple"}:
24452447
self.write(f"__hash__ = {type_}.__hash__")
24462448
elif not un_ops and not bin_ops:
24472449
self.write("pass")
@@ -2536,7 +2538,36 @@ def _write_enum_scalar_type(
25362538
self.write_description(stype)
25372539
for value in stype.enum_values:
25382540
self.write(f"{ident(value)} = {value!r}")
2541+
2542+
self.write()
25392543
self.write_type_reflection(stype)
2544+
2545+
# cast method
2546+
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
2547+
2548+
type_ = self.import_name("builtins", "type")
2549+
self_ = self.import_name("typing_extensions", "Self")
2550+
type_self = f"{type_}[{self_}]"
2551+
2552+
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2553+
cast_op = self.import_name(BASE_IMPL, "CastOp")
2554+
2555+
self.write()
2556+
with self._classmethod_def(
2557+
"cast",
2558+
[f"expr: {expr_compat}"],
2559+
type_self,
2560+
):
2561+
self.write(f"return {aexpr}( # type: ignore [return-value]")
2562+
with self.indented():
2563+
self.write("cls,")
2564+
self.write(f"{cast_op}(")
2565+
with self.indented():
2566+
self.write("expr=expr,")
2567+
self.write("type_=cls.__gel_reflection__.name,")
2568+
self.write(")")
2569+
self.write(")")
2570+
25402571
self.write_section_break()
25412572

25422573
def _write_scalar_type(
@@ -2627,6 +2658,8 @@ def _write_regular_scalar_type(
26272658

26282659
self.export("anyenum")
26292660

2661+
is_generic = type_name in GENERIC_TYPES
2662+
26302663
if not runtime_parents:
26312664
typecheck_parents = [self.get_type(self._types_by_name["anytype"])]
26322665
runtime_parents = typecheck_parents
@@ -2728,6 +2761,22 @@ def _write_regular_scalar_type(
27282761
):
27292762
self.write("...")
27302763

2764+
# cast method
2765+
if not is_generic:
2766+
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
2767+
2768+
type_ = self.import_name("builtins", "type")
2769+
self_ = self.import_name("typing_extensions", "Self")
2770+
type_self = f"{type_}[{self_}]"
2771+
2772+
self.write()
2773+
with self._classmethod_def(
2774+
"cast",
2775+
[f"expr: {expr_compat}"],
2776+
type_self,
2777+
):
2778+
self.write("...")
2779+
27312780
self.write()
27322781

27332782
with self.not_type_checking():
@@ -2742,6 +2791,37 @@ def _write_regular_scalar_type(
27422791
self.write()
27432792
self.write_type_reflection(stype)
27442793

2794+
# cast method
2795+
if not is_generic:
2796+
expr_compat = self.import_name(BASE_IMPL, "ExprCompatible")
2797+
2798+
type_ = self.import_name("builtins", "type")
2799+
self_ = self.import_name("typing_extensions", "Self")
2800+
type_self = f"{type_}[{self_}]"
2801+
2802+
aexpr = self.import_name(BASE_IMPL, "AnnotatedExpr")
2803+
cast_op = self.import_name(BASE_IMPL, "CastOp")
2804+
2805+
self.write()
2806+
with self._classmethod_def(
2807+
"cast",
2808+
[f"expr: {expr_compat}"],
2809+
type_self,
2810+
):
2811+
self.write(
2812+
f"return {aexpr}( # type: ignore [return-value]"
2813+
)
2814+
with self.indented():
2815+
self.write("cls,")
2816+
self.write(f"{cast_op}(")
2817+
with self.indented():
2818+
self.write("expr=expr,")
2819+
self.write(
2820+
"type_=cls.__gel_reflection__.name,"
2821+
)
2822+
self.write(")")
2823+
self.write(")")
2824+
27452825
self.write_section_break()
27462826

27472827
def render_callable_return_type(

gel/_internal/_qbmodel/_abstract/_primitive.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,22 @@ def _reconstruct_from_pickle(
269269
def __gel_get_py_type__(cls) -> type:
270270
return list
271271

272+
def __edgeql_literal__(self) -> _qb.Literal:
273+
return _qb.Literal(
274+
type_=type(self).__gel_reflection__.name,
275+
val=self,
276+
)
277+
278+
@classmethod
279+
def cast(cls, expr: _qb.ExprCompatible) -> type[Array[_T]]:
280+
return _qb.AnnotatedExpr( # type: ignore [return-value]
281+
cls,
282+
_qb.CastOp(
283+
expr=expr,
284+
type_=cls.__gel_reflection__.name,
285+
),
286+
)
287+
272288

273289
_Ts = TypeVarTuple("_Ts")
274290

@@ -335,6 +351,22 @@ class __gel_reflection__(GelPrimitiveType.__gel_reflection__): # noqa: N801
335351
def __gel_get_py_type__(cls) -> type:
336352
return tuple
337353

354+
def __edgeql_literal__(self) -> _qb.Literal:
355+
return _qb.Literal(
356+
type_=type(self).__gel_reflection__.name,
357+
val=self,
358+
)
359+
360+
@classmethod
361+
def cast(cls, expr: _qb.ExprCompatible) -> type[Tuple[Unpack[_Ts]]]:
362+
return _qb.AnnotatedExpr( # type: ignore [return-value]
363+
cls,
364+
_qb.CastOp(
365+
expr=expr,
366+
type_=cls.__gel_reflection__.name,
367+
),
368+
)
369+
338370

339371
if TYPE_CHECKING:
340372

@@ -374,6 +406,16 @@ class __gel_reflection__(GelPrimitiveType.__gel_reflection__): # noqa: N801
374406

375407
return __gel_reflection__
376408

409+
@classmethod
410+
def cast(cls, expr: _qb.ExprCompatible) -> type[Range[_T]]:
411+
return _qb.AnnotatedExpr( # type: ignore [return-value]
412+
cls,
413+
_qb.CastOp(
414+
expr=expr,
415+
type_=cls.__gel_reflection__.name,
416+
),
417+
)
418+
377419

378420
if TYPE_CHECKING:
379421

gel/models/pydantic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OPERAND_IS_ALIAS,
2323
AnnotatedExpr,
2424
BaseAlias,
25+
CastOp,
2526
EmptyDirection,
2627
Direction,
2728
GelLinkMetadata,
@@ -116,6 +117,7 @@
116117
"ArrayMeta",
117118
"BaseAlias",
118119
"Cardinality",
120+
"CastOp",
119121
"ComputedLink",
120122
"ComputedLinkWithProps",
121123
"ComputedMultiLink",

0 commit comments

Comments
 (0)