Skip to content

Commit f9329cb

Browse files
committed
Support TD(mapping, **kwargs)
1 parent 09b7f4b commit f9329cb

File tree

3 files changed

+188
-45
lines changed

3 files changed

+188
-45
lines changed

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,31 @@ def _(ATD: Intersection[type[CtorRequired], type[CtorOptional]]):
419419
ok_dict = ATD({"a": 1})
420420
```
421421

422+
TypedDict constructors also support the `dict(mapping, **kwargs)`-style merge form. Keyword
423+
arguments should override the positional mapping when validating the final shape:
424+
425+
```py
426+
class BaseKwargs(TypedDict, total=False):
427+
name: str
428+
429+
class ChildKwargs(BaseKwargs, total=False):
430+
count: int
431+
432+
class OverrideCountKwargs(TypedDict, total=False):
433+
count: str
434+
435+
def _(base: BaseKwargs, override: OverrideCountKwargs):
436+
ok = ChildKwargs(base, count=1)
437+
overridden = ChildKwargs(override, count=1)
438+
overridden_literal = ChildKwargs({"count": "wrong"}, count=1)
439+
440+
# error: [invalid-argument-type]
441+
bad_value = ChildKwargs({"name": 1}, count=1)
442+
443+
# error: [invalid-argument-type]
444+
bad_mapping = ChildKwargs(1, count=1)
445+
```
446+
422447
All of these have an invalid type for the `name` field:
423448

424449
```py

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ use crate::types::set_theoretic::RecursivelyDefined;
9696
use crate::types::subclass_of::SubclassOfInner;
9797
use crate::types::tuple::{Tuple, TupleLength, TupleSpecBuilder, TupleType};
9898
use crate::types::type_alias::{ManualPEP695TypeAliasType, PEP695TypeAliasType};
99-
use crate::types::typed_dict::{validate_typed_dict_constructor, validate_typed_dict_dict_literal};
99+
use crate::types::typed_dict::{
100+
TypedDictConstructorCallKind, typed_dict_constructor_call_kind,
101+
validate_typed_dict_constructor, validate_typed_dict_dict_literal,
102+
};
100103
use crate::types::typevar::{BoundTypeVarIdentity, TypeVarConstraints, TypeVarIdentity};
101104
use crate::types::{
102105
CallDunderError, CallableBinding, CallableType, ClassType, DynamicType, EvaluationMode,
@@ -7064,12 +7067,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
70647067
.then_some(TypedDictType::new(class))
70657068
});
70667069

7067-
let typed_dict_constructor_shape_supported = typed_dict_constructor.is_some()
7068-
&& (arguments.args.is_empty()
7069-
|| (arguments.args.len() == 1 && arguments.keywords.is_empty()));
7070-
let has_positional_dict_literal = arguments.args.len() == 1
7071-
&& arguments.keywords.is_empty()
7072-
&& arguments.args[0].is_dict_expr();
7070+
let typed_dict_constructor_call_kind = typed_dict_constructor
7071+
.map(|_| typed_dict_constructor_call_kind(arguments))
7072+
.unwrap_or(TypedDictConstructorCallKind::Unsupported);
7073+
let typed_dict_constructor_shape_supported =
7074+
typed_dict_constructor_call_kind != TypedDictConstructorCallKind::Unsupported;
70737075

