• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Test iterators.
2
3import unittest
4from test.test_support import run_unittest, TESTFN, unlink, have_unicode, \
5                              check_py3k_warnings, cpython_only
6
# Expected output of three nested loops over range(3), in loop order.
# It is kept as one module-level literal (rather than repeated inline in
# each test, or computed) so it serves as an independent oracle for the
# loop and comprehension tests below.
TRIPLETS = [
    (0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    (2, 0, 0), (2, 0, 1), (2, 0, 2), (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
19
20# Helper classes
21
class BasicIterClass:
    """Minimal hand-written iterator producing 0, 1, ..., n-1.

    Implements the Python 2 iterator protocol: ``next()`` returns the next
    value or raises StopIteration once ``n`` values have been produced, and
    ``__iter__`` returns the iterator itself so that ``iter()`` on the
    iterator works as the protocol requires.
    """
    def __init__(self, n):
        self.n = n  # exclusive upper bound of the produced values
        self.i = 0  # next value to hand out
    def next(self):
        res = self.i
        if res >= self.n:
            raise StopIteration
        self.i = res + 1
        return res
    def __iter__(self):
        # Iterators must be iterable themselves (return self).
        return self
32
class IteratingSequenceClass:
    """Iterable whose ``__iter__`` returns a fresh BasicIterClass over 0..n-1.

    Each call to ``__iter__`` builds a new iterator, so independent
    iterations over the same instance do not interfere.
    """
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        limit = self.n
        return BasicIterClass(limit)
38
class SequenceClass:
    """Old-protocol sequence: defines only ``__getitem__`` (no ``__iter__``).

    Indexing with 0 <= i < n returns i itself; any other index raises
    IndexError, which is what ends the sequence-protocol iteration.
    """
    def __init__(self, n):
        self.n = n
    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
47
48# Main test suite
49
class TestCase(unittest.TestCase):
    """Exercise the iterator protocol across builtins, user classes and files.

    Tests that need a real file create TESTFN with _write_testfile() and
    always remove it again with _remove_testfile() in a ``finally`` clause.
    """

    # --- helpers -----------------------------------------------------------

    def check_iterator(self, it, seq):
        """Drain *it* through explicit it.next() calls and compare to *seq*."""
        res = []
        while True:
            try:
                val = it.next()
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    def check_for_loop(self, expr, seq):
        """Collect *expr* with a plain ``for`` loop and compare to *seq*."""
        # Deliberately a ``for`` statement (not list()) since the for-loop
        # machinery is what is under test.
        res = []
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    def _write_testfile(self, data):
        """Create (or truncate) TESTFN and write the string *data* to it."""
        with open(TESTFN, "w") as f:
            f.write(data)

    def _remove_testfile(self):
        """Remove TESTFN, tolerating an OSError (e.g. file already gone)."""
        try:
            unlink(TESTFN)
        except OSError:
            pass

    # --- tests -------------------------------------------------------------

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), range(10))

    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = range(10)
        it = iter(seq)
        it2 = iter(it)
        self.assertIs(it, it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), range(10))

    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), range(10))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), range(10))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), range(10))

    # Test a new_style class with __iter__ but no next() method
    def test_new_style_iter_class(self):
        class IterClass(object):
            def __iter__(self):
                return self
        self.assertRaises(TypeError, iter, IterClass())

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        self.check_iterator(iter(C(), 10), range(10))

    # Test two-argument iter() with function
    def test_iter_function(self):
        # The mutable default is intentional: it is the callable's counter,
        # and it is fresh for every run because the def is re-executed.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))

    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), range(10))

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), range(10))

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), range(10000))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))

    # Test an xrange
    def test_iter_xrange(self):
        self.check_for_loop(iter(xrange(10)), range(10))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a Unicode string
    @unittest.skipUnless(have_unicode, "requires Unicode support")
    def test_iter_unicode(self):
        self.check_for_loop(iter(unicode("abcde")),
                            [unicode("a"), unicode("b"), unicode("c"),
                             unicode("d"), unicode("e")])

    # Test a dictionary
    def test_iter_dict(self):
        d = dict.fromkeys(range(10))
        self.check_for_loop(d, d.keys())

    # Test a file
    def test_iter_file(self):
        try:
            self._write_testfile("".join("%d\n" % i for i in range(5)))
            with open(TESTFN, "r") as f:
                self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
                # A second pass over the same file object yields nothing.
                self.check_for_loop(f, [])
        finally:
            self._remove_testfile()

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        try:
            self._write_testfile("".join("%d\n" % i for i in range(5)))
            with open(TESTFN, "r") as f:
                self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
                f.seek(0, 0)
                self.assertEqual(list(f),
                                 ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            self._remove_testfile()

    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        try:
            self._write_testfile("".join("%d\n" % i for i in range(5)))
            with open(TESTFN, "r") as f:
                self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
                f.seek(0, 0)
                self.assertEqual(tuple(f),
                                 ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            self._remove_testfile()

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        bTrue = Boolean(1)
        bFalse = Boolean(0)

        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [bFalse]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [bFalse]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        try:
            self._write_testfile("medium line\n"
                                 "xtra large line\n"
                                 "itty-bitty line\n")
            with open(TESTFN, "r") as f:
                self.assertEqual(min(f), "itty-bitty line\n")
                f.seek(0, 0)
                self.assertEqual(max(f), "xtra large line\n")
        finally:
            self._remove_testfile()

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        dkeys = d.keys()
        # map(None, ...) pads shorter inputs with None.
        expected = [(i < len(d) and dkeys[i] or None,
                     i,
                     i < len(d) and dkeys[i] or None)
                    for i in range(5)]

        # Deprecated map(None, ...)
        with check_py3k_warnings():
            self.assertEqual(map(None, SequenceClass(5)), range(5))
            self.assertEqual(map(None, d), d.keys())
            self.assertEqual(map(None, d,
                                       SequenceClass(5),
                                       iter(d.iterkeys())),
                             expected)

        try:
            # line i has len 2*i+1
            self._write_testfile("".join("xy" * i + "\n" for i in range(10)))
            with open(TESTFN, "r") as f:
                self.assertEqual(map(len, f), range(1, 21, 2))
        finally:
            self._remove_testfile()

    # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        self.assertEqual(zip(), [])
        self.assertEqual(zip(*[]), [])
        self.assertEqual(zip(*[(1, 2), 'ab']), [(1, 'a'), (2, 'b')])

        self.assertRaises(TypeError, zip, None)
        self.assertRaises(TypeError, zip, range(10), 42)
        self.assertRaises(TypeError, zip, range(10), zip)

        self.assertEqual(zip(IteratingSequenceClass(3)),
                         [(0,), (1,), (2,)])
        self.assertEqual(zip(SequenceClass(3)),
                         [(0,), (1,), (2,)])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(d.items(), zip(d, d.itervalues()))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                return i

        try:
            self._write_testfile("a\n" "bbb\n" "cc\n")
            with open(TESTFN, "r") as f:
                # The finite file bounds the two infinite iterators.
                self.assertEqual(zip(IntsFrom(0), f, IntsFrom(-100)),
                                 [(0, "a\n", -100),
                                  (1, "bbb\n", -99),
                                  (2, "cc\n", -98)])
        finally:
            self._remove_testfile()

        self.assertEqual(zip(xrange(5)), [(i,) for i in range(5)])

        # Classes that lie about their lengths.
        class NoGuessLen5:
            def __getitem__(self, i):
                if i >= 5:
                    raise IndexError
                return i

        class Guess3Len5(NoGuessLen5):
            def __len__(self):
                return 3

        class Guess30Len5(NoGuessLen5):
            def __len__(self):
                return 30

        self.assertEqual(len(Guess3Len5()), 3)
        self.assertEqual(len(Guess30Len5()), 30)
        self.assertEqual(zip(NoGuessLen5()), zip(range(5)))
        self.assertEqual(zip(Guess3Len5()), zip(range(5)))
        self.assertEqual(zip(Guess30Len5()), zip(range(5)))

        expected = [(i, i) for i in range(5)]
        for x in NoGuessLen5(), Guess3Len5(), Guess30Len5():
            for y in NoGuessLen5(), Guess3Len5(), Guess30Len5():
                self.assertEqual(zip(x, y), expected)

    # Test reduce()'s use of iterators.
    def test_deprecated_builtin_reduce(self):
        with check_py3k_warnings():
            self._test_builtin_reduce()

    def _test_builtin_reduce(self):
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))

    # str.join() must fall back to unicode.join() mid-iteration.
    @unittest.skipUnless(have_unicode, "requires Unicode support")
    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return unicode("fooled you!")
                return self.it.next()

        try:
            self._write_testfile("a\n" + "b\n" + "c\n")
            with open(TESTFN, "r") as f:
                # Nasty:  string.join(s) can't know whether unicode.join()
                # is needed until it's seen all of s's elements.  But in
                # this case, f's iterator cannot be restarted.  So what
                # we're testing here is whether string.join() can manage to
                # remember everything it's seen and pass that on to
                # unicode.join().
                got = " - ".join(OhPhooey(f))
                self.assertEqual(got,
                                 unicode("a\n - b\n - fooled you! - c\n"))
        finally:
            self._remove_testfile()

    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assertIn(i, sc5)
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assertNotIn(i, sc5)

        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assertIn(k, d)
            self.assertNotIn(k, d.itervalues())
        for v in d.values():
            self.assertIn(v, d.itervalues())
            self.assertNotIn(v, d)
        for k, v in d.iteritems():
            self.assertIn((k, v), d.iteritems())
            self.assertNotIn((v, k), d.iteritems())

        try:
            self._write_testfile("a\n" "b\n" "c\n")
            with open(TESTFN, "r") as f:
                for chunk in "abc":
                    f.seek(0, 0)
                    self.assertNotIn(chunk, f)
                    f.seek(0, 0)
                    self.assertIn((chunk + "\n"), f)
        finally:
            self._remove_testfile()

    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.itervalues(), 3), 3)
        self.assertEqual(countOf(d.itervalues(), 2j), 1)
        self.assertEqual(countOf(d.itervalues(), 1j), 0)

        try:
            self._write_testfile("a\n" "b\n" "c\n" "b\n")
            with open(TESTFN, "r") as f:
                for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                    f.seek(0, 0)
                    self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            self._remove_testfile()

    # Test iterators with operator.indexOf (PySequence_Index).
    def test_indexOf(self):
        from operator import indexOf
        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
        self.assertEqual(indexOf((1,2,2,3,2,5), 2), 1)
        self.assertEqual(indexOf((1,2,2,3,2,5), 3), 3)
        self.assertEqual(indexOf((1,2,2,3,2,5), 5), 5)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 0)
        self.assertRaises(ValueError, indexOf, (1,2,2,3,2,5), 6)

        self.assertEqual(indexOf("122325", "2"), 1)
        self.assertEqual(indexOf("122325", "5"), 5)
        self.assertRaises(ValueError, indexOf, "122325", "6")

        self.assertRaises(TypeError, indexOf, 42, 1)
        self.assertRaises(TypeError, indexOf, indexOf, indexOf)

        try:
            self._write_testfile("a\n" "b\n" "c\n" "d\n" "e\n")
            with open(TESTFN, "r") as f:
                fiter = iter(f)
                self.assertEqual(indexOf(fiter, "b\n"), 1)
                # The same iterator is reused, so each search resumes just
                # past the previous match.
                self.assertEqual(indexOf(fiter, "d\n"), 1)
                self.assertEqual(indexOf(fiter, "e\n"), 0)
                self.assertRaises(ValueError, indexOf, fiter, "a\n")
        finally:
            self._remove_testfile()

        iclass = IteratingSequenceClass(3)
        for i in range(3):
            self.assertEqual(indexOf(iclass, i), i)
        self.assertRaises(ValueError, indexOf, iclass, -1)

    # Test iterators with file.writelines().
    def test_writelines(self):
        # A user-defined iterable (Whatever) whose iterator yields the
        # lines str(start)+"\n" .. str(finish-1)+"\n".
        class Iterator:
            def __init__(self, start, finish):
                self.start = start
                self.finish = finish
                self.i = self.start

            def next(self):
                if self.i >= self.finish:
                    raise StopIteration
                result = str(self.i) + '\n'
                self.i += 1
                return result

            def __iter__(self):
                return self

        class Whatever:
            def __init__(self, start, finish):
                self.start = start
                self.finish = finish

            def __iter__(self):
                return Iterator(self.start, self.finish)

        try:
            with file(TESTFN, "w") as f:
                self.assertRaises(TypeError, f.writelines, None)
                self.assertRaises(TypeError, f.writelines, 42)

                f.writelines(["1\n", "2\n"])
                f.writelines(("3\n", "4\n"))
                f.writelines({'5\n': None})
                f.writelines({})

                # Try a big chunk too.
                f.writelines(Whatever(6, 6+2000))

            with file(TESTFN) as f:
                expected = [str(i) + "\n" for i in range(1, 2006)]
                self.assertEqual(list(f), expected)
        finally:
            self._remove_testfile()


    # Test iterators on RHS of unpacking assignments.
    def test_unpack_iter(self):
        a, b = 1, 2
        self.assertEqual((a, b), (1, 2))

        a, b, c = IteratingSequenceClass(3)
        self.assertEqual((a, b, c), (0, 1, 2))

        try:    # too many values
            a, b = IteratingSequenceClass(3)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not enough values
            a, b, c = IteratingSequenceClass(2)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not iterable
            a, b, c = len
        except TypeError:
            pass
        else:
            self.fail("should have raised TypeError")

        a, b, c = {1: 42, 2: 42, 3: 42}.itervalues()
        self.assertEqual((a, b, c), (42, 42, 42))

        lines = ("a\n", "bb\n", "ccc\n")
        try:
            self._write_testfile("".join(lines))
            with open(TESTFN, "r") as f:
                a, b, c = f
                self.assertEqual((a, b, c), lines)
        finally:
            self._remove_testfile()

        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
        self.assertEqual((a, b, c), (0, 1, 42))


    @cpython_only
    def test_ref_counting_behavior(self):
        # C.count tracks live instances; a failed unpack must not leak
        # references to the list's items.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        try:
            a, b = iter(l)
        except ValueError:
            pass
        del l
        self.assertEqual(C.count, 0)


    # Make sure StopIteration is a "sink state".
    # This tests various things that weren't sink states in Python 2.2.1,
    # plus various things that always were fine.

    def test_sinkstate_list(self):
        # This used to fail
        a = range(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        a.extend(range(5, 10))
        self.assertEqual(list(b), [])

    def test_sinkstate_tuple(self):
        a = (0, 1, 2, 3, 4)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])

    def test_sinkstate_string(self):
        a = "abcde"
        b = iter(a)
        self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
        self.assertEqual(list(b), [])

    def test_sinkstate_sequence(self):
        # This used to fail
        a = SequenceClass(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        a.n = 10
        self.assertEqual(list(b), [])

    def test_sinkstate_callable(self):
        # This used to fail
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise AssertionError("shouldn't have gotten this far")
            return i
        b = iter(spam, 5)
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])

    def test_sinkstate_dict(self):
        # XXX For a more thorough test, see towards the end of:
        # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
        a = {1:1, 2:2, 0:0, 4:4, 3:3}
        # Check each kind of dict iterator as given.  (b must not be
        # rebound to iter(a) here, or only key iteration gets tested.)
        for b in iter(a), a.iterkeys(), a.iteritems(), a.itervalues():
            self.assertEqual(len(list(b)), 5)
            self.assertEqual(list(b), [])

    def test_sinkstate_yield(self):
        def gen():
            for i in range(5):
                yield i
        b = gen()
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])

    def test_sinkstate_range(self):
        a = xrange(5)
        b = iter(a)
        self.assertEqual(list(b), range(5))
        self.assertEqual(list(b), [])

    def test_sinkstate_enumerate(self):
        a = range(5)
        e = enumerate(a)
        b = iter(e)
        self.assertEqual(list(b), zip(range(5), range(5)))
        self.assertEqual(list(b), [])

    def test_3720(self):
        # Avoid a crash, when an iterator deletes its next() method.
        class BadIterator(object):
            def __iter__(self):
                return self
            def next(self):
                del BadIterator.next
                return 1

        try:
            for i in BadIterator() :
                pass
        except TypeError:
            pass
905        try:
906            for i in BadIterator() :
907                pass
908        except TypeError:
909            pass
910
911
912def test_main():
913    run_unittest(TestCase)
914
915
916if __name__ == "__main__":
917    test_main()
918