• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test iterators.
2
3import sys
4import unittest
5from test.support import run_unittest, TESTFN, unlink, cpython_only
6from test.support import check_free_after_iterating
7import pickle
8import collections.abc
9
# Expected result of a three-level nested loop over range(3): every
# (i, j, k) triple in lexicographic order.  Spelled out as a literal
# (it is too big to inline in each test that compares against it).
TRIPLETS = [
    (0, 0, 0), (0, 0, 1), (0, 0, 2),
    (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    (1, 0, 0), (1, 0, 1), (1, 0, 2),
    (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    (2, 0, 0), (2, 0, 1), (2, 0, 2),
    (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
22
23# Helper classes
24
class BasicIterClass:
    """Hand-written iterator producing 0, 1, ..., n - 1 exactly once."""

    def __init__(self, n):
        self.n = n  # exclusive upper bound
        self.i = 0  # next value to hand out

    def __iter__(self):
        # An iterator is its own iterable.
        return self

    def __next__(self):
        current = self.i
        if current < self.n:
            self.i = current + 1
            return current
        # The position is left untouched once the bound is reached, so
        # every later call raises again.
        raise StopIteration
37
class IteratingSequenceClass:
    """Iterable (not itself an iterator) over 0, 1, ..., n - 1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Each call hands out a fresh, independent iterator.
        fresh = BasicIterClass(self.n)
        return fresh
43
class SequenceClass:
    """Sequence of 0..n-1 exposed only through __getitem__ (no __iter__)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        # Out-of-range indexes raise IndexError, which is how the
        # sequence-iteration fallback learns that the sequence has ended.
        if not (0 <= i < self.n):
            raise IndexError
        return i
52
class UnlimitedSequenceClass:
    """Sequence without an end: every index maps to itself."""

    def __getitem__(self, i):
        # Never raises IndexError, so iteration over it does not stop
        # on its own.
        return i
56
class DefaultIterClass:
    """Plain class defining neither __iter__ nor __getitem__."""
59
class NoIterClass:
    """Class with __getitem__ whose __iter__ is explicitly set to None."""

    # Setting __iter__ to None marks the class as non-iterable even though
    # __getitem__ is available.
    __iter__ = None

    def __getitem__(self, i):
        return i
64
class BadIterableClass:
    """Iterable whose __iter__ always fails with ZeroDivisionError."""

    def __iter__(self):
        # Lets tests check that errors raised by __iter__ propagate unchanged.
        raise ZeroDivisionError
68
69# Main test suite
70
71class TestCase(unittest.TestCase):
72
73    # Helper to check that an iterator returns a given sequence
74    def check_iterator(self, it, seq, pickle=True):
75        if pickle:
76            self.check_pickle(it, seq)
77        res = []
78        while 1:
79            try:
80                val = next(it)
81            except StopIteration:
82                break
83            res.append(val)
84        self.assertEqual(res, seq)
85
86    # Helper to check that a for loop generates a given sequence
87    def check_for_loop(self, expr, seq, pickle=True):
88        if pickle:
89            self.check_pickle(iter(expr), seq)
90        res = []
91        for val in expr:
92            res.append(val)
93        self.assertEqual(res, seq)
94
95    # Helper to check picklability
96    def check_pickle(self, itorg, seq):
97        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
98            d = pickle.dumps(itorg, proto)
99            it = pickle.loads(d)
100            # Cannot assert type equality because dict iterators unpickle as list
101            # iterators.
102            # self.assertEqual(type(itorg), type(it))
103            self.assertTrue(isinstance(it, collections.abc.Iterator))
104            self.assertEqual(list(it), seq)
105
106            it = pickle.loads(d)
107            try:
108                next(it)
109            except StopIteration:
110                continue
111            d = pickle.dumps(it, proto)
112            it = pickle.loads(d)
113            self.assertEqual(list(it), seq[1:])
114
115    # Test basic use of iter() function
116    def test_iter_basic(self):
117        self.check_iterator(iter(range(10)), list(range(10)))
118
119    # Test that iter(iter(x)) is the same as iter(x)
120    def test_iter_idempotency(self):
121        seq = list(range(10))
122        it = iter(seq)
123        it2 = iter(it)
124        self.assertTrue(it is it2)
125
126    # Test that for loops over iterators work
127    def test_iter_for_loop(self):
128        self.check_for_loop(iter(range(10)), list(range(10)))
129
130    # Test several independent iterators over the same list
131    def test_iter_independence(self):
132        seq = range(3)
133        res = []
134        for i in iter(seq):
135            for j in iter(seq):
136                for k in iter(seq):
137                    res.append((i, j, k))
138        self.assertEqual(res, TRIPLETS)
139
140    # Test triple list comprehension using iterators
141    def test_nested_comprehensions_iter(self):
142        seq = range(3)
143        res = [(i, j, k)
144               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
145        self.assertEqual(res, TRIPLETS)
146
147    # Test triple list comprehension without iterators
148    def test_nested_comprehensions_for(self):
149        seq = range(3)
150        res = [(i, j, k) for i in seq for j in seq for k in seq]
151        self.assertEqual(res, TRIPLETS)
152
153    # Test a class with __iter__ in a for loop
154    def test_iter_class_for(self):
155        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))
156
157    # Test a class with __iter__ with explicit iter()
158    def test_iter_class_iter(self):
159        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
160
161    # Test for loop on a sequence class without __iter__
162    def test_seq_class_for(self):
163        self.check_for_loop(SequenceClass(10), list(range(10)))
164
165    # Test iter() on a sequence class without __iter__
166    def test_seq_class_iter(self):
167        self.check_iterator(iter(SequenceClass(10)), list(range(10)))
168
169    def test_mutating_seq_class_iter_pickle(self):
170        orig = SequenceClass(5)
171        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
172            # initial iterator
173            itorig = iter(orig)
174            d = pickle.dumps((itorig, orig), proto)
175            it, seq = pickle.loads(d)
176            seq.n = 7
177            self.assertIs(type(it), type(itorig))
178            self.assertEqual(list(it), list(range(7)))
179
180            # running iterator
181            next(itorig)
182            d = pickle.dumps((itorig, orig), proto)
183            it, seq = pickle.loads(d)
184            seq.n = 7
185            self.assertIs(type(it), type(itorig))
186            self.assertEqual(list(it), list(range(1, 7)))
187
188            # empty iterator
189            for i in range(1, 5):
190                next(itorig)
191            d = pickle.dumps((itorig, orig), proto)
192            it, seq = pickle.loads(d)
193            seq.n = 7
194            self.assertIs(type(it), type(itorig))
195            self.assertEqual(list(it), list(range(5, 7)))
196
197            # exhausted iterator
198            self.assertRaises(StopIteration, next, itorig)
199            d = pickle.dumps((itorig, orig), proto)
200            it, seq = pickle.loads(d)
201            seq.n = 7
202            self.assertTrue(isinstance(it, collections.abc.Iterator))
203            self.assertEqual(list(it), [])
204
205    def test_mutating_seq_class_exhausted_iter(self):
206        a = SequenceClass(5)
207        exhit = iter(a)
208        empit = iter(a)
209        for x in exhit:  # exhaust the iterator
210            next(empit)  # not exhausted
211        a.n = 7
212        self.assertEqual(list(exhit), [])
213        self.assertEqual(list(empit), [5, 6])
214        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
215
216    # Test a new_style class with __iter__ but no next() method
217    def test_new_style_iter_class(self):
218        class IterClass(object):
219            def __iter__(self):
220                return self
221        self.assertRaises(TypeError, iter, IterClass())
222
223    # Test two-argument iter() with callable instance
224    def test_iter_callable(self):
225        class C:
226            def __init__(self):
227                self.i = 0
228            def __call__(self):
229                i = self.i
230                self.i = i + 1
231                if i > 100:
232                    raise IndexError # Emergency stop
233                return i
234        self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)
235
236    # Test two-argument iter() with function
237    def test_iter_function(self):
238        def spam(state=[0]):
239            i = state[0]
240            state[0] = i+1
241            return i
242        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
243
244    # Test two-argument iter() with function that raises StopIteration
245    def test_iter_function_stop(self):
246        def spam(state=[0]):
247            i = state[0]
248            if i == 10:
249                raise StopIteration
250            state[0] = i+1
251            return i
252        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
253
254    # Test exception propagation through function iterator
255    def test_exception_function(self):
256        def spam(state=[0]):
257            i = state[0]
258            state[0] = i+1
259            if i == 10:
260                raise RuntimeError
261            return i
262        res = []
263        try:
264            for x in iter(spam, 20):
265                res.append(x)
266        except RuntimeError:
267            self.assertEqual(res, list(range(10)))
268        else:
269            self.fail("should have raised RuntimeError")
270
271    # Test exception propagation through sequence iterator
272    def test_exception_sequence(self):
273        class MySequenceClass(SequenceClass):
274            def __getitem__(self, i):
275                if i == 10:
276                    raise RuntimeError
277                return SequenceClass.__getitem__(self, i)
278        res = []
279        try:
280            for x in MySequenceClass(20):
281                res.append(x)
282        except RuntimeError:
283            self.assertEqual(res, list(range(10)))
284        else:
285            self.fail("should have raised RuntimeError")
286
287    # Test for StopIteration from __getitem__
288    def test_stop_sequence(self):
289        class MySequenceClass(SequenceClass):
290            def __getitem__(self, i):
291                if i == 10:
292                    raise StopIteration
293                return SequenceClass.__getitem__(self, i)
294        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
295
296    # Test a big range
297    def test_iter_big_range(self):
298        self.check_for_loop(iter(range(10000)), list(range(10000)))
299
300    # Test an empty list
301    def test_iter_empty(self):
302        self.check_for_loop(iter([]), [])
303
304    # Test a tuple
305    def test_iter_tuple(self):
306        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))
307
308    # Test a range
309    def test_iter_range(self):
310        self.check_for_loop(iter(range(10)), list(range(10)))
311
312    # Test a string
313    def test_iter_string(self):
314        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
315
    # Test a dictionary
