# (Captured from a web code viewer; viewer navigation: Home | Line# | Scopes# | Navigate# | Raw | Download)
import ast
import sys
import unittest


# Fixture sources for TypeCommentTests.  Each string is fed to ast.parse()
# by the tests below, which check either that parsing succeeds/fails for
# given feature versions or which ``type_comment`` values come back.

# Function type comments in both placements: on its own line after the
# ``def`` header (foo) and trailing the header on the same line (bar).
funcdef = """\
def foo():
    # type: () -> int
    pass

def bar():  # type: () -> None
    pass
"""

# The same two placements on ``async def``; test_asyncdef parses it with
# minver=5.
asyncdef = """\
async def foo():
    # type: () -> int
    return await bar()

async def bar():  # type: () -> int
    return await bar()
"""

# ``async``/``await`` used as ordinary names; test_asyncvar expects the
# default (classic) parse to reject this with SyntaxError.
asyncvar = """\
async = 12
await = 13
"""

# Asynchronous comprehension; test_asynccomp parses it with minver=6.
asynccomp = """\
async def foo(xs):
    [x async for x in xs]
"""

# Matrix-multiplication operator; test_matmul parses it with minver=5.
matmul = """\
a = b @ c
"""

# f-string; test_fstring parses it across the full version range.
fstring = """\
a = 42
f"{a}"
"""

# Underscores in a numeric literal; test_underscorednumber uses minver=6.
underscorednumber = """\
a = 42_42_42
"""

# Two type comments on one def (trailing and on its own line);
# test_redundantdef expects every feature version to reject it.
redundantdef = """\
def foo():  # type: () -> int
    # type: () -> str
    return ''
"""

# Non-ASCII text inside a type comment; test_nonasciidef checks that it
# comes back unchanged.
nonasciidef = """\
def foo():
    # type: () -> àçčéñt
    pass
"""

# Type comment trailing a ``for`` header.
forstmt = """\
for a in []:  # type: int
    pass
"""

# Type comment trailing a ``with`` header.
withstmt = """\
with context() as a:  # type: int
    pass
"""

# Type comments trailing parenthesized ``with`` headers;
# test_parenthesized_withstmt expects both statements to carry "int".
parenthesized_withstmt = """\
with (a as b):  # type: int
    pass

with (a, b):  # type: int
    pass
"""

# Type comment trailing a simple assignment.
vardecl = """\
a = 0  # type: int
"""

# ``# type: ignore`` comments with various trailing tags.  test_ignores
# asserts the exact (lineno, tag) pairs (2, ''), (5, ''), (8, '[excuse]'),
# (9, '=excuse'), (10, ' [excuse]'), (11, ' whatever'), so the line
# layout of this string must not change.
ignores = """\
def foo():
    pass  # type: ignore

def bar():
    x = 1  # type: ignore

def baz():
    pass  # type: ignore[excuse]
    pass  # type: ignore=excuse
    pass  # type: ignore [excuse]
    x = 1  # type: ignore whatever
"""

# Test for long-form type comments on individual arguments.  Each function
# name encodes its expected parameters after the leading 'f': lowercase
# positional names in order (a, b), 'v' for a var-arg *v, and 'k' for a
# kw-arg **k.  For example, a function named 'fabvk' would have two
# positional args a and b, plus *v, plus **k.  test_longargs() verifies
# that each function has exactly these arguments (no more, no fewer) and
# that each argument's type comment is its own name in uppercase.  The
# variants cover a trailing comma, no trailing comma, and the '/'
# positional-only marker.
longargs = """\
def fa(
    a = 1,  # type: A
):
    pass

def fa(
    a = 1  # type: A
):
    pass

def fa(
    a = 1,  # type: A
    /
):
    pass

def fab(
    a,  # type: A
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    /,
    b,  # type: B
):
    pass

def fab(
    a,  # type: A
    b   # type: B
):
    pass

def fv(
    *v,  # type: V
):
    pass

def fv(
    *v  # type: V
):
    pass

def fk(
    **k,  # type: K
):
    pass

def fk(
    **k  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k,  # type: K
):
    pass

def fvk(
    *v,  # type: V
    **k  # type: K
):
    pass

def fav(
    a,  # type: A
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    /,
    *v,  # type: V
):
    pass

def fav(
    a,  # type: A
    *v  # type: V
):
    pass

def fak(
    a,  # type: A
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    /,
    **k,  # type: K
):
    pass

def fak(
    a,  # type: A
    **k  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    /,
    *v,  # type: V
    **k,  # type: K
):
    pass

def favk(
    a,  # type: A
    *v,  # type: V
    **k  # type: K
):
    pass
"""


