@@ -46,21 +46,22 @@ class subtype(abc.ABCMeta):
4646 __args__ : tuple
4747
4848 def __new__ (cls , tp , * args ):
49- if tp is Any :
50- return object
51- if isinstance (tp , cls ): # If already a subtype, return it directly
52- return tp
53- if isinstance (tp , typing .NewType ):
54- return cls (tp .__supertype__ , * args )
49+ match tp :
50+ case typing .Any :
51+ return object
52+ case subtype (): # If already a subtype, return it directly
53+ return tp
54+ case typing .NewType ():
55+ return cls (tp .__supertype__ , * args )
56+ case TypeVar ():
57+ return cls (Union [tp .__constraints__ ], * args ) if tp .__constraints__ else object
58+ case typing ._AnnotatedAlias ():
59+ return cls (tp .__origin__ , * args )
5560 if hasattr (typing , 'TypeAliasType' ) and isinstance (tp , typing .TypeAliasType ):
5661 return cls (tp .__value__ , * args )
57- if isinstance (tp , TypeVar ):
58- return cls (Union [tp .__constraints__ ], * args ) if tp .__constraints__ else object
59- if isinstance (tp , typing ._AnnotatedAlias ):
60- return cls (tp .__origin__ , * args )
6162 origin = get_origin (tp ) or tp
6263 args = tuple (map (cls , get_args (tp ) or args ))
63- if set (args ) <= {object } and not (origin is tuple and args ):
64+ if set (args ) <= {object } and (origin is not tuple or tp is tuple ):
6465 return origin
6566 bases = (origin ,) if type (origin ) in (type , abc .ABCMeta ) else ()
6667 if origin is Literal :
@@ -87,22 +88,27 @@ def __hash__(self) -> int:
8788 return hash (self .key ())
8889
8990 def __subclasscheck__ (self , subclass ):
90- origin = get_origin (subclass ) or subclass
9191 args = get_args (subclass )
92- if origin is Literal :
93- return all (isinstance (arg , self ) for arg in args )
94- if origin in (Union , types .UnionType ):
95- return all (issubclass (cls , self ) for cls in args )
96- if self .__origin__ is Literal :
97- return False
98- if self .__origin__ is types .UnionType :
99- return issubclass (subclass , self .__args__ )
100- if self .__origin__ is Callable :
101- return (
102- origin is Callable
103- and signature (self .__args__ [- 1 :]) <= signature (args [- 1 :]) # covariant return
104- and signature (args [:- 1 ]) <= signature (self .__args__ [:- 1 ]) # contravariant args
105- )
92+ match origin := get_origin (subclass ) or subclass :
93+ case typing .Literal :
94+ return all (isinstance (arg , self ) for arg in args )
95+ case typing .Union | types .UnionType :
96+ return all (issubclass (cls , self ) for cls in args )
97+ match self .__origin__ :
98+ case typing .Literal :
99+ return False
100+ case types .UnionType :
101+ return issubclass (subclass , self .__args__ )
102+ case builtins .tuple :
103+ if issubclass (origin , tuple ) and ... in self .__args__ :
104+ param = self .__args__ [0 ]
105+ return all (arg is ... or issubclass (arg , param ) for arg in args )
106+ case collections .abc .Callable :
107+ return (
108+ origin is Callable
109+ and signature (self .__args__ [- 1 :]) <= signature (args [- 1 :]) # covariant return
110+ and signature (args [:- 1 ]) <= signature (self .__args__ [:- 1 ]) # contravariant args
111+ )
106112 return ( # check args first to avoid recursion error: python/cpython#73407
107113 len (args ) == len (self .__args__ )
108114 and issubclass (origin , self .__origin__ )
0 commit comments