• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import sys
2import os
3import abc
4import contextlib
5import collections
6import collections.abc
7from functools import lru_cache
8import inspect
9import pickle
10import subprocess
11import types
12from unittest import TestCase, main, skipUnless, skipIf
13from test import ann_module, ann_module2, ann_module3
14import typing
15from typing import TypeVar, Optional, Union, Any, AnyStr
16from typing import T, KT, VT  # Not in __all__.
17from typing import Tuple, List, Dict, Iterable, Iterator, Callable
18from typing import Generic, NamedTuple
19from typing import no_type_check
20import typing_extensions
21from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
22from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard
23from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
24from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
25from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
26from typing_extensions import assert_type, get_type_hints, get_origin, get_args
27
# Version gates for tests that depend on behaviour introduced in a
# particular release of the typing module.  Each flag is True when the
# running interpreter is at least the named version.
TYPING_3_8_0 = (3, 8, 0) <= sys.version_info[:3]
TYPING_3_10_0 = (3, 10, 0) <= sys.version_info[:3]
TYPING_3_11_0 = (3, 11, 0) <= sys.version_info[:3]
33
34
class BaseTestCase(TestCase):
    """TestCase base class adding subclass-relationship assertions."""

    def assertIsSubclass(self, cls, class_or_tuple, msg=None):
        """Fail unless ``issubclass(cls, class_or_tuple)`` is true.

        ``msg``, when given, is appended to the failure text.
        """
        if issubclass(cls, class_or_tuple):
            return
        text = f'{cls!r} is not a subclass of {class_or_tuple!r}'
        if msg is not None:
            text = f'{text} : {msg}'
        raise self.failureException(text)

    def assertNotIsSubclass(self, cls, class_or_tuple, msg=None):
        """Fail if ``issubclass(cls, class_or_tuple)`` is true.

        ``msg``, when given, is appended to the failure text.
        """
        if not issubclass(cls, class_or_tuple):
            return
        text = f'{cls!r} is a subclass of {class_or_tuple!r}'
        if msg is not None:
            text = f'{text} : {msg}'
        raise self.failureException(text)
49
50
class Employee:
    """Empty placeholder class the tests use as an arbitrary user-defined type."""
    pass
53
54
class BottomTypeTestsMixin:
    """Checks shared by the bottom-type special forms.

    Concrete test cases combine this mixin with a ``TestCase`` subclass and
    set ``bottom_type`` to the object under test.
    """

    # The special form exercised by every test in this mixin.
    bottom_type: ClassVar[Any]

    def test_equality(self):
        self.assertEqual(self.bottom_type, self.bottom_type)
        self.assertIs(self.bottom_type, self.bottom_type)
        self.assertNotEqual(self.bottom_type, None)

    def test_get_origin(self):
        self.assertIs(get_origin(self.bottom_type), None)

    def test_instance_type_error(self):
        with self.assertRaises(TypeError):
            isinstance(42, self.bottom_type)

    def test_subclass_type_error(self):
        with self.assertRaises(TypeError):
            issubclass(Employee, self.bottom_type)
        with self.assertRaises(TypeError):
            issubclass(NoReturn, self.bottom_type)

    def test_not_generic(self):
        with self.assertRaises(TypeError):
            self.bottom_type[int]

    def test_cannot_subclass(self):
        # Neither the special form itself nor its type may be used as a base.
        with self.assertRaises(TypeError):
            class A(self.bottom_type):
                pass
        with self.assertRaises(TypeError):
            class A(type(self.bottom_type)):
                pass

    def test_cannot_instantiate(self):
        with self.assertRaises(TypeError):
            self.bottom_type()
        with self.assertRaises(TypeError):
            type(self.bottom_type)()

    def test_pickle(self):
        # ``range`` excludes its stop value, so add one to make sure
        # ``pickle.HIGHEST_PROTOCOL`` itself is exercised as well.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            pickled = pickle.dumps(self.bottom_type, protocol=proto)
            # Round-tripping must give back the very same object.
            self.assertIs(self.bottom_type, pickle.loads(pickled))
98
99
class NoReturnTests(BottomTypeTestsMixin, BaseTestCase):
    """Shared bottom-type checks applied to ``NoReturn``, plus NoReturn-specific ones."""

    bottom_type = NoReturn

    def test_repr(self):
        # The repr names whichever module actually provides the object.
        expected = (
            'typing.NoReturn'
            if hasattr(typing, 'NoReturn')
            else 'typing_extensions.NoReturn'
        )
        self.assertEqual(repr(NoReturn), expected)

    def test_get_type_hints(self):
        def some(arg: NoReturn) -> NoReturn: ...
        def some_str(arg: 'NoReturn') -> 'typing.NoReturn': ...

        # On 3.7.0 and 3.7.1, https://github.com/python/cpython/pull/10772
        # wasn't applied yet and NoReturn fails _type_check, so the
        # string-annotated variant is skipped there.
        lacks_fix = (3, 7, 0) <= sys.version_info < (3, 7, 2)
        targets = [some] if lacks_fix else [some, some_str]

        expected = {'arg': NoReturn, 'return': NoReturn}
        for target in targets:
            with self.subTest(target=target):
                self.assertEqual(gth(target), expected)

    def test_not_equality(self):
        # Inequality must hold whichever operand comes first.
        for left, right in ((NoReturn, Never), (Never, NoReturn)):
            self.assertNotEqual(left, right)
127
128
class NeverTests(BottomTypeTestsMixin, BaseTestCase):
    """Shared bottom-type checks applied to ``Never``, plus Never-specific ones."""

    bottom_type = Never

    def test_repr(self):
        # The repr names whichever module actually provides the object.
        expected = (
            'typing.Never'
            if hasattr(typing, 'Never')
            else 'typing_extensions.Never'
        )
        self.assertEqual(repr(Never), expected)

    def test_get_type_hints(self):
        def some(arg: Never) -> Never: ...
        def some_str(arg: 'Never') -> 'typing_extensions.Never': ...

        # Direct and string annotations must resolve to the same hints.
        expected = {'arg': Never, 'return': Never}
        for target in (some, some_str):
            with self.subTest(target=target):
                self.assertEqual(gth(target), expected)
146
147
class AssertNeverTests(BaseTestCase):
    """Checks the runtime behaviour of ``assert_never``."""

    def test_exception(self):
        # Calling assert_never at runtime must raise AssertionError.
        self.assertRaises(AssertionError, assert_never, None)
152
153
class ClassVarTests(BaseTestCase):
    """Runtime behaviour checks for the ``ClassVar`` special form."""

    def test_basics(self):
        # A non-type argument and a multi-argument subscript are rejected ...
        for params in (1, (int, str)):
            with self.assertRaises(TypeError):
                ClassVar[params]
        # ... and so is subscripting an already-subscripted form.
        with self.assertRaises(TypeError):
            ClassVar[int][str]

    def test_repr(self):
        mod_name = 'typing' if hasattr(typing, 'ClassVar') else 'typing_extensions'
        self.assertEqual(repr(ClassVar), mod_name + '.ClassVar')
        self.assertEqual(repr(ClassVar[int]), mod_name + '.ClassVar[int]')
        self.assertEqual(repr(ClassVar[Employee]),
                         mod_name + f'.ClassVar[{__name__}.Employee]')

    def test_cannot_subclass(self):
        # Bases are produced lazily so they are evaluated inside assertRaises.
        for get_base in (lambda: type(ClassVar), lambda: type(ClassVar[int])):
            with self.assertRaises(TypeError):
                class C(get_base()):
                    pass

    def test_cannot_init(self):
        factories = (
            lambda: ClassVar,
            lambda: type(ClassVar),
            lambda: type(ClassVar[Optional[int]]),
        )
        for get_callable in factories:
            with self.assertRaises(TypeError):
                get_callable()()

    def test_no_isinstance(self):
        self.assertRaises(TypeError, lambda: isinstance(1, ClassVar[int]))
        self.assertRaises(TypeError, lambda: issubclass(int, ClassVar))
196
197
class FinalTests(BaseTestCase):
    """Runtime behaviour checks for the ``Final`` special form."""

    def test_basics(self):
        # A non-type argument and a multi-argument subscript are rejected ...
        for params in (1, (int, str)):
            with self.assertRaises(TypeError):
                Final[params]
        # ... and so is subscripting an already-subscripted form.
        with self.assertRaises(TypeError):
            Final[int][str]

    def test_repr(self):
        from_typing = hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7)
        mod_name = 'typing' if from_typing else 'typing_extensions'
        self.assertEqual(repr(Final), mod_name + '.Final')
        self.assertEqual(repr(Final[int]), mod_name + '.Final[int]')
        self.assertEqual(repr(Final[Employee]),
                         mod_name + f'.Final[{__name__}.Employee]')

    def test_cannot_subclass(self):
        # Bases are produced lazily so they are evaluated inside assertRaises.
        for get_base in (lambda: type(Final), lambda: type(Final[int])):
            with self.assertRaises(TypeError):
                class C(get_base()):
                    pass

    def test_cannot_init(self):
        factories = (
            lambda: Final,
            lambda: type(Final),
            lambda: type(Final[Optional[int]]),
        )
        for get_callable in factories:
            with self.assertRaises(TypeError):
                get_callable()()

    def test_no_isinstance(self):
        self.assertRaises(TypeError, lambda: isinstance(1, Final[int]))
        self.assertRaises(TypeError, lambda: issubclass(int, Final))
240
241
class RequiredTests(BaseTestCase):
    """Runtime behaviour checks for the ``Required`` special form."""

    def test_basics(self):
        # A non-type argument and a multi-argument subscript are rejected ...
        for params in (1, (int, str)):
            with self.assertRaises(TypeError):
                Required[params]
        # ... and so is subscripting an already-subscripted form.
        with self.assertRaises(TypeError):
            Required[int][str]

    def test_repr(self):
        mod_name = 'typing' if hasattr(typing, 'Required') else 'typing_extensions'
        self.assertEqual(repr(Required), mod_name + '.Required')
        self.assertEqual(repr(Required[int]), mod_name + '.Required[int]')
        self.assertEqual(repr(Required[Employee]),
                         mod_name + '.Required[%s.Employee]' % __name__)

    def test_cannot_subclass(self):
        # Bases are produced lazily so they are evaluated inside assertRaises.
        for get_base in (lambda: type(Required), lambda: type(Required[int])):
            with self.assertRaises(TypeError):
                class C(get_base()):
                    pass

    def test_cannot_init(self):
        factories = (
            lambda: Required,
            lambda: type(Required),
            lambda: type(Required[Optional[int]]),
        )
        for get_callable in factories:
            with self.assertRaises(TypeError):
                get_callable()()

    def test_no_isinstance(self):
        self.assertRaises(TypeError, lambda: isinstance(1, Required[int]))
        self.assertRaises(TypeError, lambda: issubclass(int, Required))
284
285
class NotRequiredTests(BaseTestCase):
    """Runtime behaviour checks for the ``NotRequired`` special form."""

    def test_basics(self):
        # A non-type argument and a multi-argument subscript are rejected ...
        for params in (1, (int, str)):
            with self.assertRaises(TypeError):
                NotRequired[params]
        # ... and so is subscripting an already-subscripted form.
        with self.assertRaises(TypeError):
            NotRequired[int][str]

    def test_repr(self):
        mod_name = 'typing' if hasattr(typing, 'NotRequired') else 'typing_extensions'
        self.assertEqual(repr(NotRequired), mod_name + '.NotRequired')
        self.assertEqual(repr(NotRequired[int]), mod_name + '.NotRequired[int]')
        self.assertEqual(repr(NotRequired[Employee]),
                         mod_name + '.NotRequired[%s.Employee]' % __name__)

    def test_cannot_subclass(self):
        # Bases are produced lazily so they are evaluated inside assertRaises.
        for get_base in (lambda: type(NotRequired), lambda: type(NotRequired[int])):
            with self.assertRaises(TypeError):
                class C(get_base()):
                    pass

    def test_cannot_init(self):
        factories = (
            lambda: NotRequired,
            lambda: type(NotRequired),
            lambda: type(NotRequired[Optional[int]]),
        )
        for get_callable in factories:
            with self.assertRaises(TypeError):
                get_callable()()

    def test_no_isinstance(self):
        self.assertRaises(TypeError, lambda: isinstance(1, NotRequired[int]))
        self.assertRaises(TypeError, lambda: issubclass(int, NotRequired))
328
329
class IntVarTests(BaseTestCase):
    """Checks which ``IntVar`` constructor calls are accepted."""

    def test_valid(self):
        # A bare name is the only supported call shape.
        IntVar("T_ints")

    def test_invalid(self):
        # Constraints, bounds and variance arguments must all be rejected.
        self.assertRaises(TypeError, IntVar, "T_ints", int)
        self.assertRaises(TypeError, IntVar, "T_ints", bound=int)
        self.assertRaises(TypeError, IntVar, "T_ints", covariant=True)
