1import unittest 2 3GLOBAL_VAR = None 4 5class NamedExpressionInvalidTest(unittest.TestCase): 6 7 def test_named_expression_invalid_01(self): 8 code = """x := 0""" 9 10 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 11 exec(code, {}, {}) 12 13 def test_named_expression_invalid_02(self): 14 code = """x = y := 0""" 15 16 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 17 exec(code, {}, {}) 18 19 def test_named_expression_invalid_03(self): 20 code = """y := f(x)""" 21 22 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 23 exec(code, {}, {}) 24 25 def test_named_expression_invalid_04(self): 26 code = """y0 = y1 := f(x)""" 27 28 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 29 exec(code, {}, {}) 30 31 def test_named_expression_invalid_06(self): 32 code = """((a, b) := (1, 2))""" 33 34 with self.assertRaisesRegex(SyntaxError, "cannot use assignment expressions with tuple"): 35 exec(code, {}, {}) 36 37 def test_named_expression_invalid_07(self): 38 code = """def spam(a = b := 42): pass""" 39 40 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 41 exec(code, {}, {}) 42 43 def test_named_expression_invalid_08(self): 44 code = """def spam(a: b := 42 = 5): pass""" 45 46 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 47 exec(code, {}, {}) 48 49 def test_named_expression_invalid_09(self): 50 code = """spam(a=b := 'c')""" 51 52 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 53 exec(code, {}, {}) 54 55 def test_named_expression_invalid_10(self): 56 code = """spam(x = y := f(x))""" 57 58 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 59 exec(code, {}, {}) 60 61 def test_named_expression_invalid_11(self): 62 code = """spam(a=1, b := 2)""" 63 64 with self.assertRaisesRegex(SyntaxError, 65 "positional argument follows keyword argument"): 66 exec(code, {}, {}) 67 68 def test_named_expression_invalid_12(self): 69 code = """spam(a=1, (b := 2))""" 70 71 with self.assertRaisesRegex(SyntaxError, 72 "positional argument follows keyword argument"): 73 exec(code, {}, {}) 74 75 def test_named_expression_invalid_13(self): 76 code = """spam(a=1, (b := 2))""" 77 78 with self.assertRaisesRegex(SyntaxError, 79 "positional argument follows keyword argument"): 80 exec(code, {}, {}) 81 82 def test_named_expression_invalid_14(self): 83 code = """(x := lambda: y := 1)""" 84 85 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 86 exec(code, {}, {}) 87 88 def test_named_expression_invalid_15(self): 89 code = """(lambda: x := 1)""" 90 91 with self.assertRaisesRegex(SyntaxError, 92 "cannot use assignment expressions with lambda"): 93 exec(code, {}, {}) 94 95 def test_named_expression_invalid_16(self): 96 code = "[i + 1 for i in i := [1,2]]" 97 98 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 99 exec(code, {}, {}) 100 101 def test_named_expression_invalid_17(self): 102 code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]" 103 104 with self.assertRaisesRegex(SyntaxError, 105 "did you forget parentheses around the comprehension target?"): 106 exec(code, {}, {}) 107 108 def test_named_expression_invalid_in_class_body(self): 109 code = """class Foo(): 110 [(42, 1 + ((( j := i )))) for i in range(5)] 111 """ 112 113 with self.assertRaisesRegex(SyntaxError, 114 "assignment expression within a comprehension cannot be used in a class body"): 115 exec(code, {}, {}) 116 117 def test_named_expression_invalid_rebinding_list_comprehension_iteration_variable(self): 118 cases = [ 119 ("Local reuse", 'i', "[i := 0 for i in range(5)]"), 120 ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"), 121 ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"), 122 ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"), 123 ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"), 124 ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"), 125 ("Unreachable nested reuse", 'i', 126 "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"), 127 ] 128 for case, target, code in cases: 129 msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'" 130 with self.subTest(case=case): 131 with self.assertRaisesRegex(SyntaxError, msg): 132 exec(code, {}, {}) 133 134 def test_named_expression_invalid_rebinding_list_comprehension_inner_loop(self): 135 cases = [ 136 ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"), 137 ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"), 138 ] 139 for case, target, code in cases: 140 msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'" 141 with self.subTest(case=case): 142 with self.assertRaisesRegex(SyntaxError, msg): 143 exec(code, {}) # Module scope 144 with self.assertRaisesRegex(SyntaxError, msg): 145 exec(code, {}, {}) # Class scope 146 with self.assertRaisesRegex(SyntaxError, msg): 147 exec(f"lambda: {code}", {}) # Function scope 148 149 def test_named_expression_invalid_list_comprehension_iterable_expression(self): 150 cases = [ 151 ("Top level", "[i for i in (i := range(5))]"), 152 ("Inside tuple", "[i for i in (2, 3, i := range(5))]"), 153 ("Inside list", "[i for i in [2, 3, i := range(5)]]"), 154 ("Different name", "[i for i in (j := range(5))]"), 155 ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"), 156 ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"), 157 ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"), 158 ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"), 159 ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"), 160 ] 161 msg = "assignment expression cannot be used in a comprehension iterable expression" 162 for case, code in cases: 163 with self.subTest(case=case): 164 with self.assertRaisesRegex(SyntaxError, msg): 165 exec(code, {}) # Module scope 166 with self.assertRaisesRegex(SyntaxError, msg): 167 exec(code, {}, {}) # Class scope 168 with self.assertRaisesRegex(SyntaxError, msg): 169 exec(f"lambda: {code}", {}) # Function scope 170 171 def test_named_expression_invalid_rebinding_set_comprehension_iteration_variable(self): 172 cases = [ 173 ("Local reuse", 'i', "{i := 0 for i in range(5)}"), 174 ("Nested reuse", 'j', "{{(j := 0) for i in range(5)} for j in range(5)}"), 175 ("Reuse inner loop target", 'j', "{(j := 0) for i in range(5) for j in range(5)}"), 176 ("Unpacking reuse", 'i', "{i := 0 for i, j in {(0, 1)}}"), 177 ("Reuse in loop condition", 'i', "{i+1 for i in range(5) if (i := 0)}"), 178 ("Unreachable reuse", 'i', "{False or (i:=0) for i in range(5)}"), 179 ("Unreachable nested reuse", 'i', 180 "{(i, j) for i in range(5) for j in range(5) if True or (i:=10)}"), 181 ] 182 for case, target, code in cases: 183 msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'" 184 with self.subTest(case=case): 185 with self.assertRaisesRegex(SyntaxError, msg): 186 exec(code, {}, {}) 187 188 def test_named_expression_invalid_rebinding_set_comprehension_inner_loop(self): 189 cases = [ 190 ("Inner reuse", 'j', "{i for i in range(5) if (j := 0) for j in range(5)}"), 191 ("Inner unpacking reuse", 'j', "{i for i in range(5) if (j := 0) for j, k in {(0, 1)}}"), 192 ] 193 for case, target, code in cases: 194 msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'" 195 with self.subTest(case=case): 196 with self.assertRaisesRegex(SyntaxError, msg): 197 exec(code, {}) # Module scope 198 with self.assertRaisesRegex(SyntaxError, msg): 199 exec(code, {}, {}) # Class scope 200 with self.assertRaisesRegex(SyntaxError, msg): 201 exec(f"lambda: {code}", {}) # Function scope 202 203 def test_named_expression_invalid_set_comprehension_iterable_expression(self): 204 cases = [ 205 ("Top level", "{i for i in (i := range(5))}"), 206 ("Inside tuple", "{i for i in (2, 3, i := range(5))}"), 207 ("Inside list", "{i for i in {2, 3, i := range(5)}}"), 208 ("Different name", "{i for i in (j := range(5))}"), 209 ("Lambda expression", "{i for i in (lambda:(j := range(5)))()}"), 210 ("Inner loop", "{i for i in range(5) for j in (i := range(5))}"), 211 ("Nested comprehension", "{i for i in {j for j in (k := range(5))}}"), 212 ("Nested comprehension condition", "{i for i in {j for j in range(5) if (j := True)}}"), 213 ("Nested comprehension body", "{i for i in {(j := True) for j in range(5)}}"), 214 ] 215 msg = "assignment expression cannot be used in a comprehension iterable expression" 216 for case, code in cases: 217 with self.subTest(case=case): 218 with self.assertRaisesRegex(SyntaxError, msg): 219 exec(code, {}) # Module scope 220 with self.assertRaisesRegex(SyntaxError, msg): 221 exec(code, {}, {}) # Class scope 222 with self.assertRaisesRegex(SyntaxError, msg): 223 exec(f"lambda: {code}", {}) # Function scope 224 225 226class NamedExpressionAssignmentTest(unittest.TestCase): 227 228 def test_named_expression_assignment_01(self): 229 (a := 10) 230 231 self.assertEqual(a, 10) 232 233 def test_named_expression_assignment_02(self): 234 a = 20 235 (a := a) 236 237 self.assertEqual(a, 20) 238 239 def test_named_expression_assignment_03(self): 240 (total := 1 + 2) 241 242 self.assertEqual(total, 3) 243 244 def test_named_expression_assignment_04(self): 245 (info := (1, 2, 3)) 246 247 self.assertEqual(info, (1, 2, 3)) 248 249 def test_named_expression_assignment_05(self): 250 (x := 1, 2) 251 252 self.assertEqual(x, 1) 253 254 def test_named_expression_assignment_06(self): 255 (z := (y := (x := 0))) 256 257 self.assertEqual(x, 0) 258 self.assertEqual(y, 0) 259 self.assertEqual(z, 0) 260 261 def test_named_expression_assignment_07(self): 262 (loc := (1, 2)) 263 264 self.assertEqual(loc, (1, 2)) 265 266 def test_named_expression_assignment_08(self): 267 if spam := "eggs": 268 self.assertEqual(spam, "eggs") 269 else: self.fail("variable was not assigned using named expression") 270 271 def test_named_expression_assignment_09(self): 272 if True and (spam := True): 273 self.assertTrue(spam) 274 else: self.fail("variable was not assigned using named expression") 275 276 def test_named_expression_assignment_10(self): 277 if (match := 10) == 10: 278 pass 279 else: self.fail("variable was not assigned using named expression") 280 281 def test_named_expression_assignment_11(self): 282 def spam(a): 283 return a 284 input_data = [1, 2, 3] 285 res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] 286 287 self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) 288 289 def test_named_expression_assignment_12(self): 290 def spam(a): 291 return a 292 res = [[y := spam(x), x/y] for x in range(1, 5)] 293 294 self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]]) 295 296 def test_named_expression_assignment_13(self): 297 length = len(lines := [1, 2]) 298 299 self.assertEqual(length, 2) 300 self.assertEqual(lines, [1,2]) 301 302 def test_named_expression_assignment_14(self): 303 """ 304 Where all variables are positive integers, and a is at least as large 305 as the n'th root of x, this algorithm returns the floor of the n'th 306 root of x (and roughly doubling the number of accurate bits per 307 iteration): 308 """ 309 a = 9 310 n = 2 311 x = 3 312 313 while a > (d := x // a**(n-1)): 314 a = ((n-1)*a + d) // n 315 316 self.assertEqual(a, 1) 317 318 def test_named_expression_assignment_15(self): 319 while a := False: 320 pass # This will not run 321 322 self.assertEqual(a, False) 323 324 def test_named_expression_assignment_16(self): 325 a, b = 1, 2 326 fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)} 327 self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21}) 328 329 def test_named_expression_assignment_17(self): 330 a = [1] 331 element = a[b:=0] 332 self.assertEqual(b, 0) 333 self.assertEqual(element, a[0]) 334 335 def test_named_expression_assignment_18(self): 336 class TwoDimensionalList: 337 def __init__(self, two_dimensional_list): 338 self.two_dimensional_list = two_dimensional_list 339 340 def __getitem__(self, index): 341 return self.two_dimensional_list[index[0]][index[1]] 342 343 a = TwoDimensionalList([[1], [2]]) 344 element = a[b:=0, c:=0] 345 self.assertEqual(b, 0) 346 self.assertEqual(c, 0) 347 self.assertEqual(element, a.two_dimensional_list[b][c]) 348 349 350 351class NamedExpressionScopeTest(unittest.TestCase): 352 353 def test_named_expression_scope_01(self): 354 code = """def spam(): 355 (a := 5) 356print(a)""" 357 358 with self.assertRaisesRegex(NameError, "name 'a' is not defined"): 359 exec(code, {}, {}) 360 361 def test_named_expression_scope_02(self): 362 total = 0 363 partial_sums = [total := total + v for v in range(5)] 364 365 self.assertEqual(partial_sums, [0, 1, 3, 6, 10]) 366 self.assertEqual(total, 10) 367 368 def test_named_expression_scope_03(self): 369 containsOne = any((lastNum := num) == 1 for num in [1, 2, 3]) 370 371 self.assertTrue(containsOne) 372 self.assertEqual(lastNum, 1) 373 374 def test_named_expression_scope_04(self): 375 def spam(a): 376 return a 377 res = [[y := spam(x), x/y] for x in range(1, 5)] 378 379 self.assertEqual(y, 4) 380 381 def test_named_expression_scope_05(self): 382 def spam(a): 383 return a 384 input_data = [1, 2, 3] 385 res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] 386 387 self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) 388 self.assertEqual(y, 3) 389 390 def test_named_expression_scope_06(self): 391 res = [[spam := i for i in range(3)] for j in range(2)] 392 393 self.assertEqual(res, [[0, 1, 2], [0, 1, 2]]) 394 self.assertEqual(spam, 2) 395 396 def test_named_expression_scope_07(self): 397 len(lines := [1, 2]) 398 399 self.assertEqual(lines, [1, 2]) 400 401 def test_named_expression_scope_08(self): 402 def spam(a): 403 return a 404 405 def eggs(b): 406 return b * 2 407 408 res = [spam(a := eggs(b := h)) for h in range(2)] 409 410 self.assertEqual(res, [0, 2]) 411 self.assertEqual(a, 2) 412 self.assertEqual(b, 1) 413 414 def test_named_expression_scope_09(self): 415 def spam(a): 416 return a 417 418 def eggs(b): 419 return b * 2 420 421 res = [spam(a := eggs(a := h)) for h in range(2)] 422 423 self.assertEqual(res, [0, 2]) 424 self.assertEqual(a, 2) 425 426 def test_named_expression_scope_10(self): 427 res = [b := [a := 1 for i in range(2)] for j in range(2)] 428 429 self.assertEqual(res, [[1, 1], [1, 1]]) 430 self.assertEqual(a, 1) 431 self.assertEqual(b, [1, 1]) 432 433 def test_named_expression_scope_11(self): 434 res = [j := i for i in range(5)] 435 436 self.assertEqual(res, [0, 1, 2, 3, 4]) 437 self.assertEqual(j, 4) 438 439 def test_named_expression_scope_17(self): 440 b = 0 441 res = [b := i + b for i in range(5)] 442 443 self.assertEqual(res, [0, 1, 3, 6, 10]) 444 self.assertEqual(b, 10) 445 446 def test_named_expression_scope_18(self): 447 def spam(a): 448 return a 449 450 res = spam(b := 2) 451 452 self.assertEqual(res, 2) 453 self.assertEqual(b, 2) 454 455 def test_named_expression_scope_19(self): 456 def spam(a): 457 return a 458 459 res = spam((b := 2)) 460 461 self.assertEqual(res, 2) 462 self.assertEqual(b, 2) 463 464 def test_named_expression_scope_20(self): 465 def spam(a): 466 return a 467 468 res = spam(a=(b := 2)) 469 470 self.assertEqual(res, 2) 471 self.assertEqual(b, 2) 472 473 def test_named_expression_scope_21(self): 474 def spam(a, b): 475 return a + b 476 477 res = spam(c := 2, b=1) 478 479 self.assertEqual(res, 3) 480 self.assertEqual(c, 2) 481 482 def test_named_expression_scope_22(self): 483 def spam(a, b): 484 return a + b 485 486 res = spam((c := 2), b=1) 487 488 self.assertEqual(res, 3) 489 self.assertEqual(c, 2) 490 491 def test_named_expression_scope_23(self): 492 def spam(a, b): 493 return a + b 494 495 res = spam(b=(c := 2), a=1) 496 497 self.assertEqual(res, 3) 498 self.assertEqual(c, 2) 499 500 def test_named_expression_scope_24(self): 501 a = 10 502 def spam(): 503 nonlocal a 504 (a := 20) 505 spam() 506 507 self.assertEqual(a, 20) 508 509 def test_named_expression_scope_25(self): 510 ns = {} 511 code = """a = 10 512def spam(): 513 global a 514 (a := 20) 515spam()""" 516 517 exec(code, ns, {}) 518 519 self.assertEqual(ns["a"], 20) 520 521 def test_named_expression_variable_reuse_in_comprehensions(self): 522 # The compiler is expected to raise syntax error for comprehension 523 # iteration variables, but should be fine with rebinding of other 524 # names (e.g. globals, nonlocals, other assignment expressions) 525 526 # The cases are all defined to produce the same expected result 527 # Each comprehension is checked at both function scope and module scope 528 rebinding = "[x := i for i in range(3) if (x := i) or not x]" 529 filter_ref = "[x := i for i in range(3) if x or not x]" 530 body_ref = "[x for i in range(3) if (x := i) or not x]" 531 nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]" 532 cases = [ 533 ("Rebind global", f"x = 1; result = {rebinding}"), 534 ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"), 535 ("Filter global", f"x = 1; result = {filter_ref}"), 536 ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"), 537 ("Body global", f"x = 1; result = {body_ref}"), 538 ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"), 539 ("Nested global", f"x = 1; result = {nested_ref}"), 540 ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"), 541 ] 542 for case, code in cases: 543 with self.subTest(case=case): 544 ns = {} 545 exec(code, ns) 546 self.assertEqual(ns["x"], 2) 547 self.assertEqual(ns["result"], [0, 1, 2]) 548 549 def test_named_expression_global_scope(self): 550 sentinel = object() 551 global GLOBAL_VAR 552 def f(): 553 global GLOBAL_VAR 554 [GLOBAL_VAR := sentinel for _ in range(1)] 555 self.assertEqual(GLOBAL_VAR, sentinel) 556 try: 557 f() 558 self.assertEqual(GLOBAL_VAR, sentinel) 559 finally: 560 GLOBAL_VAR = None 561 562 def test_named_expression_global_scope_no_global_keyword(self): 563 sentinel = object() 564 def f(): 565 GLOBAL_VAR = None 566 [GLOBAL_VAR := sentinel for _ in range(1)] 567 self.assertEqual(GLOBAL_VAR, sentinel) 568 f() 569 self.assertEqual(GLOBAL_VAR, None) 570 571 def test_named_expression_nonlocal_scope(self): 572 sentinel = object() 573 def f(): 574 nonlocal_var = None 575 def g(): 576 nonlocal nonlocal_var 577 [nonlocal_var := sentinel for _ in range(1)] 578 g() 579 self.assertEqual(nonlocal_var, sentinel) 580 f() 581 582 def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self): 583 sentinel = object() 584 def f(): 585 nonlocal_var = None 586 def g(): 587 [nonlocal_var := sentinel for _ in range(1)] 588 g() 589 self.assertEqual(nonlocal_var, None) 590 f() 591 592 def test_named_expression_scope_in_genexp(self): 593 a = 1 594 b = [1, 2, 3, 4] 595 genexp = (c := i + a for i in b) 596 597 self.assertNotIn("c", locals()) 598 for idx, elem in enumerate(genexp): 599 self.assertEqual(elem, b[idx] + a) 600 601 602if __name__ == "__main__": 603 unittest.main() 604