Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/typed_dict.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,82 @@ def _(ATD: type[MyTD]):
z = ATD(a="foo")
```

Constructor validation should also work when the call target is a union or intersection of
`type[...]` values:

```py
from typing import Union
from ty_extensions import Intersection

class CtorRequired(TypedDict):
a: int

class CtorOptional(TypedDict, total=False):
a: int

def _(ATD: Union[type[CtorRequired], type[CtorOptional]]):
ok = ATD(a=1)

# Both union variants reject the `str` argument for `a`.
# error: [invalid-argument-type]
# error: [invalid-argument-type]
bad = ATD(a="foo")

# 0-arg construction: valid for `CtorOptional` (all fields optional),
# but `CtorRequired` requires `a`.
# error: [no-matching-overload]
no_args = ATD()

# Dict-literal construction through a union.
ok_dict = ATD({"a": 1})

# error: [invalid-argument-type]
# error: [invalid-argument-type]
bad_dict = ATD({"a": "foo"})

def _(ATD: Intersection[type[CtorRequired], type[CtorRequired]]):
ok = ATD(a=1)

# error: [invalid-argument-type]
bad = ATD(a="foo")

def _(ATD: Intersection[type[CtorRequired], type[CtorOptional]]):
ok = ATD(a=1)

# Both intersection members check the argument independently.
# error: [invalid-argument-type]
# error: [invalid-argument-type]
bad = ATD(a="foo")

# Dict-literal construction through an intersection.
ok_dict = ATD({"a": 1})
```

TypedDict constructors also support the `dict(mapping, **kwargs)`-style merge form. Keyword
arguments should override the positional mapping when validating the final shape:

```py
class BaseKwargs(TypedDict, total=False):
name: str

class ChildKwargs(BaseKwargs, total=False):
count: int

class OverrideCountKwargs(TypedDict, total=False):
count: str

def _(base: BaseKwargs, override: OverrideCountKwargs):
ok = ChildKwargs(base, count=1)
overridden = ChildKwargs(override, count=1)
overridden_literal = ChildKwargs({"count": "wrong"}, count=1)

# error: [invalid-argument-type]
bad_value = ChildKwargs({"name": 1}, count=1)

# error: [invalid-argument-type]
bad_mapping = ChildKwargs(1, count=1)
```

All of these have an invalid type for the `name` field:

```py
Expand Down Expand Up @@ -1915,6 +1991,27 @@ def _(node: Node, person: Person):
_: Node = Person(name="Alice", parent=Node(name="Bob", parent=Person(name="Charlie", parent=None)))
```

TypedDict constructor calls should also use field type context when inferring nested recursive
values:

```py
from typing import Any, List, TypedDict, Union
from typing_extensions import NotRequired

class Comparison(TypedDict):
field: str
op: NotRequired[str]
value: Any

class Logical(TypedDict):
op: NotRequired[str]
conditions: List["Filter"]

Filter = Union[Comparison, Logical]