341
342
class LiteralTests(BaseTestCase):
    """Runtime behaviour checks for the ``Literal`` special form."""

    def test_basics(self):
        # Ordinary literal values, one or several at a time, are accepted.
        for params in (1, (1, 2, 3), ("x", "y", "z"), None):
            Literal[params]

    def test_illegal_parameters_do_not_raise_runtime_errors(self):
        # Type checkers should reject these types, but we do not
        # raise errors at runtime to maintain maximum flexibility
        questionable = (
            int,
            (Literal[1, 2], Literal[4, 5]),
            (3j + 2, ..., ()),
            (b"foo", u"bar"),
            {"foo": 3, "bar": 4},
            T,
        )
        for params in questionable:
            Literal[params]

    def test_literals_inside_other_types(self):
        for params in ((1, 2, 3), ("foo", "bar", "baz")):
            List[Literal[params]]

    def test_repr(self):
        mod_name = 'typing' if hasattr(typing, 'Literal') else 'typing_extensions'
        self.assertEqual(repr(Literal[1]), mod_name + ".Literal[1]")
        self.assertEqual(repr(Literal[1, True, "foo"]),
                         mod_name + ".Literal[1, True, 'foo']")
        self.assertEqual(repr(Literal[int]), mod_name + ".Literal[int]")
        self.assertEqual(repr(Literal), mod_name + ".Literal")
        self.assertEqual(repr(Literal[None]), mod_name + ".Literal[None]")

    def test_cannot_init(self):
        # Each callable is looked up lazily so that the lookup itself also
        # happens inside the assertRaises block.
        factories = (
            lambda: Literal,
            lambda: Literal[1],
            lambda: type(Literal),
            lambda: type(Literal[1]),
        )
        for get_callable in factories:
            with self.assertRaises(TypeError):
                get_callable()()

    def test_no_isinstance_or_issubclass(self):
        checks = (
            (isinstance, 1),
            (isinstance, int),
            (issubclass, 1),
            (issubclass, int),
        )
        for check, subject in checks:
            with self.assertRaises(TypeError):
                check(subject, Literal[1])

    def test_no_subclassing(self):
        with self.assertRaises(TypeError):
            class Foo(Literal[1]):
                pass
        with self.assertRaises(TypeError):
            class Bar(Literal):
                pass

    def test_no_multiple_subscripts(self):
        self.assertRaises(TypeError, lambda: Literal[1][1])
404
405
class OverloadTests(BaseTestCase):
    """Checks how ``@overload``-decorated functions behave when called."""

    def test_overload_fails(self):
        # Calling a function that only has an @overload-decorated definition
        # is expected to raise RuntimeError.
        with self.assertRaises(RuntimeError):

            @overload
            def blah():
                pass

            blah()

    def test_overload_succeeds(self):
        @overload
        def blah():
            pass

        # The later, undecorated definition rebinds the name, so the call
        # below runs this plain function.
        def blah():
            pass

        blah()
426
427
class AssertTypeTests(BaseTestCase):
    """Checks that ``assert_type`` hands back its first argument unchanged."""

    def test_basics(self):
        arg = 42
        for typ in (int, Union[str, float], AnyStr, None):
            self.assertIs(assert_type(arg, typ), arg)

    def test_errors(self):
        # Bogus calls are not expected to fail.
        arg = 42
        for bogus in (42, 'hello'):
            self.assertIs(assert_type(arg, bogus), arg)
442
443
# Type variable used to parameterise the async wrapper fixtures below.
T_a = TypeVar('T_a')
445
class AwaitableWrapper(Awaitable[T_a]):
    """Minimal awaitable fixture that wraps a single value."""

    def __init__(self, value):
        self.value = value

    def __await__(self) -> typing.Iterator[T_a]:
        # Generator-based __await__: suspend once via the bare ``yield``, then
        # finish with the wrapped value as the generator's return value.
        yield
        return self.value
454
class AsyncIteratorWrapper(AsyncIterator[T_a]):
    """Async-iterator fixture that stores a value and is its own ``__aiter__``."""

    def __init__(self, value: Iterable[T_a]):
        self.value = value

    def __aiter__(self) -> AsyncIterator[T_a]:
        return self

    async def __anext__(self) -> T_a:
        # NOTE(review): ``value`` is annotated as Iterable yet is awaited
        # here; the tests visible in this file only run isinstance checks on
        # the wrapper and never drive __anext__ -- confirm before relying on it.
        data = await self.value
        if data:
            return data
        else:
            raise StopAsyncIteration
469
class ACM:
    """Bare-bones async context manager fixture whose ``__aenter__`` yields 42."""

    async def __aenter__(self) -> int:
        return 42

    async def __aexit__(self, etype, eval, tb):
        # Returning None (falsy) means exceptions are not suppressed.
        return None
476
477
478
# Annotation fixture: base class declaring ``y`` as float.
class A:
    y: float
# Annotation fixture: re-annotates ``y`` as int and adds a ClassVar whose
# type is a forward reference to B itself.
class B(A):
    x: ClassVar[Optional['B']] = None
    y: int
    b: int
# Annotation fixture: adds a ClassVar annotated with a forward reference to
# CSub and initialised with a B instance.
class CSub(B):
    z: ClassVar['CSub'] = B()
# Annotation fixture: generic class with a ClassVar parameterised by T.
class G(Generic[T]):
    lst: ClassVar[List[T]] = []
489
# Annotation fixture: Final attribute whose type is a forward reference to
# the enclosing class.
class Loop:
    attr: Final['Loop']
492
# Annotation fixture: one forward reference to itself and one plain ``None``
# annotation.
class NoneAndForward:
    parent: 'NoneAndForward'
    meaning: None
496
class XRepr(NamedTuple):
    # NamedTuple fixture with one defaulted field and user-defined methods.
    x: int
    y: int = 1

    def __str__(self):
        return f'{self.x} -> {self.y}'

    def __add__(self, other):
        # Replaces tuple concatenation; always returns 0 whatever ``other`` is.
        return 0
506
# Protocol fixture, decorated with ``runtime``, whose only declared member
# is ``__call__``.
@runtime
class HasCallProtocol(Protocol):
    __call__: typing.Callable
510
511
async def g_with(am: AsyncContextManager[int]):
    # Enters ``am`` and returns the value bound by ``async with ... as``.
    x: int
    async with am as x:
        return x
516
# Import-time sanity check: drive the coroutine by hand.  ACM's methods
# contain no awaits, so the first send() finishes the coroutine and the
# resulting StopIteration carries its return value.
try:
    g_with(ACM()).send(None)
except StopIteration as e:
    assert e.args[0] == 42
521
# TypedDict fixture created with the functional (call) syntax.
Label = TypedDict('Label', [('label', str)])
523
# TypedDict fixture with two int keys.
class Point2D(TypedDict):
    x: int
    y: int
527
# Extends Point2D with total=False, adding the ``z`` key.
class Point2Dor3D(Point2D, total=False):
    z: int
530
# TypedDict fixture inheriting from two TypedDicts without adding keys.
class LabelPoint2D(Point2D, Label): ...
532
# TypedDict fixture declared with total=False.
class Options(TypedDict, total=False):
    log_level: int
    log_path: str
536
# Root of a small TypedDict inheritance chain (BaseAnimal -> Animal -> Cat).
class BaseAnimal(TypedDict):
    name: str
539
# Extends BaseAnimal with total=False, adding ``voice`` and ``tail``.
class Animal(BaseAnimal, total=False):
    voice: str
    tail: bool
543
# Extends Animal without passing ``total``, adding ``fur_color``.
class Cat(Animal):
    fur_color: str
546
# TypedDict declared without ``total`` in which ``year`` is marked NotRequired.
class TotalMovie(TypedDict):
    title: str
    year: NotRequired[int]
550
# total=False TypedDict in which ``title`` is marked Required.
class NontotalMovie(TypedDict, total=False):
    title: Required[str]
    year: int
554
# TypedDict fixture combining Annotated with Required / NotRequired, nested
# in both orders.
class AnnotatedMovie(TypedDict):
    title: Annotated[Required[str], "foobar"]
    year: NotRequired[Annotated[int, 2000]]
558
559
# Short alias for get_type_hints, used by the tests in this file.
gth = get_type_hints
561
562
class GetTypeHintTests(BaseTestCase):
    """Checks ``get_type_hints`` against annotated modules, classes and fixtures."""

    def test_get_type_hints_modules(self):
        expected = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str}
        # More tests were added in 3.10rc1.
        rc_or_final = sys.version_info.releaselevel in {'candidate', 'final'}
        if TYPING_3_11_0 or (TYPING_3_10_0 and rc_or_final):
            expected['u'] = int | float
        self.assertEqual(gth(ann_module), expected)
        for empty_module in (ann_module2, ann_module3):
            self.assertEqual(gth(empty_module), {})

    def test_get_type_hints_classes(self):
        optional_c = Optional[ann_module.C]
        self.assertEqual(gth(ann_module.C, ann_module.__dict__),
                         {'y': optional_c})
        self.assertIsInstance(gth(ann_module.j_class), dict)
        self.assertEqual(gth(ann_module.M), {'123': 123, 'o': type})
        self.assertEqual(gth(ann_module.D),
                         {'j': str, 'k': str, 'y': optional_c})
        self.assertEqual(gth(ann_module.Y), {'z': int})
        self.assertEqual(gth(ann_module.h_class), {'y': optional_c})
        self.assertEqual(gth(ann_module.S), {'x': str, 'y': str})
        self.assertEqual(gth(ann_module.foo), {'x': int})
        none_type = type(None)
        self.assertEqual(gth(NoneAndForward, globals()),
                         {'parent': NoneAndForward, 'meaning': none_type})

    def test_respect_no_type_check(self):
        @no_type_check
        class NoTpCheck:
            class Inn:
                def __init__(self, x: 'not a type'): ...  # noqa

        # The marker must reach both the class and its nested methods.
        self.assertTrue(NoTpCheck.__no_type_check__)
        self.assertTrue(NoTpCheck.Inn.__init__.__no_type_check__)
        self.assertEqual(gth(ann_module2.NTC.meth), {})

        class ABase(Generic[T]):
            def meth(x: int): ...

        @no_type_check
        class Der(ABase): ...

        # Decorating the subclass must leave the base class's hints intact.
        self.assertEqual(gth(ABase.meth), {'x': int})

    def test_get_type_hints_ClassVar(self):
        self.assertEqual(gth(ann_module2.CV, ann_module2.__dict__),
                         {'var': ClassVar[ann_module2.CV]})
        hints_of_b = {'y': int, 'x': ClassVar[Optional[B]], 'b': int}
        self.assertEqual(gth(B, globals()), hints_of_b)
        hints_of_csub = {
            'z': ClassVar[CSub],
            'y': int,
            'b': int,
            'x': ClassVar[Optional[B]],
        }
        self.assertEqual(gth(CSub, globals()), hints_of_csub)
        self.assertEqual(gth(G), {'lst': ClassVar[List[T]]})

    def test_final_forward_ref(self):
        def resolved_attr():
            # Re-resolve the hints on every call, as each assertion expects.
            return gth(Loop, globals())['attr']

        self.assertEqual(resolved_attr(), Final[Loop])
        self.assertNotEqual(resolved_attr(), Final[int])
        self.assertNotEqual(resolved_attr(), Final)
617
618
class GetUtilitiesTestCase(TestCase):
    """Checks ``get_origin`` and ``get_args`` over a range of typing constructs."""

    def test_get_origin(self):
        T = TypeVar('T')
        P = ParamSpec('P')
        Ts = TypeVarTuple('Ts')
        class C(Generic[T]): pass
        self.assertIs(get_origin(C[int]), C)
        self.assertIs(get_origin(C[T]), C)
        self.assertIs(get_origin(int), None)
        self.assertIs(get_origin(ClassVar[int]), ClassVar)
        self.assertIs(get_origin(Union[int, str]), Union)
        self.assertIs(get_origin(Literal[42, 43]), Literal)
        self.assertIs(get_origin(Final[List[int]]), Final)
        self.assertIs(get_origin(Generic), Generic)
        self.assertIs(get_origin(Generic[T]), Generic)
        self.assertIs(get_origin(List[Tuple[T, T]][int]), list)
        self.assertIs(get_origin(Annotated[T, 'thing']), Annotated)
        self.assertIs(get_origin(List), list)
        self.assertIs(get_origin(Tuple), tuple)
        self.assertIs(get_origin(Callable), collections.abc.Callable)
        # Subscripting the builtin ``list`` is only attempted on 3.9+.
        if sys.version_info >= (3, 9):
            self.assertIs(get_origin(list[int]), list)
        self.assertIs(get_origin(list), None)
        self.assertIs(get_origin(P.args), P)
        self.assertIs(get_origin(P.kwargs), P)
        self.assertIs(get_origin(Required[int]), Required)
        self.assertIs(get_origin(NotRequired[int]), NotRequired)
        self.assertIs(get_origin(Unpack[Ts]), Unpack)
        self.assertIs(get_origin(Unpack), None)

    def test_get_args(self):
        T = TypeVar('T')
        Ts = TypeVarTuple('Ts')
        class C(Generic[T]): pass
        self.assertEqual(get_args(C[int]), (int,))
        self.assertEqual(get_args(C[T]), (T,))
        self.assertEqual(get_args(int), ())
        self.assertEqual(get_args(ClassVar[int]), (int,))
        self.assertEqual(get_args(Union[int, str]), (int, str))
        self.assertEqual(get_args(Literal[42, 43]), (42, 43))
        self.assertEqual(get_args(Final[List[int]]), (List[int],))
        self.assertEqual(get_args(Union[int, Tuple[T, int]][str]),
                         (int, Tuple[str, int]))
        self.assertEqual(get_args(typing.Dict[int, Tuple[T, T]][Optional[int]]),
                         (int, Tuple[Optional[int], Optional[int]]))
        self.assertEqual(get_args(Callable[[], T][int]), ([], int))
        self.assertEqual(get_args(Callable[..., int]), (..., int))
        self.assertEqual(get_args(Union[int, Callable[[Tuple[T, ...]], str]]),
                         (int, Callable[[Tuple[T, ...]], str]))
        self.assertEqual(get_args(Tuple[int, ...]), (int, ...))
        self.assertEqual(get_args(Tuple[()]), ((),))
        self.assertEqual(get_args(Annotated[T, 'one', 2, ['three']]), (T, 'one', 2, ['three']))
        # Unsubscripted generic aliases have no arguments.
        self.assertEqual(get_args(List), ())
        self.assertEqual(get_args(Tuple), ())
        self.assertEqual(get_args(Callable), ())
        # Subscripting the builtin ``list`` is only attempted on 3.9+.
        if sys.version_info >= (3, 9):
            self.assertEqual(get_args(list[int]), (int,))
        self.assertEqual(get_args(list), ())
        if sys.version_info >= (3, 9):
            # Support Python versions with and without the fix for
            # https://bugs.python.org/issue42195
            # The first variant is for 3.9.2+, the second for 3.9.0 and 1
            self.assertIn(get_args(collections.abc.Callable[[int], str]),
                          (([int], str), ([[int]], str)))
            self.assertIn(get_args(collections.abc.Callable[[], str]),
                          (([], str), ([[]], str)))
            self.assertEqual(get_args(collections.abc.Callable[..., str]), (..., str))
        P = ParamSpec('P')
        # In 3.9 and lower we use typing_extensions's hacky implementation
        # of ParamSpec, which gets incorrectly wrapped in a list
        self.assertIn(get_args(Callable[P, int]), [(P, int), ([P], int)])
        self.assertEqual(get_args(Callable[Concatenate[int, P], int]),
                         (Concatenate[int, P], int))
        self.assertEqual(get_args(Required[int]), (int,))
        self.assertEqual(get_args(NotRequired[int]), (int,))
        self.assertEqual(get_args(Unpack[Ts]), (Ts,))
        self.assertEqual(get_args(Unpack), ())
