• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import support
3
4
class TestMROEntry(unittest.TestCase):
    """Tests for PEP 560 ``__mro_entries__`` base substitution.

    When a class statement lists a non-class object among its bases, that
    object's ``__mro_entries__`` is called with the original bases tuple and
    its tuple result replaces the object in ``__bases__``.  The unmodified
    bases are preserved in ``__orig_bases__``.
    """

    def test_mro_entry_signature(self):
        # __mro_entries__ receives exactly one positional argument (the
        # original bases tuple) and no keyword arguments.
        tested = []
        class B: ...
        class C:
            def __mro_entries__(self, *args, **kwargs):
                tested.extend([args, kwargs])
                return (C,)
        c = C()
        self.assertEqual(tested, [])
        class D(B, c): ...
        self.assertEqual(tested[0], ((B, c),))
        self.assertEqual(tested[1], {})

    def test_mro_entry(self):
        tested = []
        class A: ...
        class B: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (self.__class__,)
        c = C()
        self.assertEqual(tested, [])
        class D(A, c, B): ...
        self.assertEqual(tested[-1], (A, c, B))
        self.assertEqual(D.__bases__, (A, C, B))
        self.assertEqual(D.__orig_bases__, (A, c, B))
        self.assertEqual(D.__mro__, (D, A, C, B, object))
        # Instances of a class that inherited __mro_entries__ substitute too.
        d = D()
        class E(d): ...
        self.assertEqual(tested[-1], (d,))
        self.assertEqual(E.__bases__, (D,))

    def test_mro_entry_none(self):
        # Returning an empty tuple drops the entry from the bases entirely.
        tested = []
        class A: ...
        class B: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return ()
        c = C()
        self.assertEqual(tested, [])
        class D(A, c, B): ...
        self.assertEqual(tested[-1], (A, c, B))
        self.assertEqual(D.__bases__, (A, B))
        self.assertEqual(D.__orig_bases__, (A, c, B))
        self.assertEqual(D.__mro__, (D, A, B, object))
        # With no bases left, the class falls back to inheriting from object.
        class E(c): ...
        self.assertEqual(tested[-1], (c,))
        self.assertEqual(E.__bases__, (object,))
        self.assertEqual(E.__orig_bases__, (c,))
        self.assertEqual(E.__mro__, (E, object))

    def test_mro_entry_with_builtins(self):
        tested = []
        class A: ...
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (dict,)
        c = C()
        self.assertEqual(tested, [])
        class D(A, c): ...
        self.assertEqual(tested[-1], (A, c))
        self.assertEqual(D.__bases__, (A, dict))
        self.assertEqual(D.__orig_bases__, (A, c))
        self.assertEqual(D.__mro__, (D, A, dict, object))

    def test_mro_entry_with_builtins_2(self):
        tested = []
        class C:
            def __mro_entries__(self, bases):
                tested.append(bases)
                return (C,)
        c = C()
        self.assertEqual(tested, [])
        class D(c, dict): ...
        self.assertEqual(tested[-1], (c, dict))
        self.assertEqual(D.__bases__, (C, dict))
        self.assertEqual(D.__orig_bases__, (c, dict))
        self.assertEqual(D.__mro__, (D, C, dict, object))

    def test_mro_entry_errors(self):
        # __mro_entries__ is called with exactly one argument; signatures
        # that cannot accept that are rejected with TypeError.
        class C_too_many:
            def __mro_entries__(self, bases, something, other):
                return ()
        c = C_too_many()
        with self.assertRaises(TypeError):
            class D(c): ...
        class C_too_few:
            def __mro_entries__(self):
                return ()
        d = C_too_few()
        with self.assertRaises(TypeError):
            class D(d): ...

    def test_mro_entry_errors_2(self):
        # A non-callable __mro_entries__ is rejected.
        class C_not_callable:
            __mro_entries__ = "Surprise!"
        c = C_not_callable()
        with self.assertRaises(TypeError):
            class D(c): ...
        # A correctly-called __mro_entries__ must return a tuple.  The method
        # accepts ``bases`` so that the TypeError comes from the non-tuple
        # result rather than from an arity mismatch (already covered by
        # test_mro_entry_errors).
        class C_not_tuple:
            def __mro_entries__(self, bases):
                return object
        c = C_not_tuple()
        with self.assertRaises(TypeError):
            class D(c): ...

    def test_mro_entry_metaclass(self):
        # The metaclass sees the already-substituted bases.
        meta_args = []
        class Meta(type):
            def __new__(mcls, name, bases, ns):
                meta_args.extend([mcls, name, bases, ns])
                return super().__new__(mcls, name, bases, ns)
        class A: ...
        class C:
            def __mro_entries__(self, bases):
                return (A,)
        c = C()
        class D(c, metaclass=Meta):
            x = 1
        self.assertEqual(meta_args[0], Meta)
        self.assertEqual(meta_args[1], 'D')
        self.assertEqual(meta_args[2], (A,))
        self.assertEqual(meta_args[3]['x'], 1)
        self.assertEqual(D.__bases__, (A,))
        self.assertEqual(D.__orig_bases__, (c,))
        self.assertEqual(D.__mro__, (D, A, object))
        self.assertEqual(D.__class__, Meta)

    def test_mro_entry_type_call(self):
        # Substitution should _not_ happen in direct type call
        class C:
            def __mro_entries__(self, bases):
                return ()
        c = C()
        # Escape regex metacharacters so "types.new_class()" is matched
        # literally; unescaped, '.' is a wildcard and '()' an empty group.
        with self.assertRaisesRegex(TypeError,
                                    r"MRO entry resolution; "
                                    r"use types\.new_class\(\)"):
            type('Bad', (c,), {})