logical = Logical(conditions=[Comparison(field="a", value="b")])
```

## Function/assignment syntax

This is not yet supported. Make sure that we do not emit false positives for this syntax:
Expand Down
60 changes: 28 additions & 32 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2181,7 +2181,12 @@ impl<'db> Type<'db> {
}

Type::GenericAlias(alias) if alias.is_typed_dict(db) => {
Some(alias.origin(db).typed_dict_member(db, None, name, policy))
Some(alias.origin(db).typed_dict_member(
db,
Some(alias.specialization(db)),
name,
policy,
))
}

Type::GenericAlias(alias) => {
Expand Down Expand Up @@ -4220,32 +4225,42 @@ impl<'db> Type<'db> {
})
}

let (class_literal, class_specialization) = class.class_literal_and_specialization(db);
let (class_literal, _) = class.class_literal_and_specialization(db);
let class_generic_context = class_literal.generic_context(db);

let self_type = match self {
Type::ClassLiteral(class) if class.generic_context(db).is_some() => {
Type::from(class.identity_specialization(db))
}
_ => self,
};

let Some(constructor_instance_ty) = self_type.to_instance(db) else {
let return_type = self.to_instance(db).unwrap_or(Type::unknown());
return Binding::single(
self,
Signature::new_generic(
class_generic_context,
Parameters::gradual_form(),
return_type,
),
)
.into();
};

// Keep bespoke constructor behavior for cases that don't map cleanly to `__new__`/`__init__`.
let fallback_bindings = || {
let return_type = self.to_instance(db).unwrap_or(Type::unknown());
Binding::single(
self,
Signature::new_generic(
class_generic_context,
Parameters::gradual_form(),
return_type,
constructor_instance_ty,
),
)
.into()
};

// Checking TypedDict construction happens in `infer_call_expression_impl`, so here we just
// return a permissive fallback binding. TODO maybe we should just synthesize bindings for
// a TypedDict constructor? That would handle unions/intersections correctly.
if class_literal.is_typed_dict(db)
|| class::CodeGeneratorKind::TypedDict.matches(db, class_literal, class_specialization)
{
return fallback_bindings();
}

// These cases are checked in `Type::known_class_literal_bindings`, but currently we only
// call that for `ClassLiteral` types, so we need a permissive fallback here. TODO Ideally
// that would be called from `constructor_bindings` for better consistency, but that causes
Expand Down Expand Up @@ -4278,21 +4293,6 @@ impl<'db> Type<'db> {
return fallback_bindings();
}

// If we are trying to construct a non-specialized generic class, we should use the
// constructor parameters to try to infer the class specialization. To do this, we need to
// tweak our member lookup logic a bit. Normally, when looking up a class or instance
// member, we first apply the class's default specialization, and apply that specialization
// to the type of the member. To infer a specialization from the argument types, we need to
// have the class's typevars still in the method signature when we attempt to call it. To
// do this, we instead use the _identity_ specialization, which maps each of the class's
// generic typevars to itself.
let self_type = match self {
Type::ClassLiteral(class) if class.generic_context(db).is_some() => {
Type::from(class.identity_specialization(db))
}
_ => self,
};

// As of now we do not model custom `__call__` on meta-classes, so the code below
// only deals with interplay between `__new__` and `__init__` methods.
// The logic is roughly as follows:
Expand Down Expand Up @@ -4320,10 +4320,6 @@ impl<'db> Type<'db> {
// constructor-call bindings.
let new_method = self_type.lookup_dunder_new(db);

let Some(constructor_instance_ty) = self_type.to_instance(db) else {
return fallback_bindings();
};

// Construct an instance type to look up `__init__`. We use `self_type` (possibly identity-
// specialized) so the instance retains inferable class typevars during constructor checking.
// TODO: we should use the actual return type of `__new__` to determine the instance type
Expand Down
77 changes: 71 additions & 6 deletions crates/ty_python_semantic/src/types/class/static_literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ use crate::{
use_def_map,
},
types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, CallArguments, CallableType, ClassBase,
ClassLiteral, ClassType, DATACLASS_FLAGS, DataclassFlags, DataclassParams, GenericAlias,
GenericContext, KnownClass, KnownInstanceType, MaterializationKind, MemberLookupPolicy,
MetaclassCandidate, MetaclassTransformInfo, Parameter, Parameters, PropertyInstanceType,
Signature, SpecialFormType, StaticMroError, SubclassOfType, Truthiness, Type, TypeContext,
TypeMapping, TypeVarVariance, UnionBuilder, UnionType,
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, CallArguments, CallableType,
ClassBase, ClassLiteral, ClassType, DATACLASS_FLAGS, DataclassFlags, DataclassParams,
GenericAlias, GenericContext, KnownClass, KnownInstanceType, MaterializationKind,
MemberLookupPolicy, MetaclassCandidate, MetaclassTransformInfo, Parameter, Parameters,
PropertyInstanceType, Signature, SpecialFormType, StaticMroError, SubclassOfType,
Truthiness, Type, TypeContext, TypeMapping, TypeVarVariance, UnionBuilder, UnionType,
call::{CallError, CallErrorKind},
callable::CallableTypeKind,
class::{
Expand Down Expand Up @@ -1543,6 +1543,71 @@ impl<'db> StaticClassLiteral<'db> {
Type::heterogeneous_tuple(db, slots)
})
}
(CodeGeneratorKind::TypedDict, "__new__") => {
let inherited_generic_context =
inherited_generic_context.or_else(|| self.inherited_generic_context(db));

let self_typevar = BoundTypeVarInstance::synthetic_self(
db,
instance_ty,
BindingContext::Synthetic,
);
let self_ty = Type::TypeVar(self_typevar);
let generic_context = GenericContext::from_typevar_instances(
db,
inherited_generic_context
.iter()
.flat_map(|ctx| ctx.variables(db))
.chain(std::iter::once(self_typevar)),
);

let make_cls_parameter = || {
Parameter::positional_or_keyword(Name::new_static("cls"))
.with_annotated_type(SubclassOfType::from(db, self_typevar))
};

let fields = self.fields(db, specialization, field_policy);

let keyword_signature = Signature::new_generic(
Some(generic_context),
Parameters::new(
db,
std::iter::once(make_cls_parameter()).chain(fields.iter().map(
|(field_name, field)| {
let parameter = Parameter::keyword_only(field_name.clone())
.with_annotated_type(field.declared_ty);
if field.is_required() {
parameter
} else {
parameter.with_default_type(field.declared_ty)
}
},
)),
),
self_ty,
);

let positional_signature = Signature::new_generic(
Some(generic_context),
Parameters::new(
db,
[
make_cls_parameter(),
Parameter::positional_only(Some(Name::new_static("mapping")))
.with_annotated_type(Type::typed_dict(
self.apply_optional_specialization(db, specialization),
)),
],
),
self_ty,
);

Some(Type::Callable(CallableType::new(
db,
CallableSignature::from_overloads([keyword_signature, positional_signature]),
CallableTypeKind::FunctionLike,
)))
}
(CodeGeneratorKind::TypedDict, "__setitem__") => {
let fields = self.fields(db, specialization, field_policy);

Expand Down
35 changes: 31 additions & 4 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ use crate::types::set_theoretic::RecursivelyDefined;
use crate::types::subclass_of::SubclassOfInner;
use crate::types::tuple::{Tuple, TupleLength, TupleSpecBuilder, TupleType};
use crate::types::type_alias::{ManualPEP695TypeAliasType, PEP695TypeAliasType};
use crate::types::typed_dict::{validate_typed_dict_constructor, validate_typed_dict_dict_literal};
use crate::types::typed_dict::{
TypedDictConstructorCallKind, typed_dict_constructor_call_kind,
validate_typed_dict_constructor, validate_typed_dict_dict_literal,
};
use crate::types::typevar::{BoundTypeVarIdentity, TypeVarConstraints, TypeVarIdentity};
use crate::types::{
CallDunderError, CallableBinding, CallableType, ClassType, DynamicType, EvaluationMode,
Expand Down Expand Up @@ -7057,6 +7060,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.bindings(self.db())
.match_parameters(self.db(), &call_arguments);

let typed_dict_constructor = class.and_then(|class| {
class
.class_literal(self.db())
.is_typed_dict(self.db())
.then_some(TypedDictType::new(class))
});

let typed_dict_constructor_call_kind = typed_dict_constructor
.map(|_| typed_dict_constructor_call_kind(arguments))
.unwrap_or(TypedDictConstructorCallKind::Unsupported);
let typed_dict_constructor_shape_supported =
typed_dict_constructor_call_kind != TypedDictConstructorCallKind::Unsupported;

report_missing_implicit_constructor_call(
&self.context,
self.db(),
Expand All @@ -7074,12 +7090,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
);

// Validate `TypedDict` constructor calls after argument type inference.
if let Some(class) = class
&& class.class_literal(self.db()).is_typed_dict(self.db())
//
// Dict-literal positional args (e.g., `TD({"a": 1})`) are excluded here because the
// synthesized `__new__` mapping overload already handles them via normal callable checking.
if let Some(typed_dict) = typed_dict_constructor
&& typed_dict_constructor_shape_supported
&& typed_dict_constructor_call_kind
!= TypedDictConstructorCallKind::PositionalDictLiteralOnly
{
validate_typed_dict_constructor(
&self.context,
TypedDictType::new(class),
typed_dict,
arguments,
func.as_ref().into(),
|expr| self.expression_type(expr),
Expand All @@ -7088,6 +7109,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {

let mut bindings = match bindings_result {
Ok(()) => bindings,
// For TypedDict constructors with supported call shapes (keyword-only or single
// positional mapping), suppress binding errors from the synthesized `__new__` — the
// TypedDict-specific validator above produces more precise diagnostics.
Err(CallErrorKind::BindingError) if typed_dict_constructor_shape_supported => {
return bindings.return_type(self.db());
}
Err(_) => {
bindings.report_diagnostics(&self.context, call_expression.into());
return bindings.return_type(self.db());
Expand Down
Loading
Loading