696
697
class CollectionsAbcTests(BaseTestCase):
    """Checks the generic aliases for ``collections`` / ``collections.abc`` types.

    Covers isinstance/issubclass behaviour, instantiation of the aliases and
    subclassing of their subscripted forms.
    """

    def test_isinstance_collections(self):
        self.assertNotIsInstance(1, collections.abc.Mapping)
        self.assertNotIsInstance(1, collections.abc.Iterable)
        self.assertNotIsInstance(1, collections.abc.Container)
        self.assertNotIsInstance(1, collections.abc.Sized)
        # Subscripted aliases cannot be used in isinstance/issubclass checks.
        with self.assertRaises(TypeError):
            isinstance(collections.deque(), typing_extensions.Deque[int])
        with self.assertRaises(TypeError):
            issubclass(collections.Counter, typing_extensions.Counter[str])

    def test_awaitable(self):
        ns = {}
        exec(
            "async def foo() -> typing_extensions.Awaitable[int]:\n"
            "    return await AwaitableWrapper(42)\n",
            globals(), ns)
        foo = ns['foo']
        g = foo()
        # The coroutine object is awaitable; the coroutine function is not.
        self.assertIsInstance(g, typing_extensions.Awaitable)
        self.assertNotIsInstance(foo, typing_extensions.Awaitable)
        g.send(None)  # Run foo() till completion, to avoid warning.

    def test_coroutine(self):
        ns = {}
        exec(
            "async def foo():\n"
            "    return\n",
            globals(), ns)
        foo = ns['foo']
        g = foo()
        self.assertIsInstance(g, typing_extensions.Coroutine)
        with self.assertRaises(TypeError):
            isinstance(g, typing_extensions.Coroutine[int])
        self.assertNotIsInstance(foo, typing_extensions.Coroutine)
        # Drive the coroutine to its end; finishing surfaces as StopIteration.
        try:
            g.send(None)
        except StopIteration:
            pass

    def test_async_iterable(self):
        base_it = range(10)  # type: Iterator[int]
        it = AsyncIteratorWrapper(base_it)
        # This assertion used to appear twice, verbatim; one copy is enough.
        self.assertIsInstance(it, typing_extensions.AsyncIterable)
        self.assertNotIsInstance(42, typing_extensions.AsyncIterable)

    def test_async_iterator(self):
        base_it = range(10)  # type: Iterator[int]
        it = AsyncIteratorWrapper(base_it)
        self.assertIsInstance(it, typing_extensions.AsyncIterator)
        self.assertNotIsInstance(42, typing_extensions.AsyncIterator)

    def test_deque(self):
        self.assertIsSubclass(collections.deque, typing_extensions.Deque)
        class MyDeque(typing_extensions.Deque[int]): ...
        self.assertIsInstance(MyDeque(), collections.deque)

    def test_counter(self):
        self.assertIsSubclass(collections.Counter, typing_extensions.Counter)

    def test_defaultdict_instantiation(self):
        # Calling the alias, subscripted or not, builds the real collection.
        self.assertIs(
            type(typing_extensions.DefaultDict()),
            collections.defaultdict)
        self.assertIs(
            type(typing_extensions.DefaultDict[KT, VT]()),
            collections.defaultdict)
        self.assertIs(
            type(typing_extensions.DefaultDict[str, int]()),
            collections.defaultdict)

    def test_defaultdict_subclass(self):

        class MyDefDict(typing_extensions.DefaultDict[str, int]):
            pass

        dd = MyDefDict()
        self.assertIsInstance(dd, MyDefDict)

        self.assertIsSubclass(MyDefDict, collections.defaultdict)
        self.assertNotIsSubclass(collections.defaultdict, MyDefDict)

    def test_ordereddict_instantiation(self):
        self.assertIs(
            type(typing_extensions.OrderedDict()),
            collections.OrderedDict)
        self.assertIs(
            type(typing_extensions.OrderedDict[KT, VT]()),
            collections.OrderedDict)
        self.assertIs(
            type(typing_extensions.OrderedDict[str, int]()),
            collections.OrderedDict)

    def test_ordereddict_subclass(self):

        class MyOrdDict(typing_extensions.OrderedDict[str, int]):
            pass

        od = MyOrdDict()
        self.assertIsInstance(od, MyOrdDict)

        self.assertIsSubclass(MyOrdDict, collections.OrderedDict)
        self.assertNotIsSubclass(collections.OrderedDict, MyOrdDict)

    def test_chainmap_instantiation(self):
        self.assertIs(type(typing_extensions.ChainMap()), collections.ChainMap)
        self.assertIs(type(typing_extensions.ChainMap[KT, VT]()), collections.ChainMap)
        self.assertIs(type(typing_extensions.ChainMap[str, int]()), collections.ChainMap)
        class CM(typing_extensions.ChainMap[KT, VT]): ...
        # Instantiating a subscripted user subclass yields that subclass.
        self.assertIs(type(CM[int, str]()), CM)

    def test_chainmap_subclass(self):

        class MyChainMap(typing_extensions.ChainMap[str, int]):
            pass

        cm = MyChainMap()
        self.assertIsInstance(cm, MyChainMap)

        self.assertIsSubclass(MyChainMap, collections.ChainMap)
        self.assertNotIsSubclass(collections.ChainMap, MyChainMap)

    def test_deque_instantiation(self):
        self.assertIs(type(typing_extensions.Deque()), collections.deque)
        self.assertIs(type(typing_extensions.Deque[T]()), collections.deque)
        self.assertIs(type(typing_extensions.Deque[int]()), collections.deque)
        class D(typing_extensions.Deque[T]): ...
        self.assertIs(type(D[int]()), D)

    def test_counter_instantiation(self):
        self.assertIs(type(typing_extensions.Counter()), collections.Counter)
        self.assertIs(type(typing_extensions.Counter[T]()), collections.Counter)
        self.assertIs(type(typing_extensions.Counter[int]()), collections.Counter)
        class C(typing_extensions.Counter[T]): ...
        self.assertIs(type(C[int]()), C)
        self.assertEqual(C.__bases__, (collections.Counter, typing.Generic))

    def test_counter_subclass_instantiation(self):

        class MyCounter(typing_extensions.Counter[int]):
            pass

        d = MyCounter()
        self.assertIsInstance(d, MyCounter)
        self.assertIsInstance(d, collections.Counter)
        self.assertIsInstance(d, typing_extensions.Counter)

    def test_async_generator(self):
        ns = {}
        exec("async def f():\n"
             "    yield 42\n", globals(), ns)
        g = ns['f']()
        self.assertIsSubclass(type(g), typing_extensions.AsyncGenerator)

    def test_no_async_generator_instantiation(self):
        with self.assertRaises(TypeError):
            typing_extensions.AsyncGenerator()
        with self.assertRaises(TypeError):
            typing_extensions.AsyncGenerator[T, T]()
        with self.assertRaises(TypeError):
            typing_extensions.AsyncGenerator[int, int]()

    def test_subclassing_async_generator(self):
        class G(typing_extensions.AsyncGenerator[int, int]):
            def asend(self, value):
                pass
            def athrow(self, typ, val=None, tb=None):
                pass

        ns = {}
        exec('async def g(): yield 0', globals(), ns)
        # ``g`` is the async generator *function*, not a generator object.
        g = ns['g']
        self.assertIsSubclass(G, typing_extensions.AsyncGenerator)
        self.assertIsSubclass(G, typing_extensions.AsyncIterable)
        self.assertIsSubclass(G, collections.abc.AsyncGenerator)
        self.assertIsSubclass(G, collections.abc.AsyncIterable)
        self.assertNotIsSubclass(type(g), G)

        instance = G()
        self.assertIsInstance(instance, typing_extensions.AsyncGenerator)
        self.assertIsInstance(instance, typing_extensions.AsyncIterable)
        self.assertIsInstance(instance, collections.abc.AsyncGenerator)
        self.assertIsInstance(instance, collections.abc.AsyncIterable)
        self.assertNotIsInstance(type(g), G)
        self.assertNotIsInstance(g, G)
885
886
class OtherABCTests(BaseTestCase):
    # Runtime checks for the (async) context manager aliases.

    def test_contextmanager(self):
        # An object produced by @contextlib.contextmanager is a
        # ContextManager; an int is not.
        @contextlib.contextmanager
        def sync_manager():
            yield 42

        alias = typing_extensions.ContextManager
        self.assertIsInstance(sync_manager(), alias)
        self.assertNotIsInstance(42, alias)

    def test_async_contextmanager(self):
        alias = typing_extensions.AsyncContextManager

        class NotACM:
            pass

        self.assertIsInstance(ACM(), alias)
        self.assertNotIsInstance(NotACM(), alias)

        # A synchronous context manager must not pass as an async one.
        @contextlib.contextmanager
        def sync_manager():
            yield 42

        self.assertNotIsInstance(sync_manager(), alias)

        # Subscription keeps the argument, rejects isinstance() and
        # rejects a wrong number of parameters.
        self.assertEqual(alias[int].__args__, (int,))
        with self.assertRaises(TypeError):
            isinstance(42, alias[int])
        with self.assertRaises(TypeError):
            alias[int, str]
914
915
class TypeTests(BaseTestCase):
    # Smoke tests for ``Type[...]`` used in annotations at runtime.

    def test_type_basic(self):
        # Type[User] is usable as an annotation and the annotated factory
        # still instantiates whatever class it is given.
        class User: pass
        class BasicUser(User): pass
        class ProUser(User): pass

        def new_user(user_class: Type[User]) -> User:
            return user_class()

        self.assertIsInstance(new_user(BasicUser), BasicUser)

    def test_type_typevar(self):
        # Same as above, with Type parameterised by a bound TypeVar.
        class User: pass
        class BasicUser(User): pass
        class ProUser(User): pass

        U = TypeVar('U', bound=User)

        def new_user(user_class: Type[U]) -> U:
            return user_class()

        self.assertIsInstance(new_user(BasicUser), BasicUser)

    def test_type_optional(self):
        # Optional[Type[...]] must be constructible and usable as an alias.
        A = Optional[Type[BaseException]]

        def foo(a: A) -> Optional[BaseException]:
            if a is None:
                return None
            else:
                return a()

        # unittest assertions instead of bare ``assert`` so the checks
        # survive ``python -O`` and report useful failure messages.
        self.assertIsInstance(foo(KeyboardInterrupt), KeyboardInterrupt)
        self.assertIsNone(foo(None))
953
954
class NewTypeTests(BaseTestCase):
    # Runtime behaviour of NewType: transparent calls, no subclassing.

    def test_basic(self):
        UserId = NewType('UserId', int)
        UserName = NewType('UserName', str)
        # Calling a NewType hands back a value of the underlying type.
        for value, underlying in ((UserId(5), int), (UserName('Joe'), str)):
            self.assertIsInstance(value, underlying)
        self.assertEqual(UserId(5) + 1, 6)

    def test_errors(self):
        UserId = NewType('UserId', int)
        UserName = NewType('UserName', str)
        # A NewType is neither usable in issubclass() nor as a base class.
        self.assertRaises(TypeError, issubclass, UserId, int)
        with self.assertRaises(TypeError):
            class D(UserName):
                pass
972
973
class Coordinate(Protocol):
    """Protocol declaring two annotated int members, ``x`` and ``y``."""

    x: int
    y: int
977
@runtime
class Point(Coordinate, Protocol):
    """Runtime-checkable protocol: ``Coordinate`` plus a ``label`` str."""

    label: str
981
class MyPoint:
    """Plain class that only annotates ``x``, ``y`` and ``label``.

    No values are assigned, so instances carry none of these attributes.
    """

    x: int
    y: int
    label: str
986
class XAxis(Protocol):
    """Protocol declaring a single annotated int member ``x``."""

    x: int
989
class YAxis(Protocol):
    """Protocol declaring a single annotated int member ``y``."""

    y: int
992
@runtime
class Position(XAxis, YAxis, Protocol):
    """Runtime-checkable protocol merging ``XAxis`` and ``YAxis``."""
