• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test iterators.
2
3import sys
4import unittest
5from test.support import run_unittest, TESTFN, unlink, cpython_only
6from test.support import check_free_after_iterating
7import pickle
8import collections.abc
9
# Expected result of a triple-nested loop over range(3), in iteration order.
# It is spelled out literally (rather than computed) so that tests comparing
# iterator-driven loops against it do not rely on the very machinery they test.
TRIPLETS = [
    # first element 0
    (0, 0, 0), (0, 0, 1), (0, 0, 2),
    (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    # first element 1
    (1, 0, 0), (1, 0, 1), (1, 0, 2),
    (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    # first element 2
    (2, 0, 0), (2, 0, 1), (2, 0, 2),
    (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
22
23# Helper classes
24
class BasicIterClass:
    """Minimal hand-written iterator producing 0, 1, ..., n-1.

    The instance is its own iterator (``__iter__`` returns ``self``), so it
    can be consumed only once; afterwards every ``next()`` raises
    ``StopIteration``.
    """

    def __init__(self, n):
        self.n = n      # exclusive upper bound
        self.i = 0      # next value to hand out

    def __iter__(self):
        return self

    def __next__(self):
        current = self.i
        if current >= self.n:
            raise StopIteration
        self.i += 1
        return current
37
class IteratingSequenceClass:
    """Iterable (not an iterator) whose ``__iter__`` hands out a fresh
    ``BasicIterClass`` over ``range(n)`` on every call, so it can be
    iterated any number of times.
    """

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        fresh = BasicIterClass(self.n)
        return fresh
43
class SequenceClass:
    """Old-style sequence (``__getitem__`` only, no ``__iter__``).

    Index ``i`` maps to the value ``i`` for ``0 <= i < n``; any other index
    raises ``IndexError``, which ends iteration under the legacy protocol.
    ``n`` is a plain attribute, so tests may mutate it between iterations.
    """

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        in_range = 0 <= i < self.n
        if not in_range:
            raise IndexError
        return i
52
class UnlimitedSequenceClass:
    """Old-style sequence with no end: item ``i`` is ``i`` for every index.

    ``__getitem__`` never raises ``IndexError``, so legacy iteration over it
    has no natural stopping point.
    """

    def __getitem__(self, i):
        # Identity mapping: position i holds the value i.
        return i
56
class DefaultIterClass:
    """Class with neither ``__iter__`` nor ``__getitem__``: not iterable."""
59
class NoIterClass:
    """Indexable class that explicitly opts out of iteration.

    Setting ``__iter__`` to ``None`` blocks the legacy ``__getitem__``
    fallback, so ``iter()`` raises ``TypeError`` even though indexing works.
    """

    __iter__ = None

    def __getitem__(self, i):
        return i
64
65# Main test suite
66
class TestCase(unittest.TestCase):
    """Tests for the iterator protocol and for consumers of iterators.

    Covers iter() (one- and two-argument forms), for loops, comprehensions,
    legacy ``__getitem__`` sequences, pickling of iterators, "sink state"
    behaviour after exhaustion, and iterator consumption by list(), tuple(),
    filter(), map(), zip(), min()/max(), str.join(), ``in``/``not in``,
    operator.countOf()/indexOf(), file.writelines(), list.extend() and
    unpacking assignment.
    """

    # Helper to check that an iterator returns a given sequence
    def check_iterator(self, it, seq, pickle=True):
        # ``pickle`` shadows the module name here; callers pass pickle=False
        # for iterators built from local callables, local classes or files.
        if pickle:
            self.check_pickle(it, seq)
        res = []
        # Drive the iterator by hand with next() rather than a for loop, so
        # the explicit StopIteration protocol is what gets exercised.
        while True:
            try:
                val = next(it)
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to check that a for loop generates a given sequence
    def check_for_loop(self, expr, seq, pickle=True):
        if pickle:
            self.check_pickle(iter(expr), seq)
        res = []
        # Deliberately a plain for statement: the loop itself is under test.
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to check picklability
    def check_pickle(self, itorg, seq):
        # For every pickle protocol: an unpickled copy of *itorg* must yield
        # *seq*, and a copy that is advanced once, re-pickled and unpickled
        # must resume and yield seq[1:].  Only unpickled copies are advanced.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            # Type equality is not asserted because dict iterators unpickle
            # as list iterators.
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(list(it), seq)

            it = pickle.loads(d)
            try:
                next(it)
            except StopIteration:
                # Empty sequence: there is no partially consumed state to test.
                continue
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(list(it), seq[1:])

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), list(range(10)))

    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = list(range(10))
        it = iter(seq)
        it2 = iter(it)
        self.assertIs(it, it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), list(range(10)))

    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), list(range(10)))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), list(range(10)))

    def test_mutating_seq_class_iter_pickle(self):
        orig = SequenceClass(5)
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # initial iterator
            itorig = iter(orig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(7)))

            # running iterator
            next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(1, 7)))

            # empty iterator
            for i in range(1, 5):
                next(itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIs(type(it), type(itorig))
            self.assertEqual(list(it), list(range(5, 7)))

            # exhausted iterator
            self.assertRaises(StopIteration, next, itorig)
            d = pickle.dumps((itorig, orig), proto)
            it, seq = pickle.loads(d)
            seq.n = 7
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(list(it), [])

    def test_mutating_seq_class_exhausted_iter(self):
        a = SequenceClass(5)
        exhit = iter(a)
        empit = iter(a)
        for x in exhit:  # exhaust the iterator
            next(empit)  # not exhausted
        a.n = 7
        self.assertEqual(list(exhit), [])
        self.assertEqual(list(empit), [5, 6])
        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])

    # Test a new_style class with __iter__ but no __next__() method
    def test_new_style_iter_class(self):
        class IterClass(object):
            def __iter__(self):
                return self
        self.assertRaises(TypeError, iter, IterClass())

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)

    # Test two-argument iter() with function
    def test_iter_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)

    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, list(range(10)))
        else:
            self.fail("should have raised RuntimeError")

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, list(range(10)))
        else:
            self.fail("should have raised RuntimeError")

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), list(range(10000)))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))

    # Test a range
    def test_iter_range(self):
        self.check_for_loop(iter(range(10)), list(range(10)))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a dictionary
    def test_iter_dict(self):
        # Named ``d`` (not ``dict``) to avoid shadowing the builtin.
        d = dict.fromkeys(range(10))
        self.check_for_loop(d, list(d.keys()))

    # Test a file
    def test_iter_file(self):
        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
            self.check_for_loop(f, [], pickle=False)
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), list(range(5)))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), list(d.keys()))

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
            f.seek(0, 0)
            self.assertEqual(list(f),
                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        f = open(TESTFN, "w")
        try:
            for i in range(5):
                f.write("%d\n" % i)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
            f.seek(0, 0)
            self.assertEqual(tuple(f),
                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        self.assertEqual(list(filter(None, SequenceClass(5))),
                         list(range(1, 5)))
        self.assertEqual(list(filter(None, SequenceClass(0))), [])
        self.assertEqual(list(filter(None, ())), [])
        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(filter(None, d)), list(d.keys()))

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __bool__(self):
                return self.truth
        bTrue = Boolean(True)
        bFalse = Boolean(False)

        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def __next__(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.values()), 3)
        self.assertEqual(min(iter(d.values())), 1)

        f = open(TESTFN, "w")
        try:
            f.write("medium line\n")
            f.write("xtra large line\n")
            f.write("itty-bitty line\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(min(f), "itty-bitty line\n")
            f.seek(0, 0)
            self.assertEqual(max(f), "xtra large line\n")
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
                         list(range(1, 6)))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
                         list(d.items()))

        f = open(TESTFN, "w")
        try:
            for i in range(10):
                f.write("xy" * i + "\n") # line i has len 2*i+1
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        self.assertEqual(list(zip()), [])
        self.assertEqual(list(zip(*[])), [])
        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])

        self.assertRaises(TypeError, zip, None)
        self.assertRaises(TypeError, zip, range(10), 42)
        self.assertRaises(TypeError, zip, range(10), zip)

        self.assertEqual(list(zip(IteratingSequenceClass(3))),
                         [(0,), (1,), (2,)])
        self.assertEqual(list(zip(SequenceClass(3))),
                         [(0,), (1,), (2,)])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d.items()), list(zip(d, d.values())))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def __next__(self):
                i = self.i
                self.i = i+1
                return i

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "bbb\n" "cc\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
                             [(0, "a\n", -100),
                              (1, "bbb\n", -99),
                              (2, "cc\n", -98)])
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])

        # Classes that lie about their lengths.
        class NoGuessLen5:
            def __getitem__(self, i):
                if i >= 5:
                    raise IndexError
                return i

        class Guess3Len5(NoGuessLen5):
            def __len__(self):
                return 3

        class Guess30Len5(NoGuessLen5):
            def __len__(self):
                return 30

        def lzip(*args):
            return list(zip(*args))

        self.assertEqual(len(Guess3Len5()), 3)
        self.assertEqual(len(Guess30Len5()), 30)
        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))

        expected = [(i, i) for i in range(5)]
        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
                self.assertEqual(lzip(x, y), expected)

    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def __next__(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return "fooled you!"
                return next(self.it)

        f = open(TESTFN, "w")
        try:
            f.write("a\n" + "b\n" + "c\n")
        finally:
            f.close()

        f = open(TESTFN, "r")
        # Nasty:  string.join(s) can't know whether unicode.join() is needed
        # until it's seen all of s's elements.  But in this case, f's
        # iterator cannot be restarted.  So what we're testing here is
        # whether string.join() can manage to remember everything it's seen
        # and pass that on to unicode.join().
        try:
            got = " - ".join(OhPhooey(f))
            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assertIn(i, sc5)
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assertNotIn(i, sc5)

        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assertIn(k, d)
            self.assertNotIn(k, d.values())
        for v in d.values():
            self.assertIn(v, d.values())
            self.assertNotIn(v, d)
        for k, v in d.items():
            self.assertIn((k, v), d.items())
            self.assertNotIn((v, k), d.items())

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            for chunk in "abc":
                f.seek(0, 0)
                self.assertNotIn(chunk, f)
                f.seek(0, 0)
                self.assertIn((chunk + "\n"), f)
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.values(), 3), 3)
        self.assertEqual(countOf(d.values(), 2j), 1)
        self.assertEqual(countOf(d.values(), 1j), 0)

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n" "b\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                f.seek(0, 0)
                self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

    # Test iterators with operator.indexOf (PySequence_Index).
    def test_indexOf(self):
        from operator import indexOf
        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)

        self.assertEqual(indexOf("122325", "2"), 1)
        self.assertEqual(indexOf("122325", "5"), 5)
        self.assertRaises(ValueError, indexOf, "122325", "6")

        self.assertRaises(TypeError, indexOf, 42, 1)
        self.assertRaises(TypeError, indexOf, indexOf, indexOf)

        f = open(TESTFN, "w")
        try:
            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            # The same one-shot iterator is reused, so each index is relative
            # to where the previous indexOf() call stopped consuming.
            fiter = iter(f)
            self.assertEqual(indexOf(fiter, "b\n"), 1)
            self.assertEqual(indexOf(fiter, "d\n"), 1)
            self.assertEqual(indexOf(fiter, "e\n"), 0)
            self.assertRaises(ValueError, indexOf, fiter, "a\n")
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

        iclass = IteratingSequenceClass(3)
        for i in range(3):
            self.assertEqual(indexOf(iclass, i), i)
        self.assertRaises(ValueError, indexOf, iclass, -1)

    # Test iterators with file.writelines().
    def test_writelines(self):
        f = open(TESTFN, "w")

        try:
            self.assertRaises(TypeError, f.writelines, None)
            self.assertRaises(TypeError, f.writelines, 42)

            f.writelines(["1\n", "2\n"])
            f.writelines(("3\n", "4\n"))
            f.writelines({'5\n': None})
            f.writelines({})

            # Try a big chunk too.
            class Iterator:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish
                    self.i = self.start

                def __next__(self):
                    if self.i >= self.finish:
                        raise StopIteration
                    result = str(self.i) + '\n'
                    self.i += 1
                    return result

                def __iter__(self):
                    return self

            class Whatever:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish

                def __iter__(self):
                    return Iterator(self.start, self.finish)

            f.writelines(Whatever(6, 6+2000))
            f.close()

            f = open(TESTFN)
            expected = [str(i) + "\n" for i in range(1, 2006)]
            self.assertEqual(list(f), expected)

        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass


    # Test iterators on RHS of unpacking assignments.
    def test_unpack_iter(self):
        a, b = 1, 2
        self.assertEqual((a, b), (1, 2))

        a, b, c = IteratingSequenceClass(3)
        self.assertEqual((a, b, c), (0, 1, 2))

        try:    # too many values
            a, b = IteratingSequenceClass(3)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not enough values
            a, b, c = IteratingSequenceClass(2)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not iterable
            a, b, c = len
        except TypeError:
            pass
        else:
            self.fail("should have raised TypeError")

        a, b, c = {1: 42, 2: 42, 3: 42}.values()
        self.assertEqual((a, b, c), (42, 42, 42))

        f = open(TESTFN, "w")
        lines = ("a\n", "bb\n", "ccc\n")
        try:
            for line in lines:
                f.write(line)
        finally:
            f.close()
        f = open(TESTFN, "r")
        try:
            a, b, c = f
            self.assertEqual((a, b, c), lines)
        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass

        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
        self.assertEqual((a, b, c), (0, 1, 42))


    @cpython_only
    def test_ref_counting_behavior(self):
        # Relies on CPython's immediate refcount-driven __del__: a failed
        # unpacking must not leak references to the items it fetched.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        try:
            a, b = iter(l)
        except ValueError:
            pass
        del l
        self.assertEqual(C.count, 0)


    # Make sure StopIteration is a "sink state".
    # This tests various things that weren't sink states in Python 2.2.1,
    # plus various things that always were fine.

    def test_sinkstate_list(self):
        # This used to fail
        a = list(range(5))
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        a.extend(range(5, 10))
        self.assertEqual(list(b), [])

    def test_sinkstate_tuple(self):
        a = (0, 1, 2, 3, 4)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_string(self):
        a = "abcde"
        b = iter(a)
        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(list(b), [])

    def test_sinkstate_sequence(self):
        # This used to fail
        a = SequenceClass(5)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        a.n = 10
        self.assertEqual(list(b), [])

    def test_sinkstate_callable(self):
        # This used to fail
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise AssertionError("shouldn't have gotten this far")
            return i
        b = iter(spam, 5)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_dict(self):
        # XXX For a more thorough test, see towards the end of:
        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
        a = {1:1, 2:2, 0:0, 4:4, 3:3}
        for b in iter(a), a.keys(), a.items(), a.values():
            # Iterate over b itself (previously this rebound b to iter(a)),
            # so the keys, items and values iterators are each exercised
            # instead of the key iterator four times.
            it = iter(b)
            self.assertEqual(len(list(it)), 5)
            self.assertEqual(list(it), [])

    def test_sinkstate_yield(self):
        def gen():
            for i in range(5):
                yield i
        b = gen()
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_range(self):
        a = range(5)
        b = iter(a)
        self.assertEqual(list(b), list(range(5)))
        self.assertEqual(list(b), [])

    def test_sinkstate_enumerate(self):
        a = range(5)
        e = enumerate(a)
        b = iter(e)
        self.assertEqual(list(b), list(zip(range(5), range(5))))
        self.assertEqual(list(b), [])

    def test_3720(self):
        # Avoid a crash, when an iterator deletes its __next__() method.
        # Only the absence of a crash is checked; a TypeError is acceptable.
        class BadIterator(object):
            def __iter__(self):
                return self
            def __next__(self):
                del BadIterator.__next__
                return 1

        try:
            for i in BadIterator() :
                pass
        except TypeError:
            pass

    def test_extending_list_with_iterator_does_not_segfault(self):
        # The code to extend a list with an iterator has a fair
        # amount of nontrivial logic in terms of guessing how
        # much memory to allocate in advance, "stealing" refs,
        # and then shrinking at the end.  This is a basic smoke
        # test for that scenario.
        def gen():
            for i in range(500):
                yield i
        lst = [0] * 500
        for i in range(240):
            lst.pop(0)
        lst.extend(gen())
        self.assertEqual(len(lst), 760)

    @cpython_only
    def test_iter_overflow(self):
        # Test for the issue 22939
        it = iter(UnlimitedSequenceClass())
        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
        it.__setstate__(sys.maxsize - 2)
        self.assertEqual(next(it), sys.maxsize - 2)
        self.assertEqual(next(it), sys.maxsize - 1)
        with self.assertRaises(OverflowError):
            next(it)
        # Check that Overflow error is always raised
        with self.assertRaises(OverflowError):
            next(it)

    def test_iter_neg_setstate(self):
        # A negative saved index is clamped: iteration restarts at 0.
        it = iter(UnlimitedSequenceClass())
        it.__setstate__(-42)
        self.assertEqual(next(it), 0)
        self.assertEqual(next(it), 1)

    def test_free_after_iterating(self):
        check_free_after_iterating(self, iter, SequenceClass, (0,))

    def test_error_iter(self):
        for typ in (DefaultIterClass, NoIterClass):
            self.assertRaises(TypeError, iter, typ())
1009
1010
def test_main():
    """Run every test in TestCase via test.support.run_unittest."""
    run_unittest(TestCase)


# Allow running this module directly as a script.
if __name__ == "__main__":
    test_main()
1017