1# Test iterators. 2 3import sys 4import unittest 5from test.support import cpython_only 6from test.support.os_helper import TESTFN, unlink 7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ 8from test.support import BrokenIter 9import pickle 10import collections.abc 11import functools 12import contextlib 13import builtins 14import traceback 15 16# Test result of triple loop (too big to inline) 17TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2), 18 (0, 1, 0), (0, 1, 1), (0, 1, 2), 19 (0, 2, 0), (0, 2, 1), (0, 2, 2), 20 21 (1, 0, 0), (1, 0, 1), (1, 0, 2), 22 (1, 1, 0), (1, 1, 1), (1, 1, 2), 23 (1, 2, 0), (1, 2, 1), (1, 2, 2), 24 25 (2, 0, 0), (2, 0, 1), (2, 0, 2), 26 (2, 1, 0), (2, 1, 1), (2, 1, 2), 27 (2, 2, 0), (2, 2, 1), (2, 2, 2)] 28 29# Helper classes 30 31class BasicIterClass: 32 def __init__(self, n): 33 self.n = n 34 self.i = 0 35 def __next__(self): 36 res = self.i 37 if res >= self.n: 38 raise StopIteration 39 self.i = res + 1 40 return res 41 def __iter__(self): 42 return self 43 44class IteratingSequenceClass: 45 def __init__(self, n): 46 self.n = n 47 def __iter__(self): 48 return BasicIterClass(self.n) 49 50class IteratorProxyClass: 51 def __init__(self, i): 52 self.i = i 53 def __next__(self): 54 return next(self.i) 55 def __iter__(self): 56 return self 57 58class SequenceClass: 59 def __init__(self, n): 60 self.n = n 61 def __getitem__(self, i): 62 if 0 <= i < self.n: 63 return i 64 else: 65 raise IndexError 66 67class SequenceProxyClass: 68 def __init__(self, s): 69 self.s = s 70 def __getitem__(self, i): 71 return self.s[i] 72 73class UnlimitedSequenceClass: 74 def __getitem__(self, i): 75 return i 76 77class DefaultIterClass: 78 pass 79 80class NoIterClass: 81 def __getitem__(self, i): 82 return i 83 __iter__ = None 84 85class BadIterableClass: 86 def __iter__(self): 87 raise ZeroDivisionError 88 89class CallableIterClass: 90 def __init__(self): 91 self.i = 0 92 def __call__(self): 93 i = self.i 94 self.i = i + 1 95 if i > 100: 96 raise IndexError # Emergency stop 97 return i 98 99class EmptyIterClass: 100 def __len__(self): 101 return 0 102 def __getitem__(self, i): 103 raise StopIteration 104 105# Main test suite 106 107class TestCase(unittest.TestCase): 108 109 # Helper to check that an iterator returns a given sequence 110 def check_iterator(self, it, seq, pickle=True): 111 if pickle: 112 self.check_pickle(it, seq) 113 res = [] 114 while 1: 115 try: 116 val = next(it) 117 except StopIteration: 118 break 119 res.append(val) 120 self.assertEqual(res, seq) 121 122 # Helper to check that a for loop generates a given sequence 123 def check_for_loop(self, expr, seq, pickle=True): 124 if pickle: 125 self.check_pickle(iter(expr), seq) 126 res = [] 127 for val in expr: 128 res.append(val) 129 self.assertEqual(res, seq) 130 131 # Helper to check picklability 132 def check_pickle(self, itorg, seq): 133 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 134 d = pickle.dumps(itorg, proto) 135 it = pickle.loads(d) 136 # Cannot assert type equality because dict iterators unpickle as list 137 # iterators. 138 # self.assertEqual(type(itorg), type(it)) 139 self.assertTrue(isinstance(it, collections.abc.Iterator)) 140 self.assertEqual(list(it), seq) 141 142 it = pickle.loads(d) 143 try: 144 next(it) 145 except StopIteration: 146 continue 147 d = pickle.dumps(it, proto) 148 it = pickle.loads(d) 149 self.assertEqual(list(it), seq[1:]) 150 151 # Test basic use of iter() function 152 def test_iter_basic(self): 153 self.check_iterator(iter(range(10)), list(range(10))) 154 155 # Test that iter(iter(x)) is the same as iter(x) 156 def test_iter_idempotency(self): 157 seq = list(range(10)) 158 it = iter(seq) 159 it2 = iter(it) 160 self.assertTrue(it is it2) 161 162 # Test that for loops over iterators work 163 def test_iter_for_loop(self): 164 self.check_for_loop(iter(range(10)), list(range(10))) 165 166 # Test several independent iterators over the same list 167 def test_iter_independence(self): 168 seq = range(3) 169 res = [] 170 for i in iter(seq): 171 for j in iter(seq): 172 for k in iter(seq): 173 res.append((i, j, k)) 174 self.assertEqual(res, TRIPLETS) 175 176 # Test triple list comprehension using iterators 177 def test_nested_comprehensions_iter(self): 178 seq = range(3) 179 res = [(i, j, k) 180 for i in iter(seq) for j in iter(seq) for k in iter(seq)] 181 self.assertEqual(res, TRIPLETS) 182 183 # Test triple list comprehension without iterators 184 def test_nested_comprehensions_for(self): 185 seq = range(3) 186 res = [(i, j, k) for i in seq for j in seq for k in seq] 187 self.assertEqual(res, TRIPLETS) 188 189 # Test a class with __iter__ in a for loop 190 def test_iter_class_for(self): 191 self.check_for_loop(IteratingSequenceClass(10), list(range(10))) 192 193 # Test a class with __iter__ with explicit iter() 194 def test_iter_class_iter(self): 195 self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10))) 196 197 # Test for loop on a sequence class without __iter__ 198 def test_seq_class_for(self): 199 self.check_for_loop(SequenceClass(10), list(range(10))) 200 201 # Test iter() on a sequence class without __iter__ 202 def test_seq_class_iter(self): 203 self.check_iterator(iter(SequenceClass(10)), list(range(10))) 204 205 def test_mutating_seq_class_iter_pickle(self): 206 orig = SequenceClass(5) 207 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 208 # initial iterator 209 itorig = iter(orig) 210 d = pickle.dumps((itorig, orig), proto) 211 it, seq = pickle.loads(d) 212 seq.n = 7 213 self.assertIs(type(it), type(itorig)) 214 self.assertEqual(list(it), list(range(7))) 215 216 # running iterator 217 next(itorig) 218 d = pickle.dumps((itorig, orig), proto) 219 it, seq = pickle.loads(d) 220 seq.n = 7 221 self.assertIs(type(it), type(itorig)) 222 self.assertEqual(list(it), list(range(1, 7))) 223 224 # empty iterator 225 for i in range(1, 5): 226 next(itorig) 227 d = pickle.dumps((itorig, orig), proto) 228 it, seq = pickle.loads(d) 229 seq.n = 7 230 self.assertIs(type(it), type(itorig)) 231 self.assertEqual(list(it), list(range(5, 7))) 232 233 # exhausted iterator 234 self.assertRaises(StopIteration, next, itorig) 235 d = pickle.dumps((itorig, orig), proto) 236 it, seq = pickle.loads(d) 237 seq.n = 7 238 self.assertTrue(isinstance(it, collections.abc.Iterator)) 239 self.assertEqual(list(it), []) 240 241 def test_mutating_seq_class_exhausted_iter(self): 242 a = SequenceClass(5) 243 exhit = iter(a) 244 empit = iter(a) 245 for x in exhit: # exhaust the iterator 246 next(empit) # not exhausted 247 a.n = 7 248 self.assertEqual(list(exhit), []) 249 self.assertEqual(list(empit), [5, 6]) 250 self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6]) 251 252 def test_reduce_mutating_builtins_iter(self): 253 # This is a reproducer of issue #101765 254 # where iter `__reduce__` calls could lead to a segfault or SystemError 255 # depending on the order of C argument evaluation, which is undefined 256 257 # Backup builtins 258 builtins_dict = builtins.__dict__ 259 orig = {"iter": iter, "reversed": reversed} 260 261 def run(builtin_name, item, sentinel=None): 262 it = iter(item) if sentinel is None else iter(item, sentinel) 263 264 class CustomStr: 265 def __init__(self, name, iterator): 266 self.name = name 267 self.iterator = iterator 268 def __hash__(self): 269 return hash(self.name) 270 def __eq__(self, other): 271 # Here we exhaust our iterator, possibly changing 272 # its `it_seq` pointer to NULL 273 # The `__reduce__` call should correctly get 274 # the pointers after this call 275 list(self.iterator) 276 return other == self.name 277 278 # del is required here 279 # to not prematurely call __eq__ from 280 # the hash collision with the old key 281 del builtins_dict[builtin_name] 282 builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name] 283 284 return it.__reduce__() 285 286 types = [ 287 (EmptyIterClass(),), 288 (bytes(8),), 289 (bytearray(8),), 290 ((1, 2, 3),), 291 (lambda: 0, 0), 292 (tuple[int],) # GenericAlias 293 ] 294 295 try: 296 run_iter = functools.partial(run, "iter") 297 # The returned value of `__reduce__` should not only be valid 298 # but also *empty*, as `it` was exhausted during `__eq__` 299 # i.e "xyz" returns (iter, ("",)) 300 self.assertEqual(run_iter("xyz"), (orig["iter"], ("",))) 301 self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],))) 302 303 # _PyEval_GetBuiltin is also called for `reversed` in a branch of 304 # listiter_reduce_general 305 self.assertEqual( 306 run("reversed", orig["reversed"](list(range(8)))), 307 (reversed, ([],)) 308 ) 309 310 for case in types: 311 self.assertEqual(run_iter(*case), (orig["iter"], ((),))) 312 finally: 313 # Restore original builtins 314 for key, func in orig.items(): 315 # need to suppress KeyErrors in case 316 # a failed test deletes the key without setting anything 317 with contextlib.suppress(KeyError): 318 # del is required here 319 # to not invoke our custom __eq__ from 320 # the hash collision with the old key 321 del builtins_dict[key] 322 builtins_dict[key] = func 323 324 # Test a new_style class with __iter__ but no next() method 325 def test_new_style_iter_class(self): 326 class IterClass(object): 327 def __iter__(self): 328 return self 329 self.assertRaises(TypeError, iter, IterClass()) 330 331 # Test two-argument iter() with callable instance 332 def test_iter_callable(self): 333 self.check_iterator(iter(CallableIterClass(), 10), list(range(10)), pickle=True) 334 335 # Test two-argument iter() with function 336 def test_iter_function(self): 337 def spam(state=[0]): 338 i = state[0] 339 state[0] = i+1 340 return i 341 self.check_iterator(iter(spam, 10), list(range(10)), pickle=False) 342 343 # Test two-argument iter() with function that raises StopIteration 344 def test_iter_function_stop(self): 345 def spam(state=[0]): 346 i = state[0] 347 if i == 10: 348 raise StopIteration 349 state[0] = i+1 350 return i 351 self.check_iterator(iter(spam, 20), list(range(10)), pickle=False) 352 353 def test_iter_function_concealing_reentrant_exhaustion(self): 354 # gh-101892: Test two-argument iter() with a function that 355 # exhausts its associated iterator but forgets to either return 356 # a sentinel value or raise StopIteration. 357 HAS_MORE = 1 358 NO_MORE = 2 359 360 def exhaust(iterator): 361 """Exhaust an iterator without raising StopIteration.""" 362 list(iterator) 363 364 def spam(): 365 # Touching the iterator with exhaust() below will call 366 # spam() once again so protect against recursion. 367 if spam.is_recursive_call: 368 return NO_MORE 369 spam.is_recursive_call = True 370 exhaust(spam.iterator) 371 return HAS_MORE 372 373 spam.is_recursive_call = False 374 spam.iterator = iter(spam, NO_MORE) 375 with self.assertRaises(StopIteration): 376 next(spam.iterator) 377 378 # Test exception propagation through function iterator 379 def test_exception_function(self): 380 def spam(state=[0]): 381 i = state[0] 382 state[0] = i+1 383 if i == 10: 384 raise RuntimeError 385 return i 386 res = [] 387 try: 388 for x in iter(spam, 20): 389 res.append(x) 390 except RuntimeError: 391 self.assertEqual(res, list(range(10))) 392 else: 393 self.fail("should have raised RuntimeError") 394 395 # Test exception propagation through sequence iterator 396 def test_exception_sequence(self): 397 class MySequenceClass(SequenceClass): 398 def __getitem__(self, i): 399 if i == 10: 400 raise RuntimeError 401 return SequenceClass.__getitem__(self, i) 402 res = [] 403 try: 404 for x in MySequenceClass(20): 405 res.append(x) 406 except RuntimeError: 407 self.assertEqual(res, list(range(10))) 408 else: 409 self.fail("should have raised RuntimeError") 410 411 # Test for StopIteration from __getitem__ 412 def test_stop_sequence(self): 413 class MySequenceClass(SequenceClass): 414 def __getitem__(self, i): 415 if i == 10: 416 raise StopIteration 417 return SequenceClass.__getitem__(self, i) 418 self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False) 419 420 # Test a big range 421 def test_iter_big_range(self): 422 self.check_for_loop(iter(range(10000)), list(range(10000))) 423 424 # Test an empty list 425 def test_iter_empty(self): 426 self.check_for_loop(iter([]), []) 427 428 # Test a tuple 429 def test_iter_tuple(self): 430 self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10))) 431 432 # Test a range 433 def test_iter_range(self): 434 self.check_for_loop(iter(range(10)), list(range(10))) 435 436 # Test a string 437 def test_iter_string(self): 438 self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"]) 439 440 # Test a directory 441 def test_iter_dict(self): 442 dict = {} 443 for i in range(10): 444 dict[i] = None 445 self.check_for_loop(dict, list(dict.keys())) 446 447 # Test a file 448 def test_iter_file(self): 449 f = open(TESTFN, "w", encoding="utf-8") 450 try: 451 for i in range(5): 452 f.write("%d\n" % i) 453 finally: 454 f.close() 455 f = open(TESTFN, "r", encoding="utf-8") 456 try: 457 self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False) 458 self.check_for_loop(f, [], pickle=False) 459 finally: 460 f.close() 461 try: 462 unlink(TESTFN) 463 except OSError: 464 pass 465 466 # Test list()'s use of iterators. 467 def test_builtin_list(self): 468 self.assertEqual(list(SequenceClass(5)), list(range(5))) 469 self.assertEqual(list(SequenceClass(0)), []) 470 self.assertEqual(list(()), []) 471 472 d = {"one": 1, "two": 2, "three": 3} 473 self.assertEqual(list(d), list(d.keys())) 474 475 self.assertRaises(TypeError, list, list) 476 self.assertRaises(TypeError, list, 42) 477 478 f = open(TESTFN, "w", encoding="utf-8") 479 try: 480 for i in range(5): 481 f.write("%d\n" % i) 482 finally: 483 f.close() 484 f = open(TESTFN, "r", encoding="utf-8") 485 try: 486 self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"]) 487 f.seek(0, 0) 488 self.assertEqual(list(f), 489 ["0\n", "1\n", "2\n", "3\n", "4\n"]) 490 finally: 491 f.close() 492 try: 493 unlink(TESTFN) 494 except OSError: 495 pass 496 497 # Test tuples()'s use of iterators. 498 def test_builtin_tuple(self): 499 self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4)) 500 self.assertEqual(tuple(SequenceClass(0)), ()) 501 self.assertEqual(tuple([]), ()) 502 self.assertEqual(tuple(()), ()) 503 self.assertEqual(tuple("abc"), ("a", "b", "c")) 504 505 d = {"one": 1, "two": 2, "three": 3} 506 self.assertEqual(tuple(d), tuple(d.keys())) 507 508 self.assertRaises(TypeError, tuple, list) 509 self.assertRaises(TypeError, tuple, 42) 510 511 f = open(TESTFN, "w", encoding="utf-8") 512 try: 513 for i in range(5): 514 f.write("%d\n" % i) 515 finally: 516 f.close() 517 f = open(TESTFN, "r", encoding="utf-8") 518 try: 519 self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n")) 520 f.seek(0, 0) 521 self.assertEqual(tuple(f), 522 ("0\n", "1\n", "2\n", "3\n", "4\n")) 523 finally: 524 f.close() 525 try: 526 unlink(TESTFN) 527 except OSError: 528 pass 529 530 # Test filter()'s use of iterators. 531 def test_builtin_filter(self): 532 self.assertEqual(list(filter(None, SequenceClass(5))), 533 list(range(1, 5))) 534 self.assertEqual(list(filter(None, SequenceClass(0))), []) 535 self.assertEqual(list(filter(None, ())), []) 536 self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"]) 537 538 d = {"one": 1, "two": 2, "three": 3} 539 self.assertEqual(list(filter(None, d)), list(d.keys())) 540 541 self.assertRaises(TypeError, filter, None, list) 542 self.assertRaises(TypeError, filter, None, 42) 543 544 class Boolean: 545 def __init__(self, truth): 546 self.truth = truth 547 def __bool__(self): 548 return self.truth 549 bTrue = Boolean(True) 550 bFalse = Boolean(False) 551 552 class Seq: 553 def __init__(self, *args): 554 self.vals = args 555 def __iter__(self): 556 class SeqIter: 557 def __init__(self, vals): 558 self.vals = vals 559 self.i = 0 560 def __iter__(self): 561 return self 562 def __next__(self): 563 i = self.i 564 self.i = i + 1 565 if i < len(self.vals): 566 return self.vals[i] 567 else: 568 raise StopIteration 569 return SeqIter(self.vals) 570 571 seq = Seq(*([bTrue, bFalse] * 25)) 572 self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25) 573 self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25) 574 575 # Test max() and min()'s use of iterators. 576 def test_builtin_max_min(self): 577 self.assertEqual(max(SequenceClass(5)), 4) 578 self.assertEqual(min(SequenceClass(5)), 0) 579 self.assertEqual(max(8, -1), 8) 580 self.assertEqual(min(8, -1), -1) 581 582 d = {"one": 1, "two": 2, "three": 3} 583 self.assertEqual(max(d), "two") 584 self.assertEqual(min(d), "one") 585 self.assertEqual(max(d.values()), 3) 586 self.assertEqual(min(iter(d.values())), 1) 587 588 f = open(TESTFN, "w", encoding="utf-8") 589 try: 590 f.write("medium line\n") 591 f.write("xtra large line\n") 592 f.write("itty-bitty line\n") 593 finally: 594 f.close() 595 f = open(TESTFN, "r", encoding="utf-8") 596 try: 597 self.assertEqual(min(f), "itty-bitty line\n") 598 f.seek(0, 0) 599 self.assertEqual(max(f), "xtra large line\n") 600 finally: 601 f.close() 602 try: 603 unlink(TESTFN) 604 except OSError: 605 pass 606 607 # Test map()'s use of iterators. 608 def test_builtin_map(self): 609 self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))), 610 list(range(1, 6))) 611 612 d = {"one": 1, "two": 2, "three": 3} 613 self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)), 614 list(d.items())) 615 dkeys = list(d.keys()) 616 expected = [(i < len(d) and dkeys[i] or None, 617 i, 618 i < len(d) and dkeys[i] or None) 619 for i in range(3)] 620 621 f = open(TESTFN, "w", encoding="utf-8") 622 try: 623 for i in range(10): 624 f.write("xy" * i + "\n") # line i has len 2*i+1 625 finally: 626 f.close() 627 f = open(TESTFN, "r", encoding="utf-8") 628 try: 629 self.assertEqual(list(map(len, f)), list(range(1, 21, 2))) 630 finally: 631 f.close() 632 try: 633 unlink(TESTFN) 634 except OSError: 635 pass 636 637 # Test zip()'s use of iterators. 638 def test_builtin_zip(self): 639 self.assertEqual(list(zip()), []) 640 self.assertEqual(list(zip(*[])), []) 641 self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')]) 642 643 self.assertRaises(TypeError, zip, None) 644 self.assertRaises(TypeError, zip, range(10), 42) 645 self.assertRaises(TypeError, zip, range(10), zip) 646 647 self.assertEqual(list(zip(IteratingSequenceClass(3))), 648 [(0,), (1,), (2,)]) 649 self.assertEqual(list(zip(SequenceClass(3))), 650 [(0,), (1,), (2,)]) 651 652 d = {"one": 1, "two": 2, "three": 3} 653 self.assertEqual(list(d.items()), list(zip(d, d.values()))) 654 655 # Generate all ints starting at constructor arg. 656 class IntsFrom: 657 def __init__(self, start): 658 self.i = start 659 660 def __iter__(self): 661 return self 662 663 def __next__(self): 664 i = self.i 665 self.i = i+1 666 return i 667 668 f = open(TESTFN, "w", encoding="utf-8") 669 try: 670 f.write("a\n" "bbb\n" "cc\n") 671 finally: 672 f.close() 673 f = open(TESTFN, "r", encoding="utf-8") 674 try: 675 self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))), 676 [(0, "a\n", -100), 677 (1, "bbb\n", -99), 678 (2, "cc\n", -98)]) 679 finally: 680 f.close() 681 try: 682 unlink(TESTFN) 683 except OSError: 684 pass 685 686 self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)]) 687 688 # Classes that lie about their lengths. 689 class NoGuessLen5: 690 def __getitem__(self, i): 691 if i >= 5: 692 raise IndexError 693 return i 694 695 class Guess3Len5(NoGuessLen5): 696 def __len__(self): 697 return 3 698 699 class Guess30Len5(NoGuessLen5): 700 def __len__(self): 701 return 30 702 703 def lzip(*args): 704 return list(zip(*args)) 705 706 self.assertEqual(len(Guess3Len5()), 3) 707 self.assertEqual(len(Guess30Len5()), 30) 708 self.assertEqual(lzip(NoGuessLen5()), lzip(range(5))) 709 self.assertEqual(lzip(Guess3Len5()), lzip(range(5))) 710 self.assertEqual(lzip(Guess30Len5()), lzip(range(5))) 711 712 expected = [(i, i) for i in range(5)] 713 for x in NoGuessLen5(), Guess3Len5(), Guess30Len5(): 714 for y in NoGuessLen5(), Guess3Len5(), Guess30Len5(): 715 self.assertEqual(lzip(x, y), expected) 716 717 def test_unicode_join_endcase(self): 718 719 # This class inserts a Unicode object into its argument's natural 720 # iteration, in the 3rd position. 721 class OhPhooey: 722 def __init__(self, seq): 723 self.it = iter(seq) 724 self.i = 0 725 726 def __iter__(self): 727 return self 728 729 def __next__(self): 730 i = self.i 731 self.i = i+1 732 if i == 2: 733 return "fooled you!" 734 return next(self.it) 735 736 f = open(TESTFN, "w", encoding="utf-8") 737 try: 738 f.write("a\n" + "b\n" + "c\n") 739 finally: 740 f.close() 741 742 f = open(TESTFN, "r", encoding="utf-8") 743 # Nasty: string.join(s) can't know whether unicode.join() is needed 744 # until it's seen all of s's elements. But in this case, f's 745 # iterator cannot be restarted. So what we're testing here is 746 # whether string.join() can manage to remember everything it's seen 747 # and pass that on to unicode.join(). 748 try: 749 got = " - ".join(OhPhooey(f)) 750 self.assertEqual(got, "a\n - b\n - fooled you! - c\n") 751 finally: 752 f.close() 753 try: 754 unlink(TESTFN) 755 except OSError: 756 pass 757 758 # Test iterators with 'x in y' and 'x not in y'. 759 def test_in_and_not_in(self): 760 for sc5 in IteratingSequenceClass(5), SequenceClass(5): 761 for i in range(5): 762 self.assertIn(i, sc5) 763 for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5: 764 self.assertNotIn(i, sc5) 765 766 self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1]))) 767 self.assertIn(ALWAYS_EQ, SequenceProxyClass([1])) 768 self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ]))) 769 self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ])) 770 self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ]))) 771 self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ])) 772 773 self.assertRaises(TypeError, lambda: 3 in 12) 774 self.assertRaises(TypeError, lambda: 3 not in map) 775 self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass()) 776 777 d = {"one": 1, "two": 2, "three": 3, 1j: 2j} 778 for k in d: 779 self.assertIn(k, d) 780 self.assertNotIn(k, d.values()) 781 for v in d.values(): 782 self.assertIn(v, d.values()) 783 self.assertNotIn(v, d) 784 for k, v in d.items(): 785 self.assertIn((k, v), d.items()) 786 self.assertNotIn((v, k), d.items()) 787 788 f = open(TESTFN, "w", encoding="utf-8") 789 try: 790 f.write("a\n" "b\n" "c\n") 791 finally: 792 f.close() 793 f = open(TESTFN, "r", encoding="utf-8") 794 try: 795 for chunk in "abc": 796 f.seek(0, 0) 797 self.assertNotIn(chunk, f) 798 f.seek(0, 0) 799 self.assertIn((chunk + "\n"), f) 800 finally: 801 f.close() 802 try: 803 unlink(TESTFN) 804 except OSError: 805 pass 806 807 # Test iterators with operator.countOf (PySequence_Count). 808 def test_countOf(self): 809 from operator import countOf 810 self.assertEqual(countOf([1,2,2,3,2,5], 2), 3) 811 self.assertEqual(countOf((1,2,2,3,2,5), 2), 3) 812 self.assertEqual(countOf("122325", "2"), 3) 813 self.assertEqual(countOf("122325", "6"), 0) 814 815 self.assertRaises(TypeError, countOf, 42, 1) 816 self.assertRaises(TypeError, countOf, countOf, countOf) 817 818 d = {"one": 3, "two": 3, "three": 3, 1j: 2j} 819 for k in d: 820 self.assertEqual(countOf(d, k), 1) 821 self.assertEqual(countOf(d.values(), 3), 3) 822 self.assertEqual(countOf(d.values(), 2j), 1) 823 self.assertEqual(countOf(d.values(), 1j), 0) 824 825 f = open(TESTFN, "w", encoding="utf-8") 826 try: 827 f.write("a\n" "b\n" "c\n" "b\n") 828 finally: 829 f.close() 830 f = open(TESTFN, "r", encoding="utf-8") 831 try: 832 for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0): 833 f.seek(0, 0) 834 self.assertEqual(countOf(f, letter + "\n"), count) 835 finally: 836 f.close() 837 try: 838 unlink(TESTFN) 839 except OSError: 840 pass 841 842 # Test iterators with operator.indexOf (PySequence_Index). 843 def test_indexOf(self): 844 from operator import indexOf 845 self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0) 846 self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1) 847 self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3) 848 self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5) 849 self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0) 850 self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6) 851 852 self.assertEqual(indexOf("122325", "2"), 1) 853 self.assertEqual(indexOf("122325", "5"), 5) 854 self.assertRaises(ValueError, indexOf, "122325", "6") 855 856 self.assertRaises(TypeError, indexOf, 42, 1) 857 self.assertRaises(TypeError, indexOf, indexOf, indexOf) 858 self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1) 859 860 f = open(TESTFN, "w", encoding="utf-8") 861 try: 862 f.write("a\n" "b\n" "c\n" "d\n" "e\n") 863 finally: 864 f.close() 865 f = open(TESTFN, "r", encoding="utf-8") 866 try: 867 fiter = iter(f) 868 self.assertEqual(indexOf(fiter, "b\n"), 1) 869 self.assertEqual(indexOf(fiter, "d\n"), 1) 870 self.assertEqual(indexOf(fiter, "e\n"), 0) 871 self.assertRaises(ValueError, indexOf, fiter, "a\n") 872 finally: 873 f.close() 874 try: 875 unlink(TESTFN) 876 except OSError: 877 pass 878 879 iclass = IteratingSequenceClass(3) 880 for i in range(3): 881 self.assertEqual(indexOf(iclass, i), i) 882 self.assertRaises(ValueError, indexOf, iclass, -1) 883 884 # Test iterators with file.writelines(). 885 def test_writelines(self): 886 f = open(TESTFN, "w", encoding="utf-8") 887 888 try: 889 self.assertRaises(TypeError, f.writelines, None) 890 self.assertRaises(TypeError, f.writelines, 42) 891 892 f.writelines(["1\n", "2\n"]) 893 f.writelines(("3\n", "4\n")) 894 f.writelines({'5\n': None}) 895 f.writelines({}) 896 897 # Try a big chunk too. 898 class Iterator: 899 def __init__(self, start, finish): 900 self.start = start 901 self.finish = finish 902 self.i = self.start 903 904 def __next__(self): 905 if self.i >= self.finish: 906 raise StopIteration 907 result = str(self.i) + '\n' 908 self.i += 1 909 return result 910 911 def __iter__(self): 912 return self 913 914 class Whatever: 915 def __init__(self, start, finish): 916 self.start = start 917 self.finish = finish 918 919 def __iter__(self): 920 return Iterator(self.start, self.finish) 921 922 f.writelines(Whatever(6, 6+2000)) 923 f.close() 924 925 f = open(TESTFN, encoding="utf-8") 926 expected = [str(i) + "\n" for i in range(1, 2006)] 927 self.assertEqual(list(f), expected) 928 929 finally: 930 f.close() 931 try: 932 unlink(TESTFN) 933 except OSError: 934 pass 935 936 937 # Test iterators on RHS of unpacking assignments. 938 def test_unpack_iter(self): 939 a, b = 1, 2 940 self.assertEqual((a, b), (1, 2)) 941 942 a, b, c = IteratingSequenceClass(3) 943 self.assertEqual((a, b, c), (0, 1, 2)) 944 945 try: # too many values 946 a, b = IteratingSequenceClass(3) 947 except ValueError: 948 pass 949 else: 950 self.fail("should have raised ValueError") 951 952 try: # not enough values 953 a, b, c = IteratingSequenceClass(2) 954 except ValueError: 955 pass 956 else: 957 self.fail("should have raised ValueError") 958 959 try: # not iterable 960 a, b, c = len 961 except TypeError: 962 pass 963 else: 964 self.fail("should have raised TypeError") 965 966 a, b, c = {1: 42, 2: 42, 3: 42}.values() 967 self.assertEqual((a, b, c), (42, 42, 42)) 968 969 f = open(TESTFN, "w", encoding="utf-8") 970 lines = ("a\n", "bb\n", "ccc\n") 971 try: 972 for line in lines: 973 f.write(line) 974 finally: 975 f.close() 976 f = open(TESTFN, "r", encoding="utf-8") 977 try: 978 a, b, c = f 979 self.assertEqual((a, b, c), lines) 980 finally: 981 f.close() 982 try: 983 unlink(TESTFN) 984 except OSError: 985 pass 986 987 (a, b), (c,) = IteratingSequenceClass(2), {42: 24} 988 self.assertEqual((a, b, c), (0, 1, 42)) 989 990 991 @cpython_only 992 def test_ref_counting_behavior(self): 993 class C(object): 994 count = 0 995 def __new__(cls): 996 cls.count += 1 997 return object.__new__(cls) 998 def __del__(self): 999 cls = self.__class__ 1000 assert cls.count > 0 1001 cls.count -= 1 1002 x = C() 1003 self.assertEqual(C.count, 1) 1004 del x 1005 self.assertEqual(C.count, 0) 1006 l = [C(), C(), C()] 1007 self.assertEqual(C.count, 3) 1008 try: 1009 a, b = iter(l) 1010 except ValueError: 1011 pass 1012 del l 1013 self.assertEqual(C.count, 0) 1014 1015 1016 # Make sure StopIteration is a "sink state". 1017 # This tests various things that weren't sink states in Python 2.2.1, 1018 # plus various things that always were fine. 1019 1020 def test_sinkstate_list(self): 1021 # This used to fail 1022 a = list(range(5)) 1023 b = iter(a) 1024 self.assertEqual(list(b), list(range(5))) 1025 a.extend(range(5, 10)) 1026 self.assertEqual(list(b), []) 1027 1028 def test_sinkstate_tuple(self): 1029 a = (0, 1, 2, 3, 4) 1030 b = iter(a) 1031 self.assertEqual(list(b), list(range(5))) 1032 self.assertEqual(list(b), []) 1033 1034 def test_sinkstate_string(self): 1035 a = "abcde" 1036 b = iter(a) 1037 self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e']) 1038 self.assertEqual(list(b), []) 1039 1040 def test_sinkstate_sequence(self): 1041 # This used to fail 1042 a = SequenceClass(5) 1043 b = iter(a) 1044 self.assertEqual(list(b), list(range(5))) 1045 a.n = 10 1046 self.assertEqual(list(b), []) 1047 1048 def test_sinkstate_callable(self): 1049 # This used to fail 1050 def spam(state=[0]): 1051 i = state[0] 1052 state[0] = i+1 1053 if i == 10: 1054 raise AssertionError("shouldn't have gotten this far") 1055 return i 1056 b = iter(spam, 5) 1057 self.assertEqual(list(b), list(range(5))) 1058 self.assertEqual(list(b), []) 1059 1060 def test_sinkstate_dict(self): 1061 # XXX For a more thorough test, see towards the end of: 1062 # http://mail.python.org/pipermail/python-dev/2002-July/026512.html 1063 a = {1:1, 2:2, 0:0, 4:4, 3:3} 1064 for b in iter(a), a.keys(), a.items(), a.values(): 1065 b = iter(a) 1066 self.assertEqual(len(list(b)), 5) 1067 self.assertEqual(list(b), []) 1068 1069 def test_sinkstate_yield(self): 1070 def gen(): 1071 for i in range(5): 1072 yield i 1073 b = gen() 1074 self.assertEqual(list(b), list(range(5))) 1075 self.assertEqual(list(b), []) 1076 1077 def test_sinkstate_range(self): 1078 a = range(5) 1079 b = iter(a) 1080 self.assertEqual(list(b), list(range(5))) 1081 self.assertEqual(list(b), []) 1082 1083 def test_sinkstate_enumerate(self): 1084 a = range(5) 1085 e = enumerate(a) 1086 b = iter(e) 1087 self.assertEqual(list(b), list(zip(range(5), range(5)))) 1088 self.assertEqual(list(b), []) 1089 1090 def test_3720(self): 1091 # Avoid a crash, when an iterator deletes its next() method. 1092 class BadIterator(object): 1093 def __iter__(self): 1094 return self 1095 def __next__(self): 1096 del BadIterator.__next__ 1097 return 1 1098 1099 try: 1100 for i in BadIterator() : 1101 pass 1102 except TypeError: 1103 pass 1104 1105 def test_extending_list_with_iterator_does_not_segfault(self): 1106 # The code to extend a list with an iterator has a fair 1107 # amount of nontrivial logic in terms of guessing how 1108 # much memory to allocate in advance, "stealing" refs, 1109 # and then shrinking at the end. This is a basic smoke 1110 # test for that scenario. 1111 def gen(): 1112 for i in range(500): 1113 yield i 1114 lst = [0] * 500 1115 for i in range(240): 1116 lst.pop(0) 1117 lst.extend(gen()) 1118 self.assertEqual(len(lst), 760) 1119 1120 @cpython_only 1121 def test_iter_overflow(self): 1122 # Test for the issue 22939 1123 it = iter(UnlimitedSequenceClass()) 1124 # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop 1125 it.__setstate__(sys.maxsize - 2) 1126 self.assertEqual(next(it), sys.maxsize - 2) 1127 self.assertEqual(next(it), sys.maxsize - 1) 1128 with self.assertRaises(OverflowError): 1129 next(it) 1130 # Check that Overflow error is always raised 1131 with self.assertRaises(OverflowError): 1132 next(it) 1133 1134 def test_iter_neg_setstate(self): 1135 it = iter(UnlimitedSequenceClass()) 1136 it.__setstate__(-42) 1137 self.assertEqual(next(it), 0) 1138 self.assertEqual(next(it), 1) 1139 1140 def test_free_after_iterating(self): 1141 check_free_after_iterating(self, iter, SequenceClass, (0,)) 1142 1143 def test_error_iter(self): 1144 for typ in (DefaultIterClass, NoIterClass): 1145 self.assertRaises(TypeError, iter, typ()) 1146 self.assertRaises(ZeroDivisionError, iter, BadIterableClass()) 1147 1148 def test_exception_locations(self): 1149 # The location of an exception raised from __init__ or 1150 # __next__ should should be the iterator expression 1151 1152 def init_raises(): 1153 try: 1154 for x in BrokenIter(init_raises=True): 1155 pass 1156 except Exception as e: 1157 return e 1158 1159 def next_raises(): 1160 try: 1161 for x in BrokenIter(next_raises=True): 1162 pass 1163 except Exception as e: 1164 return e 1165 1166 def iter_raises(): 1167 try: 1168 for x in BrokenIter(iter_raises=True): 1169 pass 1170 except Exception as e: 1171 return e 1172 1173 for func, expected in [(init_raises, "BrokenIter(init_raises=True)"), 1174 (next_raises, "BrokenIter(next_raises=True)"), 1175 (iter_raises, "BrokenIter(iter_raises=True)"), 1176 ]: 1177 with self.subTest(func): 1178 exc = func() 1179 f = traceback.extract_tb(exc.__traceback__)[0] 1180 indent = 16 1181 co = func.__code__ 1182 self.assertEqual(f.lineno, co.co_firstlineno + 2) 1183 self.assertEqual(f.end_lineno, co.co_firstlineno + 2) 1184 self.assertEqual(f.line[f.colno - indent : f.end_colno - indent], 1185 expected) 1186 1187 1188 1189if __name__ == "__main__": 1190 unittest.main() 1191