1import os 2import unittest 3 4GLOBAL_VAR = None 5 6class NamedExpressionInvalidTest(unittest.TestCase): 7 8 def test_named_expression_invalid_01(self): 9 code = """x := 0""" 10 11 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 12 exec(code, {}, {}) 13 14 def test_named_expression_invalid_02(self): 15 code = """x = y := 0""" 16 17 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 18 exec(code, {}, {}) 19 20 def test_named_expression_invalid_03(self): 21 code = """y := f(x)""" 22 23 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 24 exec(code, {}, {}) 25 26 def test_named_expression_invalid_04(self): 27 code = """y0 = y1 := f(x)""" 28 29 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 30 exec(code, {}, {}) 31 32 def test_named_expression_invalid_06(self): 33 code = """((a, b) := (1, 2))""" 34 35 with self.assertRaisesRegex(SyntaxError, "cannot use named assignment with tuple"): 36 exec(code, {}, {}) 37 38 def test_named_expression_invalid_07(self): 39 code = """def spam(a = b := 42): pass""" 40 41 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 42 exec(code, {}, {}) 43 44 def test_named_expression_invalid_08(self): 45 code = """def spam(a: b := 42 = 5): pass""" 46 47 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 48 exec(code, {}, {}) 49 50 def test_named_expression_invalid_09(self): 51 code = """spam(a=b := 'c')""" 52 53 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 54 exec(code, {}, {}) 55 56 def test_named_expression_invalid_10(self): 57 code = """spam(x = y := f(x))""" 58 59 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 60 exec(code, {}, {}) 61 62 def test_named_expression_invalid_11(self): 63 code = """spam(a=1, b := 2)""" 64 65 with self.assertRaisesRegex(SyntaxError, 66 "positional argument follows keyword argument"): 67 exec(code, {}, {}) 68 69 def test_named_expression_invalid_12(self): 70 code = """spam(a=1, (b := 2))""" 71 72 with self.assertRaisesRegex(SyntaxError, 73 "positional argument follows keyword argument"): 74 exec(code, {}, {}) 75 76 def test_named_expression_invalid_13(self): 77 code = """spam(a=1, (b := 2))""" 78 79 with self.assertRaisesRegex(SyntaxError, 80 "positional argument follows keyword argument"): 81 exec(code, {}, {}) 82 83 def test_named_expression_invalid_14(self): 84 code = """(x := lambda: y := 1)""" 85 86 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 87 exec(code, {}, {}) 88 89 def test_named_expression_invalid_15(self): 90 code = """(lambda: x := 1)""" 91 92 with self.assertRaisesRegex(SyntaxError, 93 "cannot use named assignment with lambda"): 94 exec(code, {}, {}) 95 96 def test_named_expression_invalid_16(self): 97 code = "[i + 1 for i in i := [1,2]]" 98 99 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 100 exec(code, {}, {}) 101 102 def test_named_expression_invalid_17(self): 103 code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]" 104 105 with self.assertRaisesRegex(SyntaxError, "invalid syntax"): 106 exec(code, {}, {}) 107 108 def test_named_expression_invalid_in_class_body(self): 109 code = """class Foo(): 110 [(42, 1 + ((( j := i )))) for i in range(5)] 111 """ 112 113 with self.assertRaisesRegex(SyntaxError, 114 "assignment expression within a comprehension cannot be used in a class body"): 115 exec(code, {}, {}) 116 117 def test_named_expression_invalid_rebinding_comprehension_iteration_variable(self): 118 cases = [ 119 ("Local reuse", 'i', "[i := 0 for i in range(5)]"), 120 ("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"), 121 ("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"), 122 ("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"), 123 ("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"), 124 ("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"), 125 ("Unreachable nested reuse", 'i', 126 "[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"), 127 ] 128 for case, target, code in cases: 129 msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'" 130 with self.subTest(case=case): 131 with self.assertRaisesRegex(SyntaxError, msg): 132 exec(code, {}, {}) 133 134 def test_named_expression_invalid_rebinding_comprehension_inner_loop(self): 135 cases = [ 136 ("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"), 137 ("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"), 138 ] 139 for case, target, code in cases: 140 msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'" 141 with self.subTest(case=case): 142 with self.assertRaisesRegex(SyntaxError, msg): 143 exec(code, {}) # Module scope 144 with self.assertRaisesRegex(SyntaxError, msg): 145 exec(code, {}, {}) # Class scope 146 with self.assertRaisesRegex(SyntaxError, msg): 147 exec(f"lambda: {code}", {}) # Function scope 148 149 def test_named_expression_invalid_comprehension_iterable_expression(self): 150 cases = [ 151 ("Top level", "[i for i in (i := range(5))]"), 152 ("Inside tuple", "[i for i in (2, 3, i := range(5))]"), 153 ("Inside list", "[i for i in [2, 3, i := range(5)]]"), 154 ("Different name", "[i for i in (j := range(5))]"), 155 ("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"), 156 ("Inner loop", "[i for i in range(5) for j in (i := range(5))]"), 157 ("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"), 158 ("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"), 159 ("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"), 160 ] 161 msg = "assignment expression cannot be used in a comprehension iterable expression" 162 for case, code in cases: 163 with self.subTest(case=case): 164 with self.assertRaisesRegex(SyntaxError, msg): 165 exec(code, {}) # Module scope 166 with self.assertRaisesRegex(SyntaxError, msg): 167 exec(code, {}, {}) # Class scope 168 with self.assertRaisesRegex(SyntaxError, msg): 169 exec(f"lambda: {code}", {}) # Function scope 170 171 172class NamedExpressionAssignmentTest(unittest.TestCase): 173 174 def test_named_expression_assignment_01(self): 175 (a := 10) 176 177 self.assertEqual(a, 10) 178 179 def test_named_expression_assignment_02(self): 180 a = 20 181 (a := a) 182 183 self.assertEqual(a, 20) 184 185 def test_named_expression_assignment_03(self): 186 (total := 1 + 2) 187 188 self.assertEqual(total, 3) 189 190 def test_named_expression_assignment_04(self): 191 (info := (1, 2, 3)) 192 193 self.assertEqual(info, (1, 2, 3)) 194 195 def test_named_expression_assignment_05(self): 196 (x := 1, 2) 197 198 self.assertEqual(x, 1) 199 200 def test_named_expression_assignment_06(self): 201 (z := (y := (x := 0))) 202 203 self.assertEqual(x, 0) 204 self.assertEqual(y, 0) 205 self.assertEqual(z, 0) 206 207 def test_named_expression_assignment_07(self): 208 (loc := (1, 2)) 209 210 self.assertEqual(loc, (1, 2)) 211 212 def test_named_expression_assignment_08(self): 213 if spam := "eggs": 214 self.assertEqual(spam, "eggs") 215 else: self.fail("variable was not assigned using named expression") 216 217 def test_named_expression_assignment_09(self): 218 if True and (spam := True): 219 self.assertTrue(spam) 220 else: self.fail("variable was not assigned using named expression") 221 222 def test_named_expression_assignment_10(self): 223 if (match := 10) == 10: 224 pass 225 else: self.fail("variable was not assigned using named expression") 226 227 def test_named_expression_assignment_11(self): 228 def spam(a): 229 return a 230 input_data = [1, 2, 3] 231 res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] 232 233 self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) 234 235 def test_named_expression_assignment_12(self): 236 def spam(a): 237 return a 238 res = [[y := spam(x), x/y] for x in range(1, 5)] 239 240 self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]]) 241 242 def test_named_expression_assignment_13(self): 243 length = len(lines := [1, 2]) 244 245 self.assertEqual(length, 2) 246 self.assertEqual(lines, [1,2]) 247 248 def test_named_expression_assignment_14(self): 249 """ 250 Where all variables are positive integers, and a is at least as large 251 as the n'th root of x, this algorithm returns the floor of the n'th 252 root of x (and roughly doubling the number of accurate bits per 253 iteration): 254 """ 255 a = 9 256 n = 2 257 x = 3 258 259 while a > (d := x // a**(n-1)): 260 a = ((n-1)*a + d) // n 261 262 self.assertEqual(a, 1) 263 264 def test_named_expression_assignment_15(self): 265 while a := False: 266 pass # This will not run 267 268 self.assertEqual(a, False) 269 270 def test_named_expression_assignment_16(self): 271 a, b = 1, 2 272 fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)} 273 self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21}) 274 275 276class NamedExpressionScopeTest(unittest.TestCase): 277 278 def test_named_expression_scope_01(self): 279 code = """def spam(): 280 (a := 5) 281print(a)""" 282 283 with self.assertRaisesRegex(NameError, "name 'a' is not defined"): 284 exec(code, {}, {}) 285 286 def test_named_expression_scope_02(self): 287 total = 0 288 partial_sums = [total := total + v for v in range(5)] 289 290 self.assertEqual(partial_sums, [0, 1, 3, 6, 10]) 291 self.assertEqual(total, 10) 292 293 def test_named_expression_scope_03(self): 294 containsOne = any((lastNum := num) == 1 for num in [1, 2, 3]) 295 296 self.assertTrue(containsOne) 297 self.assertEqual(lastNum, 1) 298 299 def test_named_expression_scope_04(self): 300 def spam(a): 301 return a 302 res = [[y := spam(x), x/y] for x in range(1, 5)] 303 304 self.assertEqual(y, 4) 305 306 def test_named_expression_scope_05(self): 307 def spam(a): 308 return a 309 input_data = [1, 2, 3] 310 res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0] 311 312 self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)]) 313 self.assertEqual(y, 3) 314 315 def test_named_expression_scope_06(self): 316 res = [[spam := i for i in range(3)] for j in range(2)] 317 318 self.assertEqual(res, [[0, 1, 2], [0, 1, 2]]) 319 self.assertEqual(spam, 2) 320 321 def test_named_expression_scope_07(self): 322 len(lines := [1, 2]) 323 324 self.assertEqual(lines, [1, 2]) 325 326 def test_named_expression_scope_08(self): 327 def spam(a): 328 return a 329 330 def eggs(b): 331 return b * 2 332 333 res = [spam(a := eggs(b := h)) for h in range(2)] 334 335 self.assertEqual(res, [0, 2]) 336 self.assertEqual(a, 2) 337 self.assertEqual(b, 1) 338 339 def test_named_expression_scope_09(self): 340 def spam(a): 341 return a 342 343 def eggs(b): 344 return b * 2 345 346 res = [spam(a := eggs(a := h)) for h in range(2)] 347 348 self.assertEqual(res, [0, 2]) 349 self.assertEqual(a, 2) 350 351 def test_named_expression_scope_10(self): 352 res = [b := [a := 1 for i in range(2)] for j in range(2)] 353 354 self.assertEqual(res, [[1, 1], [1, 1]]) 355 self.assertEqual(a, 1) 356 self.assertEqual(b, [1, 1]) 357 358 def test_named_expression_scope_11(self): 359 res = [j := i for i in range(5)] 360 361 self.assertEqual(res, [0, 1, 2, 3, 4]) 362 self.assertEqual(j, 4) 363 364 def test_named_expression_scope_17(self): 365 b = 0 366 res = [b := i + b for i in range(5)] 367 368 self.assertEqual(res, [0, 1, 3, 6, 10]) 369 self.assertEqual(b, 10) 370 371 def test_named_expression_scope_18(self): 372 def spam(a): 373 return a 374 375 res = spam(b := 2) 376 377 self.assertEqual(res, 2) 378 self.assertEqual(b, 2) 379 380 def test_named_expression_scope_19(self): 381 def spam(a): 382 return a 383 384 res = spam((b := 2)) 385 386 self.assertEqual(res, 2) 387 self.assertEqual(b, 2) 388 389 def test_named_expression_scope_20(self): 390 def spam(a): 391 return a 392 393 res = spam(a=(b := 2)) 394 395 self.assertEqual(res, 2) 396 self.assertEqual(b, 2) 397 398 def test_named_expression_scope_21(self): 399 def spam(a, b): 400 return a + b 401 402 res = spam(c := 2, b=1) 403 404 self.assertEqual(res, 3) 405 self.assertEqual(c, 2) 406 407 def test_named_expression_scope_22(self): 408 def spam(a, b): 409 return a + b 410 411 res = spam((c := 2), b=1) 412 413 self.assertEqual(res, 3) 414 self.assertEqual(c, 2) 415 416 def test_named_expression_scope_23(self): 417 def spam(a, b): 418 return a + b 419 420 res = spam(b=(c := 2), a=1) 421 422 self.assertEqual(res, 3) 423 self.assertEqual(c, 2) 424 425 def test_named_expression_scope_24(self): 426 a = 10 427 def spam(): 428 nonlocal a 429 (a := 20) 430 spam() 431 432 self.assertEqual(a, 20) 433 434 def test_named_expression_scope_25(self): 435 ns = {} 436 code = """a = 10 437def spam(): 438 global a 439 (a := 20) 440spam()""" 441 442 exec(code, ns, {}) 443 444 self.assertEqual(ns["a"], 20) 445 446 def test_named_expression_variable_reuse_in_comprehensions(self): 447 # The compiler is expected to raise syntax error for comprehension 448 # iteration variables, but should be fine with rebinding of other 449 # names (e.g. globals, nonlocals, other assignment expressions) 450 451 # The cases are all defined to produce the same expected result 452 # Each comprehension is checked at both function scope and module scope 453 rebinding = "[x := i for i in range(3) if (x := i) or not x]" 454 filter_ref = "[x := i for i in range(3) if x or not x]" 455 body_ref = "[x for i in range(3) if (x := i) or not x]" 456 nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]" 457 cases = [ 458 ("Rebind global", f"x = 1; result = {rebinding}"), 459 ("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"), 460 ("Filter global", f"x = 1; result = {filter_ref}"), 461 ("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"), 462 ("Body global", f"x = 1; result = {body_ref}"), 463 ("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"), 464 ("Nested global", f"x = 1; result = {nested_ref}"), 465 ("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"), 466 ] 467 for case, code in cases: 468 with self.subTest(case=case): 469 ns = {} 470 exec(code, ns) 471 self.assertEqual(ns["x"], 2) 472 self.assertEqual(ns["result"], [0, 1, 2]) 473 474 def test_named_expression_global_scope(self): 475 sentinel = object() 476 global GLOBAL_VAR 477 def f(): 478 global GLOBAL_VAR 479 [GLOBAL_VAR := sentinel for _ in range(1)] 480 self.assertEqual(GLOBAL_VAR, sentinel) 481 try: 482 f() 483 self.assertEqual(GLOBAL_VAR, sentinel) 484 finally: 485 GLOBAL_VAR = None 486 487 def test_named_expression_global_scope_no_global_keyword(self): 488 sentinel = object() 489 def f(): 490 GLOBAL_VAR = None 491 [GLOBAL_VAR := sentinel for _ in range(1)] 492 self.assertEqual(GLOBAL_VAR, sentinel) 493 f() 494 self.assertEqual(GLOBAL_VAR, None) 495 496 def test_named_expression_nonlocal_scope(self): 497 sentinel = object() 498 def f(): 499 nonlocal_var = None 500 def g(): 501 nonlocal nonlocal_var 502 [nonlocal_var := sentinel for _ in range(1)] 503 g() 504 self.assertEqual(nonlocal_var, sentinel) 505 f() 506 507 def test_named_expression_nonlocal_scope_no_nonlocal_keyword(self): 508 sentinel = object() 509 def f(): 510 nonlocal_var = None 511 def g(): 512 [nonlocal_var := sentinel for _ in range(1)] 513 g() 514 self.assertEqual(nonlocal_var, None) 515 f() 516 517 518if __name__ == "__main__": 519 unittest.main() 520