317    def test_iter_dict(self):
318        dict = {}
319        for i in range(10):
320            dict[i] = None
321        self.check_for_loop(dict, list(dict.keys()))
322
323    # Test a file
324    def test_iter_file(self):
325        f = open(TESTFN, "w")
326        try:
327            for i in range(5):
328                f.write("%d\n" % i)
329        finally:
330            f.close()
331        f = open(TESTFN, "r")
332        try:
333            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
334            self.check_for_loop(f, [], pickle=False)
335        finally:
336            f.close()
337            try:
338                unlink(TESTFN)
339            except OSError:
340                pass
341
342    # Test list()'s use of iterators.
343    def test_builtin_list(self):
344        self.assertEqual(list(SequenceClass(5)), list(range(5)))
345        self.assertEqual(list(SequenceClass(0)), [])
346        self.assertEqual(list(()), [])
347
348        d = {"one": 1, "two": 2, "three": 3}
349        self.assertEqual(list(d), list(d.keys()))
350
351        self.assertRaises(TypeError, list, list)
352        self.assertRaises(TypeError, list, 42)
353
354        f = open(TESTFN, "w")
355        try:
356            for i in range(5):
357                f.write("%d\n" % i)
358        finally:
359            f.close()
360        f = open(TESTFN, "r")
361        try:
362            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
363            f.seek(0, 0)
364            self.assertEqual(list(f),
365                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
366        finally:
367            f.close()
368            try:
369                unlink(TESTFN)
370            except OSError:
371                pass
372
    # Test tuple()'s use of iterators.
374    def test_builtin_tuple(self):
375        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
376        self.assertEqual(tuple(SequenceClass(0)), ())
377        self.assertEqual(tuple([]), ())
378        self.assertEqual(tuple(()), ())
379        self.assertEqual(tuple("abc"), ("a", "b", "c"))
380
381        d = {"one": 1, "two": 2, "three": 3}
382        self.assertEqual(tuple(d), tuple(d.keys()))
383
384        self.assertRaises(TypeError, tuple, list)
385        self.assertRaises(TypeError, tuple, 42)
386
387        f = open(TESTFN, "w")
388        try:
389            for i in range(5):
390                f.write("%d\n" % i)
391        finally:
392            f.close()
393        f = open(TESTFN, "r")
394        try:
395            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
396            f.seek(0, 0)
397            self.assertEqual(tuple(f),
398                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
399        finally:
400            f.close()
401            try:
402                unlink(TESTFN)
403            except OSError:
404                pass
405
406    # Test filter()'s use of iterators.
407    def test_builtin_filter(self):
408        self.assertEqual(list(filter(None, SequenceClass(5))),
409                         list(range(1, 5)))
410        self.assertEqual(list(filter(None, SequenceClass(0))), [])
411        self.assertEqual(list(filter(None, ())), [])
412        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])
413
414        d = {"one": 1, "two": 2, "three": 3}
415        self.assertEqual(list(filter(None, d)), list(d.keys()))
416
417        self.assertRaises(TypeError, filter, None, list)
418        self.assertRaises(TypeError, filter, None, 42)
419
420        class Boolean:
421            def __init__(self, truth):
422                self.truth = truth
423            def __bool__(self):
424                return self.truth
425        bTrue = Boolean(True)
426        bFalse = Boolean(False)
427
428        class Seq:
429            def __init__(self, *args):
430                self.vals = args
431            def __iter__(self):
432                class SeqIter:
433                    def __init__(self, vals):
434                        self.vals = vals
435                        self.i = 0
436                    def __iter__(self):
437                        return self
438                    def __next__(self):
439                        i = self.i
440                        self.i = i + 1
441                        if i < len(self.vals):
442                            return self.vals[i]
443                        else:
444                            raise StopIteration
445                return SeqIter(self.vals)
446
447        seq = Seq(*([bTrue, bFalse] * 25))
448        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
449        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)
450
451    # Test max() and min()'s use of iterators.
452    def test_builtin_max_min(self):
453        self.assertEqual(max(SequenceClass(5)), 4)
454        self.assertEqual(min(SequenceClass(5)), 0)
455        self.assertEqual(max(8, -1), 8)
456        self.assertEqual(min(8, -1), -1)
457
458        d = {"one": 1, "two": 2, "three": 3}
459        self.assertEqual(max(d), "two")
460        self.assertEqual(min(d), "one")
461        self.assertEqual(max(d.values()), 3)
462        self.assertEqual(min(iter(d.values())), 1)
463
464        f = open(TESTFN, "w")
465        try:
466            f.write("medium line\n")
467            f.write("xtra large line\n")
468            f.write("itty-bitty line\n")
469        finally:
470            f.close()
471        f = open(TESTFN, "r")
472        try:
473            self.assertEqual(min(f), "itty-bitty line\n")
474            f.seek(0, 0)
475            self.assertEqual(max(f), "xtra large line\n")
476        finally:
477            f.close()
478            try:
479                unlink(TESTFN)
480            except OSError:
481                pass
482
483    # Test map()'s use of iterators.
484    def test_builtin_map(self):
485        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
486                         list(range(1, 6)))
487
488        d = {"one": 1, "two": 2, "three": 3}
489        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
490                         list(d.items()))
491        dkeys = list(d.keys())
492        expected = [(i < len(d) and dkeys[i] or None,
493                     i,
494                     i < len(d) and dkeys[i] or None)
495                    for i in range(3)]
496
497        f = open(TESTFN, "w")
498        try:
499            for i in range(10):
500                f.write("xy" * i + "\n") # line i has len 2*i+1
501        finally:
502            f.close()
503        f = open(TESTFN, "r")
504        try:
505            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
506        finally:
507            f.close()
508            try:
509                unlink(TESTFN)
510            except OSError:
511                pass
512
513    # Test zip()'s use of iterators.
514    def test_builtin_zip(self):
515        self.assertEqual(list(zip()), [])
516        self.assertEqual(list(zip(*[])), [])
517        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])
518
519        self.assertRaises(TypeError, zip, None)
520        self.assertRaises(TypeError, zip, range(10), 42)
521        self.assertRaises(TypeError, zip, range(10), zip)
522
523        self.assertEqual(list(zip(IteratingSequenceClass(3))),
524                         [(0,), (1,), (2,)])
525        self.assertEqual(list(zip(SequenceClass(3))),
526                         [(0,), (1,), (2,)])
527
528        d = {"one": 1, "two": 2, "three": 3}
529        self.assertEqual(list(d.items()), list(zip(d, d.values())))
530
531        # Generate all ints starting at constructor arg.
532        class IntsFrom:
533            def __init__(self, start):
534                self.i = start
535
536            def __iter__(self):
537                return self
538
539            def __next__(self):
540                i = self.i
541                self.i = i+1
542                return i
543
544        f = open(TESTFN, "w")
545        try:
546            f.write("a\n" "bbb\n" "cc\n")
547        finally:
548            f.close()
549        f = open(TESTFN, "r")
550        try:
551            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
552                             [(0, "a\n", -100),
553                              (1, "bbb\n", -99),
554                              (2, "cc\n", -98)])
555        finally:
556            f.close()
557            try:
558                unlink(TESTFN)
559            except OSError:
560                pass
561
562        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
563
564        # Classes that lie about their lengths.
565        class NoGuessLen5:
566            def __getitem__(self, i):
567                if i >= 5:
568                    raise IndexError
569                return i
570
571        class Guess3Len5(NoGuessLen5):
572            def __len__(self):
573                return 3
574
575        class Guess30Len5(NoGuessLen5):
576            def __len__(self):
577                return 30
578
579        def lzip(*args):
580            return list(zip(*args))
581
582        self.assertEqual(len(Guess3Len5()), 3)
583        self.assertEqual(len(Guess30Len5()), 30)
584        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
585        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
586        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))
587
588        expected = [(i, i) for i in range(5)]
589        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
590            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
591                self.assertEqual(lzip(x, y), expected)
592
593    def test_unicode_join_endcase(self):
594
595        # This class inserts a Unicode object into its argument's natural
596        # iteration, in the 3rd position.
597        class OhPhooey:
598            def __init__(self, seq):
599                self.it = iter(seq)
600                self.i = 0
601
602            def __iter__(self):
603                return self
604
605            def __next__(self):
606                i = self.i
607                self.i = i+1
608                if i == 2:
609                    return "fooled you!"
610                return next(self.it)
611
612        f = open(TESTFN, "w")
613        try:
614            f.write("a\n" + "b\n" + "c\n")
615        finally:
616            f.close()
617
618        f = open(TESTFN, "r")
619        # Nasty:  string.join(s) can't know whether unicode.join() is needed
620        # until it's seen all of s's elements.  But in this case, f's
621        # iterator cannot be restarted.  So what we're testing here is
622        # whether string.join() can manage to remember everything it's seen
623        # and pass that on to unicode.join().
624        try:
625            got = " - ".join(OhPhooey(f))
626            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
627        finally:
628            f.close()
629            try:
630                unlink(TESTFN)
631            except OSError:
632                pass
633
634    # Test iterators with 'x in y' and 'x not in y'.
635    def test_in_and_not_in(self):
636        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
637            for i in range(5):
638                self.assertIn(i, sc5)
639            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
640                self.assertNotIn(i, sc5)
641
642        self.assertRaises(TypeError, lambda: 3 in 12)
643        self.assertRaises(TypeError, lambda: 3 not in map)
644        self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
645
646        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
647        for k in d:
648            self.assertIn(k, d)
649            self.assertNotIn(k, d.values())
650        for v in d.values():
651            self.assertIn(v, d.values())
652            self.assertNotIn(v, d)
653        for k, v in d.items():
654            self.assertIn((k, v), d.items())
655            self.assertNotIn((v, k), d.items())
656
657        f = open(TESTFN, "w")
658        try:
659            f.write("a\n" "b\n" "c\n")
660        finally:
661            f.close()
662        f = open(TESTFN, "r")
663        try:
664            for chunk in "abc":
665                f.seek(0, 0)
666                self.assertNotIn(chunk, f)
667                f.seek(0, 0)
668                self.assertIn((chunk + "\n"), f)
669        finally:
670            f.close()
671            try:
672                unlink(TESTFN)
673            except OSError:
674                pass
675
676    # Test iterators with operator.countOf (PySequence_Count).
677    def test_countOf(self):
678        from operator import countOf
679        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
680        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
681        self.assertEqual(countOf("122325", "2"), 3)
682        self.assertEqual(countOf("122325", "6"), 0)
683
684        self.assertRaises(TypeError, countOf, 42, 1)
685        self.assertRaises(TypeError, countOf, countOf, countOf)
686
687        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
688        for k in d:
689            self.assertEqual(countOf(d, k), 1)
690        self.assertEqual(countOf(d.values(), 3), 3)
691        self.assertEqual(countOf(d.values(), 2j), 1)
692        self.assertEqual(countOf(d.values(), 1j), 0)
693
694        f = open(TESTFN, "w")
695        try:
696            f.write("a\n" "b\n" "c\n" "b\n")
697        finally:
698            f.close()
699        f = open(TESTFN, "r")
700        try:
701            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
702                f.seek(0, 0)
703                self.assertEqual(countOf(f, letter + "\n"), count)
704        finally:
705            f.close()
706            try:
707                unlink(TESTFN)
708            except OSError:
709                pass
710
711    # Test iterators with operator.indexOf (PySequence_Index).
712    def test_indexOf(self):
713        from operator import indexOf
714        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
715        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
716        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
717        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
718        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
719        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
720
721        self.assertEqual(indexOf("122325", "2"), 1)
722        self.assertEqual(indexOf("122325", "5"), 5)
723        self.assertRaises(ValueError, indexOf, "122325", "6")
724
725        self.assertRaises(TypeError, indexOf, 42, 1)
726        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
727        self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
728
729        f = open(TESTFN, "w")
730        try:
731            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
732        finally:
733            f.close()
734        f = open(TESTFN, "r")
735        try:
736            fiter = iter(f)
737            self.assertEqual(indexOf(fiter, "b\n"), 1)
738            self.assertEqual(indexOf(fiter, "d\n"), 1)
739            self.assertEqual(indexOf(fiter, "e\n"), 0)
740            self.assertRaises(ValueError, indexOf, fiter, "a\n")
741        finally:
742            f.close()
743            try:
744                unlink(TESTFN)
745            except OSError:
746                pass
747
748        iclass = IteratingSequenceClass(3)
749        for i in range(3):
750            self.assertEqual(indexOf(iclass, i), i)
751        self.assertRaises(ValueError, indexOf, iclass, -1)
752
753    # Test iterators with file.writelines().
754    def test_writelines(self):
755        f = open(TESTFN, "w")
756
757        try:
758            self.assertRaises(TypeError, f.writelines, None)
759            self.assertRaises(TypeError, f.writelines, 42)
760
761            f.writelines(["1\n", "2\n"])
762            f.writelines(("3\n", "4\n"))
763            f.writelines({'5\n': None})
764            f.writelines({})
765
766            # Try a big chunk too.
767            class Iterator:
768                def __init__(self, start, finish):
769                    self.start = start
770                    self.finish = finish
771                    self.i = self.start
772
773                def __next__(self):
774                    if self.i >= self.finish:
775                        raise StopIteration
776                    result = str(self.i) + '\n'
777                    self.i += 1
778                    return result
779
780                def __iter__(self):
781                    return self
782
783            class Whatever:
784                def __init__(self, start, finish):
785                    self.start = start
786                    self.finish = finish
787
788                def __iter__(self):
789                    return Iterator(self.start, self.finish)
790
791            f.writelines(Whatever(6, 6+2000))
792            f.close()
793
794            f = open(TESTFN)
795            expected = [str(i) + "\n" for i in range(1, 2006)]
796            self.assertEqual(list(f), expected)
797
798        finally:
799            f.close()
800            try:
801                unlink(TESTFN)
802            except OSError:
803                pass
804
805
806    # Test iterators on RHS of unpacking assignments.
807    def test_unpack_iter(self):
808        a, b = 1, 2
809        self.assertEqual((a, b), (1, 2))
810
811        a, b, c = IteratingSequenceClass(3)
812        self.assertEqual((a, b, c), (0, 1, 2))
813
814        try:    # too many values
815            a, b = IteratingSequenceClass(3)
816        except ValueError:
817            pass
818        else:
819            self.fail("should have raised ValueError")
820
821        try:    # not enough values
822            a, b, c = IteratingSequenceClass(2)
823        except ValueError:
824            pass
825        else:
826            self.fail("should have raised ValueError")
827
828        try:    # not iterable
829            a, b, c = len
830        except TypeError:
831            pass
832        else:
833            self.fail("should have raised TypeError")
834
835        a, b, c = {1: 42, 2: 42, 3: 42}.values()
836        self.assertEqual((a, b, c), (42, 42, 42))
837
838        f = open(TESTFN, "w")
839        lines = ("a\n", "bb\n", "ccc\n")
840        try:
841            for line in lines:
842                f.write(line)
843        finally:
844            f.close()
845        f = open(TESTFN, "r")
846        try:
847            a, b, c = f
848            self.assertEqual((a, b, c), lines)
849        finally:
850            f.close()
851            try:
852                unlink(TESTFN)
853            except OSError:
854                pass
855
856        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
857        self.assertEqual((a, b, c), (0, 1, 42))
858
859
860    @cpython_only
861    def test_ref_counting_behavior(self):
862        class C(object):
863            count = 0
864            def __new__(cls):
865                cls.count += 1
866                return object.__new__(cls)
867            def __del__(self):
868                cls = self.__class__
869                assert cls.count > 0
870                cls.count -= 1
871        x = C()
872        self.assertEqual(C.count, 1)
873        del x
874        self.assertEqual(C.count, 0)
875        l = [C(), C(), C()]
876        self.assertEqual(C.count, 3)
877        try:
878            a, b = iter(l)
879        except ValueError:
880            pass
881        del l
882        self.assertEqual(C.count, 0)
883
884
885    # Make sure StopIteration is a "sink state".
886    # This tests various things that weren't sink states in Python 2.2.1,
887    # plus various things that always were fine.
888
889    def test_sinkstate_list(self):
890        # This used to fail
891        a = list(range(5))
892        b = iter(a)
893        self.assertEqual(list(b), list(range(5)))
894        a.extend(range(5, 10))
895        self.assertEqual(list(b), [])
896
897    def test_sinkstate_tuple(self):
898        a = (0, 1, 2, 3, 4)
899        b = iter(a)
900        self.assertEqual(list(b), list(range(5)))
901        self.assertEqual(list(b), [])
902
903    def test_sinkstate_string(self):
904        a = "abcde"
905        b = iter(a)
906        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
907        self.assertEqual(list(b), [])
908
909    def test_sinkstate_sequence(self):
910        # This used to fail
911        a = SequenceClass(5)
912        b = iter(a)
913        self.assertEqual(list(b), list(range(5)))
914        a.n = 10
915        self.assertEqual(list(b), [])
916
917    def test_sinkstate_callable(self):
918        # This used to fail
919        def spam(state=[0]):
920            i = state[0]
921            state[0] = i+1
922            if i == 10:
923                raise AssertionError("shouldn't have gotten this far")
924            return i
925        b = iter(spam, 5)
926        self.assertEqual(list(b), list(range(5)))
927        self.assertEqual(list(b), [])
928
929    def test_sinkstate_dict(self):
930        # XXX For a more thorough test, see towards the end of:
931        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
932        a = {1:1, 2:2, 0:0, 4:4, 3:3}
933        for b in iter(a), a.keys(), a.items(), a.values():
934            b = iter(a)
935            self.assertEqual(len(list(b)), 5)
936            self.assertEqual(list(b), [])
937
938    def test_sinkstate_yield(self):
939        def gen():
940            for i in range(5):
941                yield i
942        b = gen()
943        self.assertEqual(list(b), list(range(5)))
944        self.assertEqual(list(b), [])
945
946    def test_sinkstate_range(self):
947        a = range(5)
948        b = iter(a)
949        self.assertEqual(list(b), list(range(5)))
950        self.assertEqual(list(b), [])
951
952    def test_sinkstate_enumerate(self):
953        a = range(5)
954        e = enumerate(a)
955        b = iter(e)
956        self.assertEqual(list(b), list(zip(range(5), range(5))))
957        self.assertEqual(list(b), [])
958
959    def test_3720(self):
960        # Avoid a crash, when an iterator deletes its next() method.
961        class BadIterator(object):
962            def __iter__(self):
963                return self
964            def __next__(self):
965                del BadIterator.__next__
966                return 1
967
968        try:
969            for i in BadIterator() :
970                pass
971        except TypeError:
972            pass
973
974    def test_extending_list_with_iterator_does_not_segfault(self):
975        # The code to extend a list with an iterator has a fair
976        # amount of nontrivial logic in terms of guessing how
977        # much memory to allocate in advance, "stealing" refs,
978        # and then shrinking at the end.  This is a basic smoke
979        # test for that scenario.
980        def gen():
981            for i in range(500):
982                yield i
983        lst = [0] * 500
984        for i in range(240):
985            lst.pop(0)
986        lst.extend(gen())
987        self.assertEqual(len(lst), 760)
988
989    @cpython_only
990    def test_iter_overflow(self):
991        # Test for the issue 22939
992        it = iter(UnlimitedSequenceClass())
993        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
994        it.__setstate__(sys.maxsize - 2)
995        self.assertEqual(next(it), sys.maxsize - 2)
996        self.assertEqual(next(it), sys.maxsize - 1)
997        with self.assertRaises(OverflowError):
998            next(it)
999        # Check that Overflow error is always raised
1000        with self.assertRaises(OverflowError):
1001            next(it)
1002
1003    def test_iter_neg_setstate(self):
1004        it = iter(UnlimitedSequenceClass())
1005        it.__setstate__(-42)
1006        self.assertEqual(next(it), 0)
1007        self.assertEqual(next(it), 1)
1008
1009    def test_free_after_iterating(self):
1010        check_free_after_iterating(self, iter, SequenceClass, (0,))
1011
1012    def test_error_iter(self):
1013        for typ in (DefaultIterClass, NoIterClass):
1014            self.assertRaises(TypeError, iter, typ())
1015        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
1016
1017
def test_main():
    """Run every test in TestCase through test.support.run_unittest."""
    run_unittest(TestCase)


# Allow running this file directly as a script.
if __name__ == "__main__":
    test_main()
1024