996
@runtime
class Proto(Protocol):
    """Runtime-checkable protocol with one data member and one method."""

    attr: int

    def meth(self, arg: str) -> int:
        ...
1003
class Concrete(Proto):
    """Nominal (non-protocol) subclass of ``Proto`` adding nothing."""
1006
class Other:
    # Unrelated class that happens to provide both an ``attr`` value and a
    # ``meth`` method, without inheriting from anything.
    attr: int = 1

    def meth(self, arg: str) -> int:
        # 1 for the exact string 'this', 0 for anything else.
        return 1 if arg == 'this' else 0
1014
class NT(NamedTuple):
    # Two-field named tuple; its instances expose ``x`` and ``y``.
    x: int
    y: int
1018
1019
1020class ProtocolTests(BaseTestCase):
1021
1022    def test_basic_protocol(self):
1023        @runtime
1024        class P(Protocol):
1025            def meth(self):
1026                pass
1027        class C: pass
1028        class D:
1029            def meth(self):
1030                pass
1031        def f():
1032            pass
1033        self.assertIsSubclass(D, P)
1034        self.assertIsInstance(D(), P)
1035        self.assertNotIsSubclass(C, P)
1036        self.assertNotIsInstance(C(), P)
1037        self.assertNotIsSubclass(types.FunctionType, P)
1038        self.assertNotIsInstance(f, P)
1039
1040    def test_everything_implements_empty_protocol(self):
1041        @runtime
1042        class Empty(Protocol): pass
1043        class C: pass
1044        def f():
1045            pass
1046        for thing in (object, type, tuple, C, types.FunctionType):
1047            self.assertIsSubclass(thing, Empty)
1048        for thing in (object(), 1, (), typing, f):
1049            self.assertIsInstance(thing, Empty)
1050
1051    def test_function_implements_protocol(self):
1052        def f():
1053            pass
1054        self.assertIsInstance(f, HasCallProtocol)
1055
1056    def test_no_inheritance_from_nominal(self):
1057        class C: pass
1058        class BP(Protocol): pass
1059        with self.assertRaises(TypeError):
1060            class P(C, Protocol):
1061                pass
1062        with self.assertRaises(TypeError):
1063            class P(Protocol, C):
1064                pass
1065        with self.assertRaises(TypeError):
1066            class P(BP, C, Protocol):
1067                pass
1068        class D(BP, C): pass
1069        class E(C, BP): pass
1070        self.assertNotIsInstance(D(), E)
1071        self.assertNotIsInstance(E(), D)
1072
1073    def test_no_instantiation(self):
1074        class P(Protocol): pass
1075        with self.assertRaises(TypeError):
1076            P()
1077        class C(P): pass
1078        self.assertIsInstance(C(), C)
1079        T = TypeVar('T')
1080        class PG(Protocol[T]): pass
1081        with self.assertRaises(TypeError):
1082            PG()
1083        with self.assertRaises(TypeError):
1084            PG[int]()
1085        with self.assertRaises(TypeError):
1086            PG[T]()
1087        class CG(PG[T]): pass
1088        self.assertIsInstance(CG[int](), CG)
1089
1090    def test_cannot_instantiate_abstract(self):
1091        @runtime
1092        class P(Protocol):
1093            @abc.abstractmethod
1094            def ameth(self) -> int:
1095                raise NotImplementedError
1096        class B(P):
1097            pass
1098        class C(B):
1099            def ameth(self) -> int:
1100                return 26
1101        with self.assertRaises(TypeError):
1102            B()
1103        self.assertIsInstance(C(), P)
1104
1105    def test_subprotocols_extending(self):
1106        class P1(Protocol):
1107            def meth1(self):
1108                pass
1109        @runtime
1110        class P2(P1, Protocol):
1111            def meth2(self):
1112                pass
1113        class C:
1114            def meth1(self):
1115                pass
1116            def meth2(self):
1117                pass
1118        class C1:
1119            def meth1(self):
1120                pass
1121        class C2:
1122            def meth2(self):
1123                pass
1124        self.assertNotIsInstance(C1(), P2)
1125        self.assertNotIsInstance(C2(), P2)
1126        self.assertNotIsSubclass(C1, P2)
1127        self.assertNotIsSubclass(C2, P2)
1128        self.assertIsInstance(C(), P2)
1129        self.assertIsSubclass(C, P2)
1130
1131    def test_subprotocols_merging(self):
1132        class P1(Protocol):
1133            def meth1(self):
1134                pass
1135        class P2(Protocol):
1136            def meth2(self):
1137                pass
1138        @runtime
1139        class P(P1, P2, Protocol):
1140            pass
1141        class C:
1142            def meth1(self):
1143                pass
1144            def meth2(self):
1145                pass
1146        class C1:
1147            def meth1(self):
1148                pass
1149        class C2:
1150            def meth2(self):
1151                pass
1152        self.assertNotIsInstance(C1(), P)
1153        self.assertNotIsInstance(C2(), P)
1154        self.assertNotIsSubclass(C1, P)
1155        self.assertNotIsSubclass(C2, P)
1156        self.assertIsInstance(C(), P)
1157        self.assertIsSubclass(C, P)
1158
1159    def test_protocols_issubclass(self):
1160        T = TypeVar('T')
1161        @runtime
1162        class P(Protocol):
1163            def x(self): ...
1164        @runtime
1165        class PG(Protocol[T]):
1166            def x(self): ...
1167        class BadP(Protocol):
1168            def x(self): ...
1169        class BadPG(Protocol[T]):
1170            def x(self): ...
1171        class C:
1172            def x(self): ...
1173        self.assertIsSubclass(C, P)
1174        self.assertIsSubclass(C, PG)
1175        self.assertIsSubclass(BadP, PG)
1176        with self.assertRaises(TypeError):
1177            issubclass(C, PG[T])
1178        with self.assertRaises(TypeError):
1179            issubclass(C, PG[C])
1180        with self.assertRaises(TypeError):
1181            issubclass(C, BadP)
1182        with self.assertRaises(TypeError):
1183            issubclass(C, BadPG)
1184        with self.assertRaises(TypeError):
1185            issubclass(P, PG[T])
1186        with self.assertRaises(TypeError):
1187            issubclass(PG, PG[int])
1188
1189    def test_protocols_issubclass_non_callable(self):
1190        class C:
1191            x = 1
1192        @runtime
1193        class PNonCall(Protocol):
1194            x = 1
1195        with self.assertRaises(TypeError):
1196            issubclass(C, PNonCall)
1197        self.assertIsInstance(C(), PNonCall)
1198        PNonCall.register(C)
1199        with self.assertRaises(TypeError):
1200            issubclass(C, PNonCall)
1201        self.assertIsInstance(C(), PNonCall)
1202        # check that non-protocol subclasses are not affected
1203        class D(PNonCall): ...
1204        self.assertNotIsSubclass(C, D)
1205        self.assertNotIsInstance(C(), D)
1206        D.register(C)
1207        self.assertIsSubclass(C, D)
1208        self.assertIsInstance(C(), D)
1209        with self.assertRaises(TypeError):
1210            issubclass(D, PNonCall)
1211
1212    def test_protocols_isinstance(self):
1213        T = TypeVar('T')
1214        @runtime
1215        class P(Protocol):
1216            def meth(x): ...
1217        @runtime
1218        class PG(Protocol[T]):
1219            def meth(x): ...
1220        class BadP(Protocol):
1221            def meth(x): ...
1222        class BadPG(Protocol[T]):
1223            def meth(x): ...
1224        class C:
1225            def meth(x): ...
1226        self.assertIsInstance(C(), P)
1227        self.assertIsInstance(C(), PG)
1228        with self.assertRaises(TypeError):
1229            isinstance(C(), PG[T])
1230        with self.assertRaises(TypeError):
1231            isinstance(C(), PG[C])
1232        with self.assertRaises(TypeError):
1233            isinstance(C(), BadP)
1234        with self.assertRaises(TypeError):
1235            isinstance(C(), BadPG)
1236
1237    def test_protocols_isinstance_py36(self):
1238        class APoint:
1239            def __init__(self, x, y, label):
1240                self.x = x
1241                self.y = y
1242                self.label = label
1243        class BPoint:
1244            label = 'B'
1245            def __init__(self, x, y):
1246                self.x = x
1247                self.y = y
1248        class C:
1249            def __init__(self, attr):
1250                self.attr = attr
1251            def meth(self, arg):
1252                return 0
1253        class Bad: pass
1254        self.assertIsInstance(APoint(1, 2, 'A'), Point)
1255        self.assertIsInstance(BPoint(1, 2), Point)
1256        self.assertNotIsInstance(MyPoint(), Point)
1257        self.assertIsInstance(BPoint(1, 2), Position)
1258        self.assertIsInstance(Other(), Proto)
1259        self.assertIsInstance(Concrete(), Proto)
1260        self.assertIsInstance(C(42), Proto)
1261        self.assertNotIsInstance(Bad(), Proto)
1262        self.assertNotIsInstance(Bad(), Point)
1263        self.assertNotIsInstance(Bad(), Position)
1264        self.assertNotIsInstance(Bad(), Concrete)
1265        self.assertNotIsInstance(Other(), Concrete)
1266        self.assertIsInstance(NT(1, 2), Position)
1267
1268    def test_protocols_isinstance_init(self):
1269        T = TypeVar('T')
1270        @runtime
1271        class P(Protocol):
1272            x = 1
1273        @runtime
1274        class PG(Protocol[T]):
1275            x = 1
1276        class C:
1277            def __init__(self, x):
1278                self.x = x
1279        self.assertIsInstance(C(1), P)
1280        self.assertIsInstance(C(1), PG)
1281
1282    def test_protocols_support_register(self):
1283        @runtime
1284        class P(Protocol):
1285            x = 1
1286        class PM(Protocol):
1287            def meth(self): pass
1288        class D(PM): pass
1289        class C: pass
1290        D.register(C)
1291        P.register(C)
1292        self.assertIsInstance(C(), P)
1293        self.assertIsInstance(C(), D)
1294
1295    def test_none_on_non_callable_doesnt_block_implementation(self):
1296        @runtime
1297        class P(Protocol):
1298            x = 1
1299        class A:
1300            x = 1
1301        class B(A):
1302            x = None
1303        class C:
1304            def __init__(self):
1305                self.x = None
1306        self.assertIsInstance(B(), P)
1307        self.assertIsInstance(C(), P)
1308
1309    def test_none_on_callable_blocks_implementation(self):
1310        @runtime
1311        class P(Protocol):
1312            def x(self): ...
1313        class A:
1314            def x(self): ...
1315        class B(A):
1316            x = None
1317        class C:
1318            def __init__(self):
1319                self.x = None
1320        self.assertNotIsInstance(B(), P)
1321        self.assertNotIsInstance(C(), P)
1322
1323    def test_non_protocol_subclasses(self):
1324        class P(Protocol):
1325            x = 1
1326        @runtime
1327        class PR(Protocol):
1328            def meth(self): pass
1329        class NonP(P):
1330            x = 1
1331        class NonPR(PR): pass
1332        class C:
1333            x = 1
1334        class D:
1335            def meth(self): pass
1336        self.assertNotIsInstance(C(), NonP)
1337        self.assertNotIsInstance(D(), NonPR)
1338        self.assertNotIsSubclass(C, NonP)
1339        self.assertNotIsSubclass(D, NonPR)
1340        self.assertIsInstance(NonPR(), PR)
1341        self.assertIsSubclass(NonPR, PR)
1342
1343    def test_custom_subclasshook(self):
1344        class P(Protocol):
1345            x = 1
1346        class OKClass: pass
1347        class BadClass:
1348            x = 1
1349        class C(P):
1350            @classmethod
1351            def __subclasshook__(cls, other):
1352                return other.__name__.startswith("OK")
1353        self.assertIsInstance(OKClass(), C)
1354        self.assertNotIsInstance(BadClass(), C)
1355        self.assertIsSubclass(OKClass, C)
1356        self.assertNotIsSubclass(BadClass, C)
1357
1358    def test_issubclass_fails_correctly(self):
1359        @runtime
1360        class P(Protocol):
1361            x = 1
1362        class C: pass
1363        with self.assertRaises(TypeError):
1364            issubclass(C(), P)
1365
1366    def test_defining_generic_protocols(self):
1367        T = TypeVar('T')
1368        S = TypeVar('S')
1369        @runtime
1370        class PR(Protocol[T, S]):
1371            def meth(self): pass
1372        class P(PR[int, T], Protocol[T]):
1373            y = 1
1374        with self.assertRaises(TypeError):
1375            issubclass(PR[int, T], PR)
1376        with self.assertRaises(TypeError):
1377            issubclass(P[str], PR)
1378        with self.assertRaises(TypeError):
1379            PR[int]
1380        with self.assertRaises(TypeError):
1381            P[int, str]
1382        if not TYPING_3_10_0:
1383            with self.assertRaises(TypeError):
1384                PR[int, 1]
1385            with self.assertRaises(TypeError):
1386                PR[int, ClassVar]
1387        class C(PR[int, T]): pass
1388        self.assertIsInstance(C[str](), C)
1389
1390    def test_defining_generic_protocols_old_style(self):
1391        T = TypeVar('T')
1392        S = TypeVar('S')
1393        @runtime
1394        class PR(Protocol, Generic[T, S]):
1395            def meth(self): pass
1396        class P(PR[int, str], Protocol):
1397            y = 1
1398        with self.assertRaises(TypeError):
1399            self.assertIsSubclass(PR[int, str], PR)
1400        self.assertIsSubclass(P, PR)
1401        with self.assertRaises(TypeError):
1402            PR[int]
1403        if not TYPING_3_10_0:
1404            with self.assertRaises(TypeError):
1405                PR[int, 1]
1406        class P1(Protocol, Generic[T]):
1407            def bar(self, x: T) -> str: ...
1408        class P2(Generic[T], Protocol):
1409            def bar(self, x: T) -> str: ...
1410        @runtime
1411        class PSub(P1[str], Protocol):
1412            x = 1
1413        class Test:
1414            x = 1
1415            def bar(self, x: str) -> str:
1416                return x
1417        self.assertIsInstance(Test(), PSub)
1418        if not TYPING_3_10_0:
1419            with self.assertRaises(TypeError):
1420                PR[int, ClassVar]
1421
1422    def test_init_called(self):
1423        T = TypeVar('T')
1424        class P(Protocol[T]): pass
1425        class C(P[T]):
1426            def __init__(self):
1427                self.test = 'OK'
1428        self.assertEqual(C[int]().test, 'OK')
1429
1430    def test_protocols_bad_subscripts(self):
1431        T = TypeVar('T')
1432        S = TypeVar('S')
1433        with self.assertRaises(TypeError):
1434            class P(Protocol[T, T]): pass
1435        with self.assertRaises(TypeError):
1436            class P(Protocol[int]): pass
1437        with self.assertRaises(TypeError):
1438            class P(Protocol[T], Protocol[S]): pass
1439        with self.assertRaises(TypeError):
1440            class P(typing.Mapping[T, S], Protocol[T]): pass
1441
1442    def test_generic_protocols_repr(self):
1443        T = TypeVar('T')
1444        S = TypeVar('S')
1445        class P(Protocol[T, S]): pass
1446        self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]'))
1447        self.assertTrue(repr(P[int, str]).endswith('P[int, str]'))
1448
1449    def test_generic_protocols_eq(self):
1450        T = TypeVar('T')
1451        S = TypeVar('S')
1452        class P(Protocol[T, S]): pass
1453        self.assertEqual(P, P)
1454        self.assertEqual(P[int, T], P[int, T])
1455        self.assertEqual(P[T, T][Tuple[T, S]][int, str],
1456                         P[Tuple[int, str], Tuple[int, str]])
1457
1458    def test_generic_protocols_special_from_generic(self):
1459        T = TypeVar('T')
1460        class P(Protocol[T]): pass
1461        self.assertEqual(P.__parameters__, (T,))
1462        self.assertEqual(P[int].__parameters__, ())
1463        self.assertEqual(P[int].__args__, (int,))
1464        self.assertIs(P[int].__origin__, P)
1465
1466    def test_generic_protocols_special_from_protocol(self):
1467        @runtime
1468        class PR(Protocol):
1469            x = 1
1470        class P(Protocol):
1471            def meth(self):
1472                pass
1473        T = TypeVar('T')
1474        class PG(Protocol[T]):
1475            x = 1
1476            def meth(self):
1477                pass
1478        self.assertTrue(P._is_protocol)
1479        self.assertTrue(PR._is_protocol)
1480        self.assertTrue(PG._is_protocol)
1481        if hasattr(typing, 'Protocol'):
1482            self.assertFalse(P._is_runtime_protocol)
1483        else:
1484            with self.assertRaises(AttributeError):
1485                self.assertFalse(P._is_runtime_protocol)
1486        self.assertTrue(PR._is_runtime_protocol)
1487        self.assertTrue(PG[int]._is_protocol)
1488        self.assertEqual(typing_extensions._get_protocol_attrs(P), {'meth'})
1489        self.assertEqual(typing_extensions._get_protocol_attrs(PR), {'x'})
1490        self.assertEqual(frozenset(typing_extensions._get_protocol_attrs(PG)),
1491                         frozenset({'x', 'meth'}))
1492
1493    def test_no_runtime_deco_on_nominal(self):
1494        with self.assertRaises(TypeError):
1495            @runtime
1496            class C: pass
1497        class Proto(Protocol):
1498            x = 1
1499        with self.assertRaises(TypeError):
1500            @runtime
1501            class Concrete(Proto):
1502                pass
1503
1504    def test_none_treated_correctly(self):
1505        @runtime
1506        class P(Protocol):
1507            x = None  # type: int
1508        class B(object): pass
1509        self.assertNotIsInstance(B(), P)
1510        class C:
1511            x = 1
1512        class D:
1513            x = None
1514        self.assertIsInstance(C(), P)
1515        self.assertIsInstance(D(), P)
1516        class CI:
1517            def __init__(self):
1518                self.x = 1
1519        class DI:
1520            def __init__(self):
1521                self.x = None
1522        self.assertIsInstance(C(), P)
1523        self.assertIsInstance(D(), P)
1524
1525    def test_protocols_in_unions(self):
1526        class P(Protocol):
1527            x = None  # type: int
1528        Alias = typing.Union[typing.Iterable, P]
1529        Alias2 = typing.Union[P, typing.Iterable]
1530        self.assertEqual(Alias, Alias2)
1531
1532    def test_protocols_pickleable(self):
1533        global P, CP  # pickle wants to reference the class by name
1534        T = TypeVar('T')
1535
1536        @runtime
1537        class P(Protocol[T]):
1538            x = 1
1539        class CP(P[int]):
1540            pass
1541
1542        c = CP()
1543        c.foo = 42
1544        c.bar = 'abc'
1545        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1546            z = pickle.dumps(c, proto)
1547            x = pickle.loads(z)
1548            self.assertEqual(x.foo, 42)
1549            self.assertEqual(x.bar, 'abc')
1550            self.assertEqual(x.x, 1)
1551            self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'})
1552            s = pickle.dumps(P)
1553            D = pickle.loads(s)
1554            class E:
1555                x = 1
1556            self.assertIsInstance(E(), D)
1557
1558    def test_collections_protocols_allowed(self):
1559        @runtime_checkable
1560        class Custom(collections.abc.Iterable, Protocol):
1561            def close(self): pass
1562
1563        class A: ...
1564        class B:
1565            def __iter__(self):
1566                return []
1567            def close(self):
1568                return 0
1569
1570        self.assertIsSubclass(B, Custom)
1571        self.assertNotIsSubclass(A, Custom)
1572
1573    def test_no_init_same_for_different_protocol_implementations(self):
1574        class CustomProtocolWithoutInitA(Protocol):
1575            pass
1576
1577        class CustomProtocolWithoutInitB(Protocol):
1578            pass
1579
1580        self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)
1581
1582
1583class TypedDictTests(BaseTestCase):
1584
1585    def test_basics_iterable_syntax(self):
1586        Emp = TypedDict('Emp', {'name': str, 'id': int})
1587        self.assertIsSubclass(Emp, dict)
1588        self.assertIsSubclass(Emp, typing.MutableMapping)
1589        self.assertNotIsSubclass(Emp, collections.abc.Sequence)
1590        jim = Emp(name='Jim', id=1)
1591        self.assertIs(type(jim), dict)
1592        self.assertEqual(jim['name'], 'Jim')
1593        self.assertEqual(jim['id'], 1)
1594        self.assertEqual(Emp.__name__, 'Emp')
1595        self.assertEqual(Emp.__module__, __name__)
1596        self.assertEqual(Emp.__bases__, (dict,))
1597        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
1598        self.assertEqual(Emp.__total__, True)
1599
1600    def test_basics_keywords_syntax(self):
1601        Emp = TypedDict('Emp', name=str, id=int)
1602        self.assertIsSubclass(Emp, dict)
1603        self.assertIsSubclass(Emp, typing.MutableMapping)
1604        self.assertNotIsSubclass(Emp, collections.abc.Sequence)
1605        jim = Emp(name='Jim', id=1)
1606        self.assertIs(type(jim), dict)
1607        self.assertEqual(jim['name'], 'Jim')
1608        self.assertEqual(jim['id'], 1)
1609        self.assertEqual(Emp.__name__, 'Emp')
1610        self.assertEqual(Emp.__module__, __name__)
1611        self.assertEqual(Emp.__bases__, (dict,))
1612        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
1613        self.assertEqual(Emp.__total__, True)
1614
1615    def test_typeddict_special_keyword_names(self):
1616        TD = TypedDict("TD", cls=type, self=object, typename=str, _typename=int,
1617                       fields=list, _fields=dict)
1618        self.assertEqual(TD.__name__, 'TD')
1619        self.assertEqual(TD.__annotations__, {'cls': type, 'self': object, 'typename': str,
1620                                              '_typename': int, 'fields': list, '_fields': dict})
1621        a = TD(cls=str, self=42, typename='foo', _typename=53,
1622               fields=[('bar', tuple)], _fields={'baz', set})
1623        self.assertEqual(a['cls'], str)
1624        self.assertEqual(a['self'], 42)
1625        self.assertEqual(a['typename'], 'foo')
1626        self.assertEqual(a['_typename'], 53)
1627        self.assertEqual(a['fields'], [('bar', tuple)])
1628        self.assertEqual(a['_fields'], {'baz', set})
1629
1630    @skipIf(hasattr(typing, 'TypedDict'), "Should be tested by upstream")
1631    def test_typeddict_create_errors(self):
1632        with self.assertRaises(TypeError):
1633            TypedDict.__new__()
1634        with self.assertRaises(TypeError):
1635            TypedDict()
1636        with self.assertRaises(TypeError):
1637            TypedDict('Emp', [('name', str)], None)
1638
1639        with self.assertWarns(DeprecationWarning):
1640            Emp = TypedDict(_typename='Emp', name=str, id=int)
1641        self.assertEqual(Emp.__name__, 'Emp')
1642        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
1643
1644        with self.assertWarns(DeprecationWarning):
1645            Emp = TypedDict('Emp', _fields={'name': str, 'id': int})
1646        self.assertEqual(Emp.__name__, 'Emp')
1647        self.assertEqual(Emp.__annotations__, {'name': str, 'id': int})
1648
1649    def test_typeddict_errors(self):
1650        Emp = TypedDict('Emp', {'name': str, 'id': int})
1651        if hasattr(typing, "Required"):
1652            self.assertEqual(TypedDict.__module__, 'typing')
1653        else:
1654            self.assertEqual(TypedDict.__module__, 'typing_extensions')
1655        jim = Emp(name='Jim', id=1)
1656        with self.assertRaises(TypeError):
1657            isinstance({}, Emp)
1658        with self.assertRaises(TypeError):
1659            isinstance(jim, Emp)
1660        with self.assertRaises(TypeError):
1661            issubclass(dict, Emp)
1662        with self.assertRaises(TypeError):
1663            TypedDict('Hi', x=1)
1664        with self.assertRaises(TypeError):
1665            TypedDict('Hi', [('x', int), ('y', 1)])
1666        with self.assertRaises(TypeError):
1667            TypedDict('Hi', [('x', int)], y=int)
1668
1669    def test_py36_class_syntax_usage(self):
1670        self.assertEqual(LabelPoint2D.__name__, 'LabelPoint2D')
1671        self.assertEqual(LabelPoint2D.__module__, __name__)
1672        self.assertEqual(get_type_hints(LabelPoint2D), {'x': int, 'y': int, 'label': str})
1673        self.assertEqual(LabelPoint2D.__bases__, (dict,))
1674        self.assertEqual(LabelPoint2D.__total__, True)
1675        self.assertNotIsSubclass(LabelPoint2D, typing.Sequence)
1676        not_origin = Point2D(x=0, y=1)
1677        self.assertEqual(not_origin['x'], 0)
1678        self.assertEqual(not_origin['y'], 1)
1679        other = LabelPoint2D(x=0, y=1, label='hi')
1680        self.assertEqual(other['label'], 'hi')
1681
1682    def test_pickle(self):
1683        global EmpD  # pickle wants to reference the class by name
1684        EmpD = TypedDict('EmpD', name=str, id=int)
1685        jane = EmpD({'name': 'jane', 'id': 37})
1686        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1687            z = pickle.dumps(jane, proto)
1688            jane2 = pickle.loads(z)
1689            self.assertEqual(jane2, jane)
1690            self.assertEqual(jane2, {'name': 'jane', 'id': 37})
1691            ZZ = pickle.dumps(EmpD, proto)
1692            EmpDnew = pickle.loads(ZZ)
1693            self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
1694
1695    def test_optional(self):
1696        EmpD = TypedDict('EmpD', name=str, id=int)
1697
1698        self.assertEqual(typing.Optional[EmpD], typing.Union[None, EmpD])
1699        self.assertNotEqual(typing.List[EmpD], typing.Tuple[EmpD])
1700
1701    def test_total(self):
1702        D = TypedDict('D', {'x': int}, total=False)
1703        self.assertEqual(D(), {})
1704        self.assertEqual(D(x=1), {'x': 1})
1705        self.assertEqual(D.__total__, False)
1706        self.assertEqual(D.__required_keys__, frozenset())
1707        self.assertEqual(D.__optional_keys__, {'x'})
1708
1709        self.assertEqual(Options(), {})
1710        self.assertEqual(Options(log_level=2), {'log_level': 2})
1711        self.assertEqual(Options.__total__, False)
1712        self.assertEqual(Options.__required_keys__, frozenset())
1713        self.assertEqual(Options.__optional_keys__, {'log_level', 'log_path'})
1714
1715    def test_optional_keys(self):
1716        assert Point2Dor3D.__required_keys__ == frozenset(['x', 'y'])
1717        assert Point2Dor3D.__optional_keys__ == frozenset(['z'])
1718
1719    def test_required_notrequired_keys(self):
1720        assert NontotalMovie.__required_keys__ == frozenset({'title'})
1721        assert NontotalMovie.__optional_keys__ == frozenset({'year'})
1722
1723        assert TotalMovie.__required_keys__ == frozenset({'title'})
1724        assert TotalMovie.__optional_keys__ == frozenset({'year'})
1725
1726
1727    def test_keys_inheritance(self):
1728        assert BaseAnimal.__required_keys__ == frozenset(['name'])
1729        assert BaseAnimal.__optional_keys__ == frozenset([])
1730        assert get_type_hints(BaseAnimal) == {'name': str}
1731
1732        assert Animal.__required_keys__ == frozenset(['name'])
1733        assert Animal.__optional_keys__ == frozenset(['tail', 'voice'])
1734        assert get_type_hints(Animal) == {
1735            'name': str,
1736            'tail': bool,
1737            'voice': str,
1738        }
1739
1740        assert Cat.__required_keys__ == frozenset(['name', 'fur_color'])
1741        assert Cat.__optional_keys__ == frozenset(['tail', 'voice'])
1742        assert get_type_hints(Cat) == {
1743            'fur_color': str,
1744            'name': str,
1745            'tail': bool,
1746            'voice': str,
1747        }
1748
1749    def test_is_typeddict(self):
1750        assert is_typeddict(Point2D) is True
1751        assert is_typeddict(Point2Dor3D) is True
1752        assert is_typeddict(Union[str, int]) is False
1753        # classes, not instances
1754        assert is_typeddict(Point2D()) is False
1755
1756    @skipUnless(TYPING_3_8_0, "Python 3.8+ required")
1757    def test_is_typeddict_against_typeddict_from_typing(self):
1758        Point = typing.TypedDict('Point', {'x': int, 'y': int})
1759
1760        class PointDict2D(typing.TypedDict):
1761            x: int
1762            y: int
1763
1764        class PointDict3D(PointDict2D, total=False):
1765            z: int
1766
1767        assert is_typeddict(Point) is True
1768        assert is_typeddict(PointDict2D) is True
1769        assert is_typeddict(PointDict3D) is True
1770
1771
class AnnotatedTests(BaseTestCase):
    """Runtime checks for ``Annotated`` aliases.

    Covers repr, flattening, equality/hashing, instantiation and attribute
    forwarding, pickling, and type-variable substitution.
    """

    def test_repr(self):
        # The expected prefix is 'typing' when the standard library has
        # Annotated, otherwise 'typing_extensions'.
        if hasattr(typing, 'Annotated'):
            mod_name = 'typing'
        else:
            mod_name = "typing_extensions"
        self.assertEqual(
            repr(Annotated[int, 4, 5]),
            mod_name + ".Annotated[int, 4, 5]"
        )
        self.assertEqual(
            repr(Annotated[List[int], 4, 5]),
            mod_name + ".Annotated[typing.List[int], 4, 5]"
        )

    def test_flatten(self):
        # Nested Annotated collapses into a single alias; the inner metadata
        # comes before the outer metadata.
        A = Annotated[Annotated[int, 4], 5]
        self.assertEqual(A, Annotated[int, 4, 5])
        self.assertEqual(A.__metadata__, (4, 5))
        self.assertEqual(A.__origin__, int)

    def test_specialize(self):
        # Subscripting a generic Annotated alias substitutes into the origin
        # and keeps the metadata.
        L = Annotated[List[T], "my decoration"]
        LI = Annotated[List[int], "my decoration"]
        self.assertEqual(L[int], Annotated[List[int], "my decoration"])
        self.assertEqual(L[int].__metadata__, ("my decoration",))
        self.assertEqual(L[int].__origin__, List[int])
        # No free type variables left to substitute.
        with self.assertRaises(TypeError):
            LI[int]
        # Too many arguments for a single type variable.
        with self.assertRaises(TypeError):
            L[int, float]

    def test_hash_eq(self):
        # Equality and hashing depend on the origin and on the metadata,
        # including its order and length.
        self.assertEqual(len({Annotated[int, 4, 5], Annotated[int, 4, 5]}), 1)
        self.assertNotEqual(Annotated[int, 4, 5], Annotated[int, 5, 4])
        self.assertNotEqual(Annotated[int, 4, 5], Annotated[str, 4, 5])
        self.assertNotEqual(Annotated[int, 4], Annotated[int, 4, 4])
        self.assertEqual(
            {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
            {Annotated[int, 4, 5], Annotated[T, 4, 5]}
        )

    def test_instantiate(self):
        # Calling an Annotated alias builds an instance of the origin class.
        class C:
            classvar = 4

            def __init__(self, x):
                self.x = x

            def __eq__(self, other):
                if not isinstance(other, C):
                    return NotImplemented
                return other.x == self.x

        A = Annotated[C, "a decoration"]
        a = A(5)
        c = C(5)
        self.assertEqual(a, c)
        self.assertEqual(a.x, c.x)
        self.assertEqual(a.classvar, c.classvar)

    def test_instantiate_generic(self):
        # Works both before and after substituting the type variable.
        MyCount = Annotated[typing_extensions.Counter[T], "my decoration"]
        self.assertEqual(MyCount([4, 4, 5]), {4: 2, 5: 1})
        self.assertEqual(MyCount[int]([4, 4, 5]), {4: 2, 5: 1})

    def test_cannot_instantiate_forward(self):
        # A string (forward reference) origin cannot be called.
        A = Annotated["int", (5, 6)]
        with self.assertRaises(TypeError):
            A(5)

    def test_cannot_instantiate_type_var(self):
        A = Annotated[T, (5, 6)]
        with self.assertRaises(TypeError):
            A(5)

    def test_cannot_getattr_typevar(self):
        with self.assertRaises(AttributeError):
            Annotated[T, (5, 7)].x

    def test_attr_passthrough(self):
        # Attribute reads and writes on the alias reach the origin class.
        class C:
            classvar = 4

        A = Annotated[C, "a decoration"]
        self.assertEqual(A.classvar, 4)
        A.x = 5
        self.assertEqual(C.x, 5)

    @skipIf(sys.version_info[:2] in ((3, 9), (3, 10)), "Waiting for bpo-46491 bugfix.")
    def test_special_form_containment(self):
        # ClassVar and Final wrapped in Annotated.
        class C:
            classvar: Annotated[ClassVar[int], "a decoration"] = 4
            const: Annotated[Final[int], "Const"] = 4

        if sys.version_info[:2] >= (3, 7):
            # get_type_hints() strips the Annotated wrapper here.
            self.assertEqual(get_type_hints(C, globals())["classvar"], ClassVar[int])
            self.assertEqual(get_type_hints(C, globals())["const"], Final[int])
        else:
            self.assertEqual(
                get_type_hints(C, globals())["classvar"],
                Annotated[ClassVar[int], "a decoration"]
            )
            self.assertEqual(
                get_type_hints(C, globals())["const"], Annotated[Final[int], "Const"]
            )

    # NOTE: a second, byte-for-byte identical ``test_hash_eq`` used to be
    # defined here.  It silently replaced the definition above in the class
    # namespace, so it has been removed.

    def test_cannot_subclass(self):
        with self.assertRaisesRegex(TypeError, "Cannot subclass .*Annotated"):
            class C(Annotated):
                pass

    def test_cannot_check_instance(self):
        with self.assertRaises(TypeError):
            isinstance(5, Annotated[int, "positive"])

    def test_cannot_check_subclass(self):
        with self.assertRaises(TypeError):
            issubclass(int, Annotated[int, "positive"])

    def test_pickle(self):
        # Round-trip Annotated aliases over a range of origins and over
        # every pickle protocol.
        samples = [typing.Any, typing.Union[int, str],
                   typing.Optional[str], Tuple[int, ...],
                   typing.Callable[[str], bytes],
                   Self, LiteralString, Never]

        for t in samples:
            x = Annotated[t, "a"]

            for prot in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(protocol=prot, type=t):
                    pickled = pickle.dumps(x, prot)
                    restored = pickle.loads(pickled)
                    self.assertEqual(x, restored)

        # The class must be reachable as a module-level name for pickle to
        # find it again, hence the ``global`` declaration.
        global _Annotated_test_G

        class _Annotated_test_G(Generic[T]):
            x = 1

        G = Annotated[_Annotated_test_G[int], "A decoration"]
        G.foo = 42
        G.bar = 'abc'

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            z = pickle.dumps(G, proto)
            x = pickle.loads(z)
            self.assertEqual(x.foo, 42)
            self.assertEqual(x.bar, 'abc')
            self.assertEqual(x.x, 1)

    def test_subst(self):
        dec = "a decoration"
        dec2 = "another decoration"

        S = Annotated[T, dec2]
        self.assertEqual(S[int], Annotated[int, dec2])

        # Substituting an Annotated type merges the metadata: the
        # argument's metadata first, then the alias's own.
        self.assertEqual(S[Annotated[int, dec]], Annotated[int, dec, dec2])
        L = Annotated[List[T], dec]

        self.assertEqual(L[int], Annotated[List[int], dec])
        with self.assertRaises(TypeError):
            L[int, int]

        self.assertEqual(S[L[int]], Annotated[List[int], dec, dec2])

        D = Annotated[Dict[KT, VT], dec]
        self.assertEqual(D[str, int], Annotated[Dict[str, int], dec])
        # Too few arguments for two type variables.
        with self.assertRaises(TypeError):
            D[int]

        # Aliases without free type variables cannot be subscripted.
        It = Annotated[int, dec]
        with self.assertRaises(TypeError):
            It[None]

        LI = L[int]
        with self.assertRaises(TypeError):
            LI[None]

    def test_annotated_in_other_types(self):
        # Substitution reaches an Annotated nested inside another generic.
        X = List[Annotated[T, 5]]
        self.assertEqual(X[int], List[Annotated[int, 5]])
1966
1967
class GetTypeHintsTests(BaseTestCase):
    """Checks ``get_type_hints`` with and without ``include_extras=True``."""

    def test_get_type_hints(self):
        # 'X' is a forward reference; it is resolved through the locals()
        # passed to get_type_hints(), so X may be bound after the def.
        def foobar(x: List['X']): ...
        X = Annotated[int, (1, 10)]
        self.assertEqual(
            get_type_hints(foobar, globals(), locals()),
            {'x': List[int]}
        )
        self.assertEqual(
            get_type_hints(foobar, globals(), locals(), include_extras=True),
            {'x': List[Annotated[int, (1, 10)]]}
        )
        BA = Tuple[Annotated[T, (1, 0)], ...]
        def barfoo(x: BA): ...
        self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...])
        # With include_extras the very same alias object is returned.
        self.assertIs(
            get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
            BA
        )
        def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]],
                    y: typing.Union[int, Annotated[T, "mutable"]]): ...
        self.assertEqual(
            get_type_hints(barfoo2, globals(), locals()),
            {'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]}
        )
        BA2 = typing.Callable[..., List[T]]
        def barfoo3(x: BA2): ...
        self.assertIs(
            get_type_hints(barfoo3, globals(), locals(), include_extras=True)["x"],
            BA2
        )

    def test_get_type_hints_refs(self):
        # Annotated aliases combined with string forward references.

        Const = Annotated[T, "Const"]

        class MySet(Generic[T]):

            def __ior__(self, other: "Const[MySet[T]]") -> "MySet[T]":
                ...

            def __iand__(self, other: Const["MySet[T]"]) -> "MySet[T]":
                ...

        self.assertEqual(
            get_type_hints(MySet.__iand__, globals(), locals()),
            {'other': MySet[T], 'return': MySet[T]}
        )

        self.assertEqual(
            get_type_hints(MySet.__iand__, globals(), locals(), include_extras=True),
            {'other': Const[MySet[T]], 'return': MySet[T]}
        )

        self.assertEqual(
            get_type_hints(MySet.__ior__, globals(), locals()),
            {'other': MySet[T], 'return': MySet[T]}
        )

    def test_get_type_hints_typeddict(self):
        # Without include_extras the Required/NotRequired/Annotated wrappers
        # are stripped; with include_extras=True they are kept.
        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run under ``python -O`` and produce readable diffs.
        self.assertEqual(get_type_hints(TotalMovie), {'title': str, 'year': int})
        self.assertEqual(get_type_hints(TotalMovie, include_extras=True), {
            'title': str,
            'year': NotRequired[int],
        })

        self.assertEqual(get_type_hints(AnnotatedMovie), {'title': str, 'year': int})
        self.assertEqual(get_type_hints(AnnotatedMovie, include_extras=True), {
            'title': Annotated[Required[str], "foobar"],
            'year': NotRequired[Annotated[int, 2000]],
        })
2039
2040
class TypeAliasTests(BaseTestCase):
    """Checks on the ``TypeAlias`` special form."""

    def test_canonical_usage_with_variable_annotation(self):
        # The annotated assignment only has to execute without raising.
        scratch_ns = {}
        exec('Alias: TypeAlias = Employee', globals(), scratch_ns)

    def test_canonical_usage_with_type_comment(self):
        Alias = Employee  # type: TypeAlias

    def test_cannot_instantiate(self):
        self.assertRaises(TypeError, TypeAlias)

    def test_no_isinstance(self):
        self.assertRaises(TypeError, isinstance, 42, TypeAlias)

    def test_no_issubclass(self):
        # TypeAlias is rejected on either side of issubclass().
        self.assertRaises(TypeError, issubclass, Employee, TypeAlias)
        self.assertRaises(TypeError, issubclass, TypeAlias, Employee)

    def test_cannot_subclass(self):
        # Neither the special form nor its type may serve as a base class.
        with self.assertRaises(TypeError):
            class FromForm(TypeAlias):
                pass

        with self.assertRaises(TypeError):
            class FromFormType(type(TypeAlias)):
                pass

    def test_repr(self):
        # The expected prefix depends on whether typing itself has TypeAlias.
        expected = 'typing_extensions.TypeAlias'
        if hasattr(typing, 'TypeAlias'):
            expected = 'typing.TypeAlias'
        self.assertEqual(repr(TypeAlias), expected)

    def test_cannot_subscript(self):
        with self.assertRaises(TypeError):
            TypeAlias[int]
2082
class ParamSpecTests(BaseTestCase):
    """Runtime checks for ``ParamSpec``, ``ParamSpecArgs`` and ``ParamSpecKwargs``."""

    def test_basic_plain(self):
        P = ParamSpec('P')
        self.assertEqual(P, P)
        self.assertIsInstance(P, ParamSpec)
        # Should be hashable
        hash(P)

    def test_repr(self):
        P = ParamSpec('P')
        P_co = ParamSpec('P_co', covariant=True)
        P_contra = ParamSpec('P_contra', contravariant=True)
        P_2 = ParamSpec('P_2')
        self.assertEqual(repr(P), '~P')
        self.assertEqual(repr(P_2), '~P_2')

        # Note: PEP 612 doesn't require these to be repr-ed correctly, but
        # just follow CPython.
        self.assertEqual(repr(P_co), '+P_co')
        self.assertEqual(repr(P_contra), '-P_contra')

    def test_valid_uses(self):
        # A ParamSpec may stand in for the parameter list of a Callable.
        P = ParamSpec('P')
        T = TypeVar('T')
        C1 = typing.Callable[P, int]
        self.assertEqual(C1.__args__, (P, int))
        self.assertEqual(C1.__parameters__, (P,))
        C2 = typing.Callable[P, T]
        self.assertEqual(C2.__args__, (P, T))
        self.assertEqual(C2.__parameters__, (P, T))


        # Test collections.abc.Callable too.
        if sys.version_info[:2] >= (3, 9):
            # Note: no tests for Callable.__parameters__ here
            # because types.GenericAlias Callable is hardcoded to search
            # for tp_name "TypeVar" in C.  This was changed in 3.10.
            C3 = collections.abc.Callable[P, int]
            self.assertEqual(C3.__args__, (P, int))
            C4 = collections.abc.Callable[P, T]
            self.assertEqual(C4.__args__, (P, T))

        # ParamSpec instances should also have args and kwargs attributes.
        # Note: not in dir(P) because of __class__ hacks
        self.assertTrue(hasattr(P, 'args'))
        self.assertTrue(hasattr(P, 'kwargs'))

    @skipIf((3, 10, 0) <= sys.version_info[:3] <= (3, 10, 2), "Needs bpo-46676.")
    def test_args_kwargs(self):
        P = ParamSpec('P')
        P_2 = ParamSpec('P_2')
        # Note: not in dir(P) because of __class__ hacks
        self.assertTrue(hasattr(P, 'args'))
        self.assertTrue(hasattr(P, 'kwargs'))
        self.assertIsInstance(P.args, ParamSpecArgs)
        self.assertIsInstance(P.kwargs, ParamSpecKwargs)
        self.assertIs(P.args.__origin__, P)
        self.assertIs(P.kwargs.__origin__, P)
        # Equality depends on the owning ParamSpec and on args-vs-kwargs.
        self.assertEqual(P.args, P.args)
        self.assertEqual(P.kwargs, P.kwargs)
        self.assertNotEqual(P.args, P_2.args)
        self.assertNotEqual(P.kwargs, P_2.kwargs)
        self.assertNotEqual(P.args, P.kwargs)
        self.assertNotEqual(P.kwargs, P.args)
        self.assertNotEqual(P.args, P_2.kwargs)
        self.assertEqual(repr(P.args), "P.args")
        self.assertEqual(repr(P.kwargs), "P.kwargs")

    def test_user_generics(self):
        T = TypeVar("T")
        P = ParamSpec("P")
        P_2 = ParamSpec("P_2")

        class X(Generic[T, P]):
            pass

        G1 = X[int, P_2]
        self.assertEqual(G1.__args__, (int, P_2))
        self.assertEqual(G1.__parameters__, (P_2,))

        G2 = X[int, Concatenate[int, P_2]]
        self.assertEqual(G2.__args__, (int, Concatenate[int, P_2]))
        self.assertEqual(G2.__parameters__, (P_2,))

        # The following are some valid uses cases in PEP 612 that don't work:
        # These do not work in 3.9, _type_check blocks the list and ellipsis.
        # G3 = X[int, [int, bool]]
        # G4 = X[int, ...]
        # G5 = Z[[int, str, bool]]
        # Not working because this is special-cased in 3.10.
        # G6 = Z[int, str, bool]

        # Only checks that a class generic over a lone ParamSpec can be
        # defined; Z is not subscripted (see the disabled cases above).
        class Z(Generic[P]):
            pass

    def test_pickle(self):
        # The ParamSpecs are bound as module globals so that the pickle
        # round trip can find them again by name.
        global P, P_co, P_contra
        P = ParamSpec('P')
        P_co = ParamSpec('P_co', covariant=True)
        P_contra = ParamSpec('P_contra', contravariant=True)
        # range() excludes its stop value, so "+ 1" is needed to cover
        # pickle.HIGHEST_PROTOCOL itself.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(f'Pickle protocol {proto}'):
                for paramspec in (P, P_co, P_contra):
                    z = pickle.loads(pickle.dumps(paramspec, proto))
                    self.assertEqual(z.__name__, paramspec.__name__)
                    self.assertEqual(z.__covariant__, paramspec.__covariant__)
                    self.assertEqual(z.__contravariant__, paramspec.__contravariant__)
                    self.assertEqual(z.__bound__, paramspec.__bound__)

    def test_eq(self):
        P = ParamSpec('P')
        self.assertEqual(P, P)
        self.assertEqual(hash(P), hash(P))
        # ParamSpec should compare by id similar to TypeVar in CPython
        self.assertNotEqual(ParamSpec('P'), P)
        self.assertIsNot(ParamSpec('P'), P)
        # Note: normally you don't test this as it breaks when there's
        # a hash collision. However, ParamSpec *must* guarantee that
        # as long as two objects don't have the same ID, their hashes
        # won't be the same.
        self.assertNotEqual(hash(ParamSpec('P')), hash(P))
2205
2206
class ConcatenateTests(BaseTestCase):
    """Checks on building and inspecting ``Concatenate[...]`` forms."""

    def test_basics(self):
        # A subscripted Concatenate must not compare equal to the bare form.
        spec = ParamSpec('P')

        class MyClass: ...

        subscripted = Concatenate[MyClass, spec]
        self.assertNotEqual(subscripted, Concatenate)

    def test_valid_uses(self):
        # Each subscription below only has to succeed; results are unused.
        spec = ParamSpec('P')
        tvar = TypeVar('T')

        Callable[Concatenate[int, spec], int]
        Callable[Concatenate[int, tvar, spec], tvar]

        # Test collections.abc.Callable too.
        if sys.version_info[:2] < (3, 9):
            return
        collections.abc.Callable[Concatenate[int, spec], int]
        collections.abc.Callable[Concatenate[int, tvar, spec], tvar]

    def test_invalid_uses(self):
        spec = ParamSpec('P')
        tvar = TypeVar('T')

        # Empty argument list.
        no_types_msg = 'Cannot take a Concatenate of no types'
        with self.assertRaisesRegex(TypeError, no_types_msg):
            Concatenate[()]

        # The final argument is not a ParamSpec.
        last_param_msg = 'The last parameter to Concatenate should be a ParamSpec variable'
        with self.assertRaisesRegex(TypeError, last_param_msg):
            Concatenate[spec, tvar]

        # A leading argument is not a type.
        not_a_type_msg = 'each arg must be a type'
        with self.assertRaisesRegex(TypeError, not_a_type_msg):
            Concatenate[1, spec]

    def test_basic_introspection(self):
        spec = ParamSpec('P')
        two_args = Concatenate[int, spec]
        three_args = Concatenate[int, T, spec]
        expectations = (
            (two_args, (int, spec)),
            (three_args, (int, T, spec)),
        )
        for alias, expected_args in expectations:
            self.assertEqual(alias.__origin__, Concatenate)
            self.assertEqual(alias.__args__, expected_args)

    def test_eq(self):
        # Equality and hash follow the arguments, not object identity.
        spec = ParamSpec('P')
        first = Concatenate[int, spec]
        same_as_first = Concatenate[int, spec]
        longer = Concatenate[int, T, spec]
        self.assertEqual(first, same_as_first)
        self.assertEqual(hash(first), hash(same_as_first))
        self.assertNotEqual(first, longer)
2267
2268
class TypeGuardTests(BaseTestCase):
    """Checks on the ``TypeGuard`` special form."""

    def test_basics(self):
        TypeGuard[int]  # OK
        self.assertEqual(TypeGuard[int], TypeGuard[int])

        # TypeGuard used as a return annotation is reported in the hints.
        def narrowing(arg) -> TypeGuard[int]: ...
        hints = gth(narrowing)
        self.assertEqual(hints, {'return': TypeGuard[int]})

    def test_repr(self):
        # The expected prefix depends on whether typing itself has TypeGuard.
        mod_name = 'typing' if hasattr(typing, 'TypeGuard') else 'typing_extensions'
        self.assertEqual(repr(TypeGuard), f'{mod_name}.TypeGuard')
        over_int = TypeGuard[int]
        self.assertEqual(repr(over_int), f'{mod_name}.TypeGuard[int]')
        over_employee = TypeGuard[Employee]
        self.assertEqual(repr(over_employee), f'{mod_name}.TypeGuard[{__name__}.Employee]')
        over_tuple = TypeGuard[Tuple[int]]
        self.assertEqual(repr(over_tuple), f'{mod_name}.TypeGuard[typing.Tuple[int]]')

    def test_cannot_subclass(self):
        # Neither the type of the bare form nor the type of a subscripted
        # form may serve as a base class.
        with self.assertRaises(TypeError):
            class FromBare(type(TypeGuard)):
                pass
        with self.assertRaises(TypeError):
            class FromSubscripted(type(TypeGuard[int])):
                pass

    def test_cannot_init(self):
        self.assertRaises(TypeError, TypeGuard)
        with self.assertRaises(TypeError):
            type(TypeGuard)()
        with self.assertRaises(TypeError):
            type(TypeGuard[Optional[int]])()

    def test_no_isinstance(self):
        with self.assertRaises(TypeError):
            isinstance(1, TypeGuard[int])
        self.assertRaises(TypeError, issubclass, int, TypeGuard)
2311
2312
class LiteralStringTests(BaseTestCase):
    """Runtime checks for the ``LiteralString`` special form."""

    def test_basics(self):
        # Both the direct and the string form of the annotation must
        # resolve to LiteralString.
        class Foo:
            def bar(self) -> LiteralString: ...
            def baz(self) -> "LiteralString": ...

        self.assertEqual(gth(Foo.bar), {'return': LiteralString})
        self.assertEqual(gth(Foo.baz), {'return': LiteralString})

    def test_get_origin(self):
        self.assertIsNone(get_origin(LiteralString))

    def test_repr(self):
        # The expected prefix depends on whether typing itself has
        # LiteralString.
        if hasattr(typing, 'LiteralString'):
            mod_name = 'typing'
        else:
            mod_name = 'typing_extensions'
        self.assertEqual(repr(LiteralString), '{}.LiteralString'.format(mod_name))

    def test_cannot_subscript(self):
        with self.assertRaises(TypeError):
            LiteralString[int]

    def test_cannot_subclass(self):
        with self.assertRaises(TypeError):
            class C(type(LiteralString)):
                pass
        with self.assertRaises(TypeError):
            class C(LiteralString):
                pass

    def test_cannot_init(self):
        with self.assertRaises(TypeError):
            LiteralString()
        with self.assertRaises(TypeError):
            type(LiteralString)()

    def test_no_isinstance(self):
        with self.assertRaises(TypeError):
            isinstance(1, LiteralString)
        with self.assertRaises(TypeError):
            issubclass(int, LiteralString)

    def test_alias(self):
        # Only checks that LiteralString can be used inside another alias
        # and as an annotation without raising.
        StringTuple = Tuple[LiteralString, LiteralString]
        class Alias:
            def return_tuple(self) -> StringTuple:
                return ("foo", "pep" + "675")

    def test_typevar(self):
        StrT = TypeVar("StrT", bound=LiteralString)
        self.assertIs(StrT.__bound__, LiteralString)

    def test_pickle(self):
        # range() excludes its stop value, so "+ 1" is needed to cover
        # pickle.HIGHEST_PROTOCOL itself.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            pickled = pickle.dumps(LiteralString, protocol=proto)
            self.assertIs(LiteralString, pickle.loads(pickled))
2370
2371
class SelfTests(BaseTestCase):
    """Runtime checks for the ``Self`` special form."""

    def test_basics(self):
        class Foo:
            def bar(self) -> Self: ...

        self.assertEqual(gth(Foo.bar), {'return': Self})

    def test_repr(self):
        # The expected prefix depends on whether typing itself has Self.
        if hasattr(typing, 'Self'):
            mod_name = 'typing'
        else:
            mod_name = 'typing_extensions'
        self.assertEqual(repr(Self), '{}.Self'.format(mod_name))

    def test_cannot_subscript(self):
        with self.assertRaises(TypeError):
            Self[int]

    def test_cannot_subclass(self):
        with self.assertRaises(TypeError):
            class C(type(Self)):
                pass

    def test_cannot_init(self):
        with self.assertRaises(TypeError):
            Self()
        with self.assertRaises(TypeError):
            type(Self)()

    def test_no_isinstance(self):
        with self.assertRaises(TypeError):
            isinstance(1, Self)
        with self.assertRaises(TypeError):
            issubclass(int, Self)

    def test_alias(self):
        # Only checks that Self can be used inside another alias and as an
        # annotation without raising.
        TupleSelf = Tuple[Self, Self]
        class Alias:
            def return_tuple(self) -> TupleSelf:
                return (self, self)

    def test_pickle(self):
        # range() excludes its stop value, so "+ 1" is needed to cover
        # pickle.HIGHEST_PROTOCOL itself.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            pickled = pickle.dumps(Self, protocol=proto)
            self.assertIs(Self, pickle.loads(pickled))
2417
2418
class UnpackTests(BaseTestCase):
    """Checks on ``Unpack[...]`` applied to ``TypeVarTuple`` instances."""

    def test_basic_plain(self):
        ts = TypeVarTuple('Ts')
        self.assertEqual(Unpack[ts], Unpack[ts])
        self.assertRaises(TypeError, Unpack)

    def test_repr(self):
        ts = TypeVarTuple('Ts')
        unpacked_repr = repr(Unpack[ts])
        self.assertEqual(unpacked_repr, 'typing_extensions.Unpack[Ts]')

    def test_cannot_subclass_vars(self):
        with self.assertRaises(TypeError):
            class Derived(Unpack[TypeVarTuple('Ts')]):
                pass

    def test_tuple(self):
        # Only has to subscript without raising.
        ts = TypeVarTuple('Ts')
        Tuple[Unpack[ts]]

    def test_union(self):
        xs = TypeVarTuple('Xs')
        ys = TypeVarTuple('Ys')
        # A Union of a single member is that member.
        self.assertEqual(Union[Unpack[xs]], Unpack[xs])
        self.assertNotEqual(Union[Unpack[xs]], Union[Unpack[xs], Unpack[ys]])
        # Duplicate members are removed.
        self.assertEqual(Union[Unpack[xs], Unpack[xs]], Unpack[xs])
        self.assertNotEqual(Union[Unpack[xs], int], Union[Unpack[xs]])
        self.assertNotEqual(Union[Unpack[xs], int], Union[int])
        # Introspection attributes of a mixed Union.
        self.assertEqual(Union[Unpack[xs], int].__args__, (Unpack[xs], int))
        self.assertEqual(Union[Unpack[xs], int].__parameters__, (xs,))
        self.assertIs(Union[Unpack[xs], int].__origin__, Union)

    def test_concatenation(self):
        # Unpack[...] may appear at the start, end or middle of an
        # argument list, for Tuple and for a user-defined generic alike.
        xs = TypeVarTuple('Xs')

        leading_int = Tuple[int, Unpack[xs]]
        self.assertEqual(leading_int.__args__, (int, Unpack[xs]))
        trailing_int = Tuple[Unpack[xs], int]
        self.assertEqual(trailing_int.__args__, (Unpack[xs], int))
        surrounded = Tuple[int, Unpack[xs], str]
        self.assertEqual(surrounded.__args__, (int, Unpack[xs], str))

        class Variadic(Generic[Unpack[xs]]):
            pass

        leading_int = Variadic[int, Unpack[xs]]
        self.assertEqual(leading_int.__args__, (int, Unpack[xs]))
        trailing_int = Variadic[Unpack[xs], int]
        self.assertEqual(trailing_int.__args__, (Unpack[xs], int))
        surrounded = Variadic[int, Unpack[xs], str]
        self.assertEqual(surrounded.__args__, (int, Unpack[xs], str))

    def test_class(self):
        ts = TypeVarTuple('Ts')

        class Variadic(Generic[Unpack[ts]]):
            pass

        self.assertEqual(Variadic[int].__args__, (int,))
        self.assertEqual(Variadic[int, str].__args__, (int, str))

        # A concrete type is not a valid Generic[...] parameter.
        with self.assertRaises(TypeError):
            class Invalid(Generic[Unpack[ts], int]):
                pass

        first = TypeVar('T')
        second = TypeVar('T')

        class Mixed(Generic[first, second, Unpack[ts]]):
            pass

        self.assertEqual(Mixed[int, str].__args__, (int, str))
        self.assertEqual(Mixed[int, str, float].__args__, (int, str, float))
        self.assertEqual(Mixed[int, str, float, bool].__args__, (int, str, float, bool))
        # Fewer arguments than there are plain type variables.
        with self.assertRaises(TypeError):
            Mixed[int]
2505
2506
class TypeVarTupleTests(BaseTestCase):
    """Checks on ``TypeVarTuple`` instances themselves."""

    def test_basic_plain(self):
        ts = TypeVarTuple('Ts')
        self.assertEqual(ts, ts)
        self.assertIsInstance(ts, TypeVarTuple)
        xs = TypeVarTuple('Xs')
        ys = TypeVarTuple('Ys')
        self.assertNotEqual(xs, ys)

    def test_repr(self):
        ts = TypeVarTuple('Ts')
        shown = repr(ts)
        self.assertEqual(shown, 'Ts')

    def test_no_redefinition(self):
        # Two instances sharing a name must still compare unequal.
        first = TypeVarTuple('Ts')
        second = TypeVarTuple('Ts')
        self.assertNotEqual(first, second)

    def test_cannot_subclass_vars(self):
        with self.assertRaises(TypeError):
            class FromInstance(TypeVarTuple('Ts')):
                pass

    def test_cannot_subclass_var_itself(self):
        with self.assertRaises(TypeError):
            class FromClass(TypeVarTuple):
                pass

    def test_cannot_instantiate_vars(self):
        ts = TypeVarTuple('Ts')
        self.assertRaises(TypeError, ts)

    def test_tuple(self):
        ts = TypeVarTuple('Ts')
        # Not legal at type checking time but we can't really check against it.
        Tuple[ts]

    def test_args_and_parameters(self):
        ts = TypeVarTuple('Ts')

        # tuple(ts) iterates the TypeVarTuple; the resulting alias must hold
        # the unpacked form as its argument and ts as its parameter.
        alias = Tuple[tuple(ts)]
        self.assertEqual(alias.__args__, (ts.__unpacked__,))
        self.assertEqual(alias.__parameters__, (ts,))
2550
2551
class FinalDecoratorTests(BaseTestCase):
    """Checks on the ``final`` decorator and the ``__final__`` marker."""

    def test_final_unmodified(self):
        # final() must hand back the very object it was given.
        def func(x): ...
        returned = final(func)
        self.assertIs(func, returned)

    def test_dunder_final(self):
        @final
        def marked_func(): ...
        @final
        class MarkedCls: ...
        self.assertIs(True, marked_func.__final__)
        self.assertIs(True, MarkedCls.__final__)

        class Wrapper:
            __slots__ = ("func",)
            def __init__(self, func):
                self.func = func
            def __call__(self, *args, **kwargs):
                return self.func(*args, **kwargs)

        # Check that no error is thrown if the attribute
        # is not writable.
        def plain(): ...
        wrapped = final(Wrapper(plain))
        self.assertIsInstance(wrapped, Wrapper)
        self.assertIs(False, hasattr(wrapped, "__final__"))

        # A read-only __final__ property on the metaclass must be left
        # untouched by the decorator.
        class Meta(type):
            @property
            def __final__(self): return "can't set me"
        @final
        class WithMeta(metaclass=Meta): ...
        self.assertEqual(WithMeta.__final__, "can't set me")

        # Builtin classes throw TypeError if you try to set an
        # attribute.
        final(int)
        self.assertIs(False, hasattr(int, "__final__"))

        # Make sure it works with common builtin decorators
        class Methods:
            @final
            @classmethod
            def clsmethod(cls): ...

            @final
            @staticmethod
            def stmethod(): ...

            # The other order doesn't work because property objects
            # don't allow attribute assignment.
            @property
            @final
            def prop(self): ...

            @final
            @lru_cache()  # noqa: B019
            def cached(self): ...

        # Use getattr_static because the descriptor returns the
        # underlying function, which doesn't have __final__.
        raw_clsmethod = inspect.getattr_static(Methods, "clsmethod")
        self.assertIs(True, raw_clsmethod.__final__)
        raw_stmethod = inspect.getattr_static(Methods, "stmethod")
        self.assertIs(True, raw_stmethod.__final__)
        self.assertIs(True, Methods.prop.fget.__final__)
        self.assertIs(True, Methods.cached.__final__)
2624
2625
class RevealTypeTests(BaseTestCase):
    """Runtime behaviour of ``reveal_type``."""

    def test_reveal_type(self):
        """``reveal_type`` returns its argument unchanged (same object)."""
        sentinel = object()
        returned = reveal_type(sentinel)
        self.assertIs(sentinel, returned)
2630
2631
class DataclassTransformTests(BaseTestCase):
    """``dataclass_transform`` applied to a decorator, a base class and a metaclass."""

    def test_decorator(self):
        """A decorated factory function is returned as-is, carrying the metadata."""
        def create_model(*, frozen: bool = False, kw_only: bool = True):
            def identity(cls):
                return cls
            return identity

        transform = dataclass_transform(kw_only_default=True, order_default=False)
        decorated = transform(create_model)

        class CustomerModel:
            id: int

        self.assertIs(decorated, create_model)

        expected_metadata = {
            "eq_default": True,
            "order_default": False,
            "kw_only_default": True,
            "field_descriptors": (),
        }
        self.assertEqual(decorated.__dataclass_transform__, expected_metadata)

        # The decorated factory must still work as a class decorator factory.
        class_decorator = decorated(frozen=True, kw_only=False)
        self.assertIs(class_decorator(CustomerModel), CustomerModel)

    def test_base_class(self):
        """A decorated base class is returned as-is and can still be subclassed."""
        class ModelBase:
            def __init_subclass__(cls, *, frozen: bool = False): ...

        transform = dataclass_transform(eq_default=True, order_default=True)
        Decorated = transform(ModelBase)

        class CustomerModel(Decorated, frozen=True):
            id: int

        self.assertIs(Decorated, ModelBase)

        expected_metadata = {
            "eq_default": True,
            "order_default": True,
            "kw_only_default": False,
            "field_descriptors": (),
        }
        self.assertEqual(Decorated.__dataclass_transform__, expected_metadata)
        self.assertIsSubclass(CustomerModel, Decorated)

    def test_metaclass(self):
        """A decorated metaclass is returned as-is and still builds classes."""
        class Field: ...

        class ModelMeta(type):
            def __new__(
                cls, name, bases, namespace, *, init: bool = True,
            ):
                return super().__new__(cls, name, bases, namespace)

        transform = dataclass_transform(
            order_default=True, field_descriptors=(Field,)
        )
        Decorated = transform(ModelMeta)

        class ModelBase(metaclass=Decorated): ...

        class CustomerModel(ModelBase, init=False):
            id: int

        self.assertIs(Decorated, ModelMeta)

        expected_metadata = {
            "eq_default": True,
            "order_default": True,
            "kw_only_default": False,
            "field_descriptors": (Field,),
        }
        self.assertEqual(Decorated.__dataclass_transform__, expected_metadata)
        self.assertIsInstance(CustomerModel, Decorated)
2707
2708
class AllTests(BaseTestCase):
    """Sanity checks on what ``typing_extensions`` exports and how it loads."""

    def test_typing_extensions_includes_standard(self):
        """``__all__`` lists the expected names and each one really exists."""
        exported = typing_extensions.__all__
        expected_names = (
            'ClassVar',
            'Type',
            'ChainMap',
            'ContextManager',
            'Counter',
            'DefaultDict',
            'Deque',
            'NewType',
            'overload',
            'Text',
            'TYPE_CHECKING',
            'TypeAlias',
            'ParamSpec',
            'Concatenate',

            'Annotated',
            'get_type_hints',

            'Awaitable',
            'AsyncIterator',
            'AsyncIterable',
            'Coroutine',
            'AsyncContextManager',

            'AsyncGenerator',

            'Protocol',
            'runtime',
        )
        for name in expected_names:
            self.assertIn(name, exported)

        # Check that all objects in `__all__` are present in the module.
        # The message names the offending export instead of the bare
        # "False is not true" that assertTrue would otherwise report.
        for name in exported:
            self.assertTrue(
                hasattr(typing_extensions, name),
                f'{name!r} is listed in __all__ but missing from the module',
            )

    def test_typing_extensions_defers_when_possible(self):
        """Names also found in ``typing`` are the same objects, barring exclusions."""
        # Names in this set are skipped by the identity check below.
        exclude = {
            'overload',
            'Text',
            'TypedDict',
            'TYPE_CHECKING',
            'Final',
            'get_type_hints',
            'is_typeddict',
        }
        if sys.version_info < (3, 10):
            exclude |= {'get_args', 'get_origin'}
        if sys.version_info < (3, 11):
            exclude.add('final')
        for item in typing_extensions.__all__:
            if item not in exclude and hasattr(typing, item):
                self.assertIs(
                    getattr(typing_extensions, item),
                    getattr(typing, item),
                    f'typing_extensions.{item} is not typing.{item}',
                )

    def test_typing_extensions_compiles_with_opt(self):
        """Running ``typing_extensions.py`` under ``-OO`` must exit successfully.

        The module is looked up in the directory that contains this test file.
        """
        file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'typing_extensions.py')
        try:
            # Pass an argument list and no shell: a command string run through
            # the shell breaks when the interpreter or module path contains
            # spaces or shell metacharacters.
            subprocess.check_output([sys.executable, '-OO', file_path],
                                    stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as exc:
            # stderr is merged into stdout above, so exc.output holds the
            # whole transcript of the failed run.
            details = exc.output.decode(errors='replace')
            self.fail(
                'Module does not compile with optimize=2 (-OO flag).\n'
                + details
            )
2775
2776
2777
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    main()
2780