# Tests for assignment expressions (the ``:=`` "walrus" operator, PEP 572).
import os
import unittest

# Module-level global used by NamedExpressionScopeTest's global-scope tests:
# one rebinds it through ``global GLOBAL_VAR`` plus a walrus inside a
# comprehension (and resets it to None afterwards), the other checks that a
# same-named function local leaves it untouched, i.e. still None.
GLOBAL_VAR = None


class NamedExpressionInvalidTest(unittest.TestCase):
    """Source snippets that must be rejected with SyntaxError.

    Each snippet is executed with ``exec``; a test passes when compilation
    raises SyntaxError whose message matches the given regular expression
    (``assertRaisesRegex`` searches, so a pattern may match a prefix of a
    longer message).  Where CPython's wording differs between releases, the
    pattern accepts both known wordings so the module runs on any 3.8+
    interpreter.
    """

    def test_named_expression_invalid_01(self):
        code = """x := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_02(self):
        code = """x = y := 0"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_03(self):
        code = """y := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_04(self):
        code = """y0 = y1 := f(x)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_06(self):
        code = """((a, b) := (1, 2))"""

        # Older releases say "named assignment", newer ones
        # "assignment expressions".
        with self.assertRaisesRegex(SyntaxError,
            "cannot use (named assignment|assignment expressions) with tuple"):
            exec(code, {}, {})

    def test_named_expression_invalid_07(self):
        code = """def spam(a = b := 42): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_08(self):
        code = """def spam(a: b := 42 = 5): pass"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_09(self):
        code = """spam(a=b := 'c')"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_10(self):
        code = """spam(x = y := f(x))"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_11(self):
        code = """spam(a=1, b := 2)"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_12(self):
        code = """spam(a=1, (b := 2))"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument"):
            exec(code, {}, {})

    def test_named_expression_invalid_13(self):
        # Test 12 already covers a parenthesized walrus after a plain keyword
        # argument; this one places it after ``**`` keyword unpacking.
        code = """spam(**{}, (b := 2))"""

        with self.assertRaisesRegex(SyntaxError,
            "positional argument follows keyword argument unpacking"):
            exec(code, {}, {})

    def test_named_expression_invalid_14(self):
        code = """(x := lambda: y := 1)"""

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_15(self):
        code = """(lambda: x := 1)"""

        # Older releases say "named assignment", newer ones
        # "assignment expressions".
        with self.assertRaisesRegex(SyntaxError,
            "cannot use (named assignment|assignment expressions) with lambda"):
            exec(code, {}, {})

    def test_named_expression_invalid_16(self):
        code = "[i + 1 for i in i := [1,2]]"

        with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
            exec(code, {}, {})

    def test_named_expression_invalid_17(self):
        code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"

        # Newer parsers point at the unparenthesized comprehension target;
        # older ones report a generic syntax error.
        with self.assertRaisesRegex(SyntaxError,
            "invalid syntax|did you forget parentheses around the "
            "comprehension target"):
            exec(code, {}, {})

    def test_named_expression_invalid_in_class_body(self):
        code = """class Foo():
            [(42, 1 + ((( j := i )))) for i in range(5)]
        """

        with self.assertRaisesRegex(SyntaxError,
            "assignment expression within a comprehension cannot be used in a class body"):
            exec(code, {}, {})

    def test_named_expression_invalid_rebinding_comprehension_iteration_variable(self):
        # (case label, rebound name, source) -- the walrus target is an
        # iteration variable of the same comprehension, even when the walrus
        # can never actually execute.
        cases = [
            ("Local reuse", 'i', "[i := 0 for i in range(5)]"),
            ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
            ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
            ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
            ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
            ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
            ("Unreachable nested reuse", 'i',
                "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
        ]
        for case, target, code in cases:
            msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {})

    def test_named_expression_invalid_rebinding_comprehension_inner_loop(self):
        # A later ``for`` clause may not rebind a name an earlier walrus in
        # the same comprehension already assigned.
        cases = [
            ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
            ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
        ]
        for case, target, code in cases:
            msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}) # Module scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {}) # Class scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(f"lambda: {code}", {}) # Function scope

    def test_named_expression_invalid_comprehension_iterable_expression(self):
        # A walrus anywhere inside a comprehension's iterable expression is
        # rejected, including inside nested lambdas and comprehensions.
        cases = [
            ("Top level", "[i for i in (i := range(5))]"),
            ("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
            ("Inside list", "[i for i in [2, 3, i := range(5)]]"),
            ("Different name", "[i for i in (j := range(5))]"),
            ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
            ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
            ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
            ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
            ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
        ]
        msg = "assignment expression cannot be used in a comprehension iterable expression"
        for case, code in cases:
            with self.subTest(case=case):
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}) # Module scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(code, {}, {}) # Class scope
                with self.assertRaisesRegex(SyntaxError, msg):
                    exec(f"lambda: {code}", {}) # Function scope