148
149
class TestClassGetitem(unittest.TestCase):
    """Tests for PEP 560 ``__class_getitem__``.

    Subscripting a class (``C[item]``) calls ``C.__class_getitem__`` with the
    class and the subscript, unless the metaclass defines ``__getitem__``,
    which takes precedence.
    """

    def test_class_getitem(self):
        # The hook receives the class and the subscript (a tuple here) and
        # no keyword arguments.
        getitem_args = []
        class C:
            def __class_getitem__(*args, **kwargs):
                getitem_args.extend([args, kwargs])
                return None
        C[int, str]
        self.assertEqual(getitem_args[0], (C, (int, str)))
        self.assertEqual(getitem_args[1], {})

    # Previously also named ``test_class_getitem``, which silently shadowed
    # the test above so it never ran.
    def test_class_getitem_format(self):
        class C:
            def __class_getitem__(cls, item):
                return f'C[{item.__name__}]'
        self.assertEqual(C[int], 'C[int]')
        self.assertEqual(C[C], 'C[C]')

    def test_class_getitem_inheritance(self):
        # ``cls`` is the subscripted subclass, not the defining class.
        class C:
            def __class_getitem__(cls, item):
                return f'{cls.__name__}[{item.__name__}]'
        class D(C): ...
        self.assertEqual(D[int], 'D[int]')
        self.assertEqual(D[D], 'D[D]')

    def test_class_getitem_inheritance_2(self):
        class C:
            def __class_getitem__(cls, item):
                return 'Should not see this'
        class D(C):
            def __class_getitem__(cls, item):
                return f'{cls.__name__}[{item.__name__}]'
        self.assertEqual(D[int], 'D[int]')
        self.assertEqual(D[D], 'D[D]')

    def test_class_getitem_classmethod(self):
        # An explicit @classmethod behaves the same as the implicit form.
        class C:
            @classmethod
            def __class_getitem__(cls, item):
                return f'{cls.__name__}[{item.__name__}]'
        class D(C): ...
        self.assertEqual(D[int], 'D[int]')
        self.assertEqual(D[D], 'D[D]')

    def test_class_getitem_patched(self):
        # The hook may be attached after class creation.
        class C:
            def __init_subclass__(cls):
                def __class_getitem__(cls, item):
                    return f'{cls.__name__}[{item.__name__}]'
                cls.__class_getitem__ = classmethod(__class_getitem__)
        class D(C): ...
        self.assertEqual(D[int], 'D[int]')
        self.assertEqual(D[D], 'D[D]')

    def test_class_getitem_with_builtins(self):
        class A(dict):
            called_with = None

            def __class_getitem__(cls, item):
                cls.called_with = item
        class B(A):
            pass
        self.assertIs(B.called_with, None)
        B[int]
        self.assertIs(B.called_with, int)

    def test_class_getitem_errors(self):
        class C_too_few:
            def __class_getitem__(cls):
                return None
        with self.assertRaises(TypeError):
            C_too_few[int]
        class C_too_many:
            def __class_getitem__(cls, one, two):
                return None
        with self.assertRaises(TypeError):
            C_too_many[int]

    def test_class_getitem_errors_2(self):
        # The hook applies only to the class itself: not to instances, not
        # to attributes set on instances, and it must be callable.
        class C:
            def __class_getitem__(cls, item):
                return None
        with self.assertRaises(TypeError):
            C()[int]
        class E: ...
        e = E()
        e.__class_getitem__ = lambda cls, item: 'This will not work'
        with self.assertRaises(TypeError):
            e[int]
        class C_not_callable:
            __class_getitem__ = "Surprise!"
        with self.assertRaises(TypeError):
            C_not_callable[int]

    def test_class_getitem_metaclass(self):
        class Meta(type):
            def __class_getitem__(cls, item):
                return f'{cls.__name__}[{item.__name__}]'
        self.assertEqual(Meta[int], 'Meta[int]')

    def test_class_getitem_with_metaclass(self):
        class Meta(type): pass
        class C(metaclass=Meta):
            def __class_getitem__(cls, item):
                return f'{cls.__name__}[{item.__name__}]'
        self.assertEqual(C[int], 'C[int]')

    def test_class_getitem_metaclass_first(self):
        # A metaclass __getitem__ wins over the class's __class_getitem__.
        class Meta(type):
            def __getitem__(cls, item):
                return 'from metaclass'
        class C(metaclass=Meta):
            def __class_getitem__(cls, item):
                return 'from __class_getitem__'
        self.assertEqual(C[int], 'from metaclass')
266
267
@support.cpython_only
class CAPITest(unittest.TestCase):
    """Exercise a C-implemented generic type from ``_testcapi``."""

    def test_c_class(self):
        from _testcapi import Generic, GenericAlias

        # Calling the hook directly yields a GenericAlias instance.
        direct = Generic.__class_getitem__(int)
        self.assertIsInstance(direct, GenericAlias)

        # Subscription produces an alias whose MRO entry is ``int``.
        alias = Generic[int]
        self.assertIs(type(alias), GenericAlias)
        self.assertEqual(alias.__mro_entries__(()), (int,))

        # Using the alias as a base substitutes ``int`` into the bases while
        # keeping the alias itself in __orig_bases__.
        class Sub(alias):
            pass

        self.assertEqual(Sub.__bases__, (int,))
        self.assertEqual(Sub.__orig_bases__, (alias,))
        self.assertEqual(Sub.__mro__, (Sub, int, object))
283
284
if __name__ == '__main__':
    # Run as a script: execute the TestCase classes defined in this module.
    unittest.main()
287