70747076
report_missing_implicit_constructor_call(
70757077
&self.context,
@@ -7092,8 +7094,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
70927094
// Dict-literal positional args (e.g., `TD({"a": 1})`) are excluded here because the
70937095
// synthesized `__new__` mapping overload already handles them via normal callable checking.
70947096
if let Some(typed_dict) = typed_dict_constructor
7095-
&& !has_positional_dict_literal
70967097
&& typed_dict_constructor_shape_supported
7098+
&& typed_dict_constructor_call_kind
7099+
!= TypedDictConstructorCallKind::PositionalDictLiteralOnly
70977100
{
70987101
validate_typed_dict_constructor(
70997102
&self.context,

crates/ty_python_semantic/src/types/typed_dict.rs

Lines changed: 152 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -796,59 +796,153 @@ fn extract_typed_dict_keys<'db>(
796796
}
797797
}
798798

799+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
800+
pub(super) enum TypedDictConstructorCallKind {
801+
KeywordsOnly,
802+
PositionalDictLiteralOnly,
803+
PositionalMappingOnly,
804+
PositionalDictLiteralAndKeywords,
805+
PositionalMappingAndKeywords,
806+
Unsupported,
807+
}
808+
809+
pub(super) fn typed_dict_constructor_call_kind(
810+
arguments: &Arguments,
811+
) -> TypedDictConstructorCallKind {
812+
match (arguments.args.len(), arguments.keywords.is_empty()) {
813+
(0, _) => TypedDictConstructorCallKind::KeywordsOnly,
814+
(1, true) if arguments.args[0].is_dict_expr() => {
815+
TypedDictConstructorCallKind::PositionalDictLiteralOnly
816+
}
817+
(1, true) => TypedDictConstructorCallKind::PositionalMappingOnly,
818+
(1, false) if arguments.args[0].is_dict_expr() => {
819+
TypedDictConstructorCallKind::PositionalDictLiteralAndKeywords
820+
}
821+
(1, false) => TypedDictConstructorCallKind::PositionalMappingAndKeywords,
822+
_ => TypedDictConstructorCallKind::Unsupported,
823+
}
824+
}
825+
799826
pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
800827
context: &InferContext<'db, 'ast>,
801828
typed_dict: TypedDictType<'db>,
802829
arguments: &'ast Arguments,
803830
error_node: AnyNodeRef<'ast>,
804831
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
805832
) {
806-
let db = context.db();
807-
808-
// Check for a single positional argument that is a dict literal
809-
let has_positional_dict_literal = arguments.args.len() == 1 && arguments.args[0].is_dict_expr();
810-
811-
// Check for a single positional argument (not a dict literal)
812-
let is_single_positional_arg =
813-
arguments.args.len() == 1 && arguments.keywords.is_empty() && !has_positional_dict_literal;
814-
815-
if has_positional_dict_literal {
816-
let provided_keys = validate_from_dict_literal(
817-
context,
818-
typed_dict,
819-
arguments,
820-
error_node,
821-
&expression_type_fn,
822-
);
823-
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
824-
} else if is_single_positional_arg {
825-
// Single positional argument: check if assignable to the target TypedDict.
826-
// This handles TypedDict, intersections, unions, and type aliases correctly.
827-
// Assignability already checks for required keys and type compatibility,
828-
// so we don't need separate validation.
829-
let arg = &arguments.args[0];
830-
let arg_ty = expression_type_fn(arg);
831-
let target_ty = Type::TypedDict(typed_dict);
832-
833-
if !arg_ty.is_assignable_to(db, target_ty) {
834-
if let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, arg) {
833+
match typed_dict_constructor_call_kind(arguments) {
834+
TypedDictConstructorCallKind::PositionalDictLiteralOnly => {
835+
let provided_keys = validate_from_dict_literal(
836+
context,
837+
typed_dict,
838+
arguments,
839+
error_node,
840+
&expression_type_fn,
841+
None,
842+
);
843+
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
844+
}
845+
TypedDictConstructorCallKind::PositionalMappingOnly => {
846+
// Single positional argument: check if assignable to the target TypedDict.
847+
// This handles TypedDict, intersections, unions, and type aliases correctly.
848+
// Assignability already checks for required keys and type compatibility,
849+
// so we don't need separate validation.
850+
let arg = &arguments.args[0];
851+
let arg_ty = expression_type_fn(arg);
852+
let target_ty = Type::TypedDict(typed_dict);
853+
854+
if !arg_ty.is_assignable_to(context.db(), target_ty)
855+
&& let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, arg)
856+
{
835857
builder.into_diagnostic(format_args!(
836858
"Argument of type `{}` is not assignable to `{}`",
837-
arg_ty.display(db),
838-
target_ty.display(db),
859+
arg_ty.display(context.db()),
860+
target_ty.display(context.db()),
839861
));
840862
}
841863
}
842-
} else {
843-
let provided_keys = validate_from_keywords(
864+
TypedDictConstructorCallKind::PositionalDictLiteralAndKeywords
865+
| TypedDictConstructorCallKind::PositionalMappingAndKeywords => {
866+
let provided_keys = validate_from_mapping_and_keywords(
867+
context,
868+
typed_dict,
869+
arguments,
870+
error_node,
871+
&expression_type_fn,
872+
);
873+
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
874+
}
875+
TypedDictConstructorCallKind::KeywordsOnly => {
876+
let provided_keys = validate_from_keywords(
877+
context,
878+
typed_dict,
879+
arguments,
880+
error_node,
881+
&expression_type_fn,
882+
);
883+
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
884+
}
885+
TypedDictConstructorCallKind::Unsupported => {}
886+
}
887+
}
888+
889+
fn validate_from_mapping_and_keywords<'db, 'ast>(
890+
context: &InferContext<'db, 'ast>,
891+
typed_dict: TypedDictType<'db>,
892+
arguments: &'ast Arguments,
893+
typed_dict_node: AnyNodeRef<'ast>,
894+
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
895+
) -> OrderSet<Name> {
896+
let db = context.db();
897+
let mapping_arg = &arguments.args[0];
898+
let mut provided_keys = validate_from_keywords(
899+
context,
900+
typed_dict,
901+
arguments,
902+
typed_dict_node,
903+
expression_type_fn,
904+
);
905+
906+
if mapping_arg.is_dict_expr() {
907+
let mapping_keys = validate_from_dict_literal(
844908
context,
845909
typed_dict,
846910
arguments,
847-
error_node,
848-
&expression_type_fn,
911+
typed_dict_node,
912+
expression_type_fn,
913+
Some(&provided_keys),
849914
);
850-
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
915+
provided_keys.extend(mapping_keys);
916+
} else {
917+
let mapping_ty = expression_type_fn(mapping_arg);
918+
let remaining_target_ty = Type::TypedDict(typed_dict_patch_without_keys(
919+
db,
920+
typed_dict,
921+
&provided_keys,
922+
));
923+
924+
if !mapping_ty.is_assignable_to(db, remaining_target_ty)
925+
&& let Some(builder) = context.report_lint(&INVALID_ARGUMENT_TYPE, mapping_arg)
926+
{
927+
builder.into_diagnostic(format_args!(
928+
"Argument of type `{}` is not assignable to `{}`",
929+
mapping_ty.display(db),
930+
remaining_target_ty.display(db),
931+
));
932+
}
933+
934+
if mapping_ty.is_never() || mapping_ty.is_dynamic() {
935+
for (key_name, field) in typed_dict.items(db) {
936+
if field.is_required() {
937+
provided_keys.insert(key_name.clone());
938+
}
939+
}
940+
} else if let Some(mapping_keys) = extract_typed_dict_keys(db, mapping_ty) {
941+
provided_keys.extend(mapping_keys.into_keys());
942+
}
851943
}
944+
945+
provided_keys
852946
}
853947

854948
/// Validates a `TypedDict` constructor call with a single positional dictionary argument
@@ -859,6 +953,7 @@ fn validate_from_dict_literal<'db, 'ast>(
859953
arguments: &'ast Arguments,
860954
typed_dict_node: AnyNodeRef<'ast>,
861955
expression_type_fn: &impl Fn(&ast::Expr) -> Type<'db>,
956+
overridden_keys: Option<&OrderSet<Name>>,
862957
) -> OrderSet<Name> {
863958
let mut provided_keys = OrderSet::new();
864959

@@ -871,7 +966,12 @@ fn validate_from_dict_literal<'db, 'ast>(
871966
}) = key_expr
872967
{
873968
let key = key_value.to_str();
874-
provided_keys.insert(Name::new(key));
969+
let key_name = Name::new(key);
970+
provided_keys.insert(key_name.clone());
971+
972+
if overridden_keys.is_some_and(|keys| keys.contains(&key_name)) {
973+
continue;
974+
}
875975

876976
// Get the already-inferred argument type
877977
let value_ty = expression_type_fn(&dict_item.value);
@@ -895,6 +995,21 @@ fn validate_from_dict_literal<'db, 'ast>(
895995
provided_keys
896996
}
897997

998+
fn typed_dict_patch_without_keys<'db>(
999+
db: &'db dyn Db,
1000+
typed_dict: TypedDictType<'db>,
1001+
excluded_keys: &OrderSet<Name>,
1002+
) -> TypedDictType<'db> {
1003+
let items: TypedDictSchema<'db> = typed_dict
1004+
.items(db)
1005+
.iter()
1006+
.filter(|(name, _)| !excluded_keys.contains(*name))
1007+
.map(|(name, field)| (name.clone(), field.clone().with_required(false)))
1008+
.collect();
1009+
1010+
TypedDictType::from_patch_items(db, items)
1011+
}
1012+
8981013
/// Validates a `TypedDict` constructor call with keywords
8991014
/// e.g. `Person(name="Alice", age=30)` or `Person(**other_typed_dict)`
9001015
fn validate_from_keywords<'db, 'ast>(

0 commit comments

Comments
 (0)