class NamedExpressionAssignmentTest(unittest.TestCase):
    """Valid uses of ``:=`` and the values they bind."""

    def test_named_expression_assignment_01(self):
        (a := 10)

        self.assertEqual(a, 10)

    def test_named_expression_assignment_02(self):
        a = 20
        (a := a)

        self.assertEqual(a, 20)

    def test_named_expression_assignment_03(self):
        (total := 1 + 2)

        self.assertEqual(total, 3)

    def test_named_expression_assignment_04(self):
        (info := (1, 2, 3))

        self.assertEqual(info, (1, 2, 3))

    def test_named_expression_assignment_05(self):
        # The walrus binds only the first tuple element, not the tuple.
        (x := 1, 2)

        self.assertEqual(x, 1)

    def test_named_expression_assignment_06(self):
        (z := (y := (x := 0)))

        self.assertEqual(x, 0)
        self.assertEqual(y, 0)
        self.assertEqual(z, 0)

    def test_named_expression_assignment_07(self):
        (loc := (1, 2))

        self.assertEqual(loc, (1, 2))

    def test_named_expression_assignment_08(self):
        if spam := "eggs":
            self.assertEqual(spam, "eggs")
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_09(self):
        if True and (spam := True):
            self.assertTrue(spam)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_10(self):
        if (match := 10) == 10:
            # Check the binding itself, not just the comparison result.
            self.assertEqual(match, 10)
        else:
            self.fail("variable was not assigned using named expression")

    def test_named_expression_assignment_11(self):
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])

    def test_named_expression_assignment_12(self):
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])

    def test_named_expression_assignment_13(self):
        length = len(lines := [1, 2])

        self.assertEqual(length, 2)
        self.assertEqual(lines, [1,2])

    def test_named_expression_assignment_14(self):
        """Integer n'th root computed with a walrus in a ``while`` condition.

        When all variables are positive integers and ``a`` starts at least as
        large as the n'th root of ``x``, the loop leaves ``a`` equal to the
        floor of the n'th root of ``x``, roughly doubling the number of
        accurate bits per iteration.  Here floor(sqrt(3)) == 1.
        """
        a = 9
        n = 2
        x = 3

        while a > (d := x // a**(n-1)):
            a = ((n-1)*a + d) // n

        self.assertEqual(a, 1)

    def test_named_expression_assignment_15(self):
        # The condition is falsy, so the body never runs, but the binding
        # still happens.
        while a := False:
            pass  # This will not run

        self.assertEqual(a, False)

    def test_named_expression_assignment_16(self):
        # Each iteration maps the current Fibonacci number (key) to the next
        # one (value) while the walruses advance (a, b) in place; relies on
        # dict-comprehension keys being evaluated before values.
        a, b = 1, 2
        fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
        self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})


