• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test iterators.
2
3import sys
4import unittest
5from test.support import cpython_only
6from test.support.os_helper import TESTFN, unlink
7from test.support import check_free_after_iterating, ALWAYS_EQ, NEVER_EQ
8from test.support import BrokenIter
9import pickle
10import collections.abc
11import functools
12import contextlib
13import builtins
14import traceback
15
# Expected result of a triple loop over range(3): every (i, j, k) in
# lexicographic order.  Spelled out literally (rather than computed) so the
# iteration tests below never check iteration against iteration.
TRIPLETS = [
    (0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1),
    (0, 1, 2), (0, 2, 0), (0, 2, 1), (0, 2, 2),
    (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 1, 0), (1, 1, 1),
    (1, 1, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2),
    (2, 0, 0), (2, 0, 1), (2, 0, 2), (2, 1, 0), (2, 1, 1),
    (2, 1, 2), (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
28
29# Helper classes
30
class BasicIterClass:
    """Minimal iterator yielding 0, 1, ..., n-1, then raising StopIteration."""

    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration
        current = self.i
        self.i = current + 1
        return current
43
class IteratingSequenceClass:
    """Iterable (not itself an iterator) producing 0 .. n-1.

    Every call to __iter__ hands out a brand-new BasicIterClass, so
    independent loops over one instance do not interfere.
    """

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return BasicIterClass(self.n)
49
class IteratorProxyClass:
    """Iterator that forwards every next() call to the wrapped iterator *i*."""

    def __init__(self, i):
        self.i = i

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.i)
57
class SequenceClass:
    """Old-style sequence (only __getitem__) whose item i is i, for 0 <= i < n.

    The attribute n may be changed after construction to grow or shrink it.
    """

    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
66
class SequenceProxyClass:
    """Old-style sequence that delegates indexing to the wrapped object *s*."""

    def __init__(self, s):
        self.s = s

    def __getitem__(self, i):
        wrapped = self.s
        return wrapped[i]
72
class UnlimitedSequenceClass:
    """Endless old-style sequence: item i is simply i, never IndexError."""

    def __getitem__(self, i):
        return i
76
class DefaultIterClass:
    """Class defining neither __iter__ nor __getitem__, so it is not iterable."""
79
class NoIterClass:
    """Indexable class explicitly opted out of iteration.

    Setting __iter__ to None blocks the __getitem__ fallback, so iter()
    raises TypeError even though indexing works.
    """

    __iter__ = None

    def __getitem__(self, i):
        return i
84
class BadIterableClass:
    """Iterable whose __iter__ always fails with ZeroDivisionError."""

    def __iter__(self):
        raise ZeroDivisionError
88
class CallableIterClass:
    """Callable returning 0, 1, 2, ... on successive calls.

    Calls after the one that returns 100 raise IndexError, a safety net so a
    two-argument iter() with a bad sentinel cannot loop forever.
    """

    def __init__(self):
        self.i = 0

    def __call__(self):
        current = self.i
        self.i += 1
        if current > 100:
            raise IndexError  # Emergency stop
        return current
98
class EmptyIterClass:
    """Zero-length old-style sequence whose __getitem__ raises StopIteration."""

    def __len__(self):
        return 0

    def __getitem__(self, i):
        raise StopIteration
104
105# Main test suite
106
107class TestCase(unittest.TestCase):
108
109    # Helper to check that an iterator returns a given sequence
110    def check_iterator(self, it, seq, pickle=True):
111        if pickle:
112            self.check_pickle(it, seq)
113        res = []
114        while 1:
115            try:
116                val = next(it)
117            except StopIteration:
118                break
119            res.append(val)
120        self.assertEqual(res, seq)
121
122    # Helper to check that a for loop generates a given sequence
123    def check_for_loop(self, expr, seq, pickle=True):
124        if pickle:
125            self.check_pickle(iter(expr), seq)
126        res = []
127        for val in expr:
128            res.append(val)
129        self.assertEqual(res, seq)
130
131    # Helper to check picklability
132    def check_pickle(self, itorg, seq):
133        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
134            d = pickle.dumps(itorg, proto)
135            it = pickle.loads(d)
136            # Cannot assert type equality because dict iterators unpickle as list
137            # iterators.
138            # self.assertEqual(type(itorg), type(it))
139            self.assertTrue(isinstance(it, collections.abc.Iterator))
140            self.assertEqual(list(it), seq)
141
142            it = pickle.loads(d)
143            try:
144                next(it)
145            except StopIteration:
146                continue
147            d = pickle.dumps(it, proto)
148            it = pickle.loads(d)
149            self.assertEqual(list(it), seq[1:])
150
151    # Test basic use of iter() function
152    def test_iter_basic(self):
153        self.check_iterator(iter(range(10)), list(range(10)))
154
155    # Test that iter(iter(x)) is the same as iter(x)
156    def test_iter_idempotency(self):
157        seq = list(range(10))
158        it = iter(seq)
159        it2 = iter(it)
160        self.assertTrue(it is it2)
161
162    # Test that for loops over iterators work
163    def test_iter_for_loop(self):
164        self.check_for_loop(iter(range(10)), list(range(10)))
165
166    # Test several independent iterators over the same list
167    def test_iter_independence(self):
168        seq = range(3)
169        res = []
170        for i in iter(seq):
171            for j in iter(seq):
172                for k in iter(seq):
173                    res.append((i, j, k))
174        self.assertEqual(res, TRIPLETS)
175
176    # Test triple list comprehension using iterators
177    def test_nested_comprehensions_iter(self):
178        seq = range(3)
179        res = [(i, j, k)
180               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
181        self.assertEqual(res, TRIPLETS)
182
183    # Test triple list comprehension without iterators
184    def test_nested_comprehensions_for(self):
185        seq = range(3)
186        res = [(i, j, k) for i in seq for j in seq for k in seq]
187        self.assertEqual(res, TRIPLETS)
188
189    # Test a class with __iter__ in a for loop
190    def test_iter_class_for(self):
191        self.check_for_loop(IteratingSequenceClass(10), list(range(10)))
192
193    # Test a class with __iter__ with explicit iter()
194    def test_iter_class_iter(self):
195        self.check_iterator(iter(IteratingSequenceClass(10)), list(range(10)))
196
197    # Test for loop on a sequence class without __iter__
198    def test_seq_class_for(self):
199        self.check_for_loop(SequenceClass(10), list(range(10)))
200
201    # Test iter() on a sequence class without __iter__
202    def test_seq_class_iter(self):
203        self.check_iterator(iter(SequenceClass(10)), list(range(10)))
204
205    def test_mutating_seq_class_iter_pickle(self):
206        orig = SequenceClass(5)
207        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
208            # initial iterator
209            itorig = iter(orig)
210            d = pickle.dumps((itorig, orig), proto)
211            it, seq = pickle.loads(d)
212            seq.n = 7
213            self.assertIs(type(it), type(itorig))
214            self.assertEqual(list(it), list(range(7)))
215
216            # running iterator
217            next(itorig)
218            d = pickle.dumps((itorig, orig), proto)
219            it, seq = pickle.loads(d)
220            seq.n = 7
221            self.assertIs(type(it), type(itorig))
222            self.assertEqual(list(it), list(range(1, 7)))
223
224            # empty iterator
225            for i in range(1, 5):
226                next(itorig)
227            d = pickle.dumps((itorig, orig), proto)
228            it, seq = pickle.loads(d)
229            seq.n = 7
230            self.assertIs(type(it), type(itorig))
231            self.assertEqual(list(it), list(range(5, 7)))
232
233            # exhausted iterator
234            self.assertRaises(StopIteration, next, itorig)
235            d = pickle.dumps((itorig, orig), proto)
236            it, seq = pickle.loads(d)
237            seq.n = 7
238            self.assertTrue(isinstance(it, collections.abc.Iterator))
239            self.assertEqual(list(it), [])
240
241    def test_mutating_seq_class_exhausted_iter(self):
242        a = SequenceClass(5)
243        exhit = iter(a)
244        empit = iter(a)
245        for x in exhit:  # exhaust the iterator
246            next(empit)  # not exhausted
247        a.n = 7
248        self.assertEqual(list(exhit), [])
249        self.assertEqual(list(empit), [5, 6])
250        self.assertEqual(list(a), [0, 1, 2, 3, 4, 5, 6])
251
    def test_reduce_mutating_builtins_iter(self):
        # This is a reproducer of issue #101765
        # where iter `__reduce__` calls could lead to a segfault or SystemError
        # depending on the order of C argument evaluation, which is undefined

        # Backup builtins: keep references to the real functions so they can
        # be reinstalled in the `finally` clause below.
        builtins_dict = builtins.__dict__
        orig = {"iter": iter, "reversed": reversed}

        # Build an iterator over `item` (or the callable/sentinel form when
        # `sentinel` is given). Then swap builtins[builtin_name] for an entry
        # keyed by a CustomStr whose __eq__ exhausts that iterator, and return
        # it.__reduce__(). Per the comments below, __reduce__ looks the builtin
        # up by name, which presumably reaches the CustomStr key and runs __eq__
        # in the middle of the reduce.
        def run(builtin_name, item, sentinel=None):
            it = iter(item) if sentinel is None else iter(item, sentinel)

            class CustomStr:
                # Hashes like `name`, so a lookup of the string `name` in the
                # builtins dict collides with this key and calls __eq__.
                def __init__(self, name, iterator):
                    self.name = name
                    self.iterator = iterator
                def __hash__(self):
                    return hash(self.name)
                def __eq__(self, other):
                    # Here we exhaust our iterator, possibly changing
                    # its `it_seq` pointer to NULL
                    # The `__reduce__` call should correctly get
                    # the pointers after this call
                    list(self.iterator)
                    return other == self.name

            # del is required here
            # to not prematurely call __eq__ from
            # the hash collision with the old key
            del builtins_dict[builtin_name]
            builtins_dict[CustomStr(builtin_name, it)] = orig[builtin_name]

            return it.__reduce__()

        # Argument tuples for iter(). Each resulting iterator is exhausted by
        # CustomStr.__eq__ during __reduce__, and each is expected below to
        # reduce to (iter, ((),)). `(lambda: 0, 0)` is the callable/sentinel form.
        types = [
            (EmptyIterClass(),),
            (bytes(8),),
            (bytearray(8),),
            ((1, 2, 3),),
            (lambda: 0, 0),
            (tuple[int],)  # GenericAlias
        ]

        try:
            run_iter = functools.partial(run, "iter")
            # The returned value of `__reduce__` should not only be valid
            # but also *empty*, as `it` was exhausted during `__eq__`
            # i.e "xyz" returns (iter, ("",))
            self.assertEqual(run_iter("xyz"), (orig["iter"], ("",)))
            self.assertEqual(run_iter([1, 2, 3]), (orig["iter"], ([],)))

            # _PyEval_GetBuiltin is also called for `reversed` in a branch of
            # listiter_reduce_general
            # NOTE(review): the bare name `reversed` in the expected tuple is
            # looked up while the CustomStr key is still installed. It appears
            # to resolve through the hash/__eq__ match. orig["reversed"] would
            # be more explicit.
            self.assertEqual(
                run("reversed", orig["reversed"](list(range(8)))),
                (reversed, ([],))
            )

            for case in types:
                self.assertEqual(run_iter(*case), (orig["iter"], ((),)))
        finally:
            # Restore original builtins
            for key, func in orig.items():
                # need to suppress KeyErrors in case
                # a failed test deletes the key without setting anything
                with contextlib.suppress(KeyError):
                    # del is required here
                    # to not invoke our custom __eq__ from
                    # the hash collision with the old key
                    del builtins_dict[key]
                builtins_dict[key] = func
323
    # Test a new-style class with __iter__ but no __next__() method
325    def test_new_style_iter_class(self):
326        class IterClass(object):
327            def __iter__(self):
328                return self
329        self.assertRaises(TypeError, iter, IterClass())
330
331    # Test two-argument iter() with callable instance
332    def test_iter_callable(self):
333        self.check_iterator(iter(CallableIterClass(), 10), list(range(10)), pickle=True)
334
335    # Test two-argument iter() with function
336    def test_iter_function(self):
337        def spam(state=[0]):
338            i = state[0]
339            state[0] = i+1
340            return i
341        self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
342
343    # Test two-argument iter() with function that raises StopIteration
344    def test_iter_function_stop(self):
345        def spam(state=[0]):
346            i = state[0]
347            if i == 10:
348                raise StopIteration
349            state[0] = i+1
350            return i
351        self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
352
353    def test_iter_function_concealing_reentrant_exhaustion(self):
354        # gh-101892: Test two-argument iter() with a function that
355        # exhausts its associated iterator but forgets to either return
356        # a sentinel value or raise StopIteration.
357        HAS_MORE = 1
358        NO_MORE = 2
359
360        def exhaust(iterator):
361            """Exhaust an iterator without raising StopIteration."""
362            list(iterator)
363
364        def spam():
365            # Touching the iterator with exhaust() below will call
366            # spam() once again so protect against recursion.
367            if spam.is_recursive_call:
368                return NO_MORE
369            spam.is_recursive_call = True
370            exhaust(spam.iterator)
371            return HAS_MORE
372
373        spam.is_recursive_call = False
374        spam.iterator = iter(spam, NO_MORE)
375        with self.assertRaises(StopIteration):
376            next(spam.iterator)
377
378    # Test exception propagation through function iterator
379    def test_exception_function(self):
380        def spam(state=[0]):
381            i = state[0]
382            state[0] = i+1
383            if i == 10:
384                raise RuntimeError
385            return i
386        res = []
387        try:
388            for x in iter(spam, 20):
389                res.append(x)
390        except RuntimeError:
391            self.assertEqual(res, list(range(10)))
392        else:
393            self.fail("should have raised RuntimeError")
394
395    # Test exception propagation through sequence iterator
396    def test_exception_sequence(self):
397        class MySequenceClass(SequenceClass):
398            def __getitem__(self, i):
399                if i == 10:
400                    raise RuntimeError
401                return SequenceClass.__getitem__(self, i)
402        res = []
403        try:
404            for x in MySequenceClass(20):
405                res.append(x)
406        except RuntimeError:
407            self.assertEqual(res, list(range(10)))
408        else:
409            self.fail("should have raised RuntimeError")
410
411    # Test for StopIteration from __getitem__
412    def test_stop_sequence(self):
413        class MySequenceClass(SequenceClass):
414            def __getitem__(self, i):
415                if i == 10:
416                    raise StopIteration
417                return SequenceClass.__getitem__(self, i)
418        self.check_for_loop(MySequenceClass(20), list(range(10)), pickle=False)
419
420    # Test a big range
421    def test_iter_big_range(self):
422        self.check_for_loop(iter(range(10000)), list(range(10000)))
423
424    # Test an empty list
425    def test_iter_empty(self):
426        self.check_for_loop(iter([]), [])
427
428    # Test a tuple
429    def test_iter_tuple(self):
430        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), list(range(10)))
431
432    # Test a range
433    def test_iter_range(self):
434        self.check_for_loop(iter(range(10)), list(range(10)))
435
436    # Test a string
437    def test_iter_string(self):
438        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
439
    # Test a dictionary
