"""Tests for assignment expressions (the ``:=`` operator)."""
import unittest

# Module-level name that NamedExpressionScopeTest rebinds through ``global``
# inside a comprehension and then resets to None.
GLOBAL_VAR = None

5class NamedExpressionInvalidTest(unittest.TestCase):
6
7    def test_named_expression_invalid_01(self):
8        code = """x := 0"""
9
10        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
11            exec(code, {}, {})
12
13    def test_named_expression_invalid_02(self):
14        code = """x = y := 0"""
15
16        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
17            exec(code, {}, {})
18
19    def test_named_expression_invalid_03(self):
20        code = """y := f(x)"""
21
22        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
23            exec(code, {}, {})
24
25    def test_named_expression_invalid_04(self):
26        code = """y0 = y1 := f(x)"""
27
28        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
29            exec(code, {}, {})
30
31    def test_named_expression_invalid_06(self):
32        code = """((a, b) := (1, 2))"""
33
34        with self.assertRaisesRegex(SyntaxError, "cannot use assignment expressions with tuple"):
35            exec(code, {}, {})
36
37    def test_named_expression_invalid_07(self):
38        code = """def spam(a = b := 42): pass"""
39
40        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
41            exec(code, {}, {})
42
43    def test_named_expression_invalid_08(self):
44        code = """def spam(a: b := 42 = 5): pass"""
45
46        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
47            exec(code, {}, {})
48
49    def test_named_expression_invalid_09(self):
50        code = """spam(a=b := 'c')"""
51
52        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
53            exec(code, {}, {})
54
55    def test_named_expression_invalid_10(self):
56        code = """spam(x = y := f(x))"""
57
58        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
59            exec(code, {}, {})
60
61    def test_named_expression_invalid_11(self):
62        code = """spam(a=1, b := 2)"""
63
64        with self.assertRaisesRegex(SyntaxError,
65            "positional argument follows keyword argument"):
66            exec(code, {}, {})
67
68    def test_named_expression_invalid_12(self):
69        code = """spam(a=1, (b := 2))"""
70
71        with self.assertRaisesRegex(SyntaxError,
72            "positional argument follows keyword argument"):
73            exec(code, {}, {})
74
75    def test_named_expression_invalid_13(self):
76        code = """spam(a=1, (b := 2))"""
77
78        with self.assertRaisesRegex(SyntaxError,
79            "positional argument follows keyword argument"):
80            exec(code, {}, {})
81
82    def test_named_expression_invalid_14(self):
83        code = """(x := lambda: y := 1)"""
84
85        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
86            exec(code, {}, {})
87
88    def test_named_expression_invalid_15(self):
89        code = """(lambda: x := 1)"""
90
91        with self.assertRaisesRegex(SyntaxError,
92            "cannot use assignment expressions with lambda"):
93            exec(code, {}, {})
94
95    def test_named_expression_invalid_16(self):
96        code = "[i + 1 for i in i := [1,2]]"
97
98        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
99            exec(code, {}, {})
100
101    def test_named_expression_invalid_17(self):
102        code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"
103
104        with self.assertRaisesRegex(SyntaxError,
105                "did you forget parentheses around the comprehension target?"):
106            exec(code, {}, {})
107
108    def test_named_expression_invalid_in_class_body(self):
109        code = """class Foo():
110            [(42, 1 + ((( j := i )))) for i in range(5)]
111        """
112
113        with self.assertRaisesRegex(SyntaxError,
114            "assignment expression within a comprehension cannot be used in a class body"):
115            exec(code, {}, {})
116
117    def test_named_expression_invalid_rebinding_list_comprehension_iteration_variable(self):
118        cases = [
119            ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
120            ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
121            ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
122            ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
123            ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
124            ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
125            ("Unreachable nested reuse", 'i',
126                "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
127        ]
128        for case, target, code in cases:
129            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
130            with self.subTest(case=case):
131                with self.assertRaisesRegex(SyntaxError, msg):
132                    exec(code, {}, {})
133
134    def test_named_expression_invalid_rebinding_list_comprehension_inner_loop(self):
135        cases = [
136            ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
137            ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
138        ]
139        for case, target, code in cases:
140            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
141            with self.subTest(case=case):
142                with self.assertRaisesRegex(SyntaxError, msg):
143                    exec(code, {}) # Module scope
144                with self.assertRaisesRegex(SyntaxError, msg):
145                    exec(code, {}, {}) # Class scope
146                with self.assertRaisesRegex(SyntaxError, msg):
147                    exec(f"lambda: {code}", {}) # Function scope
148
149    def test_named_expression_invalid_list_comprehension_iterable_expression(self):
150        cases = [
151            ("Top level", "[i for i in (i := range(5))]"),
152            ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
153            ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
154            ("Different name", "[i for i in (j := range(5))]"),
155            ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
156            ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
157            ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
158            ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
159            ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
160        ]
161        msg = "assignment expression cannot be used in a comprehension iterable expression"
162        for case, code in cases:
163            with self.subTest(case=case):
164                with self.assertRaisesRegex(SyntaxError, msg):
165                    exec(code, {}) # Module scope
166                with self.assertRaisesRegex(SyntaxError, msg):
167                    exec(code, {}, {}) # Class scope
168                with self.assertRaisesRegex(SyntaxError, msg):
169                    exec(f"lambda: {code}", {}) # Function scope
170
171    def test_named_expression_invalid_rebinding_set_comprehension_iteration_variable(self):
172        cases = [
173            ("Local reuse", 'i', "{i := 0 for i in range(5)}"),
174            ("Nested reuse", 'j', "{{(j := 0) for i in range(5)} for j in range(5)}"),
175            ("Reuse inner loop target", 'j', "{(j := 0) for i in range(5) for j in range(5)}"),
176            ("Unpacking reuse", 'i', "{i := 0 for i, j in {(0, 1)}}"),
177            ("Reuse in loop condition", 'i', "{i+1 for i in range(5) if (i := 0)}"),
178            ("Unreachable reuse", 'i', "{False or (i:=0) for i in range(5)}"),
179            ("Unreachable nested reuse", 'i',
180                "{(i, j) for i in range(5) for j in range(5) if True or (i:=10)}"),
181        ]
182        for case, target, code in cases:
183            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
184            with self.subTest(case=case):
185                with self.assertRaisesRegex(SyntaxError, msg):
186                    exec(code, {}, {})
187
188    def test_named_expression_invalid_rebinding_set_comprehension_inner_loop(self):
189        cases = [
190            ("Inner reuse", 'j', "{i for i in range(5) if (j := 0) for j in range(5)}"),
191            ("Inner unpacking reuse", 'j', "{i for i in range(5) if (j := 0) for j, k in {(0, 1)}}"),
192        ]
193        for case, target, code in cases:
194            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
195            with self.subTest(case=case):
196                with self.assertRaisesRegex(SyntaxError, msg):
197                    exec(code, {}) # Module scope
198                with self.assertRaisesRegex(SyntaxError, msg):
199                    exec(code, {}, {}) # Class scope
200                with self.assertRaisesRegex(SyntaxError, msg):
201                    exec(f"lambda: {code}", {}) # Function scope
202
203    def test_named_expression_invalid_set_comprehension_iterable_expression(self):
204        cases = [
205            ("Top level", "{i for i in (i := range(5))}"),
206            ("Inside tuple", "{i for i in (2, 3, i := range(5))}"),
207            ("Inside list", "{i for i in {2, 3, i := range(5)}}"),
208            ("Different name", "{i for i in (j := range(5))}"),
209            ("Lambda expression", "{i for i in (lambda:(j := range(5)))()}"),
210            ("Inner loop", "{i for i in range(5) for j in (i := range(5))}"),
211            ("Nested comprehension", "{i for i in {j for j in (k := range(5))}}"),
212            ("Nested comprehension condition", "{i for i in {j for j in range(5) if (j := True)}}"),
213            ("Nested comprehension body", "{i for i in {(j := True) for j in range(5)}}"),
214        ]
215        msg = "assignment expression cannot be used in a comprehension iterable expression"
216        for case, code in cases:
217            with self.subTest(case=case):
218                with self.assertRaisesRegex(SyntaxError, msg):
219                    exec(code, {}) # Module scope
220                with self.assertRaisesRegex(SyntaxError, msg):
221                    exec(code, {}, {}) # Class scope
222                with self.assertRaisesRegex(SyntaxError, msg):
223                    exec(f"lambda: {code}", {}) # Function scope
224
225
226class NamedExpressionAssignmentTest(unittest.TestCase):
227
228    def test_named_expression_assignment_01(self):
229        (a := 10)
230
231        self.assertEqual(a, 10)
232
233    def test_named_expression_assignment_02(self):
234        a = 20
235        (a := a)
236
237        self.assertEqual(a, 20)
238
239    def test_named_expression_assignment_03(self):
240        (total := 1 + 2)
241
242        self.assertEqual(total, 3)
243
244    def test_named_expression_assignment_04(self):
245        (info := (1, 2, 3))
246
247        self.assertEqual(info, (1, 2, 3))
248
249    def test_named_expression_assignment_05(self):
250        (x := 1, 2)
251
252        self.assertEqual(x, 1)
253
254    def test_named_expression_assignment_06(self):
255        (z := (y := (x := 0)))
256
257        self.assertEqual(x, 0)
258        self.assertEqual(y, 0)
259        self.assertEqual(z, 0)
260
261    def test_named_expression_assignment_07(self):
262        (loc := (1, 2))
263
264        self.assertEqual(loc, (1, 2))
265
266    def test_named_expression_assignment_08(self):
267        if spam := "eggs":
268            self.assertEqual(spam, "eggs")
269        else: self.fail("variable was not assigned using named expression")
270
271    def test_named_expression_assignment_09(self):
272        if True and (spam := True):
273            self.assertTrue(spam)
274        else: self.fail("variable was not assigned using named expression")
275
276    def test_named_expression_assignment_10(self):
277        if (match := 10) == 10:
278            pass
279        else: self.fail("variable was not assigned using named expression")
280
281    def test_named_expression_assignment_11(self):
282        def spam(a):
283            return a
284        input_data = [1, 2, 3]
285        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]
286
287        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
288
289    def test_named_expression_assignment_12(self):
290        def spam(a):
291            return a
292        res = [[y := spam(x), x/y] for x in range(1, 5)]
293
294        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])
295
296    def test_named_expression_assignment_13(self):
297        length = len(lines := [1, 2])
298
299        self.assertEqual(length, 2)
300        self.assertEqual(lines, [1,2])
301
302    def test_named_expression_assignment_14(self):
303        """
304        Where all variables are positive integers, and a is at least as large
305        as the n'th root of x, this algorithm returns the floor of the n'th
306        root of x (and roughly doubling the number of accurate bits per
307        iteration):
308        """
309        a = 9
310        n = 2
311        x = 3
312
313        while a > (d := x // a**(n-1)):
314            a = ((n-1)*a + d) // n
315
316        self.assertEqual(a, 1)
317
318    def test_named_expression_assignment_15(self):
319        while a := False:
320            pass  # This will not run
321
322        self.assertEqual(a, False)
323
324    def test_named_expression_assignment_16(self):
325        a, b = 1, 2
326        fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
327        self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
328
329    def test_named_expression_assignment_17(self):
330        a = [1]
331        element = a[b:=0]
332        self.assertEqual(b, 0)
333        self.assertEqual(element, a[0])
334
335    def test_named_expression_assignment_18(self):
336        class TwoDimensionalList:
337            def __init__(self, two_dimensional_list):
338                self.two_dimensional_list = two_dimensional_list
339
340            def __getitem__(self, index):
341                return self.two_dimensional_list[index[0]][index[1]]
342
343        a = TwoDimensionalList([[1], [2]])
344        element = a[b:=0, c:=0]
345        self.assertEqual(b, 0)
346        self.assertEqual(c, 0)
347        self.assertEqual(element, a.two_dimensional_list[b][c])
348
349
350
class NamedExpressionScopeTest(unittest.TestCase):
    """Which scope a name bound with ``:=`` ends up in.

    Most tests bind a name inside a comprehension, generator expression,
    call argument or nested function and then check whether (and with what
    value) the name is visible in the enclosing test method.  Two tests read
    the module-level ``GLOBAL_VAR``.
    """

    def test_named_expression_scope_01(self):
        # The binding inside spam() is never made (spam is not called) and
        # would be local to it anyway, so print(a) must fail.
        code = """def spam():
    (a := 5)
print(a)"""

        with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
            exec(code, {}, {})

    def test_named_expression_scope_02(self):
        total = 0
        # The comprehension both reads and rebinds the enclosing total.
        partial_sums = [total := total + v for v in range(5)]

        self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
        self.assertEqual(total, 10)

    def test_named_expression_scope_03(self):
        # any() stops at the first match, so lastNum keeps the value 1.
        containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])

        self.assertTrue(containsOne)
        self.assertEqual(lastNum, 1)

    def test_named_expression_scope_04(self):
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        # y keeps the value from the last iteration.
        self.assertEqual(y, 4)

    def test_named_expression_scope_05(self):
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
        self.assertEqual(y, 3)

    def test_named_expression_scope_06(self):
        # A binding in the inner comprehension is visible in the test method.
        res = [[spam := i for i in range(3)] for j in range(2)]

        self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
        self.assertEqual(spam, 2)

    def test_named_expression_scope_07(self):
        len(lines := [1, 2])

        self.assertEqual(lines, [1, 2])

    def test_named_expression_scope_08(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        # a and b here are the test method's names, not the parameters of
        # spam() and eggs().
        res = [spam(a := eggs(b := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)
        self.assertEqual(b, 1)

    def test_named_expression_scope_09(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        # The same name is bound twice; the outer (later) binding wins.
        res = [spam(a := eggs(a := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)

    def test_named_expression_scope_10(self):
        res = [b := [a := 1 for i in range(2)] for j in range(2)]

        self.assertEqual(res, [[1, 1], [1, 1]])
        self.assertEqual(a, 1)
        self.assertEqual(b, [1, 1])

    def test_named_expression_scope_11(self):
        res = [j := i for i in range(5)]

        self.assertEqual(res, [0, 1, 2, 3, 4])
        self.assertEqual(j, 4)

    def test_named_expression_scope_17(self):
        b = 0
        res = [b := i + b for i in range(5)]

        self.assertEqual(res, [0, 1, 3, 6, 10])
        self.assertEqual(b, 10)

    def test_named_expression_scope_18(self):
        def spam(a):
            return a

        # Unparenthesized assignment expression as a positional argument.
        res = spam(b := 2)

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_19(self):
        def spam(a):
            return a

        res = spam((b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_20(self):
        def spam(a):
            return a

        # As a keyword argument value the expression is parenthesized.
        res = spam(a=(b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_21(self):
        def spam(a, b):
            return a + b

        res = spam(c := 2, b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_22(self):
        def spam(a, b):
            return a + b

        res = spam((c := 2), b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_23(self):
        def spam(a, b):
            return a + b

        res = spam(b=(c := 2), a=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_24(self):
        a = 10
        def spam():
            # The nonlocal declaration makes := rebind the enclosing a.
            nonlocal a
            (a := 20)
        spam()

        self.assertEqual(a, 20)

    def test_named_expression_scope_25(self):
        ns = {}
        # The global declaration makes := rebind a in the globals mapping
        # (ns) passed to exec().
        code = """a = 10
def spam():
    global a
    (a := 20)
spam()"""

        exec(code, ns, {})

        self.assertEqual(ns["a"], 20)

    def test_named_expression_variable_reuse_in_comprehensions(self):
        # The compiler is expected to raise syntax error for comprehension
        # iteration variables, but should be fine with rebinding of other
        # names (e.g. globals, nonlocals, other assignment expressions)

        # The cases are all defined to produce the same expected result
        # Each comprehension is checked at both function scope and module scope
        rebinding = "[x := i for i in range(3) if (x := i) or not x]"
        filter_ref = "[x := i for i in range(3) if x or not x]"
        body_ref = "[x for i in range(3) if (x := i) or not x]"
        nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
        cases = [
            ("Rebind global", f"x = 1; result = {rebinding}"),
            ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
            ("Filter global", f"x = 1; result = {filter_ref}"),
            ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
            ("Body global", f"x = 1; result = {body_ref}"),
            ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
            ("Nested global", f"x = 1; result = {nested_ref}"),
            ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
        ]
        for case, code in cases:
            with self.subTest(case=case):
                # A fresh namespace per case keeps the cases independent.
                ns = {}
                exec(code, ns)
                self.assertEqual(ns["x"], 2)
                self.assertEqual(ns["result"], [0, 1, 2])

    def test_named_expression_global_scope(self):
        sentinel = object()
        global GLOBAL_VAR
        def f():
            global GLOBAL_VAR
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertIs(GLOBAL_VAR, sentinel)
        try:
            f()
            self.assertIs(GLOBAL_VAR, sentinel)
        finally:
            # Reset the module global so other tests see None again.
            GLOBAL_VAR = None

    def test_named_expression_global_scope_no_global_keyword(self):
        sentinel = object()
        def f():
            # Without a global declaration this is a local of f(), and the
            # comprehension rebinds that local.
            GLOBAL_VAR = None
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertIs(GLOBAL_VAR, sentinel)
        f()
        # The module-level GLOBAL_VAR must be untouched.
        self.assertIsNone(GLOBAL_VAR)

    def test_named_expression_nonlocal_scope(self):
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                nonlocal nonlocal_var
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertIs(nonlocal_var, sentinel)
        f()

    def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                # Without a nonlocal declaration the binding stays inside g().
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertIsNone(nonlocal_var)
        f()

    def test_named_expression_scope_in_genexp(self):
        a = 1
        b = [1, 2, 3, 4]
        genexp = (c := i + a for i in b)

        # Nothing has been consumed from the generator yet, so c is unbound.
        self.assertNotIn("c", locals())
        for idx, elem in enumerate(genexp):
            self.assertEqual(elem, b[idx] + a)


if __name__ == "__main__":
    # Run every test case in this module when the file is executed as a script.
    unittest.main()