class NamedExpressionScopeTest(unittest.TestCase):
    """Which scope a ``:=`` binding lands in.

    A walrus inside a comprehension binds the name in the nearest enclosing
    non-comprehension scope, honoring that scope's ``global``/``nonlocal``
    declarations; a walrus inside a nested function stays local to it.
    """

    def test_named_expression_scope_01(self):
        code = """def spam():
    (a := 5)
print(a)"""

        with self.assertRaisesRegex(NameError, "name 'a' is not defined"):
            exec(code, {}, {})

    def test_named_expression_scope_02(self):
        total = 0
        partial_sums = [total := total + v for v in range(5)]

        self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
        self.assertEqual(total, 10)

    def test_named_expression_scope_03(self):
        # any() stops at the first match, so the last bound value is 1.
        contains_one = any((last_num := num) == 1 for num in [1, 2, 3])

        self.assertTrue(contains_one)
        self.assertEqual(last_num, 1)

    def test_named_expression_scope_04(self):
        def spam(a):
            return a
        res = [[y := spam(x), x/y] for x in range(1, 5)]

        self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])
        self.assertEqual(y, 4)

    def test_named_expression_scope_05(self):
        def spam(a):
            return a
        input_data = [1, 2, 3]
        res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]

        self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
        self.assertEqual(y, 3)

    def test_named_expression_scope_06(self):
        # Even from a nested comprehension, the binding escapes to the method.
        res = [[spam := i for i in range(3)] for j in range(2)]

        self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
        self.assertEqual(spam, 2)

    def test_named_expression_scope_07(self):
        len(lines := [1, 2])

        self.assertEqual(lines, [1, 2])

    def test_named_expression_scope_08(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(b := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)
        self.assertEqual(b, 1)

    def test_named_expression_scope_09(self):
        def spam(a):
            return a

        def eggs(b):
            return b * 2

        res = [spam(a := eggs(a := h)) for h in range(2)]

        self.assertEqual(res, [0, 2])
        self.assertEqual(a, 2)

    def test_named_expression_scope_10(self):
        res = [b := [a := 1 for i in range(2)] for j in range(2)]

        self.assertEqual(res, [[1, 1], [1, 1]])
        self.assertEqual(a, 1)
        self.assertEqual(b, [1, 1])

    def test_named_expression_scope_11(self):
        res = [j := i for i in range(5)]

        self.assertEqual(res, [0, 1, 2, 3, 4])
        self.assertEqual(j, 4)

    def test_named_expression_scope_17(self):
        b = 0
        res = [b := i + b for i in range(5)]

        self.assertEqual(res, [0, 1, 3, 6, 10])
        self.assertEqual(b, 10)

    def test_named_expression_scope_18(self):
        def spam(a):
            return a

        res = spam(b := 2)

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_19(self):
        def spam(a):
            return a

        res = spam((b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_20(self):
        def spam(a):
            return a

        res = spam(a=(b := 2))

        self.assertEqual(res, 2)
        self.assertEqual(b, 2)

    def test_named_expression_scope_21(self):
        def spam(a, b):
            return a + b

        res = spam(c := 2, b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_22(self):
        def spam(a, b):
            return a + b

        res = spam((c := 2), b=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_23(self):
        def spam(a, b):
            return a + b

        res = spam(b=(c := 2), a=1)

        self.assertEqual(res, 3)
        self.assertEqual(c, 2)

    def test_named_expression_scope_24(self):
        a = 10
        def spam():
            nonlocal a
            (a := 20)
        spam()

        self.assertEqual(a, 20)

    def test_named_expression_scope_25(self):
        # With separate globals/locals dicts, only the ``global a`` inside
        # spam() writes into ``ns``.
        ns = {}
        code = """a = 10
def spam():
    global a
    (a := 20)
spam()"""

        exec(code, ns, {})

        self.assertEqual(ns["a"], 20)

    def test_named_expression_variable_reuse_in_comprehensions(self):
        # The compiler is expected to raise syntax error for comprehension
        # iteration variables, but should be fine with rebinding of other
        # names (e.g. globals, nonlocals, other assignment expressions)

        # The cases are all defined to produce the same expected result
        # Each comprehension is checked at both function scope and module scope
        rebinding = "[x := i for i in range(3) if (x := i) or not x]"
        filter_ref = "[x := i for i in range(3) if x or not x]"
        body_ref = "[x for i in range(3) if (x := i) or not x]"
        nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
        cases = [
            ("Rebind global", f"x = 1; result = {rebinding}"),
            ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
            ("Filter global", f"x = 1; result = {filter_ref}"),
            ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
            ("Body global", f"x = 1; result = {body_ref}"),
            ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
            ("Nested global", f"x = 1; result = {nested_ref}"),
            ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
        ]
        for case, code in cases:
            with self.subTest(case=case):
                ns = {}
                exec(code, ns)
                self.assertEqual(ns["x"], 2)
                self.assertEqual(ns["result"], [0, 1, 2])

    def test_named_expression_global_scope(self):
        sentinel = object()
        global GLOBAL_VAR
        def f():
            global GLOBAL_VAR
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        try:
            f()
            self.assertEqual(GLOBAL_VAR, sentinel)
        finally:
            # Restore the module global for other tests.
            GLOBAL_VAR = None

    def test_named_expression_global_scope_no_global_keyword(self):
        # Without ``global``, the walrus rebinds f()'s own local, leaving the
        # module-level GLOBAL_VAR untouched.
        sentinel = object()
        def f():
            GLOBAL_VAR = None
            [GLOBAL_VAR := sentinel for _ in range(1)]
            self.assertEqual(GLOBAL_VAR, sentinel)
        f()
        self.assertEqual(GLOBAL_VAR, None)

    def test_named_expression_nonlocal_scope(self):
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                nonlocal nonlocal_var
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, sentinel)
        f()

    def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self):
        # Without ``nonlocal``, the walrus binds a local of g(), so f()'s
        # variable keeps its original value.
        sentinel = object()
        def f():
            nonlocal_var = None
            def g():
                [nonlocal_var := sentinel for _ in range(1)]
            g()
            self.assertEqual(nonlocal_var, None)
        f()


# Allow running this module directly: ``python <this file>``.
if __name__ == "__main__":
    unittest.main()