441    def test_iter_dict(self):
442        dict = {}
443        for i in range(10):
444            dict[i] = None
445        self.check_for_loop(dict, list(dict.keys()))
446
447    # Test a file
448    def test_iter_file(self):
449        f = open(TESTFN, "w", encoding="utf-8")
450        try:
451            for i in range(5):
452                f.write("%d\n" % i)
453        finally:
454            f.close()
455        f = open(TESTFN, "r", encoding="utf-8")
456        try:
457            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"], pickle=False)
458            self.check_for_loop(f, [], pickle=False)
459        finally:
460            f.close()
461            try:
462                unlink(TESTFN)
463            except OSError:
464                pass
465
466    # Test list()'s use of iterators.
467    def test_builtin_list(self):
468        self.assertEqual(list(SequenceClass(5)), list(range(5)))
469        self.assertEqual(list(SequenceClass(0)), [])
470        self.assertEqual(list(()), [])
471
472        d = {"one": 1, "two": 2, "three": 3}
473        self.assertEqual(list(d), list(d.keys()))
474
475        self.assertRaises(TypeError, list, list)
476        self.assertRaises(TypeError, list, 42)
477
478        f = open(TESTFN, "w", encoding="utf-8")
479        try:
480            for i in range(5):
481                f.write("%d\n" % i)
482        finally:
483            f.close()
484        f = open(TESTFN, "r", encoding="utf-8")
485        try:
486            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
487            f.seek(0, 0)
488            self.assertEqual(list(f),
489                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
490        finally:
491            f.close()
492            try:
493                unlink(TESTFN)
494            except OSError:
495                pass
496
    # Test tuple()'s use of iterators.
498    def test_builtin_tuple(self):
499        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
500        self.assertEqual(tuple(SequenceClass(0)), ())
501        self.assertEqual(tuple([]), ())
502        self.assertEqual(tuple(()), ())
503        self.assertEqual(tuple("abc"), ("a", "b", "c"))
504
505        d = {"one": 1, "two": 2, "three": 3}
506        self.assertEqual(tuple(d), tuple(d.keys()))
507
508        self.assertRaises(TypeError, tuple, list)
509        self.assertRaises(TypeError, tuple, 42)
510
511        f = open(TESTFN, "w", encoding="utf-8")
512        try:
513            for i in range(5):
514                f.write("%d\n" % i)
515        finally:
516            f.close()
517        f = open(TESTFN, "r", encoding="utf-8")
518        try:
519            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
520            f.seek(0, 0)
521            self.assertEqual(tuple(f),
522                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
523        finally:
524            f.close()
525            try:
526                unlink(TESTFN)
527            except OSError:
528                pass
529
530    # Test filter()'s use of iterators.
531    def test_builtin_filter(self):
532        self.assertEqual(list(filter(None, SequenceClass(5))),
533                         list(range(1, 5)))
534        self.assertEqual(list(filter(None, SequenceClass(0))), [])
535        self.assertEqual(list(filter(None, ())), [])
536        self.assertEqual(list(filter(None, "abc")), ["a", "b", "c"])
537
538        d = {"one": 1, "two": 2, "three": 3}
539        self.assertEqual(list(filter(None, d)), list(d.keys()))
540
541        self.assertRaises(TypeError, filter, None, list)
542        self.assertRaises(TypeError, filter, None, 42)
543
544        class Boolean:
545            def __init__(self, truth):
546                self.truth = truth
547            def __bool__(self):
548                return self.truth
549        bTrue = Boolean(True)
550        bFalse = Boolean(False)
551
552        class Seq:
553            def __init__(self, *args):
554                self.vals = args
555            def __iter__(self):
556                class SeqIter:
557                    def __init__(self, vals):
558                        self.vals = vals
559                        self.i = 0
560                    def __iter__(self):
561                        return self
562                    def __next__(self):
563                        i = self.i
564                        self.i = i + 1
565                        if i < len(self.vals):
566                            return self.vals[i]
567                        else:
568                            raise StopIteration
569                return SeqIter(self.vals)
570
571        seq = Seq(*([bTrue, bFalse] * 25))
572        self.assertEqual(list(filter(lambda x: not x, seq)), [bFalse]*25)
573        self.assertEqual(list(filter(lambda x: not x, iter(seq))), [bFalse]*25)
574
575    # Test max() and min()'s use of iterators.
576    def test_builtin_max_min(self):
577        self.assertEqual(max(SequenceClass(5)), 4)
578        self.assertEqual(min(SequenceClass(5)), 0)
579        self.assertEqual(max(8, -1), 8)
580        self.assertEqual(min(8, -1), -1)
581
582        d = {"one": 1, "two": 2, "three": 3}
583        self.assertEqual(max(d), "two")
584        self.assertEqual(min(d), "one")
585        self.assertEqual(max(d.values()), 3)
586        self.assertEqual(min(iter(d.values())), 1)
587
588        f = open(TESTFN, "w", encoding="utf-8")
589        try:
590            f.write("medium line\n")
591            f.write("xtra large line\n")
592            f.write("itty-bitty line\n")
593        finally:
594            f.close()
595        f = open(TESTFN, "r", encoding="utf-8")
596        try:
597            self.assertEqual(min(f), "itty-bitty line\n")
598            f.seek(0, 0)
599            self.assertEqual(max(f), "xtra large line\n")
600        finally:
601            f.close()
602            try:
603                unlink(TESTFN)
604            except OSError:
605                pass
606
607    # Test map()'s use of iterators.
608    def test_builtin_map(self):
609        self.assertEqual(list(map(lambda x: x+1, SequenceClass(5))),
610                         list(range(1, 6)))
611
612        d = {"one": 1, "two": 2, "three": 3}
613        self.assertEqual(list(map(lambda k, d=d: (k, d[k]), d)),
614                         list(d.items()))
615        dkeys = list(d.keys())
616        expected = [(i < len(d) and dkeys[i] or None,
617                     i,
618                     i < len(d) and dkeys[i] or None)
619                    for i in range(3)]
620
621        f = open(TESTFN, "w", encoding="utf-8")
622        try:
623            for i in range(10):
624                f.write("xy" * i + "\n") # line i has len 2*i+1
625        finally:
626            f.close()
627        f = open(TESTFN, "r", encoding="utf-8")
628        try:
629            self.assertEqual(list(map(len, f)), list(range(1, 21, 2)))
630        finally:
631            f.close()
632            try:
633                unlink(TESTFN)
634            except OSError:
635                pass
636
637    # Test zip()'s use of iterators.
638    def test_builtin_zip(self):
639        self.assertEqual(list(zip()), [])
640        self.assertEqual(list(zip(*[])), [])
641        self.assertEqual(list(zip(*[(1, 2), 'ab'])), [(1, 'a'), (2, 'b')])
642
643        self.assertRaises(TypeError, zip, None)
644        self.assertRaises(TypeError, zip, range(10), 42)
645        self.assertRaises(TypeError, zip, range(10), zip)
646
647        self.assertEqual(list(zip(IteratingSequenceClass(3))),
648                         [(0,), (1,), (2,)])
649        self.assertEqual(list(zip(SequenceClass(3))),
650                         [(0,), (1,), (2,)])
651
652        d = {"one": 1, "two": 2, "three": 3}
653        self.assertEqual(list(d.items()), list(zip(d, d.values())))
654
655        # Generate all ints starting at constructor arg.
656        class IntsFrom:
657            def __init__(self, start):
658                self.i = start
659
660            def __iter__(self):
661                return self
662
663            def __next__(self):
664                i = self.i
665                self.i = i+1
666                return i
667
668        f = open(TESTFN, "w", encoding="utf-8")
669        try:
670            f.write("a\n" "bbb\n" "cc\n")
671        finally:
672            f.close()
673        f = open(TESTFN, "r", encoding="utf-8")
674        try:
675            self.assertEqual(list(zip(IntsFrom(0), f, IntsFrom(-100))),
676                             [(0, "a\n", -100),
677                              (1, "bbb\n", -99),
678                              (2, "cc\n", -98)])
679        finally:
680            f.close()
681            try:
682                unlink(TESTFN)
683            except OSError:
684                pass
685
686        self.assertEqual(list(zip(range(5))), [(i,) for i in range(5)])
687
688        # Classes that lie about their lengths.
689        class NoGuessLen5:
690            def __getitem__(self, i):
691                if i >= 5:
692                    raise IndexError
693                return i
694
695        class Guess3Len5(NoGuessLen5):
696            def __len__(self):
697                return 3
698
699        class Guess30Len5(NoGuessLen5):
700            def __len__(self):
701                return 30
702
703        def lzip(*args):
704            return list(zip(*args))
705
706        self.assertEqual(len(Guess3Len5()), 3)
707        self.assertEqual(len(Guess30Len5()), 30)
708        self.assertEqual(lzip(NoGuessLen5()), lzip(range(5)))
709        self.assertEqual(lzip(Guess3Len5()), lzip(range(5)))
710        self.assertEqual(lzip(Guess30Len5()), lzip(range(5)))
711
712        expected = [(i, i) for i in range(5)]
713        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
714            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
715                self.assertEqual(lzip(x, y), expected)
716
717    def test_unicode_join_endcase(self):
718
719        # This class inserts a Unicode object into its argument's natural
720        # iteration, in the 3rd position.
721        class OhPhooey:
722            def __init__(self, seq):
723                self.it = iter(seq)
724                self.i = 0
725
726            def __iter__(self):
727                return self
728
729            def __next__(self):
730                i = self.i
731                self.i = i+1
732                if i == 2:
733                    return "fooled you!"
734                return next(self.it)
735
736        f = open(TESTFN, "w", encoding="utf-8")
737        try:
738            f.write("a\n" + "b\n" + "c\n")
739        finally:
740            f.close()
741
742        f = open(TESTFN, "r", encoding="utf-8")
743        # Nasty:  string.join(s) can't know whether unicode.join() is needed
744        # until it's seen all of s's elements.  But in this case, f's
745        # iterator cannot be restarted.  So what we're testing here is
746        # whether string.join() can manage to remember everything it's seen
747        # and pass that on to unicode.join().
748        try:
749            got = " - ".join(OhPhooey(f))
750            self.assertEqual(got, "a\n - b\n - fooled you! - c\n")
751        finally:
752            f.close()
753            try:
754                unlink(TESTFN)
755            except OSError:
756                pass
757
758    # Test iterators with 'x in y' and 'x not in y'.
759    def test_in_and_not_in(self):
760        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
761            for i in range(5):
762                self.assertIn(i, sc5)
763            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
764                self.assertNotIn(i, sc5)
765
766        self.assertIn(ALWAYS_EQ, IteratorProxyClass(iter([1])))
767        self.assertIn(ALWAYS_EQ, SequenceProxyClass([1]))
768        self.assertNotIn(ALWAYS_EQ, IteratorProxyClass(iter([NEVER_EQ])))
769        self.assertNotIn(ALWAYS_EQ, SequenceProxyClass([NEVER_EQ]))
770        self.assertIn(NEVER_EQ, IteratorProxyClass(iter([ALWAYS_EQ])))
771        self.assertIn(NEVER_EQ, SequenceProxyClass([ALWAYS_EQ]))
772
773        self.assertRaises(TypeError, lambda: 3 in 12)
774        self.assertRaises(TypeError, lambda: 3 not in map)
775        self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
776
777        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
778        for k in d:
779            self.assertIn(k, d)
780            self.assertNotIn(k, d.values())
781        for v in d.values():
782            self.assertIn(v, d.values())
783            self.assertNotIn(v, d)
784        for k, v in d.items():
785            self.assertIn((k, v), d.items())
786            self.assertNotIn((v, k), d.items())
787
788        f = open(TESTFN, "w", encoding="utf-8")
789        try:
790            f.write("a\n" "b\n" "c\n")
791        finally:
792            f.close()
793        f = open(TESTFN, "r", encoding="utf-8")
794        try:
795            for chunk in "abc":
796                f.seek(0, 0)
797                self.assertNotIn(chunk, f)
798                f.seek(0, 0)
799                self.assertIn((chunk + "\n"), f)
800        finally:
801            f.close()
802            try:
803                unlink(TESTFN)
804            except OSError:
805                pass
806
807    # Test iterators with operator.countOf (PySequence_Count).
808    def test_countOf(self):
809        from operator import countOf
810        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
811        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
812        self.assertEqual(countOf("122325", "2"), 3)
813        self.assertEqual(countOf("122325", "6"), 0)
814
815        self.assertRaises(TypeError, countOf, 42, 1)
816        self.assertRaises(TypeError, countOf, countOf, countOf)
817
818        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
819        for k in d:
820            self.assertEqual(countOf(d, k), 1)
821        self.assertEqual(countOf(d.values(), 3), 3)
822        self.assertEqual(countOf(d.values(), 2j), 1)
823        self.assertEqual(countOf(d.values(), 1j), 0)
824
825        f = open(TESTFN, "w", encoding="utf-8")
826        try:
827            f.write("a\n" "b\n" "c\n" "b\n")
828        finally:
829            f.close()
830        f = open(TESTFN, "r", encoding="utf-8")
831        try:
832            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
833                f.seek(0, 0)
834                self.assertEqual(countOf(f, letter + "\n"), count)
835        finally:
836            f.close()
837            try:
838                unlink(TESTFN)
839            except OSError:
840                pass
841
842    # Test iterators with operator.indexOf (PySequence_Index).
843    def test_indexOf(self):
844        from operator import indexOf
845        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
846        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
847        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
848        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
849        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
850        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)
851
852        self.assertEqual(indexOf("122325", "2"), 1)
853        self.assertEqual(indexOf("122325", "5"), 5)
854        self.assertRaises(ValueError, indexOf, "122325", "6")
855
856        self.assertRaises(TypeError, indexOf, 42, 1)
857        self.assertRaises(TypeError, indexOf, indexOf, indexOf)
858        self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
859
860        f = open(TESTFN, "w", encoding="utf-8")
861        try:
862            f.write("a\n" "b\n" "c\n" "d\n" "e\n")
863        finally:
864            f.close()
865        f = open(TESTFN, "r", encoding="utf-8")
866        try:
867            fiter = iter(f)
868            self.assertEqual(indexOf(fiter, "b\n"), 1)
869            self.assertEqual(indexOf(fiter, "d\n"), 1)
870            self.assertEqual(indexOf(fiter, "e\n"), 0)
871            self.assertRaises(ValueError, indexOf, fiter, "a\n")
872        finally:
873            f.close()
874            try:
875                unlink(TESTFN)
876            except OSError:
877                pass
878
879        iclass = IteratingSequenceClass(3)
880        for i in range(3):
881            self.assertEqual(indexOf(iclass, i), i)
882        self.assertRaises(ValueError, indexOf, iclass, -1)
883
884    # Test iterators with file.writelines().
885    def test_writelines(self):
886        f = open(TESTFN, "w", encoding="utf-8")
887
888        try:
889            self.assertRaises(TypeError, f.writelines, None)
890            self.assertRaises(TypeError, f.writelines, 42)
891
892            f.writelines(["1\n", "2\n"])
893            f.writelines(("3\n", "4\n"))
894            f.writelines({'5\n': None})
895            f.writelines({})
896
897            # Try a big chunk too.
898            class Iterator:
899                def __init__(self, start, finish):
900                    self.start = start
901                    self.finish = finish
902                    self.i = self.start
903
904                def __next__(self):
905                    if self.i >= self.finish:
906                        raise StopIteration
907                    result = str(self.i) + '\n'
908                    self.i += 1
909                    return result
910
911                def __iter__(self):
912                    return self
913
914            class Whatever:
915                def __init__(self, start, finish):
916                    self.start = start
917                    self.finish = finish
918
919                def __iter__(self):
920                    return Iterator(self.start, self.finish)
921
922            f.writelines(Whatever(6, 6+2000))
923            f.close()
924
925            f = open(TESTFN, encoding="utf-8")
926            expected = [str(i) + "\n" for i in range(1, 2006)]
927            self.assertEqual(list(f), expected)
928
929        finally:
930            f.close()
931            try:
932                unlink(TESTFN)
933            except OSError:
934                pass
935
936
937    # Test iterators on RHS of unpacking assignments.
938    def test_unpack_iter(self):
939        a, b = 1, 2
940        self.assertEqual((a, b), (1, 2))
941
942        a, b, c = IteratingSequenceClass(3)
943        self.assertEqual((a, b, c), (0, 1, 2))
944
945        try:    # too many values
946            a, b = IteratingSequenceClass(3)
947        except ValueError:
948            pass
949        else:
950            self.fail("should have raised ValueError")
951
952        try:    # not enough values
953            a, b, c = IteratingSequenceClass(2)
954        except ValueError:
955            pass
956        else:
957            self.fail("should have raised ValueError")
958
959        try:    # not iterable
960            a, b, c = len
961        except TypeError:
962            pass
963        else:
964            self.fail("should have raised TypeError")
965
966        a, b, c = {1: 42, 2: 42, 3: 42}.values()
967        self.assertEqual((a, b, c), (42, 42, 42))
968
969        f = open(TESTFN, "w", encoding="utf-8")
970        lines = ("a\n", "bb\n", "ccc\n")
971        try:
972            for line in lines:
973                f.write(line)
974        finally:
975            f.close()
976        f = open(TESTFN, "r", encoding="utf-8")
977        try:
978            a, b, c = f
979            self.assertEqual((a, b, c), lines)
980        finally:
981            f.close()
982            try:
983                unlink(TESTFN)
984            except OSError:
985                pass
986
987        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
988        self.assertEqual((a, b, c), (0, 1, 42))
989
990
991    @cpython_only
992    def test_ref_counting_behavior(self):
993        class C(object):
994            count = 0
995            def __new__(cls):
996                cls.count += 1
997                return object.__new__(cls)
998            def __del__(self):
999                cls = self.__class__
1000                assert cls.count > 0
1001                cls.count -= 1
1002        x = C()
1003        self.assertEqual(C.count, 1)
1004        del x
1005        self.assertEqual(C.count, 0)
1006        l = [C(), C(), C()]
1007        self.assertEqual(C.count, 3)
1008        try:
1009            a, b = iter(l)
1010        except ValueError:
1011            pass
1012        del l
1013        self.assertEqual(C.count, 0)
1014
1015
1016    # Make sure StopIteration is a "sink state".
1017    # This tests various things that weren't sink states in Python 2.2.1,
1018    # plus various things that always were fine.
1019
1020    def test_sinkstate_list(self):
1021        # This used to fail
1022        a = list(range(5))
1023        b = iter(a)
1024        self.assertEqual(list(b), list(range(5)))
1025        a.extend(range(5, 10))
1026        self.assertEqual(list(b), [])
1027
1028    def test_sinkstate_tuple(self):
1029        a = (0, 1, 2, 3, 4)
1030        b = iter(a)
1031        self.assertEqual(list(b), list(range(5)))
1032        self.assertEqual(list(b), [])
1033
1034    def test_sinkstate_string(self):
1035        a = "abcde"
1036        b = iter(a)
1037        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
1038        self.assertEqual(list(b), [])
1039
1040    def test_sinkstate_sequence(self):
1041        # This used to fail
1042        a = SequenceClass(5)
1043        b = iter(a)
1044        self.assertEqual(list(b), list(range(5)))
1045        a.n = 10
1046        self.assertEqual(list(b), [])
1047
1048    def test_sinkstate_callable(self):
1049        # This used to fail
1050        def spam(state=[0]):
1051            i = state[0]
1052            state[0] = i+1
1053            if i == 10:
1054                raise AssertionError("shouldn't have gotten this far")
1055            return i
1056        b = iter(spam, 5)
1057        self.assertEqual(list(b), list(range(5)))
1058        self.assertEqual(list(b), [])
1059
1060    def test_sinkstate_dict(self):
1061        # XXX For a more thorough test, see towards the end of:
1062        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
1063        a = {1:1, 2:2, 0:0, 4:4, 3:3}
1064        for b in iter(a), a.keys(), a.items(), a.values():
1065            b = iter(a)
1066            self.assertEqual(len(list(b)), 5)
1067            self.assertEqual(list(b), [])
1068
1069    def test_sinkstate_yield(self):
1070        def gen():
1071            for i in range(5):
1072                yield i
1073        b = gen()
1074        self.assertEqual(list(b), list(range(5)))
1075        self.assertEqual(list(b), [])
1076
1077    def test_sinkstate_range(self):
1078        a = range(5)
1079        b = iter(a)
1080        self.assertEqual(list(b), list(range(5)))
1081        self.assertEqual(list(b), [])
1082
1083    def test_sinkstate_enumerate(self):
1084        a = range(5)
1085        e = enumerate(a)
1086        b = iter(e)
1087        self.assertEqual(list(b), list(zip(range(5), range(5))))
1088        self.assertEqual(list(b), [])
1089
1090    def test_3720(self):
1091        # Avoid a crash, when an iterator deletes its next() method.
1092        class BadIterator(object):
1093            def __iter__(self):
1094                return self
1095            def __next__(self):
1096                del BadIterator.__next__
1097                return 1
1098
1099        try:
1100            for i in BadIterator() :
1101                pass
1102        except TypeError:
1103            pass
1104
1105    def test_extending_list_with_iterator_does_not_segfault(self):
1106        # The code to extend a list with an iterator has a fair
1107        # amount of nontrivial logic in terms of guessing how
1108        # much memory to allocate in advance, "stealing" refs,
1109        # and then shrinking at the end.  This is a basic smoke
1110        # test for that scenario.
1111        def gen():
1112            for i in range(500):
1113                yield i
1114        lst = [0] * 500
1115        for i in range(240):
1116            lst.pop(0)
1117        lst.extend(gen())
1118        self.assertEqual(len(lst), 760)
1119
1120    @cpython_only
1121    def test_iter_overflow(self):
1122        # Test for the issue 22939
1123        it = iter(UnlimitedSequenceClass())
1124        # Manually set `it_index` to PY_SSIZE_T_MAX-2 without a loop
1125        it.__setstate__(sys.maxsize - 2)
1126        self.assertEqual(next(it), sys.maxsize - 2)
1127        self.assertEqual(next(it), sys.maxsize - 1)
1128        with self.assertRaises(OverflowError):
1129            next(it)
1130        # Check that Overflow error is always raised
1131        with self.assertRaises(OverflowError):
1132            next(it)
1133
1134    def test_iter_neg_setstate(self):
1135        it = iter(UnlimitedSequenceClass())
1136        it.__setstate__(-42)
1137        self.assertEqual(next(it), 0)
1138        self.assertEqual(next(it), 1)
1139
1140    def test_free_after_iterating(self):
1141        check_free_after_iterating(self, iter, SequenceClass, (0,))
1142
1143    def test_error_iter(self):
1144        for typ in (DefaultIterClass, NoIterClass):
1145            self.assertRaises(TypeError, iter, typ())
1146        self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
1147
1148    def test_exception_locations(self):
1149        # The location of an exception raised from __init__ or
1150        # __next__ should should be the iterator expression
1151
1152        def init_raises():
1153            try:
1154                for x in BrokenIter(init_raises=True):
1155                    pass
1156            except Exception as e:
1157                return e
1158
1159        def next_raises():
1160            try:
1161                for x in BrokenIter(next_raises=True):
1162                    pass
1163            except Exception as e:
1164                return e
1165
1166        def iter_raises():
1167            try:
1168                for x in BrokenIter(iter_raises=True):
1169                    pass
1170            except Exception as e:
1171                return e
1172
1173        for func, expected in [(init_raises, "BrokenIter(init_raises=True)"),
1174                               (next_raises, "BrokenIter(next_raises=True)"),
1175                               (iter_raises, "BrokenIter(iter_raises=True)"),
1176                              ]:
1177            with self.subTest(func):
1178                exc = func()
1179                f = traceback.extract_tb(exc.__traceback__)[0]
1180                indent = 16
1181                co = func.__code__
1182                self.assertEqual(f.lineno, co.co_firstlineno + 2)
1183                self.assertEqual(f.end_lineno, co.co_firstlineno + 2)
1184                self.assertEqual(f.line[f.colno - indent : f.end_colno - indent],
1185                                 expected)
1186
1187
1188
# Run this module's tests through unittest when executed directly as a script.
if __name__ == "__main__":
    unittest.main()
1191