• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import ast
2import sys
3import unittest
4from test import support
5
6
# Sample sources fed to ast.parse() by the tests.  Each one is a complete
# module; the leading backslash keeps the first source line on line 1.

# Signature type comments: one on the first body line, one on the def line.
funcdef = """\
def foo():
    # type: () -> int
    pass

def bar():  # type: () -> None
    pass
"""

# The same two placements, on coroutine definitions.
asyncdef = """\
async def foo():
    # type: () -> int
    return await bar()

async def bar():  # type: () -> int
    return await bar()
"""

# `async` and `await` used as ordinary variable names.
asyncvar = """\
async = 12
await = 13
"""

# An asynchronous comprehension inside a coroutine.
asynccomp = """\
async def foo(xs):
    [x async for x in xs]
"""

# The matrix-multiplication operator.
matmul = """\
a = b @ c
"""

# A formatted string literal.
fstring = """\
a = 42
f"{a}"
"""

# A numeric literal with underscore separators.
underscorednumber = """\
a = 42_42_42
"""

# A def carrying two signature type comments at once.
redundantdef = """\
def foo():  # type: () -> int
    # type: () -> str
    return ''
"""

# A signature type comment containing non-ASCII characters.
nonasciidef = """\
def foo():
    # type: () -> àçčéñt
    pass
"""

# Type comment on a for statement.
forstmt = """\
for a in []:  # type: int
    pass
"""

# Type comment on a with statement.
withstmt = """\
with context() as a:  # type: int
    pass
"""

# Type comment on a plain assignment.
vardecl = """\
a = 0  # type: int
"""