class TypeCommentTests(unittest.TestCase):
    """Tests for ``ast.parse(..., type_comments=True)``.

    Most tests parse one of the module-level fixture strings once per
    ``feature_version`` from ``(3, lowest)`` to ``(3, highest)`` via
    ``parse_all``; many then parse it again with ``classic_parse`` (type
    comments not requested) to check that no comments are reported.
    """

    lowest = 4  # Lowest minor version supported
    highest = sys.version_info[1]  # Highest minor version

    def parse(self, source, feature_version=highest):
        """Parse *source* with type comments enabled.

        *feature_version* is passed straight through to ``ast.parse``; the
        default is the running interpreter's minor version.
        """
        return ast.parse(source, type_comments=True,
                         feature_version=feature_version)

    def parse_all(self, source, minver=lowest, maxver=highest, expected_regex=""):
        """Parse *source* once per feature version, yielding each tree.

        For every minor version from ``lowest`` to ``highest``: versions
        inside ``[minver, maxver]`` must parse and their tree is yielded;
        versions outside that range must raise a SyntaxError whose message
        matches *expected_regex*.  Passing ``maxver=0`` therefore requires
        every version to reject *source*.

        Raises:
            SyntaxError: if an in-range version fails to parse.  The message
                is annotated with the failing feature_version and the
                original error is chained as ``__cause__``.
        """
        for version in range(self.lowest, self.highest + 1):
            feature_version = (3, version)
            if minver <= version <= maxver:
                # Keep the yield outside the try so only parser errors get
                # annotated; chain the original to keep its location info.
                try:
                    tree = self.parse(source, feature_version)
                except SyntaxError as err:
                    raise SyntaxError(
                        str(err) + f" feature_version={feature_version}"
                    ) from err
                yield tree
            else:
                with self.assertRaisesRegex(SyntaxError, expected_regex,
                                            msg=f"feature_version={feature_version}"):
                    self.parse(source, feature_version)

    def classic_parse(self, source):
        """Parse *source* with default ``ast.parse`` settings.

        Type comments are not requested here, so tests use this to check
        that ``type_comment`` attributes come back as ``None``.
        """
        return ast.parse(source)

    def test_funcdef(self):
        for tree in self.parse_all(funcdef):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> None")
        tree = self.classic_parse(funcdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncdef(self):
        for tree in self.parse_all(asyncdef, minver=5):
            self.assertEqual(tree.body[0].type_comment, "() -> int")
            self.assertEqual(tree.body[1].type_comment, "() -> int")
        tree = self.classic_parse(asyncdef)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_asyncvar(self):
        with self.assertRaises(SyntaxError):
            self.classic_parse(asyncvar)

    # The next few tests only check which feature versions accept the
    # fixture; parse_all does the asserting, so the loop bodies are empty.
    def test_asynccomp(self):
        for tree in self.parse_all(asynccomp, minver=6):
            pass

    def test_matmul(self):
        for tree in self.parse_all(matmul, minver=5):
            pass

    def test_fstring(self):
        for tree in self.parse_all(fstring):
            pass

    def test_underscorednumber(self):
        for tree in self.parse_all(underscorednumber, minver=6):
            pass

    def test_redundantdef(self):
        for tree in self.parse_all(redundantdef, maxver=0,
                                expected_regex="^Cannot have two type comments on def"):
            pass

    def test_nonasciidef(self):
        for tree in self.parse_all(nonasciidef):
            self.assertEqual(tree.body[0].type_comment, "() -> àçčéñt")

    def test_forstmt(self):
        for tree in self.parse_all(forstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(forstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_withstmt(self):
        for tree in self.parse_all(withstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(withstmt)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_parenthesized_withstmt(self):
        for tree in self.parse_all(parenthesized_withstmt):
            self.assertEqual(tree.body[0].type_comment, "int")
            self.assertEqual(tree.body[1].type_comment, "int")
        tree = self.classic_parse(parenthesized_withstmt)
        self.assertEqual(tree.body[0].type_comment, None)
        self.assertEqual(tree.body[1].type_comment, None)

    def test_vardecl(self):
        for tree in self.parse_all(vardecl):
            self.assertEqual(tree.body[0].type_comment, "int")
        tree = self.classic_parse(vardecl)
        self.assertEqual(tree.body[0].type_comment, None)

    def test_ignores(self):
        for tree in self.parse_all(ignores):
            self.assertEqual(
                [(ti.lineno, ti.tag) for ti in tree.type_ignores],
                [
                    (2, ''),
                    (5, ''),
                    (8, '[excuse]'),
                    (9, '=excuse'),
                    (10, ' [excuse]'),
                    (11, ' whatever'),
                ])
        tree = self.classic_parse(ignores)
        self.assertEqual(tree.type_ignores, [])

    def test_longargs(self):
        for tree in self.parse_all(longargs, minver=8):
            for t in tree.body:
                # The expected args are encoded in the function name:
                # 'f', then positional names, 'v' for *v, 'k' for **k.
                self.assertTrue(t.name.startswith('f'), t.name)
                todo = set(t.name[1:])
                positional = t.args.posonlyargs + t.args.args
                self.assertEqual(len(positional),
                                 len(todo) - bool(t.args.vararg) - bool(t.args.kwarg))
                for c in t.name[1:]:
                    todo.remove(c)
                    if c == 'v':
                        arg = t.args.vararg
                    elif c == 'k':
                        arg = t.args.kwarg
                    else:
                        # Positional names are a, b, ... in declaration
                        # order, so the letter is the index into
                        # posonlyargs + args regardless of where '/' sits.
                        pos = ord(c) - ord('a')
                        self.assertTrue(0 <= pos < len(positional), t.name)
                        arg = positional[pos]
                    self.assertEqual(arg.arg, c)  # That's the argument name
                    self.assertEqual(arg.type_comment, arg.arg.upper())
                self.assertFalse(todo, t.name)
        tree = self.classic_parse(longargs)
        for t in tree.body:
            # Include posonlyargs: the '/' variants put 'a' there, and it
            # must carry no type comment under a classic parse either.
            for arg in (t.args.posonlyargs + t.args.args
                        + [t.args.vararg, t.args.kwarg]):
                if arg is not None:
                    self.assertIsNone(arg.type_comment, "%s(%s:%r)" %
                                      (t.name, arg.arg, arg.type_comment))

    def test_inappropriate_type_comments(self):
        """Tests for inappropriately-placed type comments.

        These should be silently ignored with type comments off,
        but raise SyntaxError with type comments on.

        This is not meant to be exhaustive.
        """

        def check_both_ways(source):
            ast.parse(source, type_comments=False)
            for tree in self.parse_all(source, maxver=0):
                pass

        check_both_ways("pass  # type: int\n")
        check_both_ways("foo()  # type: int\n")
        check_both_ways("x += 1  # type: int\n")
        check_both_ways("while True:  # type: int\n  continue\n")
        check_both_ways("while True:\n  continue  # type: int\n")
        check_both_ways("try:  # type: int\n  pass\nfinally:\n  pass\n")
        check_both_ways("try:\n  pass\nfinally:  # type: int\n  pass\n")
        check_both_ways("pass  # type: ignorewhatever\n")
        check_both_ways("pass  # type: ignoreé\n")

    def test_func_type_input(self):

        def parse_func_type_input(source):
            return ast.parse(source, "<unknown>", "func_type")

        # Some checks below will crash if the returned structure is wrong
        tree = parse_func_type_input("() -> int")
        self.assertEqual(tree.argtypes, [])
        self.assertEqual(tree.returns.id, "int")

        tree = parse_func_type_input("(int) -> List[str]")
        self.assertEqual(len(tree.argtypes), 1)
        arg = tree.argtypes[0]
        self.assertEqual(arg.id, "int")
        self.assertEqual(tree.returns.value.id, "List")
        self.assertEqual(tree.returns.slice.id, "str")

        tree = parse_func_type_input("(int, *str, **Any) -> float")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")
        self.assertEqual(tree.argtypes[2].id, "Any")
        self.assertEqual(tree.returns.id, "float")

        tree = parse_func_type_input("(*int) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        tree = parse_func_type_input("(**int) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        tree = parse_func_type_input("(*int, **str) -> None")
        self.assertEqual(tree.argtypes[0].id, "int")
        self.assertEqual(tree.argtypes[1].id, "str")

        # Malformed signatures: repeated or misordered star-args.
        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, *str, *Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(int, **str, Any) -> float")

        with self.assertRaises(SyntaxError):
            parse_func_type_input("(**int, **str) -> float")


if __name__ == "__main__":
    # Executed only when this file is run directly: hand control to the
    # unittest command-line runner for every test in this module.
    unittest.main()