1# Test iterators. 2 3import sys 4import unittest 5from test.support import cpython_only 6from test.support.os_helper import TESTFN, unlink 7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ 8import pickle 9import collections.abc 10 11# Test result of triple loop (too big to inline) 12TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2), 13 (0, 1, 0), (0, 1, 1), (0, 1, 2), 14 (0, 2, 0), (0, 2, 1), (0, 2, 2), 15 16 (1, 0, 0), (1, 0, 1), (1, 0, 2), 17 (1, 1, 0), (1, 1, 1), (1, 1, 2), 18 (1, 2, 0), (1, 2, 1), (1, 2, 2), 19 20 (2, 0, 0), (2, 0, 1), (2, 0, 2), 21 (2, 1, 0), (2, 1, 1), (2, 1, 2), 22 (2, 2, 0), (2, 2, 1), (2, 2, 2)] 23 24# Helper classes 25 26class BasicIterClass: 27 def __init__(self, n): 28 self.n = n 29 self.i = 0 30 def __next__(self): 31 res = self.i 32 if res >= self.n: 33 raise StopIteration 34 self.i = res + 1 35 return res 36 def __iter__(self): 37 return self 38 39class IteratingSequenceClass: 40 def __init__(self, n): 41 self.n = n 42 def __iter__(self): 43 return BasicIterClass(self.n) 44 45class IteratorProxyClass: 46 def __init__(self, i): 47 self.i = i 48 def __next__(self): 49 return next(self.i) 50 def __iter__(self): 51 return self 52 53class SequenceClass: 54 def __init__(self, n): 55 self.n = n 56 def __getitem__(self, i): 57 if 0 <= i < self.n: 58 return i 59 else: 60 raise IndexError 61 62class SequenceProxyClass: 63 def __init__(self, s): 64 self.s = s 65 def __getitem__(self, i): 66 return self.s[i] 67 68class UnlimitedSequenceClass: 69 def __getitem__(self, i): 70 return i 71 72class DefaultIterClass: 73 pass 74 75class NoIterClass: 76 def __getitem__(self, i): 77 return i 78 __iter__ = None 79 80class BadIterableClass: 81 def __iter__(self): 82 raise ZeroDivisionError 83 84# Main test suite 85 86class TestCase(unittest.TestCase): 87 88 # Helper to check that an iterator returns a given sequence 89 def check_iterator(self, it, seq, pickle=True): 90 if pickle: 91 self.check_pickle(it, seq) 92 res = [] 93 while 1: 94 try: 95 val = next(it) 96 except StopIteration: 97 break 98 res.append(val) 99 self.assertEqual(res, seq) 100 101 # Helper to check that a for loop generates a given sequence 102 def check_for_loop(self, expr, seq, pickle=True): 103 if pickle: 104 self.check_pickle(iter(expr), seq) 105 res = [] 106 for val in expr: 107 res.append(val) 108 self.assertEqual(res, seq) 109 110 # Helper to check picklability 111 def check_pickle(self, itorg, seq): 112 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 113 d = pickle.dumps(itorg, proto) 114 it = pickle.loads(d) 115 # Cannot assert type equality because dict iterators unpickle as list 116 # iterators. 117 # self.assertEqual(type(itorg), type(it)) 118 self.assertTrue(isinstance(it, collections.abc.Iterator)) 119 self.assertEqual(list(it), seq) 120 121 it = pickle.loads(d) 122 try: 123 next(it) 124 except StopIteration: 125 continue 126 d = pickle.dumps(it, proto) 127 it = pickle.loads(d) 128 self.assertEqual(list(it), seq[1:]) 129 130 # Test basic use of iter() function 131 def test_iter_basic(self): 132 self.check_iterator(iter(range(10)), list(range(10))) 133 134 # Test that iter(iter(x)) is the same as iter(x) 135 def test_iter_idempotency(self): 136 seq = list(range(10)) 137 it = iter(seq) 138 it2 = iter(it) 139 self.assertTrue(it is it2) 140 141 # Test that for loops over iterators work 142 def test_iter_for_loop(self): 143 self.check_for_loop(iter(range(10)), list(range(10))) 144 145 # Test several independent iterators over the same list 146 def test_iter_independence(self): 147 seq = range(3) 148 res = [] 149 for i in iter(seq): 150 for j in iter(seq): 151 for k in iter(seq): 152 res.append((i, j, k)) 153 self.assertEqual(res, TRIPLETS) 154 155 # Test triple list comprehension using iterators 156 def test_nested_comprehensions_iter(self): 157 seq = range(3) 158 res = [(i, j, k) 159 for i in iter(seq) for j in iter(seq) for k in iter(seq)] 160 self.assertEqual(res, TRIPLETS) 161 162 # Test triple list comprehension without iterators 163 def test_nested_comprehensions_for(self): 164 seq = range(3) 165 res = [(i, j, k) for i in seq for j in seq for k in seq] 166 self.assertEqual(res, TRIPLETS) 167 168 # Test a class with __iter__ in a for loop 169 def test_iter_class_for(self): 170 self.check_for_loop(IteratingSequenceClass(10), list(range(10))) 171 172 # Test a class with __iter__ with explicit iter() 173 def test_iter_class_iter(self): 174 self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10))) 175 176 # Test for loop on a sequence class without __iter__ 177 def test_seq_class_for(self): 178 self.check_for_loop(SequenceClass(10), list(range(10))) 179 180 # Test iter() on a sequence class without __iter__ 181 def test_seq_class_iter(self): 182 self.check_iterator(iter(SequenceClass(10)), list(range(10))) 183 184 def test_mutating_seq_class_iter_pickle(self): 185 orig = SequenceClass(5) 186 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 187 # initial iterator 188 itorig = iter(orig) 189 d = pickle.dumps((itorig, orig), proto) 190 it, seq = pickle.loads(d) 191 seq.n = 7 192 self.assertIs(type(it), type(itorig)) 193 self.assertEqual(list(it), list(range(7))) 194 195 # running iterator 196 next(itorig) 197 d = pickle.dumps((itorig, orig), proto) 198 it, seq = pickle.loads(d) 199 seq.n = 7 200 self.assertIs(type(it), type(itorig)) 201 self.assertEqual(list(it), list(range(1, 7))) 202 203 # empty iterator 204 for i in range(1, 5): 205 next(itorig) 206 d = pickle.dumps((itorig, orig), proto) 207 it, seq = pickle.loads(d) 208 seq.n = 7 209 self.assertIs(type(it), type(itorig)) 210 self.assertEqual(list(it), list(range(5, 7))) 211 212 # exhausted iterator 213 self.assertRaises(StopIteration, next, itorig) 214 d = pickle.dumps((itorig, orig), proto) 215 it, seq = pickle.loads(d) 216 seq.n = 7 217 self.assertTrue(isinstance(it, collections.abc.Iterator)) 218 self.assertEqual(list(it), []) 219 220 def test_mutating_seq_class_exhausted_iter(self): 221 a = SequenceClass(5) 222 exhit = iter(a) 223 empit = iter(a) 224 for x in exhit: # exhaust the iterator 225 next(empit) # not exhausted 226 a.n = 7 227 self.assertEqual(list(exhit), []) 228 self.assertEqual(list(empit), [5, 6]) 229 self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6]) 230 231 # Test a new_style class with __iter__ but no next() method 232 def test_new_style_iter_class(self): 233 class IterClass(object): 234 def __iter__(self): 235 return self 236 self.assertRaises(TypeError, iter, IterClass()) 237 238 # Test two-argument iter() with callable instance 239 def test_iter_callable(self): 240 class C: 241 def __init__(self): 242 self.i = 0 243 def __call__(self): 244 i = self.i 245 self.i = i + 1 246 if i > 100: 247 raise IndexError # Emergency stop 248 return i 249 self.check_iterator(iter(C(), 10), list(range(10)), pickle=False) 250 251 # Test two-argument iter() with function 252 def test_iter_function(self): 253 def spam(state=[0]): 254 i = state[0] 255 state[0] = i+1 256 return i 257 self.check_iterator(iter(spam, 10), list(range(10)), pickle=False) 258 259 # Test two-argument iter() with function that raises StopIteration 260 def test_iter_function_stop(self): 261 def spam(state=[0]): 262 i = state[0] 263 if i == 10: 264 raise StopIteration 265 state[0] = i+1 266 return i 267 self.check_iterator(iter(spam, 20), list(range(10)), pickle=False) 268 269 # Test exception propagation through function iterator 270 def test_exception_function(self): 271 def spam(state=[0]): 272 i = state[0] 273 state[0] = i+1 274 if i == 10: 275 raise RuntimeError 276 return i 277 res = [] 278 try: 279 for x in iter(spam, 20): 280 res.append(x) 281 except RuntimeError: 282 self.assertEqual(res, list(range(10))) 283 else: 284 self.fail("should have raised RuntimeError") 285 286 # Test exception propagation through sequence iterator 287 def test_exception_sequence(self): 288 class MySequenceClass(SequenceClass): 289 def __getitem__(self, i): 290 if i == 10: 291 raise RuntimeError 292 return SequenceClass.__getitem__(self, i) 293 res = [] 294 try: 295 for x in MySequenceClass(20): 296 res.append(x) 297 except RuntimeError: 298 self.assertEqual(res, list(range(10))) 299 else: 300 self.fail("should have raised RuntimeError") 301 302 # Test for StopIteration from __getitem__ 303 def test_stop_sequence(self): 304 class MySequenceClass(SequenceClass): 305 def __getitem__(self, i): 306 if i == 10: 307 raise StopIteration 308 return SequenceClass.__getitem__(self, i) 309 self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False) 310 311 # Test a big range 312 def test_iter_big_range(self): 313 self.check_for_loop(iter(range(10000)), list(range(10000))) 314 315 # Test an empty list 316 def test_iter_empty(self): 317 self.check_for_loop(iter([]), []) 318 319 # Test a tuple 320 def test_iter_tuple(self): 321 self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10))) 322 323 # Test a range 324 def test_iter_range(self): 325 self.check_for_loop(iter(range(10)), list(range(10))) 326 327 # Test a string 328 def test_iter_string(self): 329 self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"]) 330 331 # Test a directory 332 def test_iter_dict(self): 333 dict = {} 334 for i in range(10): 335 dict[i] = None 336 self.check_for_loop(dict, list(dict.keys())) 337 338 # Test a file 339 def test_iter_file(self): 340 f = open(TESTFN, "w", encoding="utf-8") 341 try: 342 for i in range(5): 343 f.write("%d\n" % i) 344 finally: 345 f.close() 346 f = open(TESTFN, "r", encoding="utf-8") 347 try: 348 self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False) 349 self.check_for_loop(f, [], pickle=False) 350 finally: 351 f.close() 352 try: 353 unlink(TESTFN) 354 except OSError: 355 pass 356 357 # Test list()'s use of iterators. 358 def test_builtin_list(self): 359 self.assertEqual(list(SequenceClass(5)), list(range(5))) 360 self.assertEqual(list(SequenceClass(0)), []) 361 self.assertEqual(list(()), []) 362 363 d = {"one": 1, "two": 2, "three": 3} 364 self.assertEqual(list(d), list(d.keys())) 365 366 self.assertRaises(TypeError, list, list) 367 self.assertRaises(TypeError, list, 42) 368 369 f = open(TESTFN, "w", encoding="utf-8") 370 try: 371 for i in range(5): 372 f.write("%d\n" % i) 373 finally: 374 f.close() 375 f = open(TESTFN, "r", encoding="utf-8") 376 try: 377 self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"]) 378 f.seek(0, 0) 379 self.assertEqual(list(f), 380 ["0\n", "1\n", "2\n", "3\n", "4\n"]) 381 finally: 382 f.close() 383 try: 384 unlink(TESTFN) 385 except OSError: 386 pass 387 388 # Test tuples()'s use of iterators. 389 def test_builtin_tuple(self): 390 self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4)) 391 self.assertEqual(tuple(SequenceClass(0)), ()) 392 self.assertEqual(tuple([]), ()) 393 self.assertEqual(tuple(()), ()) 394 self.assertEqual(tuple("abc"), ("a", "b", "c")) 395 396 d = {"one": 1, "two": 2, "three": 3} 397 self.assertEqual(tuple(d), tuple(d.keys())) 398 399 self.assertRaises(TypeError, tuple, list) 400 self.assertRaises(TypeError, tuple, 42) 401 402 f = open(TESTFN, "w", encoding="utf-8") 403 try: 404 for i in range(5): 405 f.write("%d\n" % i) 406 finally: 407 f.close() 408 f = open(TESTFN, "r", encoding="utf-8") 409 try: 410 self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n")) 411 f.seek(0, 0) 412 self.assertEqual(tuple(f), 413 ("0\n", "1\n", "2\n", "3\n", "4\n")) 414 finally: 415 f.close() 416 try: 417 unlink(TESTFN) 418 except OSError: 419 pass 420 421 # Test filter()'s use of iterators. 422 def test_builtin_filter(self): 423 self.assertEqual(list(filter(None, SequenceClass(5))), 424 list(range(1, 5))) 425 self.assertEqual(list(filter(None, SequenceClass(0))), []) 426 self.assertEqual(list(filter(None, ())), []) 427 self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"]) 428 429 d = {"one": 1, "two": 2, "three": 3} 430 self.assertEqual(list(filter(None, d)), list(d.keys())) 431 432 self.assertRaises(TypeError, filter, None, list) 433 self.assertRaises(TypeError, filter, None, 42) 434 435 class Boolean: 436 def __init__(self, truth): 437 self.truth = truth 438 def __bool__(self): 439 return self.truth 440 bTrue = Boolean(True) 441 bFalse = Boolean(False) 442 443 class Seq: 444 def __init__(self, *args): 445 self.vals = args 446 def __iter__(self): 447 class SeqIter: 448 def __init__(self, vals): 449 self.vals = vals 450 self.i = 0 451 def __iter__(self): 452 return self 453 def __next__(self): 454 i = self.i 455 self.i = i + 1 456 if i < len(self.vals): 457 return self.vals[i] 458 else: 459 raise StopIteration 460 return SeqIter(self.vals) 461 462 seq = Seq(*([bTrue, bFalse] * 25)) 463 self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25) 464 self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25) 465 466 # Test max() and min()'s use of iterators. 467 def test_builtin_max_min(self): 468 self.assertEqual(max(SequenceClass(5)), 4) 469 self.assertEqual(min(SequenceClass(5)), 0) 470 self.assertEqual(max(8, -1), 8) 471 self.assertEqual(min(8, -1), -1) 472 473 d = {"one": 1, "two": 2, "three": 3} 474 self.assertEqual(max(d), "two") 475 self.assertEqual(min(d), "one") 476 self.assertEqual(max(d.values()), 3) 477 self.assertEqual(min(iter(d.values())), 1) 478 479 f = open(TESTFN, "w", encoding="utf-8") 480 try: 481 f.write("medium line\n") 482 f.write("xtra large line\n") 483 f.write("itty-bitty line\n") 484 finally: 485 f.close() 486 f = open(TESTFN, "r", encoding="utf-8") 487 try: 488 self.assertEqual(min(f), "itty-bitty line\n") 489 f.seek(0, 0) 490 self.assertEqual(max(f), "xtra large line\n") 491 finally: 492 f.close() 493 try: 494 unlink(TESTFN) 495 except OSError: 496 pass 497 498 # Test map()'s use of iterators. 499 def test_builtin_map(self): 500 self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))), 501 list(range(1, 6))) 502 503 d = {"one": 1, "two": 2, "three": 3} 504 self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)), 505 list(d.items())) 506 dkeys = list(d.keys()) 507 expected = [(i < len(d) and dkeys[i] or None, 508 i, 509 i < len(d) and dkeys[i] or None) 510 for i in range(3)] 511 512 f = open(TESTFN, "w", encoding="utf-8") 513 try: 514 for i in range(10): 515 f.write("xy" * i + "\n") # line i has len 2*i+1 516 finally: 517 f.close() 518 f = open(TESTFN, "r", encoding="utf-8") 519 try: 520 self.assertEqual(list(map(len, f)), list(range(1, 21, 2))) 521 finally: 522 f.close() 523 try: 524 unlink(TESTFN) 525 except OSError: 526 pass 527 528 # Test zip()'s use of iterators. 529 def test_builtin_zip(self): 530 self.assertEqual(list(zip()), []) 531 self.assertEqual(list(zip(*[])), []) 532 self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')]) 533 534 self.assertRaises(TypeError, zip, None) 535 self.assertRaises(TypeError, zip, range(10), 42) 536 self.assertRaises(TypeError, zip, range(10), zip) 537 538 self.assertEqual(list(zip(IteratingSequenceClass(3))), 539 [(0,), (1,), (2,)]) 540 self.assertEqual(list(zip(SequenceClass(3))), 541 [(0,), (1,), (2,)]) 542 543 d = {"one": 1, "two": 2, "three": 3} 544 self.assertEqual(list(d.items()), list(zip(d, d.values()))) 545 546 # Generate all ints starting at constructor arg. 547 class IntsFrom: 548 def __init__(self, start): 549 self.i = start 550 551 def __iter__(self): 552 return self 553 554 def __next__(self): 555 i = self.i 556 self.i = i+1 557 return i 558 559 f = open(TESTFN, "w", encoding="utf-8") 560 try: 561 f.write("a\n" "bbb\n" "cc\n") 562 finally: 563 f.close() 564 f = open(TESTFN, "r", encoding="utf-8") 565 try: 566 self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))), 567 [(0, "a\n", -100), 568 (1, "bbb\n", -99), 569 (2, "cc\n", -98)]) 570 finally: 571 f.close() 572 try: 573 unlink(TESTFN) 574 except OSError: 575 pass 576 577 self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)]) 578 579 # Classes that lie about their lengths. 580 class NoGuessLen5: 581 def __getitem__(self, i): 582 if i >= 5: 583 raise IndexError 584 return i 585 586 class Guess3Len5(NoGuessLen5): 587 def __len__(self): 588 return 3 589 590 class Guess30Len5(NoGuessLen5): 591 def __len__(self): 592 return 30 593 594 def lzip(*args): 595 return list(zip(*args)) 596 597 self.assertEqual(len(Guess3Len5()), 3) 598 self.assertEqual(len(Guess30Len5()), 30) 599 self.assertEqual(lzip(NoGuessLen5()), lzip(range(5))) 600 self.assertEqual(lzip(Guess3Len5()), lzip(range(5))) 601 self.assertEqual(lzip(Guess30Len5()), lzip(range(5))) 602 603 expected = [(i, i) for i in range(5)] 604 for x in NoGuessLen5(), Guess3Len5(), Guess30Len5(): 605 for y in NoGuessLen5(), Guess3Len5(), Guess30Len5(): 606 self.assertEqual(lzip(x, y), expected) 607 608 def test_unicode_join_endcase(self): 609 610 # This class inserts a Unicode object into its argument's natural 611 # iteration, in the 3rd position. 612 class OhPhooey: 613 def __init__(self, seq): 614 self.it = iter(seq) 615 self.i = 0 616 617 def __iter__(self): 618 return self 619 620 def __next__(self): 621 i = self.i 622 self.i = i+1 623 if i == 2: 624 return "fooled you!" 625 return next(self.it) 626 627 f = open(TESTFN, "w", encoding="utf-8") 628 try: 629 f.write("a\n" + "b\n" + "c\n") 630 finally: 631 f.close() 632 633 f = open(TESTFN, "r", encoding="utf-8") 634 # Nasty: string.join(s) can't know whether unicode.join() is needed 635 # until it's seen all of s's elements. But in this case, f's 636 # iterator cannot be restarted. So what we're testing here is 637 # whether string.join() can manage to remember everything it's seen 638 # and pass that on to unicode.join(). 639 try: 640 got = " - ".join(OhPhooey(f)) 641 self.assertEqual(got, "a\n - b\n - fooled you! - c\n") 642 finally: 643 f.close() 644 try: 645 unlink(TESTFN) 646 except OSError: 647 pass 648 649 # Test iterators with 'x in y' and 'x not in y'. 650 def test_in_and_not_in(self): 651 for sc5 in IteratingSequenceClass(5), SequenceClass(5): 652 for i in range(5): 653 self.assertIn(i, sc5) 654 for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5: 655 self.assertNotIn(i, sc5) 656 657 self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1]))) 658 self.assertIn(ALWAYS_EQ, SequenceProxyClass([1])) 659 self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ]))) 660 self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ])) 661 self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ]))) 662 self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ])) 663 664 self.assertRaises(TypeError, lambda: 3 in 12) 665 self.assertRaises(TypeError, lambda: 3 not in map) 666 self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass()) 667 668 d = {"one": 1, "two": 2, "three": 3, 1j: 2j} 669 for k in d: 670 self.assertIn(k, d) 671 self.assertNotIn(k, d.values()) 672 for v in d.values(): 673 self.assertIn(v, d.values()) 674 self.assertNotIn(v, d) 675 for k, v in d.items(): 676 self.assertIn((k, v), d.items()) 677 self.assertNotIn((v, k), d.items()) 678 679 f = open(TESTFN, "w", encoding="utf-8") 680 try: 681 f.write("a\n" "b\n" "c\n") 682 finally: 683 f.close() 684 f = open(TESTFN, "r", encoding="utf-8") 685 try: 686 for chunk in "abc": 687 f.seek(0, 0) 688 self.assertNotIn(chunk, f) 689 f.seek(0, 0) 690 self.assertIn((chunk + "\n"), f) 691 finally: 692 f.close() 693 try: 694 unlink(TESTFN) 695 except OSError: 696 pass 697 698 # Test iterators with operator.countOf (PySequence_Count). 699 def test_countOf(self): 700 from operator import countOf 701 self.assertEqual(countOf([1,2,2,3,2,5], 2), 3) 702 self.assertEqual(countOf((1,2,2,3,2,5), 2), 3) 703 self.assertEqual(countOf("122325", "2"), 3) 704 self.assertEqual(countOf("122325", "6"), 0) 705 706 self.assertRaises(TypeError, countOf, 42, 1) 707 self.assertRaises(TypeError, countOf, countOf, countOf) 708 709 d = {"one": 3, "two": 3, "three": 3, 1j: 2j} 710 for k in d: 711 self.assertEqual(countOf(d, k), 1) 712 self.assertEqual(countOf(d.values(), 3), 3) 713 self.assertEqual(countOf(d.values(), 2j), 1) 714 self.assertEqual(countOf(d.values(), 1j), 0) 715 716 f = open(TESTFN, "w", encoding="utf-8") 717 try: 718 f.write("a\n" "b\n" "c\n" "b\n") 719 finally: 720 f.close() 721 f = open(TESTFN, "r", encoding="utf-8") 722 try: 723 for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0): 724 f.seek(0, 0) 725 self.assertEqual(countOf(f, letter + "\n"), count) 726 finally: 727 f.close() 728 try: 729 unlink(TESTFN) 730 except OSError: 731 pass 732 733 # Test iterators with operator.indexOf (PySequence_Index). 734 def test_indexOf(self): 735 from operator import indexOf 736 self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0) 737 self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1) 738 self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3) 739 self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5) 740 self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0) 741 self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6) 742 743 self.assertEqual(indexOf("122325", "2"), 1) 744 self.assertEqual(indexOf("122325", "5"), 5) 745 self.assertRaises(ValueError, indexOf, "122325", "6") 746 747 self.assertRaises(TypeError, indexOf, 42, 1) 748 self.assertRaises(TypeError, indexOf, indexOf, indexOf) 749 self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1) 750 751 f = open(TESTFN, "w", encoding="utf-8") 752 try: 753 f.write("a\n" "b\n" "c\n" "d\n" "e\n") 754 finally: 755 f.close() 756 f = open(TESTFN, "r", encoding="utf-8") 757 try: 758 fiter = iter(f) 759 self.assertEqual(indexOf(fiter, "b\n"), 1) 760 self.assertEqual(indexOf(fiter, "d\n"), 1) 761 self.assertEqual(indexOf(fiter, "e\n"), 0) 762 self.assertRaises(ValueError, indexOf, fiter, "a\n") 763 finally: 764 f.close() 765 try: 766 unlink(TESTFN) 767 except OSError: 768 pass 769 770 iclass = IteratingSequenceClass(3) 771 for i in range(3): 772 self.assertEqual(indexOf(iclass, i), i) 773 self.assertRaises(ValueError, indexOf, iclass, -1) 774 775 # Test iterators with file.writelines(). 776 def test_writelines(self): 777 f = open(TESTFN, "w", encoding="utf-8") 778 779 try: 780 self.assertRaises(TypeError, f.writelines, None) 781 self.assertRaises(TypeError, f.writelines, 42) 782 783 f.writelines(["1\n", "2\n"]) 784 f.writelines(("3\n", "4\n")) 785 f.writelines({'5\n': None}) 786 f.writelines({}) 787 788 # Try a big chunk too. 789 class Iterator: 790 def __init__(self, start, finish): 791 self.start = start 792 self.finish = finish 793 self.i = self.start 794 795 def __next__(self): 796 if self.i >= self.finish: 797 raise StopIteration 798 result = str(self.i) + '\n' 799 self.i += 1 800 return result 801 802 def __iter__(self): 803 return self 804 805 class Whatever: 806 def __init__(self, start, finish): 807 self.start = start 808 self.finish = finish 809 810 def __iter__(self): 811 return Iterator(self.start, self.finish) 812 813 f.writelines(Whatever(6, 6+2000)) 814 f.close() 815 816 f = open(TESTFN, encoding="utf-8") 817 expected = [str(i) + "\n" for i in range(1, 2006)] 818 self.assertEqual(list(f), expected) 819 820 finally: 821 f.close() 822 try: 823 unlink(TESTFN) 824 except OSError: 825 pass 826 827 828 # Test iterators on RHS of unpacking assignments. 829 def test_unpack_iter(self): 830 a, b = 1, 2 831 self.assertEqual((a, b), (1, 2)) 832 833 a, b, c = IteratingSequenceClass(3) 834 self.assertEqual((a, b, c), (0, 1, 2)) 835 836 try: # too many values 837 a, b = IteratingSequenceClass(3) 838 except ValueError: 839 pass 840 else: 841 self.fail("should have raised ValueError") 842 843 try: # not enough values 844 a, b, c = IteratingSequenceClass(2) 845 except ValueError: 846 pass 847 else: 848 self.fail("should have raised ValueError") 849 850 try: # not iterable 851 a, b, c = len 852 except TypeError: 853 pass 854 else: 855 self.fail("should have raised TypeError") 856 857 a, b, c = {1: 42, 2: 42, 3: 42}.values() 858 self.assertEqual((a, b, c), (42, 42, 42)) 859 860 f = open(TESTFN, "w", encoding="utf-8") 861 lines = ("a\n", "bb\n", "ccc\n") 862 try: 863 for line in lines: 864 f.write(line) 865 finally: 866 f.close() 867 f = open(TESTFN, "r", encoding="utf-8") 868 try: 869 a, b, c = f 870 self.assertEqual((a, b, c), lines) 871 finally: 872 f.close() 873 try: 874 unlink(TESTFN) 875 except OSError: 876 pass 877 878 (a, b), (c,) = IteratingSequenceClass(2), {42: 24} 879 self.assertEqual((a, b, c), (0, 1, 42)) 880 881 882 @cpython_only 883 def test_ref_counting_behavior(self): 884 class C(object): 885 count = 0 886 def __new__(cls): 887 cls.count += 1 888 return object.__new__(cls) 889 def __del__(self): 890 cls = self.__class__ 891 assert cls.count > 0 892 cls.count -= 1 893 x = C() 894 self.assertEqual(C.count, 1) 895 del x 896 self.assertEqual(C.count, 0) 897 l = [C(), C(), C()] 898 self.assertEqual(C.count, 3) 899 try: 900 a, b = iter(l) 901 except ValueError: 902 pass 903 del l 904 self.assertEqual(C.count, 0) 905 906 907 # Make sure StopIteration is a "sink state". 908 # This tests various things that weren't sink states in Python 2.2.1, 909 # plus various things that always were fine. 910 911 def test_sinkstate_list(self): 912 # This used to fail 913 a = list(range(5)) 914 b = iter(a) 915 self.assertEqual(list(b), list(range(5))) 916 a.extend(range(5, 10)) 917 self.assertEqual(list(b), []) 918 919 def test_sinkstate_tuple(self): 920 a = (0, 1, 2, 3, 4) 921 b = iter(a) 922 self.assertEqual(list(b), list(range(5))) 923 self.assertEqual(list(b), []) 924 925 def test_sinkstate_string(self): 926 a = "abcde" 927 b = iter(a) 928 self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e']) 929 self.assertEqual(list(b), []) 930 931 def test_sinkstate_sequence(self): 932 # This used to fail 933 a = SequenceClass(5) 934 b = iter(a) 935 self.assertEqual(list(b), list(range(5))) 936 a.n = 10 937 self.assertEqual(list(b), []) 938 939 def test_sinkstate_callable(self): 940 # This used to fail 941 def spam(state=[0]): 942 i = state[0] 943 state[0] = i+1 944 if i == 10: 945 raise AssertionError("shouldn't have gotten this far") 946 return i 947 b = iter(spam, 5) 948 self.assertEqual(list(b), list(range(5))) 949 self.assertEqual(list(b), []) 950 951 def test_sinkstate_dict(self): 952 # XXX For a more thorough test, see towards the end of: 953 # http://mail.python.org/pipermail/python-dev/2002-July/026512.html 954 a = {1:1, 2:2, 0:0, 4:4, 3:3} 955 for b in iter(a), a.keys(), a.items(), a.values(): 956 b = iter(a) 957 self.assertEqual(len(list(b)), 5) 958 self.assertEqual(list(b), []) 959 960 def test_sinkstate_yield(self): 961 def gen(): 962 for i in range(5): 963 yield i 964 b = gen() 965 self.assertEqual(list(b), list(range(5))) 966 self.assertEqual(list(b), []) 967 968 def test_sinkstate_range(self): 969 a = range(5) 970 b = iter(a) 971 self.assertEqual(list(b), list(range(5))) 972 self.assertEqual(list(b), []) 973 974 def test_sinkstate_enumerate(self): 975 a = range(5) 976 e = enumerate(a) 977 b = iter(e) 978 self.assertEqual(list(b), list(zip(range(5), range(5)))) 979 self.assertEqual(list(b), []) 980 981 def test_3720(self): 982 # Avoid a crash, when an iterator deletes its next() method. 983 class BadIterator(object): 984 def __iter__(self): 985 return self 986 def __next__(self): 987 del BadIterator.__next__ 988 return 1 989 990 try: 991 for i in BadIterator() : 992 pass 993 except TypeError: 994 pass 995 996 def test_extending_list_with_iterator_does_not_segfault(self): 997 # The code to extend a list with an iterator has a fair 998 # amount of nontrivial logic in terms of guessing how 999 # much memory to allocate in advance, "stealing" refs, 1000 # and then shrinking at the end. This is a basic smoke 1001 # test for that scenario. 1002 def gen(): 1003 for i in range(500): 1004 yield i 1005 lst = [0] * 500 1006 for i in range(240): 1007 lst.pop(0) 1008 lst.extend(gen()) 1009 self.assertEqual(len(lst), 760) 1010 1011 @cpython_only 1012 def test_iter_overflow(self): 1013 # Test for the issue 22939 1014 it = iter(UnlimitedSequenceClass()) 1015 # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop 1016 it.__setstate__(sys.maxsize - 2) 1017 self.assertEqual(next(it), sys.maxsize - 2) 1018 self.assertEqual(next(it), sys.maxsize - 1) 1019 with self.assertRaises(OverflowError): 1020 next(it) 1021 # Check that Overflow error is always raised 1022 with self.assertRaises(OverflowError): 1023 next(it) 1024 1025 def test_iter_neg_setstate(self): 1026 it = iter(UnlimitedSequenceClass()) 1027 it.__setstate__(-42) 1028 self.assertEqual(next(it), 0) 1029 self.assertEqual(next(it), 1) 1030 1031 def test_free_after_iterating(self): 1032 check_free_after_iterating(self, iter, SequenceClass, (0,)) 1033 1034 def test_error_iter(self): 1035 for typ in (DefaultIterClass, NoIterClass): 1036 self.assertRaises(TypeError, iter, typ()) 1037 self.assertRaises(ZeroDivisionError, iter, BadIterableClass()) 1038 1039 1040if __name__ == "__main__": 1041 unittest.main() 1042