# `# type: ignore` comments, bare and followed by various trailing text.
# They sit on source lines 2, 5, 8, 9, 10 and 11.
ignores = """\
def foo():
    pass  # type: ignore

def bar():
    x = 1  # type: ignore

def baz():
    pass  # type: ignore[excuse]
    pass  # type: ignore=excuse
    pass  # type: ignore [excuse]
    x = 1  # type: ignore whatever
"""
87
# Test for long-form type-comments in arguments.  Every function name
# encodes its expected arguments: after the leading 'f', each letter is an
# argument name, with 'a' and 'b' positional, 'v' the var-arg *v and 'k'
# the kw-arg **k (so 'favk' takes a, *v and **k).  Each argument's type
# comment is its own name in upper case.  It is verified in
# test_longargs() that every function has exactly the arguments its name
# spells out, no more, no fewer.
longargs = """\
def fa(
    a = 1,  # type: A
):
    pass

def fa(
    a = 1  # type: A
):
    pass

def fa(
    a = 1,  # type: A
    /
):
    pass

def fab(
    a,  # type: A
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    /,
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    b   # type: B
):
    pass

def fv(
    *v,  # type: V
):
    pass

def fv(
    *v  # type: V
):
    pass

def fk(
    **k,  # type: K
):
    pass

def fk(
    **k  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k,  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k  # type: K
):
    pass

def fav(
    a,  # type: A
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    /,
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    *v  # type: V
):
    pass

def fak(
    a,  # type: A
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    /,
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    **k  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    /,
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k  # type: K
):
    pass
"""
220
221
class TypeCommentTests(unittest.TestCase):
    """Check how ``ast.parse`` handles type comments and ``feature_version``.

    Samples are parsed with ``type_comments=True`` once per minor version
    from ``lowest`` through ``highest``.  Several tests also parse the same
    sample with a plain ``ast.parse`` call and check that no type comment
    is recorded in that case.
    """

    lowest = 4  # Lowest minor version supported
    highest = sys.version_info[1]  # Highest minor version

    def parse(self, source, feature_version=highest):
        """Parse *source* with type comments enabled.

        *feature_version* is forwarded unchanged to ``ast.parse``.
        """
        return ast.parse(source, type_comments=True,
                         feature_version=feature_version)

    def parse_all(self, source, minver=lowest, maxver=highest, expected_regex=""):
        """Yield the tree of *source* for each accepted feature version.

        Every minor version from ``lowest`` through ``highest`` is tried as
        ``(3, version)``.  Versions within ``minver..maxver`` must parse and
        their trees are yielded; every other version must raise a
        SyntaxError whose message matches *expected_regex*.

        This is a generator: no check runs until the caller iterates.
        """
        for version in range(self.lowest, self.highest + 1):
            feature_version = (3, version)
            if minver <= version <= maxver:
                try:
                    tree = self.parse(source, feature_version)
                except SyntaxError as err:
                    # Name the failing version in the message and keep the
                    # parser's own error as the explicit cause.
                    raise SyntaxError(
                        str(err) + f" feature_version={feature_version}"
                    ) from err
                yield tree
            else:
                with self.assertRaisesRegex(SyntaxError, expected_regex,
                                            msg=f"feature_version={feature_version}"):
                    self.parse(source, feature_version)

    def classic_parse(self, source):
        """Parse *source* with a plain ``ast.parse`` call."""
        return ast.parse(source)

    def test_funcdef(self):
        # Both placements of a signature type comment are recorded.
        for tree in self.parse_all(funcdef):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> None")
        tree = self.classic_parse(funcdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncdef(self):
        # Must parse from minor version 5 on and be rejected below it.
        for tree in self.parse_all(asyncdef, minver=5):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> int")
        tree = self.classic_parse(asyncdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    # The next five tests only check which feature versions accept the
    # sample; iterating is what drives the checks inside parse_all().

    def test_asyncvar(self):
        for _ in self.parse_all(asyncvar, maxver=6):
            pass

    def test_asynccomp(self):
        for _ in self.parse_all(asynccomp, minver=6):
            pass

    def test_matmul(self):
        for _ in self.parse_all(matmul, minver=5):
            pass

    def test_fstring(self):
        for _ in self.parse_all(fstring, minver=6):
            pass

    def test_underscorednumber(self):
        for _ in self.parse_all(underscorednumber, minver=6):
            pass

    def test_redundantdef(self):
        # maxver=0 lies below `lowest`, so every version must be rejected.
        for _ in self.parse_all(redundantdef, maxver=0,
                                expected_regex="^Cannot have two type comments on def"):
            pass

    def test_nonasciidef(self):
        for tree in self.parse_all(nonasciidef):
            self.assertEqual(tree.body[0].type_comment, "() -> àçčéñt")

    def test_forstmt(self):
        for tree in self.parse_all(forstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(forstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_withstmt(self):
        for tree in self.parse_all(withstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(withstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_vardecl(self):
        for tree in self.parse_all(vardecl):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(vardecl)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_ignores(self):
        # Each entry pairs a line number with the text after "ignore".
        for tree in self.parse_all(ignores):
            self.assertEqual(
                [(ti.lineno, ti.tag) for ti in tree.type_ignores],
                [
                    (2, ''),
                    (5, ''),
                    (8, '[excuse]'),
                    (9, '=excuse'),
                    (10, ' [excuse]'),
                    (11, ' whatever'),
                ])
        tree = self.classic_parse(ignores)
        self.assertEqual(tree.type_ignores, [])

    def test_longargs(self):
        for tree in self.parse_all(longargs):
            for t in tree.body:
                # The expected args are encoded in the function name
                self.assertTrue(t.name.startswith('f'), t.name)
                todo = set(t.name[1:])
                # Positional-only and regular args form one sequence in
                # declaration order, so a letter's offset from 'a' indexes it.
                positional = t.args.posonlyargs + t.args.args
                self.assertEqual(len(positional),
                                 len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
                for c in t.name[1:]:
                    todo.remove(c)
                    if c == 'v':
                        arg = t.args.vararg
                    elif c == 'k':
                        arg = t.args.kwarg
                    else:
                        position = ord(c) - ord('a')
                        self.assertTrue(0 <= position < len(positional), t.name)
                        arg = positional[position]
                    self.assertEqual(arg.arg, c)  # That's the argument name
                    self.assertEqual(arg.type_comment, arg.arg.upper())
                self.assertFalse(todo, t.name)
        tree = self.classic_parse(longargs)
        for t in tree.body:
            # Include positional-only args so every declared argument is
            # checked for the absence of a type comment.
            declared = (t.args.posonlyargs + t.args.args
                        + [t.args.vararg, t.args.kwarg])
            for arg in declared:
                if arg is not None:
                    self.assertIsNone(arg.type_comment, "%s(%s:%r)" %
                                      (t.name, arg.arg, arg.type_comment))

    def test_inappropriate_type_comments(self):
        """Tests for inappropriately-placed type comments.

        These should be silently ignored with type comments off,
        but raise SyntaxError with type comments on.

        This is not meant to be exhaustive.
        """

        def check_both_ways(source):
            ast.parse(source, type_comments=False)
            # maxver=0 makes parse_all() demand a SyntaxError for every
            # feature version it tries.
            for _ in self.parse_all(source, maxver=0):
                pass

        check_both_ways("pass  # type: int\n")
        check_both_ways("foo()  # type: int\n")
        check_both_ways("x += 1  # type: int\n")
        check_both_ways("while True:  # type: int\n  continue\n")
        check_both_ways("while True:\n  continue  # type: int\n")
        check_both_ways("try:  # type: int\n  pass\nfinally:\n  pass\n")
        check_both_ways("try:\n  pass\nfinally:  # type: int\n  pass\n")
        check_both_ways("pass  # type: ignorewhatever\n")
        check_both_ways("pass  # type: ignoreé\n")

    def test_func_type_input(self):

        def parse_func_type_input(source):
            return ast.parse(source, "<unknown>", "func_type")

        # Some checks below will crash if the returned structure is wrong
        tree = parse_func_type_input("() -> int")
        self.assertEqual(tree.argtypes, [])
        self.assertEqual(tree.returns.id, "int")

        tree = parse_func_type_input("(int) -> List[str]")
        self.assertEqual(len(tree.argtypes), 1)
        arg = tree.argtypes[0]
        self.assertEqual(arg.id, "int")
        self.assertEqual(tree.returns.value.id, "List")
        self.assertEqual(tree.returns.slice.id, "str")

        tree = parse_func_type_input("(int, *str, **Any) -> float")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")
        self.assertEqual(tree.argtypes[2].id, "Any")
        self.assertEqual(tree.returns.id, "float")

        tree = parse_func_type_input("(*int) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        tree = parse_func_type_input("(**int) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        tree = parse_func_type_input("(*int, **str) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")

        # Malformed star arguments must be rejected.
        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, *str, *Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, **str, Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(**int, **str) -> float")
418
419
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
422