• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test iterators.
2
3import sys
4import unittest
5from test.support import cpython_only
6from test.support.os_helper import TESTFN, unlink
7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ
8import pickle
9import collections.abc
10
# Expected result of the triple loop over range(3) used by the iterator and
# comprehension tests: every (i, j, k) with i, j, k in {0, 1, 2}, in
# lexicographic order.  It is deliberately spelled out as a literal (too big
# to inline in each test) so the tests compare against a value that is not
# produced by the very loop/comprehension machinery under test.
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
            (0, 1, 0), (0, 1, 1), (0, 1, 2),
            (0, 2, 0), (0, 2, 1), (0, 2, 2),

            (1, 0, 0), (1, 0, 1), (1, 0, 2),
            (1, 1, 0), (1, 1, 1), (1, 1, 2),
            (1, 2, 0), (1, 2, 1), (1, 2, 2),

            (2, 0, 0), (2, 0, 1), (2, 0, 2),
            (2, 1, 0), (2, 1, 1), (2, 1, 2),
            (2, 2, 0), (2, 2, 1), (2, 2, 2)]
23
24# Helper classes
25
class BasicIterClass:
    """Hand-written iterator producing 0, 1, ..., n-1.

    After ``n`` values every further ``next()`` raises StopIteration; the
    counter is never advanced past ``n``.
    """

    def __init__(self, n):
        self.n = n  # exclusive upper bound
        self.i = 0  # next value to hand out

    def __iter__(self):
        return self

    def __next__(self):
        value = self.i
        if value >= self.n:
            raise StopIteration
        self.i += 1
        return value
38
class IteratingSequenceClass:
    """Iterable (not itself an iterator) over 0 .. n-1.

    Every call to ``iter()`` hands out a brand-new BasicIterClass, so
    independent loops over one instance do not interfere.
    """

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        fresh = BasicIterClass(self.n)
        return fresh
44
class IteratorProxyClass:
    """Iterator that forwards each ``next()`` to the wrapped iterator ``i``.

    It defines no ``__contains__``, so ``in`` tests on it go through
    iteration.
    """

    def __init__(self, i):
        self.i = i

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.i)
52
class SequenceClass:
    """Sequence-protocol object: ``s[i] == i`` for 0 <= i < n.

    It deliberately has no ``__iter__``, so iteration uses the legacy
    ``__getitem__`` protocol and stops at the first IndexError.  ``n`` may be
    changed after construction to grow or shrink the sequence.
    """

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
61
class SequenceProxyClass:
    """Sequence-protocol object that delegates indexing to the wrapped ``s``."""

    def __init__(self, s):
        self.s = s

    def __getitem__(self, i):
        target = self.s
        return target[i]
67
class UnlimitedSequenceClass:
    """Endless sequence-protocol object: ``s[i]`` returns ``i`` for any index.

    ``__getitem__`` never raises IndexError, so iterating an instance never
    terminates on its own.
    """
    def __getitem__(self, i):
        return i
71
class DefaultIterClass:
    """Class defining neither __iter__ nor __getitem__ (not iterable)."""
    pass
74
class NoIterClass:
    """Has __getitem__ but explicitly opts out of iteration.

    Setting ``__iter__ = None`` marks the class as non-iterable: iter() on an
    instance raises TypeError instead of falling back to __getitem__.
    """
    def __getitem__(self, i):
        return i
    __iter__ = None  # explicit "not iterable" marker; must stay None
79
class BadIterableClass:
    """Iterable whose __iter__ always raises ZeroDivisionError.

    Used to check that an exception raised by __iter__ propagates out of
    consumers such as ``in`` and operator.indexOf().
    """
    def __iter__(self):
        raise ZeroDivisionError
83
84# Main test suite
85
86class TestCase(unittest.TestCase):
87
88    # Helper to check that an iterator returns a given sequence
89    def check_iterator(self, it, seq, pickle=True):
90        if pickle:
91            self.check_pickle(it, seq)
92        res = []
93        while 1:
94            try:
95                val = next(it)
96            except StopIteration:
97                break
98            res.append(val)
99        self.assertEqual(res, seq)
100
101    # Helper to check that a for loop generates a given sequence
102    def check_for_loop(self, expr, seq, pickle=True):
103        if pickle:
104            self.check_pickle(iter(expr), seq)
105        res = []
106        for val in expr:
107            res.append(val)
108        self.assertEqual(res, seq)
109
110    # Helper to check picklability
111    def check_pickle(self, itorg, seq):
112        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
113            d = pickle.dumps(itorg, proto)
114            it = pickle.loads(d)
115            # Cannot assert type equality because dict iterators unpickle as list
116            # iterators.
117            # self.assertEqual(type(itorg), type(it))
118            self.assertTrue(isinstance(it, collections.abc.Iterator))
119            self.assertEqual(list(it), seq)
120
121            it = pickle.loads(d)
122            try:
123                next(it)
124            except StopIteration:
125                continue
126            d = pickle.dumps(it, proto)
127            it = pickle.loads(d)
128            self.assertEqual(list(it), seq[1:])
129
130    # Test basic use of iter() function
131    def test_iter_basic(self):
132        self.check_iterator(iter(range(10)), list(range(10)))
133
134    # Test that iter(iter(x)) is the same as iter(x)
135    def test_iter_idempotency(self):
136        seq = list(range(10))
137        it = iter(seq)
138        it2 = iter(it)
139        self.assertTrue(it is it2)
140
141    # Test that for loops over iterators work
142    def test_iter_for_loop(self):
143        self.check_for_loop(iter(range(10)), list(range(10)))
144
145    # Test several independent iterators over the same list
146    def test_iter_independence(self):
147        seq = range(3)
148        res = []
149        for i in iter(seq):
150            for j in iter(seq):
151                for k in iter(seq):
152                    res.append((i, j, k))
153        self.assertEqual(res, TRIPLETS)
154
155    # Test triple list comprehension using iterators
156    def test_nested_comprehensions_iter(self):
157        seq = range(3)
158        res = [(i, j, k)
159               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
160        self.assertEqual(res, TRIPLETS)
161
162    # Test triple list comprehension without iterators
163    def test_nested_comprehensions_for(self):
164        seq = range(3)
165        res = [(i, j, k) for i in seq for j in seq for k in seq]
166        self.assertEqual(res, TRIPLETS)
167
168    # Test a class with __iter__ in a for loop
169    def test_iter_class_for(self):
170        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))
171
172    # Test a class with __iter__ with explicit iter()
173    def test_iter_class_iter(self):
174        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
175
176    # Test for loop on a sequence class without __iter__
177    def test_seq_class_for(self):
178        self.check_for_loop(SequenceClass(10), list(range(10)))
179
180    # Test iter() on a sequence class without __iter__
181    def test_seq_class_iter(self):
182        self.check_iterator(iter(SequenceClass(10)), list(range(10)))
183
184    def test_mutating_seq_class_iter_pickle(self):
185        orig = SequenceClass(5)
186        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
187            # initial iterator
188            itorig = iter(orig)
189            d = pickle.dumps((itorig, orig), proto)
190            it, seq = pickle.loads(d)
191            seq.n = 7
192            self.assertIs(type(it), type(itorig))
193            self.assertEqual(list(it), list(range(7)))
194
195            # running iterator
196            next(itorig)
197            d = pickle.dumps((itorig, orig), proto)
198            it, seq = pickle.loads(d)
199            seq.n = 7
200            self.assertIs(type(it), type(itorig))
201            self.assertEqual(list(it), list(range(1, 7)))
202
203            # empty iterator
204            for i in range(1, 5):
205                next(itorig)
206            d = pickle.dumps((itorig, orig), proto)
207            it, seq = pickle.loads(d)
208            seq.n = 7
209            self.assertIs(type(it), type(itorig))
210            self.assertEqual(list(it), list(range(5, 7)))
211
212            # exhausted iterator
213            self.assertRaises(StopIteration, next, itorig)
214            d = pickle.dumps((itorig, orig), proto)
215            it, seq = pickle.loads(d)
216            seq.n = 7
217            self.assertTrue(isinstance(it, collections.abc.Iterator))
218            self.assertEqual(list(it), [])
219
220    def test_mutating_seq_class_exhausted_iter(self):
221        a = SequenceClass(5)
222        exhit = iter(a)
223        empit = iter(a)
224        for x in exhit:  # exhaust the iterator
225            next(empit)  # not exhausted
226        a.n = 7
227        self.assertEqual(list(exhit), [])
228        self.assertEqual(list(empit), [5, 6])
229        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
230
231    # Test a new_style class with __iter__ but no next() method
232    def test_new_style_iter_class(self):
233        class IterClass(object):
234            def __iter__(self):
235                return self
236        self.assertRaises(TypeError, iter, IterClass())
237
238    # Test two-argument iter() with callable instance
239    def test_iter_callable(self):
240        class C:
241            def __init__(self):
242                self.i = 0
243            def __call__(self):
244                i = self.i
245                self.i = i + 1
246                if i > 100:
247                    raise IndexError # Emergency stop
248                return i
249        self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)
250
251    # Test two-argument iter() with function
252    def test_iter_function(self):
253        def spam(state=[0]):
254            i = state[0]
255            state[0] = i+1
256            return i
257        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
258
259    # Test two-argument iter() with function that raises StopIteration
260    def test_iter_function_stop(self):
261        def spam(state=[0]):
262            i = state[0]
263            if i == 10:
264                raise StopIteration
265            state[0] = i+1
266            return i
267        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
268
269    # Test exception propagation through function iterator
270    def test_exception_function(self):
271        def spam(state=[0]):
272            i = state[0]
273            state[0] = i+1
274            if i == 10:
275                raise RuntimeError
276            return i
277        res = []
278        try:
279            for x in iter(spam, 20):
280                res.append(x)
281        except RuntimeError:
282            self.assertEqual(res, list(range(10)))
283        else:
284            self.fail("should have raised RuntimeError")
285
286    # Test exception propagation through sequence iterator
287    def test_exception_sequence(self):
288        class MySequenceClass(SequenceClass):
289            def __getitem__(self, i):
290                if i == 10:
291                    raise RuntimeError
292                return SequenceClass.__getitem__(self, i)
293        res = []
294        try:
295            for x in MySequenceClass(20):
296                res.append(x)
297        except RuntimeError:
298            self.assertEqual(res, list(range(10)))
299        else:
300            self.fail("should have raised RuntimeError")
301
302    # Test for StopIteration from __getitem__
303    def test_stop_sequence(self):
304        class MySequenceClass(SequenceClass):
305            def __getitem__(self, i):
306                if i == 10:
307                    raise StopIteration
308                return SequenceClass.__getitem__(self, i)
309        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
310
311    # Test a big range
312    def test_iter_big_range(self):
313        self.check_for_loop(iter(range(10000)), list(range(10000)))
314
315    # Test an empty list
316    def test_iter_empty(self):
317        self.check_for_loop(iter([]), [])
318
319    # Test a tuple
320    def test_iter_tuple(self):
321        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))
322
323    # Test a range
324    def test_iter_range(self):
325        self.check_for_loop(iter(range(10)), list(range(10)))
326
327    # Test a string
328    def test_iter_string(self):
329        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
330
331    # Test a directory
332    def test_iter_dict(self):
333        dict = {}
334        for i in range(10):
335            dict[i] = None
336        self.check_for_loop(dict, list(dict.keys()))
337
338    # Test a file
339    def test_iter_file(self):
340        f = open(TESTFN, "w", encoding="utf-8")
341        try:
342            for i in range(5):
343                f.write("%d\n" % i)
344        finally:
345            f.close()
346        f = open(TESTFN, "r", encoding="utf-8")
347        try:
348            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
349            self.check_for_loop(f, [], pickle=False)
350        finally:
351            f.close()
352            try:
353                unlink(TESTFN)
354            except OSError:
355                pass
356
357    # Test list()'s use of iterators.
358    def test_builtin_list(self):
359        self.assertEqual(list(SequenceClass(5)), list(range(5)))
360        self.assertEqual(list(SequenceClass(0)), [])
361        self.assertEqual(list(()), [])
362
363        d = {"one": 1, "two": 2, "three": 3}
364        self.assertEqual(list(d), list(d.keys()))
365
366        self.assertRaises(TypeError, list, list)
367        self.assertRaises(TypeError, list, 42)
368
369        f = open(TESTFN, "w", encoding="utf-8")
370        try:
371            for i in range(5):
372                f.write("%d\n" % i)
373        finally:
374            f.close()
375        f = open(TESTFN, "r", encoding="utf-8")
376        try:
377            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
378            f.seek(0, 0)
379            self.assertEqual(list(f),
380                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
381        finally:
382            f.close()
383            try:
384                unlink(TESTFN)
385            except OSError:
386                pass
387
388    # Test tuples()'s use of iterators.
389    def test_builtin_tuple(self):
390        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
391        self.assertEqual(tuple(SequenceClass(0)), ())
392        self.assertEqual(tuple([]), ())
393        self.assertEqual(tuple(()), ())
394        self.assertEqual(tuple("abc"), ("a", "b", "c"))
395
396        d = {"one": 1, "two": 2, "three": 3}
397        self.assertEqual(tuple(d), tuple(d.keys()))
398
399        self.assertRaises(TypeError, tuple, list)
400        self.assertRaises(TypeError, tuple, 42)
401
402        f = open(TESTFN, "w", encoding="utf-8")
403        try:
404            for i in range(5):
405                f.write("%d\n" % i)
406        finally:
407            f.close()
408        f = open(TESTFN, "r", encoding="utf-8")
409        try:
410            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
411            f.seek(0, 0)
412            self.assertEqual(tuple(f),
413                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
414        finally:
415            f.close()
416            try:
417                unlink(TESTFN)
418            except OSError:
419                pass
420
421    # Test filter()'s use of iterators.
422    def test_builtin_filter(self):
423        self.assertEqual(list(filter(None, SequenceClass(5))),
424                         list(range(1, 5)))
425        self.assertEqual(list(filter(None, SequenceClass(0))), [])
426        self.assertEqual(list(filter(None, ())), [])
427        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])
428
429        d = {"one": 1, "two": 2, "three": 3}
430        self.assertEqual(list(filter(None, d)), list(d.keys()))
431
432        self.assertRaises(TypeError, filter, None, list)
433        self.assertRaises(TypeError, filter, None, 42)
434
435        class Boolean:
436            def __init__(self, truth):
437                self.truth = truth
438            def __bool__(self):
439                return self.truth
440        bTrue = Boolean(True)
441        bFalse = Boolean(False)
442
443        class Seq:
444            def __init__(self, *args):
445                self.vals = args
446            def __iter__(self):
447                class SeqIter:
448                    def __init__(self, vals):
449                        self.vals = vals
450                        self.i = 0
451                    def __iter__(self):
452                        return self
453                    def __next__(self):
454                        i = self.i
455                        self.i = i + 1
456                        if i < len(self.vals):
457                            return self.vals[i]
458                        else:
459                            raise StopIteration
460                return SeqIter(self.vals)
461
462        seq = Seq(*([bTrue, bFalse] * 25))
463        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
464        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)
465
466    # Test max() and min()'s use of iterators.
467    def test_builtin_max_min(self):
468        self.assertEqual(max(SequenceClass(5)), 4)
469        self.assertEqual(min(SequenceClass(5)), 0)
470        self.assertEqual(max(8, -1), 8)
471        self.assertEqual(min(8, -1), -1)
472
473        d = {"one": 1, "two": 2, "three": 3}
474        self.assertEqual(max(d), "two")
475        self.assertEqual(min(d), "one")
476        self.assertEqual(max(d.values()), 3)
477        self.assertEqual(min(iter(d.values())), 1)
478
479        f = open(TESTFN, "w", encoding="utf-8")
480        try:
481            f.write("medium line\n")
482            f.write("xtra large line\n")
483            f.write("itty-bitty line\n")
484        finally:
485            f.close()
486        f = open(TESTFN, "r", encoding="utf-8")
487        try:
488            self.assertEqual(min(f), "itty-bitty line\n")
489            f.seek(0, 0)
490            self.assertEqual(max(f), "xtra large line\n")
491        finally:
492            f.close()
493            try:
494                unlink(TESTFN)
495            except OSError:
496                pass
497
498    # Test map()'s use of iterators.
499    def test_builtin_map(self):
500        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
501                         list(range(1, 6)))
502
503        d = {"one": 1, "two": 2, "three": 3}
504        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
505                         list(d.items()))
506        dkeys = list(d.keys())
507        expected = [(i < len(d) and dkeys[i] or None,
508                     i,
509                     i < len(d) and dkeys[i] or None)
510                    for i in range(3)]
511
512        f = open(TESTFN, "w", encoding="utf-8")
513        try:
514            for i in range(10):
515                f.write("xy" * i + "\n") # line i has len 2*i+1
516        finally:
517            f.close()
518        f = open(TESTFN, "r", encoding="utf-8")
519        try:
520            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
521        finally:
522            f.close()
523            try:
524                unlink(TESTFN)
525            except OSError:
526                pass
527
528    # Test zip()'s use of iterators.
529    def test_builtin_zip(self):
530        self.assertEqual(list(zip()), [])
531        self.assertEqual(list(zip(*[])), [])
532        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])
533
534        self.assertRaises(TypeError, zip, None)
535        self.assertRaises(TypeError, zip, range(10), 42)
536        self.assertRaises(TypeError, zip, range(10), zip)
537
538        self.assertEqual(list(zip(IteratingSequenceClass(3))),
539                         [(0,), (1,), (2,)])
540        self.assertEqual(list(zip(SequenceClass(3))),
541                         [(0,), (1,), (2,)])
542
543        d = {"one": 1, "two": 2, "three": 3}
544        self.assertEqual(list(d.items()), list(zip(d, d.values())))
545
546        # Generate all ints starting at constructor arg.
547        class IntsFrom:
548            def __init__(self, start):
549                self.i = start
550
551            def __iter__(self):
552                return self
553
554            def __next__(self):
555                i = self.i
556                self.i = i+1
557                return i
558
559        f = open(TESTFN, "w", encoding="utf-8")
560        try:
561            f.write("a\n" "bbb\n" "cc\n")
562        finally:
563            f.close()
564        f = open(TESTFN, "r", encoding="utf-8")
565        try:
566            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
567                             [(0, "a\n", -100),
568                              (1, "bbb\n", -99),
569                              (2, "cc\n", -98)])
570        finally:
571            f.close()
572            try:
573                unlink(TESTFN)
574            except OSError:
575                pass
576
577        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
578
579        # Classes that lie about their lengths.
580        class NoGuessLen5:
581            def __getitem__(self, i):
582                if i >= 5:
583                    raise IndexError
584                return i
585
586        class Guess3Len5(NoGuessLen5):
587            def __len__(self):
588                return 3
589
590        class Guess30Len5(NoGuessLen5):
591            def __len__(self):
592                return 30
593
594        def lzip(*args):
595            return list(zip(*args))
596
597        self.assertEqual(len(Guess3Len5()), 3)
598        self.assertEqual(len(Guess30Len5()), 30)
599        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
600        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
601        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))
602
603        expected = [(i, i) for i in range(5)]
604        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
605            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
606                self.assertEqual(lzip(x, y), expected)
607
608    def test_unicode_join_endcase(self):
609
610        # This class inserts a Unicode object into its argument's natural
611        # iteration, in the 3rd position.
612        class OhPhooey:
613            def __init__(self, seq):
614                self.it = iter(seq)
615                self.i = 0
616
617            def __iter__(self):
618                return self
619
620            def __next__(self):
621                i = self.i
622                self.i = i+1
623                if i == 2:
624                    return "fooled you!"
625                return next(self.it)
626
627        f = open(TESTFN, "w", encoding="utf-8")
628        try:
629            f.write("a\n" + "b\n" + "c\n")
630        finally:
631            f.close()
632
633        f = open(TESTFN, "r", encoding="utf-8")
634        # Nasty:  string.join(s) can't know whether unicode.join() is needed
635        # until it's seen all of s's elements.  But in this case, f's
636        # iterator cannot be restarted.  So what we're testing here is
637        # whether string.join() can manage to remember everything it's seen
638        # and pass that on to unicode.join().
639        try:
640            got = " - ".join(OhPhooey(f))
641            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
642        finally:
643            f.close()
644            try:
645                unlink(TESTFN)
646            except OSError:
647                pass
648
649    # Test iterators with 'x in y' and 'x not in y'.
650    def test_in_and_not_in(self):
651        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
652            for i in range(5):
653                self.assertIn(i, sc5)
654            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
655                self.assertNotIn(i, sc5)
656
657        self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1])))
658        self.assertIn(ALWAYS_EQ, SequenceProxyClass([1]))
659        self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ])))
660        self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ]))
661        self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ])))
662        self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ]))
663
664        self.assertRaises(TypeError, lambda: 3 in 12)
665        self.assertRaises(TypeError, lambda: 3 not in map)
666        self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
667
668        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
669        for k in d:
670            self.assertIn(k, d)
671            self.assertNotIn(k, d.values())
672        for v in d.values():
673            self.assertIn(v, d.values())
674            self.assertNotIn(v, d)
675        for k, v in d.items():
676            self.assertIn((k, v), d.items())
677            self.assertNotIn((v, k), d.items())
678
679        f = open(TESTFN, "w", encoding="utf-8")
680        try:
681            f.write("a\n" "b\n" "c\n")
682        finally:
683            f.close()
684        f = open(TESTFN, "r", encoding="utf-8")
685        try:
686            for chunk in "abc":
687                f.seek(0, 0)
688                self.assertNotIn(chunk, f)
689                f.seek(0, 0)
690                self.assertIn((chunk + "\n"), f)
691        finally:
692            f.close()
693            try:
694                unlink(TESTFN)
695            except OSError:
696                pass
697
698    # Test iterators with operator.countOf (PySequence_Count).
699    def test_countOf(self):
700        from operator import countOf
701        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
702        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
703        self.assertEqual(countOf("122325", "2"), 3)
704        self.assertEqual(countOf("122325", "6"), 0)
705
706        self.assertRaises(TypeError, countOf, 42, 1)
707        self.assertRaises(TypeError, countOf, countOf, countOf)
708
709        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
710        for k in d:
711            self.assertEqual(countOf(d, k), 1)
712        self.assertEqual(countOf(d.values(), 3), 3)
713        self.assertEqual(countOf(d.values(), 2j), 1)
714        self.assertEqual(countOf(d.values(), 1j), 0)
715
716        f = open(TESTFN, "w", encoding="utf-8")
717        try:
718            f.write("a\n" "b\n" "c\n" "b\n")
719        finally:
720            f.close()
721        f = open(TESTFN, "r", encoding="utf-8")
722        try:
723            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
724                f.seek(0, 0)
725                self.assertEqual(countOf(f, letter + "\n"), count)
726        finally:
727            f.close()
728            try:
729                unlink(TESTFN)
730            except OSError:
731                pass
732
733    # Test iterators with operator.indexOf (PySequence_Index).
734    def test_indexOf(self):
735        from operator import indexOf
736        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
737        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
738        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
739        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
740        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
741        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
742
743        self.assertEqual(indexOf("122325", "2"), 1)
744        self.assertEqual(indexOf("122325", "5"), 5)
745        self.assertRaises(ValueError, indexOf, "122325", "6")
746
747        self.assertRaises(TypeError, indexOf, 42, 1)
748        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
749        self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
750
751        f = open(TESTFN, "w", encoding="utf-8")
752        try:
753            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
754        finally:
755            f.close()
756        f = open(TESTFN, "r", encoding="utf-8")
757        try:
758            fiter = iter(f)
759            self.assertEqual(indexOf(fiter, "b\n"), 1)
760            self.assertEqual(indexOf(fiter, "d\n"), 1)
761            self.assertEqual(indexOf(fiter, "e\n"), 0)
762            self.assertRaises(ValueError, indexOf, fiter, "a\n")
763        finally:
764            f.close()
765            try:
766                unlink(TESTFN)
767            except OSError:
768                pass
769
770        iclass = IteratingSequenceClass(3)
771        for i in range(3):
772            self.assertEqual(indexOf(iclass, i), i)
773        self.assertRaises(ValueError, indexOf, iclass, -1)
774
775    # Test iterators with file.writelines().
776    def test_writelines(self):
777        f = open(TESTFN, "w", encoding="utf-8")
778
779        try:
780            self.assertRaises(TypeError, f.writelines, None)
781            self.assertRaises(TypeError, f.writelines, 42)
782
783            f.writelines(["1\n", "2\n"])
784            f.writelines(("3\n", "4\n"))
785            f.writelines({'5\n': None})
786            f.writelines({})
787
788            # Try a big chunk too.
789            class Iterator:
790                def __init__(self, start, finish):
791                    self.start = start
792                    self.finish = finish
793                    self.i = self.start
794
795                def __next__(self):
796                    if self.i >= self.finish:
797                        raise StopIteration
798                    result = str(self.i) + '\n'
799                    self.i += 1
800                    return result
801
802                def __iter__(self):
803                    return self
804
805            class Whatever:
806                def __init__(self, start, finish):
807                    self.start = start
808                    self.finish = finish
809
810                def __iter__(self):
811                    return Iterator(self.start, self.finish)
812
813            f.writelines(Whatever(6, 6+2000))
814            f.close()
815
816            f = open(TESTFN, encoding="utf-8")
817            expected = [str(i) + "\n" for i in range(1, 2006)]
818            self.assertEqual(list(f), expected)
819
820        finally:
821            f.close()
822            try:
823                unlink(TESTFN)
824            except OSError:
825                pass
826
827
828    # Test iterators on RHS of unpacking assignments.
829    def test_unpack_iter(self):
830        a, b = 1, 2
831        self.assertEqual((a, b), (1, 2))
832
833        a, b, c = IteratingSequenceClass(3)
834        self.assertEqual((a, b, c), (0, 1, 2))
835
836        try:    # too many values
837            a, b = IteratingSequenceClass(3)
838        except ValueError:
839            pass
840        else:
841            self.fail("should have raised ValueError")
842
843        try:    # not enough values
844            a, b, c = IteratingSequenceClass(2)
845        except ValueError:
846            pass
847        else:
848            self.fail("should have raised ValueError")
849
850        try:    # not iterable
851            a, b, c = len
852        except TypeError:
853            pass
854        else:
855            self.fail("should have raised TypeError")
856
857        a, b, c = {1: 42, 2: 42, 3: 42}.values()
858        self.assertEqual((a, b, c), (42, 42, 42))
859
860        f = open(TESTFN, "w", encoding="utf-8")
861        lines = ("a\n", "bb\n", "ccc\n")
862        try:
863            for line in lines:
864                f.write(line)
865        finally:
866            f.close()
867        f = open(TESTFN, "r", encoding="utf-8")
868        try:
869            a, b, c = f
870            self.assertEqual((a, b, c), lines)
871        finally:
872            f.close()
873            try:
874                unlink(TESTFN)
875            except OSError:
876                pass
877
878        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
879        self.assertEqual((a, b, c), (0, 1, 42))
880
881
882    @cpython_only
883    def test_ref_counting_behavior(self):
884        class C(object):
885            count = 0
886            def __new__(cls):
887                cls.count += 1
888                return object.__new__(cls)
889            def __del__(self):
890                cls = self.__class__
891                assert cls.count > 0
892                cls.count -= 1
893        x = C()
894        self.assertEqual(C.count, 1)
895        del x
896        self.assertEqual(C.count, 0)
897        l = [C(), C(), C()]
898        self.assertEqual(C.count, 3)
899        try:
900            a, b = iter(l)
901        except ValueError:
902            pass
903        del l
904        self.assertEqual(C.count, 0)
905
906
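    # Make sure StopIteration is a "sink state": once an iterator has raised
    # StopIteration, every later next() call must raise it again, even if the
    # underlying container changes afterwards.  These tests cover cases that
    # weren't sink states in Python 2.2.1, plus cases that always worked.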
910
911    def test_sinkstate_list(self):
912        # This used to fail
913        a = list(range(5))
914        b = iter(a)
915        self.assertEqual(list(b), list(range(5)))
916        a.extend(range(5, 10))
917        self.assertEqual(list(b), [])
918
919    def test_sinkstate_tuple(self):
920        a = (0, 1, 2, 3, 4)
921        b = iter(a)
922        self.assertEqual(list(b), list(range(5)))
923        self.assertEqual(list(b), [])
924
925    def test_sinkstate_string(self):
926        a = "abcde"
927        b = iter(a)
928        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
929        self.assertEqual(list(b), [])
930
931    def test_sinkstate_sequence(self):
932        # This used to fail
933        a = SequenceClass(5)
934        b = iter(a)
935        self.assertEqual(list(b), list(range(5)))
936        a.n = 10
937        self.assertEqual(list(b), [])
938
939    def test_sinkstate_callable(self):
940        # This used to fail
941        def spam(state=[0]):
942            i = state[0]
943            state[0] = i+1
944            if i == 10:
945                raise AssertionError("shouldn't have gotten this far")
946            return i
947        b = iter(spam, 5)
948        self.assertEqual(list(b), list(range(5)))
949        self.assertEqual(list(b), [])
950
951    def test_sinkstate_dict(self):
952        # XXX For a more thorough test, see towards the end of:
953        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
954        a = {1:1, 2:2, 0:0, 4:4, 3:3}
955        for b in iter(a), a.keys(), a.items(), a.values():
956            b = iter(a)
957            self.assertEqual(len(list(b)), 5)
958            self.assertEqual(list(b), [])
959
960    def test_sinkstate_yield(self):
961        def gen():
962            for i in range(5):
963                yield i
964        b = gen()
965        self.assertEqual(list(b), list(range(5)))
966        self.assertEqual(list(b), [])
967
968    def test_sinkstate_range(self):
969        a = range(5)
970        b = iter(a)
971        self.assertEqual(list(b), list(range(5)))
972        self.assertEqual(list(b), [])
973
974    def test_sinkstate_enumerate(self):
975        a = range(5)
976        e = enumerate(a)
977        b = iter(e)
978        self.assertEqual(list(b), list(zip(range(5), range(5))))
979        self.assertEqual(list(b), [])
980
981    def test_3720(self):
982        # Avoid a crash, when an iterator deletes its next() method.
983        class BadIterator(object):
984            def __iter__(self):
985                return self
986            def __next__(self):
987                del BadIterator.__next__
988                return 1
989
990        try:
991            for i in BadIterator() :
992                pass
993        except TypeError:
994            pass
995
996    def test_extending_list_with_iterator_does_not_segfault(self):
997        # The code to extend a list with an iterator has a fair
998        # amount of nontrivial logic in terms of guessing how
999        # much memory to allocate in advance, "stealing" refs,
1000        # and then shrinking at the end.  This is a basic smoke
1001        # test for that scenario.
1002        def gen():
1003            for i in range(500):
1004                yield i
1005        lst = [0] * 500
1006        for i in range(240):
1007            lst.pop(0)
1008        lst.extend(gen())
1009        self.assertEqual(len(lst), 760)
1010
1011    @cpython_only
1012    def test_iter_overflow(self):
1013        # Test for the issue 22939
1014        it = iter(UnlimitedSequenceClass())
1015        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
1016        it.__setstate__(sys.maxsize - 2)
1017        self.assertEqual(next(it), sys.maxsize - 2)
1018        self.assertEqual(next(it), sys.maxsize - 1)
1019        with self.assertRaises(OverflowError):
1020            next(it)
1021        # Check that Overflow error is always raised
1022        with self.assertRaises(OverflowError):
1023            next(it)
1024
1025    def test_iter_neg_setstate(self):
1026        it = iter(UnlimitedSequenceClass())
1027        it.__setstate__(-42)
1028        self.assertEqual(next(it), 0)
1029        self.assertEqual(next(it), 1)
1030
1031    def test_free_after_iterating(self):
1032        check_free_after_iterating(self, iter, SequenceClass, (0,))
1033
1034    def test_error_iter(self):
1035        for typ in (DefaultIterClass, NoIterClass):
1036            self.assertRaises(TypeError, iter, typ())
1037        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
1038
1039
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
1042