• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import unittest
2from test import test_support
3from itertools import *
4from weakref import proxy
5from decimal import Decimal
6from fractions import Fraction
7import sys
8import operator
9import random
10import copy
11import pickle
12from functools import reduce
# Largest and smallest values of a C Py_ssize_t on this build; the count()
# tests below start counting just around these limits.
maxsize = test_support.MAX_Py_ssize_t
minsize = -(maxsize + 1)
15
def onearg(x):
    """Return x doubled; a simple one-argument test function."""
    doubled = 2 * x
    return doubled
19
def errfunc(*args):
    """Accept any positional arguments and unconditionally raise ValueError."""
    raise ValueError()
23
def gen3():
    """Yield 0, 1 and 2 a single time (a non-restartable source sequence)."""
    yield 0
    yield 1
    yield 2
28
def isEven(x):
    """Test predicate: true when x leaves no remainder modulo 2."""
    remainder = x % 2
    return remainder == 0
32
def isOdd(x):
    """Test predicate: true when x leaves remainder 1 modulo 2."""
    remainder = x % 2
    return remainder == 1
36
class StopNow:
    """An iterable that is its own iterator and is exhausted from the start."""

    def __iter__(self):
        # Serve as our own iterator.
        return self

    def next(self):
        # Python 2 iterator protocol: report exhaustion on the very first call.
        raise StopIteration()
43
def take(n, seq):
    """Return the first n items of seq as a list.

    Convenience function for partially consuming a long or infinite
    iterable.  If seq is exhausted first, only the available items are
    returned.
    """
    return list(islice(seq, n))
47
def prod(iterable):
    """Return the product of the items in iterable (1 for an empty iterable)."""
    total = 1
    for factor in iterable:
        total = total * factor
    return total
50
def fact(n):
    """Factorial: the product of the integers 1..n (an empty product for n < 1)."""
    factors = range(1, n + 1)
    return prod(factors)
54
55class TestBasicOps(unittest.TestCase):
56    def test_chain(self):
57
58        def chain2(*iterables):
59            'Pure python version in the docs'
60            for it in iterables:
61                for element in it:
62                    yield element
63
64        for c in (chain, chain2):
65            self.assertEqual(list(c('abc', 'def')), list('abcdef'))
66            self.assertEqual(list(c('abc')), list('abc'))
67            self.assertEqual(list(c('')), [])
68            self.assertEqual(take(4, c('abc', 'def')), list('abcd'))
69            self.assertRaises(TypeError, list,c(2, 3))
70
71    def test_chain_from_iterable(self):
72        self.assertEqual(list(chain.from_iterable(['abc', 'def'])), list('abcdef'))
73        self.assertEqual(list(chain.from_iterable(['abc'])), list('abc'))
74        self.assertEqual(list(chain.from_iterable([''])), [])
75        self.assertEqual(take(4, chain.from_iterable(['abc', 'def'])), list('abcd'))
76        self.assertRaises(TypeError, list, chain.from_iterable([2, 3]))
77
78    def test_combinations(self):
79        self.assertRaises(TypeError, combinations, 'abc')       # missing r argument
80        self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
81        self.assertRaises(TypeError, combinations, None)        # pool is not iterable
82        self.assertRaises(ValueError, combinations, 'abc', -2)  # r is negative
83        self.assertEqual(list(combinations('abc', 32)), [])     # r > n
84        self.assertEqual(list(combinations(range(4), 3)),
85                                           [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
86
87        def combinations1(iterable, r):
88            'Pure python version shown in the docs'
89            pool = tuple(iterable)
90            n = len(pool)
91            if r > n:
92                return
93            indices = range(r)
94            yield tuple(pool[i] for i in indices)
95            while 1:
96                for i in reversed(range(r)):
97                    if indices[i] != i + n - r:
98                        break
99                else:
100                    return
101                indices[i] += 1
102                for j in range(i+1, r):
103                    indices[j] = indices[j-1] + 1
104                yield tuple(pool[i] for i in indices)
105
106        def combinations2(iterable, r):
107            'Pure python version shown in the docs'
108            pool = tuple(iterable)
109            n = len(pool)
110            for indices in permutations(range(n), r):
111                if sorted(indices) == list(indices):
112                    yield tuple(pool[i] for i in indices)
113
114        def combinations3(iterable, r):
115            'Pure python version from cwr()'
116            pool = tuple(iterable)
117            n = len(pool)
118            for indices in combinations_with_replacement(range(n), r):
119                if len(set(indices)) == r:
120                    yield tuple(pool[i] for i in indices)
121
122        for n in range(7):
123            values = [5*x-12 for x in range(n)]
124            for r in range(n+2):
125                result = list(combinations(values, r))
126                self.assertEqual(len(result), 0 if r>n else fact(n) // fact(r) // fact(n-r)) # right number of combs
127                self.assertEqual(len(result), len(set(result)))         # no repeats
128                self.assertEqual(result, sorted(result))                # lexicographic order
129                for c in result:
130                    self.assertEqual(len(c), r)                         # r-length combinations
131                    self.assertEqual(len(set(c)), r)                    # no duplicate elements
132                    self.assertEqual(list(c), sorted(c))                # keep original ordering
133                    self.assertTrue(all(e in values for e in c))           # elements taken from input iterable
134                    self.assertEqual(list(c),
135                                     [e for e in values if e in c])      # comb is a subsequence of the input iterable
136                self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
137                self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
138                self.assertEqual(result, list(combinations3(values, r))) # matches second pure python version
139
140        # Test implementation detail:  tuple re-use
141        self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
142        self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
143
144    def test_combinations_with_replacement(self):
145        cwr = combinations_with_replacement
146        self.assertRaises(TypeError, cwr, 'abc')       # missing r argument
147        self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
148        self.assertRaises(TypeError, cwr, None)        # pool is not iterable
149        self.assertRaises(ValueError, cwr, 'abc', -2)  # r is negative
150        self.assertEqual(list(cwr('ABC', 2)),
151                         [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
152
153        def cwr1(iterable, r):
154            'Pure python version shown in the docs'
155            # number items returned:  (n+r-1)! / r! / (n-1)! when n>0
156            pool = tuple(iterable)
157            n = len(pool)
158            if not n and r:
159                return
160            indices = [0] * r
161            yield tuple(pool[i] for i in indices)
162            while 1:
163                for i in reversed(range(r)):
164                    if indices[i] != n - 1:
165                        break
166                else:
167                    return
168                indices[i:] = [indices[i] + 1] * (r - i)
169                yield tuple(pool[i] for i in indices)
170
171        def cwr2(iterable, r):
172            'Pure python version shown in the docs'
173            pool = tuple(iterable)
174            n = len(pool)
175            for indices in product(range(n), repeat=r):
176                if sorted(indices) == list(indices):
177                    yield tuple(pool[i] for i in indices)
178
179        def numcombs(n, r):
180            if not n:
181                return 0 if r else 1
182            return fact(n+r-1) // fact(r) // fact(n-1)
183
184        for n in range(7):
185            values = [5*x-12 for x in range(n)]
186            for r in range(n+2):
187                result = list(cwr(values, r))
188
189                self.assertEqual(len(result), numcombs(n, r))           # right number of combs
190                self.assertEqual(len(result), len(set(result)))         # no repeats
191                self.assertEqual(result, sorted(result))                # lexicographic order
192
193                regular_combs = list(combinations(values, r))           # compare to combs without replacement
194                if n == 0 or r <= 1:
195                    self.assertEqual(result, regular_combs)            # cases that should be identical
196                else:
197                    self.assertTrue(set(result) >= set(regular_combs))     # rest should be supersets of regular combs
198
199                for c in result:
200                    self.assertEqual(len(c), r)                         # r-length combinations
201                    noruns = [k for k,v in groupby(c)]                  # combo without consecutive repeats
202                    self.assertEqual(len(noruns), len(set(noruns)))     # no repeats other than consecutive
203                    self.assertEqual(list(c), sorted(c))                # keep original ordering
204                    self.assertTrue(all(e in values for e in c))           # elements taken from input iterable
205                    self.assertEqual(noruns,
206                                     [e for e in values if e in c])     # comb is a subsequence of the input iterable
207                self.assertEqual(result, list(cwr1(values, r)))         # matches first pure python version
208                self.assertEqual(result, list(cwr2(values, r)))         # matches second pure python version
209
210        # Test implementation detail:  tuple re-use
211        self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
212        self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
213
    def test_permutations(self):
        # Argument validation (r is optional, so permutations(None) really
        # exercises the non-iterable pool case).
        self.assertRaises(TypeError, permutations)              # too few arguments
        self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
        self.assertRaises(TypeError, permutations, None)        # pool is not iterable
        self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
        self.assertEqual(list(permutations('abc', 32)), [])     # r > n
        self.assertRaises(TypeError, permutations, 'abc', 's')  # r is not an int or None
        self.assertEqual(list(permutations(range(3), 2)),
                                           [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])

        def permutations1(iterable, r=None):
            'Pure python version shown in the docs'
            pool = tuple(iterable)
            n = len(pool)
            r = n if r is None else r
            if r > n:
                return
            # indices and cycles are mutable lists here (Python 2 range()).
            indices = range(n)
            # cycles[i] starts at n - i and counts down once per step in which
            # position i is the rightmost position that changes.
            cycles = range(n, n-r, -1)
            yield tuple(pool[i] for i in indices[:r])
            # With n == 0 the only result, (), has already been yielded.
            while n:
                for i in reversed(range(r)):
                    cycles[i] -= 1
                    if cycles[i] == 0:
                        # Position i is exhausted: rotate indices[i:] left by
                        # one and reset its counter, then carry to i - 1.
                        indices[i:] = indices[i+1:] + indices[i:i+1]
                        cycles[i] = n - i
                    else:
                        # Swap the next candidate into position i and emit.
                        j = cycles[i]
                        indices[i], indices[-j] = indices[-j], indices[i]
                        yield tuple(pool[i] for i in indices[:r])
                        break
                else:
                    # Every position was exhausted: all permutations emitted.
                    return

        def permutations2(iterable, r=None):
            'Pure python version shown in the docs'
            # Keep only the cartesian-product tuples with no repeated index.
            pool = tuple(iterable)
            n = len(pool)
            r = n if r is None else r
            for indices in product(range(n), repeat=r):
                if len(set(indices)) == r:
                    yield tuple(pool[i] for i in indices)

        for n in range(7):
            # Strictly increasing values, so lexicographic order of the
            # results matches the order of the chosen positions.
            values = [5*x-12 for x in range(n)]
            for r in range(n+2):
                result = list(permutations(values, r))
                self.assertEqual(len(result), 0 if r>n else fact(n) // fact(n-r))      # right number of perms
                self.assertEqual(len(result), len(set(result)))         # no repeats
                self.assertEqual(result, sorted(result))                # lexicographic order
                for p in result:
                    self.assertEqual(len(p), r)                         # r-length permutations
                    self.assertEqual(len(set(p)), r)                    # no duplicate elements
                    self.assertTrue(all(e in values for e in p))           # elements taken from input iterable
                self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
                self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
                if r == n:
                    self.assertEqual(result, list(permutations(values, None))) # test r as None
                    self.assertEqual(result, list(permutations(values)))       # test default r

        # Test implementation detail:  tuple re-use.  map(id, ...) over the
        # live iterator drops each tuple before the next one is requested, so
        # every id is the same; tuples kept alive in a list are distinct.
        self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
        self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
277
278    def test_combinatorics(self):
279        # Test relationships between product(), permutations(),
280        # combinations() and combinations_with_replacement().
281
282        for n in range(6):
283            s = 'ABCDEFG'[:n]
284            for r in range(8):
285                prod = list(product(s, repeat=r))
286                cwr = list(combinations_with_replacement(s, r))
287                perm = list(permutations(s, r))
288                comb = list(combinations(s, r))
289
290                # Check size
291                self.assertEqual(len(prod), n**r)
292                self.assertEqual(len(cwr), (fact(n+r-1) // fact(r) // fact(n-1)) if n else (not r))
293                self.assertEqual(len(perm), 0 if r>n else fact(n) // fact(n-r))
294                self.assertEqual(len(comb), 0 if r>n else fact(n) // fact(r) // fact(n-r))
295
296                # Check lexicographic order without repeated tuples
297                self.assertEqual(prod, sorted(set(prod)))
298                self.assertEqual(cwr, sorted(set(cwr)))
299                self.assertEqual(perm, sorted(set(perm)))
300                self.assertEqual(comb, sorted(set(comb)))
301
302                # Check interrelationships
303                self.assertEqual(cwr, [t for t in prod if sorted(t)==list(t)]) # cwr: prods which are sorted
304                self.assertEqual(perm, [t for t in prod if len(set(t))==r])    # perm: prods with no dups
305                self.assertEqual(comb, [t for t in perm if sorted(t)==list(t)]) # comb: perms that are sorted
306                self.assertEqual(comb, [t for t in cwr if len(set(t))==r])      # comb: cwrs without dups
307                self.assertEqual(comb, filter(set(cwr).__contains__, perm))     # comb: perm that is a cwr
308                self.assertEqual(comb, filter(set(perm).__contains__, cwr))     # comb: cwr that is a perm
309                self.assertEqual(comb, sorted(set(cwr) & set(perm)))            # comb: both a cwr and a perm
310
311    def test_compress(self):
312        self.assertEqual(list(compress(data='ABCDEF', selectors=[1,0,1,0,1,1])), list('ACEF'))
313        self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
314        self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))
315        self.assertEqual(list(compress('ABCDEF', [1,1,1,1,1,1])), list('ABCDEF'))
316        self.assertEqual(list(compress('ABCDEF', [1,0,1])), list('AC'))
317        self.assertEqual(list(compress('ABC', [0,1,1,1,1,1])), list('BC'))
318        n = 10000
319        data = chain.from_iterable(repeat(range(6), n))
320        selectors = chain.from_iterable(repeat((0, 1)))
321        self.assertEqual(list(compress(data, selectors)), [1,3,5] * n)
322        self.assertRaises(TypeError, compress, None, range(6))      # 1st arg not iterable
323        self.assertRaises(TypeError, compress, range(6), None)      # 2nd arg not iterable
324        self.assertRaises(TypeError, compress, range(6))            # too few args
325        self.assertRaises(TypeError, compress, range(6), None)      # too many args
326
327    def test_count(self):
328        self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
329        self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
330        self.assertEqual(take(2, zip('abc',count(3))), [('a', 3), ('b', 4)])
331        self.assertEqual(take(2, zip('abc',count(-1))), [('a', -1), ('b', 0)])
332        self.assertEqual(take(2, zip('abc',count(-3))), [('a', -3), ('b', -2)])
333        self.assertRaises(TypeError, count, 2, 3, 4)
334        self.assertRaises(TypeError, count, 'a')
335        self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5))
336        self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5))
337        c = count(3)
338        self.assertEqual(repr(c), 'count(3)')
339        c.next()
340        self.assertEqual(repr(c), 'count(4)')
341        c = count(-9)
342        self.assertEqual(repr(c), 'count(-9)')
343        c.next()
344        self.assertEqual(repr(count(10.25)), 'count(10.25)')
345        self.assertEqual(c.next(), -8)
346        for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
347            # Test repr (ignoring the L in longs)
348            r1 = repr(count(i)).replace('L', '')
349            r2 = 'count(%r)'.__mod__(i).replace('L', '')
350            self.assertEqual(r1, r2)
351
352        # check copy, deepcopy, pickle
353        for value in -3, 3, sys.maxint-5, sys.maxint+5:
354            c = count(value)
355            self.assertEqual(next(copy.copy(c)), value)
356            self.assertEqual(next(copy.deepcopy(c)), value)
357            self.assertEqual(next(pickle.loads(pickle.dumps(c))), value)
358
359    def test_count_with_stride(self):
360        self.assertEqual(zip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
361        self.assertEqual(zip('abc',count(start=2,step=3)),
362                         [('a', 2), ('b', 5), ('c', 8)])
363        self.assertEqual(zip('abc',count(step=-1)),
364                         [('a', 0), ('b', -1), ('c', -2)])
365        self.assertEqual(zip('abc',count(2,0)), [('a', 2), ('b', 2), ('c', 2)])
366        self.assertEqual(zip('abc',count(2,1)), [('a', 2), ('b', 3), ('c', 4)])
367        self.assertEqual(take(20, count(maxsize-15, 3)), take(20, range(maxsize-15, maxsize+100, 3)))
368        self.assertEqual(take(20, count(-maxsize-15, 3)), take(20, range(-maxsize-15,-maxsize+100, 3)))
369        self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j])
370        self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))),
371                         [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')])
372        self.assertEqual(take(3, count(Fraction(2,3), Fraction(1,7))),
373                         [Fraction(2,3), Fraction(17,21), Fraction(20,21)])
374        self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0]))
375        c = count(3, 5)
376        self.assertEqual(repr(c), 'count(3, 5)')
377        c.next()
378        self.assertEqual(repr(c), 'count(8, 5)')
379        c = count(-9, 0)
380        self.assertEqual(repr(c), 'count(-9, 0)')
381        c.next()
382        self.assertEqual(repr(c), 'count(-9, 0)')
383        c = count(-9, -3)
384        self.assertEqual(repr(c), 'count(-9, -3)')
385        c.next()
386        self.assertEqual(repr(c), 'count(-12, -3)')
387        self.assertEqual(repr(c), 'count(-12, -3)')
388        self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)')
389        self.assertEqual(repr(count(10.5, 1)), 'count(10.5)')           # suppress step=1 when it's an int
390        self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)')   # do show float values lilke 1.0
391        for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
392            for j in  (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5):
393                # Test repr (ignoring the L in longs)
394                r1 = repr(count(i, j)).replace('L', '')
395                if j == 1:
396                    r2 = ('count(%r)' % i).replace('L', '')
397                else:
398                    r2 = ('count(%r, %r)' % (i, j)).replace('L', '')
399                self.assertEqual(r1, r2)
400
401    def test_cycle(self):
402        self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
403        self.assertEqual(list(cycle('')), [])
404        self.assertRaises(TypeError, cycle)
405        self.assertRaises(TypeError, cycle, 5)
406        self.assertEqual(list(islice(cycle(gen3()),10)), [0,1,2,0,1,2,0,1,2,0])
407
    def test_groupby(self):
        # Check whether it accepts arguments correctly.  A non-callable key
        # ([]) is only detected when iteration starts.
        self.assertEqual([], list(groupby([])))
        self.assertEqual([], list(groupby([], key=id)))
        self.assertRaises(TypeError, list, groupby('abc', []))
        self.assertRaises(TypeError, groupby, None)
        self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10)

        # Check normal input: concatenating every group reproduces s, and
        # every element of a group carries that group's key.
        s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
             (2,15,22), (3,16,23), (3,17,23)]
        dup = []
        for k, g in groupby(s, lambda r:r[0]):
            for elem in g:
                self.assertEqual(k, elem[0])
                dup.append(elem)
        self.assertEqual(s, dup)

        # Check nested case: regroup each outer group by the third field.
        dup = []
        for k, g in groupby(s, lambda r:r[0]):
            for ik, ig in groupby(g, lambda r:r[2]):
                for elem in ig:
                    self.assertEqual(k, elem[0])
                    self.assertEqual(ik, elem[2])
                    dup.append(elem)
        self.assertEqual(s, dup)

        # Check case where inner iterator is not used: the outer iterator
        # must still skip past each group and report each key exactly once.
        keys = [k for k, g in groupby(s, lambda r:r[0])]
        expectedkeys = set([r[0] for r in s])
        self.assertEqual(set(keys), expectedkeys)
        self.assertEqual(len(keys), len(expectedkeys))

        # Exercise pipes and filters style
        s = 'abracadabra'
        # sort s | uniq
        r = [k for k, g in groupby(sorted(s))]
        self.assertEqual(r, ['a', 'b', 'c', 'd', 'r'])
        # sort s | uniq -d  (keep keys whose group has a second element)
        r = [k for k, g in groupby(sorted(s)) if list(islice(g,1,2))]
        self.assertEqual(r, ['a', 'b', 'r'])
        # sort s | uniq -c
        r = [(len(list(g)), k) for k, g in groupby(sorted(s))]
        self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')])
        # sort s | uniq -c | sort -rn | head -3
        r = sorted([(len(list(g)) , k) for k, g in groupby(sorted(s))], reverse=True)[:3]
        self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')])

        # iter.next failure
        class ExpectedError(Exception):
            pass
        def delayed_raise(n=0):
            # Yield n items, then raise ExpectedError.
            for i in range(n):
                yield 'yo'
            raise ExpectedError
        def gulp(iterable, keyp=None, func=list):
            # Apply func to every group (func=list consumes the groups).
            return [func(g) for k, g in groupby(iterable, keyp)]

        # iter.next failure on outer object
        self.assertRaises(ExpectedError, gulp, delayed_raise(0))
        # iter.next failure on inner object
        self.assertRaises(ExpectedError, gulp, delayed_raise(1))

        # __cmp__ failure: comparing two DummyCmp keys raises ExpectedError
        # (old-style Python 2 class, so == goes through __cmp__).
        class DummyCmp:
            def __cmp__(self, dst):
                raise ExpectedError
        s = [DummyCmp(), DummyCmp(), None]

        # __cmp__ failure on outer object (func=id never consumes a group)
        self.assertRaises(ExpectedError, gulp, s, func=id)
        # __cmp__ failure on inner object
        self.assertRaises(ExpectedError, gulp, s)

        # keyfunc failure: the key function succeeds keyfunc.skip times,
        # then raises ExpectedError on every later call.
        def keyfunc(obj):
            if keyfunc.skip > 0:
                keyfunc.skip -= 1
                return obj
            else:
                raise ExpectedError

        # keyfunc failure on outer object
        keyfunc.skip = 0
        self.assertRaises(ExpectedError, gulp, [None], keyfunc)
        # The first key succeeds; the second call fails while groupby
        # advances past the first element.
        keyfunc.skip = 1
        self.assertRaises(ExpectedError, gulp, [None, None], keyfunc)
496
497    def test_ifilter(self):
498        self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])
499        self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2])
500        self.assertEqual(list(ifilter(bool, [0,1,0,2,0])), [1,2])
501        self.assertEqual(take(4, ifilter(isEven, count())), [0,2,4,6])
502        self.assertRaises(TypeError, ifilter)
503        self.assertRaises(TypeError, ifilter, lambda x:x)
504        self.assertRaises(TypeError, ifilter, lambda x:x, range(6), 7)
505        self.assertRaises(TypeError, ifilter, isEven, 3)
506        self.assertRaises(TypeError, ifilter(range(6), range(6)).next)
507
508    def test_ifilterfalse(self):
509        self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5])
510        self.assertEqual(list(ifilterfalse(None, [0,1,0,2,0])), [0,0,0])
511        self.assertEqual(list(ifilterfalse(bool, [0,1,0,2,0])), [0,0,0])
512        self.assertEqual(take(4, ifilterfalse(isEven, count())), [1,3,5,7])
513        self.assertRaises(TypeError, ifilterfalse)
514        self.assertRaises(TypeError, ifilterfalse, lambda x:x)
515        self.assertRaises(TypeError, ifilterfalse, lambda x:x, range(6), 7)
516        self.assertRaises(TypeError, ifilterfalse, isEven, 3)
517        self.assertRaises(TypeError, ifilterfalse(range(6), range(6)).next)
518
519    def test_izip(self):
520        ans = [(x,y) for x, y in izip('abc',count())]
521        self.assertEqual(ans, [('a', 0), ('b', 1), ('c', 2)])
522        self.assertEqual(list(izip('abc', range(6))), zip('abc', range(6)))
523        self.assertEqual(list(izip('abcdef', range(3))), zip('abcdef', range(3)))
524        self.assertEqual(take(3,izip('abcdef', count())), zip('abcdef', range(3)))
525        self.assertEqual(list(izip('abcdef')), zip('abcdef'))
526        self.assertEqual(list(izip()), zip())
527        self.assertRaises(TypeError, izip, 3)
528        self.assertRaises(TypeError, izip, range(3), 3)
529        # Check tuple re-use (implementation detail)
530        self.assertEqual([tuple(list(pair)) for pair in izip('abc', 'def')],
531                         zip('abc', 'def'))
532        self.assertEqual([pair for pair in izip('abc', 'def')],
533                         zip('abc', 'def'))
534        ids = map(id, izip('abc', 'def'))
535        self.assertEqual(min(ids), max(ids))
536        ids = map(id, list(izip('abc', 'def')))
537        self.assertEqual(len(dict.fromkeys(ids)), len(ids))
538
539    def test_iziplongest(self):
540        for args in [
541                ['abc', range(6)],
542                [range(6), 'abc'],
543                [range(1000), range(2000,2100), range(3000,3050)],
544                [range(1000), range(0), range(3000,3050), range(1200), range(1500)],
545                [range(1000), range(0), range(3000,3050), range(1200), range(1500), range(0)],
546            ]:
547            # target = map(None, *args) <- this raises a py3k warning
548            # this is the replacement:
549            target = [tuple([arg[i] if i < len(arg) else None for arg in args])
550                      for i in range(max(map(len, args)))]
551            self.assertEqual(list(izip_longest(*args)), target)
552            self.assertEqual(list(izip_longest(*args, **{})), target)
553            target = [tuple((e is None and 'X' or e) for e in t) for t in target]   # Replace None fills with 'X'
554            self.assertEqual(list(izip_longest(*args, **dict(fillvalue='X'))), target)
555
556        self.assertEqual(take(3,izip_longest('abcdef', count())), zip('abcdef', range(3))) # take 3 from infinite input
557
558        self.assertEqual(list(izip_longest()), zip())
559        self.assertEqual(list(izip_longest([])), zip([]))
560        self.assertEqual(list(izip_longest('abcdef')), zip('abcdef'))
561
562        self.assertEqual(list(izip_longest('abc', 'defg', **{})),
563                         zip(list('abc') + [None], 'defg'))  # empty keyword dict
564        self.assertRaises(TypeError, izip_longest, 3)
565        self.assertRaises(TypeError, izip_longest, range(3), 3)
566
567        for stmt in [
568            "izip_longest('abc', fv=1)",
569            "izip_longest('abc', fillvalue=1, bogus_keyword=None)",
570        ]:
571            try:
572                eval(stmt, globals(), locals())
573            except TypeError:
574                pass
575            else:
576                self.fail('Did not raise Type in:  ' + stmt)
577
578        # Check tuple re-use (implementation detail)
579        self.assertEqual([tuple(list(pair)) for pair in izip_longest('abc', 'def')],
580                         zip('abc', 'def'))
581        self.assertEqual([pair for pair in izip_longest('abc', 'def')],
582                         zip('abc', 'def'))
583        ids = map(id, izip_longest('abc', 'def'))
584        self.assertEqual(min(ids), max(ids))
585        ids = map(id, list(izip_longest('abc', 'def')))
586        self.assertEqual(len(dict.fromkeys(ids)), len(ids))
587
588    def test_bug_7244(self):
589
590        class Repeater(object):
591            # this class is similar to itertools.repeat
592            def __init__(self, o, t, e):
593                self.o = o
594                self.t = int(t)
595                self.e = e
596            def __iter__(self): # its iterator is itself
597                return self
598            def next(self):
599                if self.t > 0:
600                    self.t -= 1
601                    return self.o
602                else:
603                    raise self.e
604
605        # Formerly this code in would fail in debug mode
606        # with Undetected Error and Stop Iteration
607        r1 = Repeater(1, 3, StopIteration)
608        r2 = Repeater(2, 4, StopIteration)
609        def run(r1, r2):
610            result = []
611            for i, j in izip_longest(r1, r2, fillvalue=0):
612                with test_support.captured_output('stdout'):
613                    print (i, j)
614                result.append((i, j))
615            return result
616        self.assertEqual(run(r1, r2), [(1,2), (1,2), (1,2), (0,2)])
617
618        # Formerly, the RuntimeError would be lost
619        # and StopIteration would stop as expected
620        r1 = Repeater(1, 3, RuntimeError)
621        r2 = Repeater(2, 4, StopIteration)
622        it = izip_longest(r1, r2, fillvalue=0)
623        self.assertEqual(next(it), (1, 2))
624        self.assertEqual(next(it), (1, 2))
625        self.assertEqual(next(it), (1, 2))
626        self.assertRaises(RuntimeError, next, it)
627
628    def test_product(self):
629        for args, result in [
630            ([], [()]),                     # zero iterables
631            (['ab'], [('a',), ('b',)]),     # one iterable
632            ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]),     # two iterables
633            ([range(0), range(2), range(3)], []),           # first iterable with zero length
634            ([range(2), range(0), range(3)], []),           # middle iterable with zero length
635            ([range(2), range(3), range(0)], []),           # last iterable with zero length
636            ]:
637            self.assertEqual(list(product(*args)), result)
638            for r in range(4):
639                self.assertEqual(list(product(*(args*r))),
640                                 list(product(*args, **dict(repeat=r))))
641        self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
642        self.assertRaises(TypeError, product, range(6), None)
643
644        def product1(*args, **kwds):
645            pools = map(tuple, args) * kwds.get('repeat', 1)
646            n = len(pools)
647            if n == 0:
648                yield ()
649                return
650            if any(len(pool) == 0 for pool in pools):
651                return
652            indices = [0] * n
653            yield tuple(pool[i] for pool, i in zip(pools, indices))
654            while 1:
655                for i in reversed(range(n)):  # right to left
656                    if indices[i] == len(pools[i]) - 1:
657                        continue
658                    indices[i] += 1
659                    for j in range(i+1, n):
660                        indices[j] = 0
661                    yield tuple(pool[i] for pool, i in zip(pools, indices))
662                    break
663                else:
664                    return
665
666        def product2(*args, **kwds):
667            'Pure python version used in docs'
668            pools = map(tuple, args) * kwds.get('repeat', 1)
669            result = [[]]
670            for pool in pools:
671                result = [x+[y] for x in result for y in pool]
672            for prod in result:
673                yield tuple(prod)
674
675        argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
676                    set('abcdefg'), range(11), tuple(range(13))]
677        for i in range(100):
678            args = [random.choice(argtypes) for j in range(random.randrange(5))]
679            expected_len = prod(map(len, args))
680            self.assertEqual(len(list(product(*args))), expected_len)
681            self.assertEqual(list(product(*args)), list(product1(*args)))
682            self.assertEqual(list(product(*args)), list(product2(*args)))
683            args = map(iter, args)
684            self.assertEqual(len(list(product(*args))), expected_len)
685
686        # Test implementation detail:  tuple re-use
687        self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
688        self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
689
690    def test_repeat(self):
691        self.assertEqual(list(repeat(object='a', times=3)), ['a', 'a', 'a'])
692        self.assertEqual(zip(xrange(3),repeat('a')),
693                         [(0, 'a'), (1, 'a'), (2, 'a')])
694        self.assertEqual(list(repeat('a', 3)), ['a', 'a', 'a'])
695        self.assertEqual(take(3, repeat('a')), ['a', 'a', 'a'])
696        self.assertEqual(list(repeat('a', 0)), [])
697        self.assertEqual(list(repeat('a', -3)), [])
698        self.assertRaises(TypeError, repeat)
699        self.assertRaises(TypeError, repeat, None, 3, 4)
700        self.assertRaises(TypeError, repeat, None, 'a')
701        r = repeat(1+0j)
702        self.assertEqual(repr(r), 'repeat((1+0j))')
703        r = repeat(1+0j, 5)
704        self.assertEqual(repr(r), 'repeat((1+0j), 5)')
705        list(r)
706        self.assertEqual(repr(r), 'repeat((1+0j), 0)')
707
708    def test_imap(self):
709        self.assertEqual(list(imap(operator.pow, range(3), range(1,7))),
710                         [0**1, 1**2, 2**3])
711        self.assertEqual(list(imap(None, 'abc', range(5))),
712                         [('a',0),('b',1),('c',2)])
713        self.assertEqual(list(imap(None, 'abc', count())),
714                         [('a',0),('b',1),('c',2)])
715        self.assertEqual(take(2,imap(None, 'abc', count())),
716                         [('a',0),('b',1)])
717        self.assertEqual(list(imap(operator.pow, [])), [])
718        self.assertRaises(TypeError, imap)
719        self.assertRaises(TypeError, imap, operator.neg)
720        self.assertRaises(TypeError, imap(10, range(5)).next)
721        self.assertRaises(ValueError, imap(errfunc, [4], [5]).next)
722        self.assertRaises(TypeError, imap(onearg, [4], [5]).next)
723
724    def test_starmap(self):
725        self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))),
726                         [0**1, 1**2, 2**3])
727        self.assertEqual(take(3, starmap(operator.pow, izip(count(), count(1)))),
728                         [0**1, 1**2, 2**3])
729        self.assertEqual(list(starmap(operator.pow, [])), [])
730        self.assertEqual(list(starmap(operator.pow, [iter([4,5])])), [4**5])
731        self.assertRaises(TypeError, list, starmap(operator.pow, [None]))
732        self.assertRaises(TypeError, starmap)
733        self.assertRaises(TypeError, starmap, operator.pow, [(4,5)], 'extra')
734        self.assertRaises(TypeError, starmap(10, [(4,5)]).next)
735        self.assertRaises(ValueError, starmap(errfunc, [(4,5)]).next)
736        self.assertRaises(TypeError, starmap(onearg, [(4,5)]).next)
737
738    def test_islice(self):
739        for args in [          # islice(args) should agree with range(args)
740                (10, 20, 3),
741                (10, 3, 20),
742                (10, 20),
743                (10, 3),
744                (20,)
745                ]:
746            self.assertEqual(list(islice(xrange(100), *args)), range(*args))
747
748        for args, tgtargs in [  # Stop when seqn is exhausted
749                ((10, 110, 3), ((10, 100, 3))),
750                ((10, 110), ((10, 100))),
751                ((110,), (100,))
752                ]:
753            self.assertEqual(list(islice(xrange(100), *args)), range(*tgtargs))
754
755        # Test stop=None
756        self.assertEqual(list(islice(xrange(10), None)), range(10))
757        self.assertEqual(list(islice(xrange(10), None, None)), range(10))
758        self.assertEqual(list(islice(xrange(10), None, None, None)), range(10))
759        self.assertEqual(list(islice(xrange(10), 2, None)), range(2, 10))
760        self.assertEqual(list(islice(xrange(10), 1, None, 2)), range(1, 10, 2))
761
762        # Test number of items consumed     SF #1171417
763        it = iter(range(10))
764        self.assertEqual(list(islice(it, 3)), range(3))
765        self.assertEqual(list(it), range(3, 10))
766
767        # Test invalid arguments
768        self.assertRaises(TypeError, islice, xrange(10))
769        self.assertRaises(TypeError, islice, xrange(10), 1, 2, 3, 4)
770        self.assertRaises(ValueError, islice, xrange(10), -5, 10, 1)
771        self.assertRaises(ValueError, islice, xrange(10), 1, -5, -1)
772        self.assertRaises(ValueError, islice, xrange(10), 1, 10, -1)
773        self.assertRaises(ValueError, islice, xrange(10), 1, 10, 0)
774        self.assertRaises(ValueError, islice, xrange(10), 'a')
775        self.assertRaises(ValueError, islice, xrange(10), 'a', 1)
776        self.assertRaises(ValueError, islice, xrange(10), 1, 'a')
777        self.assertRaises(ValueError, islice, xrange(10), 'a', 1, 1)
778        self.assertRaises(ValueError, islice, xrange(10), 1, 'a', 1)
779        self.assertEqual(len(list(islice(count(), 1, 10, maxsize))), 1)
780
781        # Issue #10323:  Less islice in a predictable state
782        c = count()
783        self.assertEqual(list(islice(c, 1, 3, 50)), [1])
784        self.assertEqual(next(c), 3)
785
786    def test_takewhile(self):
787        data = [1, 3, 5, 20, 2, 4, 6, 8]
788        underten = lambda x: x<10
789        self.assertEqual(list(takewhile(underten, data)), [1, 3, 5])
790        self.assertEqual(list(takewhile(underten, [])), [])
791        self.assertRaises(TypeError, takewhile)
792        self.assertRaises(TypeError, takewhile, operator.pow)
793        self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra')
794        self.assertRaises(TypeError, takewhile(10, [(4,5)]).next)
795        self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next)
796        t = takewhile(bool, [1, 1, 1, 0, 0, 0])
797        self.assertEqual(list(t), [1, 1, 1])
798        self.assertRaises(StopIteration, t.next)
799
800    def test_dropwhile(self):
801        data = [1, 3, 5, 20, 2, 4, 6, 8]
802        underten = lambda x: x<10
803        self.assertEqual(list(dropwhile(underten, data)), [20, 2, 4, 6, 8])
804        self.assertEqual(list(dropwhile(underten, [])), [])
805        self.assertRaises(TypeError, dropwhile)
806        self.assertRaises(TypeError, dropwhile, operator.pow)
807        self.assertRaises(TypeError, dropwhile, operator.pow, [(4,5)], 'extra')
808        self.assertRaises(TypeError, dropwhile(10, [(4,5)]).next)
809        self.assertRaises(ValueError, dropwhile(errfunc, [(4,5)]).next)
810
811    def test_tee(self):
812        n = 200
813        def irange(n):
814            for i in xrange(n):
815                yield i
816
817        a, b = tee([])        # test empty iterator
818        self.assertEqual(list(a), [])
819        self.assertEqual(list(b), [])
820
821        a, b = tee(irange(n)) # test 100% interleaved
822        self.assertEqual(zip(a,b), zip(range(n),range(n)))
823
824        a, b = tee(irange(n)) # test 0% interleaved
825        self.assertEqual(list(a), range(n))
826        self.assertEqual(list(b), range(n))
827
828        a, b = tee(irange(n)) # test dealloc of leading iterator
829        for i in xrange(100):
830            self.assertEqual(a.next(), i)
831        del a
832        self.assertEqual(list(b), range(n))
833
834        a, b = tee(irange(n)) # test dealloc of trailing iterator
835        for i in xrange(100):
836            self.assertEqual(a.next(), i)
837        del b
838        self.assertEqual(list(a), range(100, n))
839
840        for j in xrange(5):   # test randomly interleaved
841            order = [0]*n + [1]*n
842            random.shuffle(order)
843            lists = ([], [])
844            its = tee(irange(n))
845            for i in order:
846                value = its[i].next()
847                lists[i].append(value)
848            self.assertEqual(lists[0], range(n))
849            self.assertEqual(lists[1], range(n))
850
851        # test argument format checking
852        self.assertRaises(TypeError, tee)
853        self.assertRaises(TypeError, tee, 3)
854        self.assertRaises(TypeError, tee, [1,2], 'x')
855        self.assertRaises(TypeError, tee, [1,2], 3, 'x')
856
857        # tee object should be instantiable
858        a, b = tee('abc')
859        c = type(a)('def')
860        self.assertEqual(list(c), list('def'))
861
862        # test long-lagged and multi-way split
863        a, b, c = tee(xrange(2000), 3)
864        for i in xrange(100):
865            self.assertEqual(a.next(), i)
866        self.assertEqual(list(b), range(2000))
867        self.assertEqual([c.next(), c.next()], range(2))
868        self.assertEqual(list(a), range(100,2000))
869        self.assertEqual(list(c), range(2,2000))
870
871        # test values of n
872        self.assertRaises(TypeError, tee, 'abc', 'invalid')
873        self.assertRaises(ValueError, tee, [], -1)
874        for n in xrange(5):
875            result = tee('abc', n)
876            self.assertEqual(type(result), tuple)
877            self.assertEqual(len(result), n)
878            self.assertEqual(map(list, result), [list('abc')]*n)
879
880        # tee pass-through to copyable iterator
881        a, b = tee('abc')
882        c, d = tee(a)
883        self.assertTrue(a is c)
884
885        # test tee_new
886        t1, t2 = tee('abc')
887        tnew = type(t1)
888        self.assertRaises(TypeError, tnew)
889        self.assertRaises(TypeError, tnew, 10)
890        t3 = tnew(t1)
891        self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
892
893        # test that tee objects are weak referencable
894        a, b = tee(xrange(10))
895        p = proxy(a)
896        self.assertEqual(getattr(p, '__class__'), type(b))
897        del a
898        self.assertRaises(ReferenceError, getattr, p, '__class__')
899
900    def test_StopIteration(self):
901        self.assertRaises(StopIteration, izip().next)
902
903        for f in (chain, cycle, izip, groupby):
904            self.assertRaises(StopIteration, f([]).next)
905            self.assertRaises(StopIteration, f(StopNow()).next)
906
907        self.assertRaises(StopIteration, islice([], None).next)
908        self.assertRaises(StopIteration, islice(StopNow(), None).next)
909
910        p, q = tee([])
911        self.assertRaises(StopIteration, p.next)
912        self.assertRaises(StopIteration, q.next)
913        p, q = tee(StopNow())
914        self.assertRaises(StopIteration, p.next)
915        self.assertRaises(StopIteration, q.next)
916
917        self.assertRaises(StopIteration, repeat(None, 0).next)
918
919        for f in (ifilter, ifilterfalse, imap, takewhile, dropwhile, starmap):
920            self.assertRaises(StopIteration, f(lambda x:x, []).next)
921            self.assertRaises(StopIteration, f(lambda x:x, StopNow()).next)
922
class TestExamples(unittest.TestCase):
    'Small worked examples, one per itertools function.'

    def test_chain(self):
        self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')

    def test_chain_from_iterable(self):
        self.assertEqual(''.join(chain.from_iterable(['ABC', 'DEF'])), 'ABCDEF')

    def test_combinations(self):
        self.assertEqual(list(combinations('ABCD', 2)),
                         [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')])
        self.assertEqual(list(combinations(range(4), 3)),
                         [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])

    def test_combinations_with_replacement(self):
        self.assertEqual(list(combinations_with_replacement('ABC', 2)),
                         [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])

    def test_compress(self):
        self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))

    def test_count(self):
        self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])

    def test_cycle(self):
        self.assertEqual(list(islice(cycle('ABCD'), 12)), list('ABCDABCDABCD'))

    def test_dropwhile(self):
        self.assertEqual(list(dropwhile(lambda x: x<5, [1,4,6,4,1])), [6,4,1])

    def test_groupby(self):
        self.assertEqual([k for k, g in groupby('AAAABBBCCDAABBB')],
                         list('ABCDAB'))
        self.assertEqual([(list(g)) for k, g in groupby('AAAABBBCCD')],
                         [list('AAAA'), list('BBB'), list('CC'), list('D')])

    def test_ifilter(self):
        self.assertEqual(list(ifilter(lambda x: x%2, range(10))), [1,3,5,7,9])

    def test_ifilterfalse(self):
        self.assertEqual(list(ifilterfalse(lambda x: x%2, range(10))), [0,2,4,6,8])

    def test_imap(self):
        self.assertEqual(list(imap(pow, (2,3,10), (5,2,3))), [32, 9, 1000])

    def test_islice(self):
        self.assertEqual(list(islice('ABCDEFG', 2)), list('AB'))
        self.assertEqual(list(islice('ABCDEFG', 2, 4)), list('CD'))
        self.assertEqual(list(islice('ABCDEFG', 2, None)), list('CDEFG'))
        self.assertEqual(list(islice('ABCDEFG', 0, None, 2)), list('ACEG'))

    def test_izip(self):
        self.assertEqual(list(izip('ABCD', 'xy')), [('A', 'x'), ('B', 'y')])

    def test_izip_longest(self):
        self.assertEqual(list(izip_longest('ABCD', 'xy', fillvalue='-')),
                         [('A', 'x'), ('B', 'y'), ('C', '-'), ('D', '-')])

    def test_permutations(self):
        self.assertEqual(list(permutations('ABCD', 2)),
                         map(tuple, 'AB AC AD BA BC BD CA CB CD DA DB DC'.split()))
        self.assertEqual(list(permutations(range(3))),
                         [(0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)])

    def test_product(self):
        self.assertEqual(list(product('ABCD', 'xy')),
                         map(tuple, 'Ax Ay Bx By Cx Cy Dx Dy'.split()))
        self.assertEqual(list(product(range(2), repeat=3)),
                         [(0,0,0), (0,0,1), (0,1,0), (0,1,1),
                          (1,0,0), (1,0,1), (1,1,0), (1,1,1)])

    def test_repeat(self):
        self.assertEqual(list(repeat(10, 3)), [10, 10, 10])

    def test_starmap(self):
        self.assertEqual(list(starmap(pow, [(2,5), (3,2), (10,3)])),
                         [32, 9, 1000])

    def test_takewhile(self):
        self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
1003
1004
class TestGC(unittest.TestCase):
    'Build reference cycles through each itertools type so the GC must traverse it.'

    def makecycle(self, iterator, container):
        '''Store iterator in container, advance it once, then drop local refs.

        The caller arranges for iterator to (indirectly) reference container,
        so afterwards the two are kept alive only by the cycle between them.
        '''
        container.append(iterator)
        iterator.next()
        del container, iterator

    def test_chain(self):
        a = []
        self.makecycle(chain(a), a)

    def test_chain_from_iterable(self):
        a = []
        self.makecycle(chain.from_iterable([a]), a)

    def test_combinations(self):
        a = []
        self.makecycle(combinations([1,2,a,3], 3), a)

    def test_combinations_with_replacement(self):
        a = []
        self.makecycle(combinations_with_replacement([1,2,a,3], 3), a)

    def test_compress(self):
        a = []
        # The data must contain `a`, otherwise the compress object never
        # references the container and no cycle is formed.
        self.makecycle(compress([a]*6, [1,0,1,0,1,0]), a)

    def test_count(self):
        a = []
        # The Int class dict references `a`, so the counter's values close the cycle.
        Int = type('Int', (int,), dict(x=a))
        self.makecycle(count(Int(0), Int(1)), a)

    def test_cycle(self):
        a = []
        self.makecycle(cycle([a]*2), a)

    def test_dropwhile(self):
        a = []
        self.makecycle(dropwhile(bool, [0, a, a]), a)

    def test_groupby(self):
        a = []
        self.makecycle(groupby([a]*2, lambda x:x), a)

    def test_issue2246(self):
        # Issue 2246 -- the _grouper iterator was not included in GC
        n = 10
        keyfunc = lambda x: x
        for i, j in groupby(xrange(n), key=keyfunc):
            keyfunc.__dict__.setdefault('x',[]).append(j)

    def test_ifilter(self):
        a = []
        self.makecycle(ifilter(lambda x:True, [a]*2), a)

    def test_ifilterfalse(self):
        a = []
        self.makecycle(ifilterfalse(lambda x:False, a), a)

    def test_izip(self):
        a = []
        self.makecycle(izip([a]*2, [a]*3), a)

    def test_izip_longest(self):
        a = []
        self.makecycle(izip_longest([a]*2, [a]*3), a)
        b = [a, None]
        self.makecycle(izip_longest([a]*2, [a]*3, fillvalue=b), a)

    def test_imap(self):
        a = []
        self.makecycle(imap(lambda x:x, [a]*2), a)

    def test_islice(self):
        a = []
        self.makecycle(islice([a]*2, None), a)

    def test_permutations(self):
        a = []
        self.makecycle(permutations([1,2,a,3], 3), a)

    def test_product(self):
        a = []
        self.makecycle(product([1,2,a,3], repeat=3), a)

    def test_repeat(self):
        a = []
        self.makecycle(repeat(a), a)

    def test_starmap(self):
        a = []
        self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)

    def test_takewhile(self):
        a = []
        self.makecycle(takewhile(bool, [1, 0, a, a]), a)
1101
def R(seqn):
    'Plain generator that re-yields every item of seqn'
    for item in iter(seqn):
        yield item
1106
class G:
    'Sequence exposing only __getitem__ (legacy iteration protocol)'
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, index):
        # An IndexError from the wrapped sequence is what ends iteration.
        item = self.seqn[index]
        return item
1113
class I:
    'Hand-written iterator: __iter__ returns self, next() walks seqn by index'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def next(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1126
class Ig:
    'Iterable whose __iter__ is a generator function over seqn'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        for item in self.seqn:
            yield item
1135
class X:
    'Has next() but neither __iter__ nor __getitem__, so it is not iterable'
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def next(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1146
class N:
    'Claims to be an iterator (__iter__ returns self) but has no next()'
    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn
    def __iter__(self):
        return self
1154
class E:
    'Iterator whose next() always raises ZeroDivisionError'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
    def next(self):
        # Callers check that this error propagates unchanged.
        return 3 // 0
1164
class S:
    'Iterator that is empty regardless of its constructor argument'
    def __init__(self, seqn):
        # seqn is deliberately ignored.
        pass
    def __iter__(self):
        return self
    def next(self):
        raise StopIteration
1173
def L(seqn):
    'Wrap seqn in several tiers of iterators'
    # __getitem__ sequence -> generator-based __iter__ -> generator -> imap -> chain
    tiered = R(Ig(G(seqn)))
    return chain(imap(lambda x: x, tiered))
1177
1178
class TestVariousIteratorArgs(unittest.TestCase):
    '''Feed each itertools function inputs wrapped by the helper classes.

    G, I, Ig, S, L and R must work; X (not iterable) and N (no next()) must
    raise TypeError, and E must let its ZeroDivisionError escape.
    '''

    def test_chain(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(chain(g(s))), list(g(s)))
                self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
            self.assertRaises(TypeError, list, chain(X(s)))
            self.assertRaises(TypeError, list, chain(N(s)))
            self.assertRaises(ZeroDivisionError, list, chain(E(s)))

    def test_compress(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(compress(g(s), repeat(1))), list(g(s)))
            self.assertRaises(TypeError, compress, X(s), repeat(1))
            self.assertRaises(TypeError, list, compress(N(s), repeat(1)))
            self.assertRaises(ZeroDivisionError, list, compress(E(s), repeat(1)))

    def test_product(self):
        # product() consumes its inputs eagerly, so errors appear at construction.
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            self.assertRaises(TypeError, product, X(s))
            self.assertRaises(TypeError, product, N(s))
            self.assertRaises(ZeroDivisionError, product, E(s))

    def test_cycle(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                tgtlen = len(s) * 3
                expected = list(g(s))*3
                actual = list(islice(cycle(g(s)), tgtlen))
                self.assertEqual(actual, expected)
            self.assertRaises(TypeError, cycle, X(s))
            self.assertRaises(TypeError, list, cycle(N(s)))
            self.assertRaises(ZeroDivisionError, list, cycle(E(s)))

    def test_groupby(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual([k for k, sb in groupby(g(s))], list(g(s)))
            self.assertRaises(TypeError, groupby, X(s))
            self.assertRaises(TypeError, list, groupby(N(s)))
            self.assertRaises(ZeroDivisionError, list, groupby(E(s)))

    def test_ifilter(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(ifilter(isEven, g(s))), filter(isEven, g(s)))
            self.assertRaises(TypeError, ifilter, isEven, X(s))
            self.assertRaises(TypeError, list, ifilter(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, ifilter(isEven, E(s)))

    def test_ifilterfalse(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(ifilterfalse(isEven, g(s))), filter(isOdd, g(s)))
            self.assertRaises(TypeError, ifilterfalse, isEven, X(s))
            self.assertRaises(TypeError, list, ifilterfalse(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, ifilterfalse(isEven, E(s)))

    def test_izip(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(izip(g(s))), zip(g(s)))
                self.assertEqual(list(izip(g(s), g(s))), zip(g(s), g(s)))
            self.assertRaises(TypeError, izip, X(s))
            self.assertRaises(TypeError, list, izip(N(s)))
            self.assertRaises(ZeroDivisionError, list, izip(E(s)))

    def test_iziplongest(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(izip_longest(g(s))), zip(g(s)))
                self.assertEqual(list(izip_longest(g(s), g(s))), zip(g(s), g(s)))
            self.assertRaises(TypeError, izip_longest, X(s))
            self.assertRaises(TypeError, list, izip_longest(N(s)))
            self.assertRaises(ZeroDivisionError, list, izip_longest(E(s)))

    def test_imap(self):
        for s in (range(10), range(0), range(100), (7,11), xrange(20,50,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(imap(onearg, g(s))), map(onearg, g(s)))
                self.assertEqual(list(imap(operator.pow, g(s), g(s))), map(operator.pow, g(s), g(s)))
            self.assertRaises(TypeError, imap, onearg, X(s))
            self.assertRaises(TypeError, list, imap(onearg, N(s)))
            self.assertRaises(ZeroDivisionError, list, imap(onearg, E(s)))

    def test_islice(self):
        for s in ("12345", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(islice(g(s),1,None,2)), list(g(s))[1::2])
            self.assertRaises(TypeError, islice, X(s), 10)
            self.assertRaises(TypeError, list, islice(N(s), 10))
            self.assertRaises(ZeroDivisionError, list, islice(E(s), 10))

    def test_starmap(self):
        for s in (range(10), range(0), range(100), (7,11), xrange(20,50,5)):
            # Argument pairs depend only on s; the error checks below reuse them.
            ss = zip(s, s)
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(starmap(operator.pow, g(ss))), map(operator.pow, g(s), g(s)))
            self.assertRaises(TypeError, starmap, operator.pow, X(ss))
            self.assertRaises(TypeError, list, starmap(operator.pow, N(ss)))
            self.assertRaises(ZeroDivisionError, list, starmap(operator.pow, E(ss)))

    def test_takewhile(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                tgt = []
                for elem in g(s):
                    if not isEven(elem): break
                    tgt.append(elem)
                self.assertEqual(list(takewhile(isEven, g(s))), tgt)
            self.assertRaises(TypeError, takewhile, isEven, X(s))
            self.assertRaises(TypeError, list, takewhile(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, takewhile(isEven, E(s)))

    def test_dropwhile(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                tgt = []
                for elem in g(s):
                    if not tgt and isOdd(elem): continue
                    tgt.append(elem)
                self.assertEqual(list(dropwhile(isOdd, g(s))), tgt)
            self.assertRaises(TypeError, dropwhile, isOdd, X(s))
            self.assertRaises(TypeError, list, dropwhile(isOdd, N(s)))
            self.assertRaises(ZeroDivisionError, list, dropwhile(isOdd, E(s)))

    def test_tee(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                it1, it2 = tee(g(s))
                self.assertEqual(list(it1), list(g(s)))
                self.assertEqual(list(it2), list(g(s)))
            self.assertRaises(TypeError, tee, X(s))
            self.assertRaises(TypeError, list, tee(N(s))[0])
            self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
1317
class LengthTransparency(unittest.TestCase):
    'Checks the length reported for repeat() objects.'

    def test_repeat(self):
        # test_iterlen supplies its own len(); presumably it consults the
        # iterator's length hint -- confirm against test_iterlen.
        from test.test_iterlen import len as length_of
        bounded, unbounded = repeat(None, 50), repeat(None)
        self.assertEqual(length_of(bounded), 50)
        # An endless repeat() has no finite length to report.
        self.assertRaises(TypeError, length_of, unbounded)
1324
class RegressionTests(unittest.TestCase):
    'Regression tests for specific SourceForge bug reports.'

    def test_sf_793826(self):
        # Fix Armin Rigo's successful efforts to wreak havoc
        # While izip is still assembling its first result, a callback
        # re-enters the same izip (z.next()) and passes the tuple it gets to
        # f.  The assertion then checks that this tuple was not changed after
        # f captured it -- presumably guarding against izip recycling a result
        # tuple that outside code still references.

        def mutatingtuple(tuple1, f, tuple2):
            # this builds a tuple t which is a copy of tuple1,
            # then calls f(t), then mutates t to be equal to tuple2
            # (needs len(tuple1) == len(tuple2)).
            # g's `first` is its own one-shot flag (a default list), distinct
            # from the test-level `first` list that f writes to.
            def g(value, first=[1]):
                if first:
                    del first[:]
                    # Re-entrant call: builds a tuple from the next len(tuple1) items.
                    f(z.next())
                return value
            # Layout: tuple2[0], *tuple1, tuple2[1:]; the re-entrant call
            # consumes tuple1 while the outer call collects tuple2.
            items = list(tuple2)
            items[1:1] = list(tuple1)
            gen = imap(g, items)
            z = izip(*[gen]*len(tuple1))
            z.next()

        def f(t):
            # Keep the tuple in module global T and snapshot its contents now.
            global T
            T = t
            first[:] = list(T)

        first = []
        mutatingtuple((1,2,3), f, (4,5,6))
        second = list(T)
        # The tuple f received must still hold what it held at f() time.
        self.assertEqual(first, second)


    def test_sf_950057(self):
        # Make sure that chain() and cycle() catch exceptions immediately
        # rather than when shifting between input sources
        # hist records how far each generator body ran.

        def gen1():
            hist.append(0)
            yield 1
            hist.append(1)
            raise AssertionError
            hist.append(2)  # unreachable by design

        def gen2(x):
            hist.append(3)
            yield 2
            hist.append(4)
            if x:
                raise StopIteration

        # In every case hist must stop at [0,1]: the error from gen1 escapes
        # before gen2 (or a second pass of cycle) is ever started.
        hist = []
        self.assertRaises(AssertionError, list, chain(gen1(), gen2(False)))
        self.assertEqual(hist, [0,1])

        hist = []
        self.assertRaises(AssertionError, list, chain(gen1(), gen2(True)))
        self.assertEqual(hist, [0,1])

        hist = []
        self.assertRaises(AssertionError, list, cycle(gen1()))
        self.assertEqual(hist, [0,1])
1385
class SubclassWithKwargsTest(unittest.TestCase):
    'Subclasses of itertools types may accept keyword arguments in __init__.'

    def test_keywords_in_subclass(self):
        # count is omitted because it is not subclassable.
        bases = (repeat, izip, ifilter, ifilterfalse, chain, imap,
                 starmap, islice, takewhile, dropwhile, cycle, compress)
        for base in bases:
            class Subclass(base):
                def __init__(self, newarg=None, *args):
                    base.__init__(self, *args)
            try:
                Subclass(newarg=1)
            except TypeError as exc:
                # A complaint about argument count is acceptable; a refusal
                # of keyword arguments is the regression being tested.
                self.assertNotIn("does not take keyword arguments", exc.args[0])
1399
1400
# Doctests mirroring the examples and recipes of the itertools library
# reference.  The string is exposed through the module-level __test__
# mapping so that doctest collects it; test_main() runs it through
# test_support.run_doctest().
libreftest = """ Doctest for examples in the library reference: Doc/library/itertools.rst


>>> amounts = [120.15, 764.05, 823.14]
>>> for checknum, amount in izip(count(1200), amounts):
...     print 'Check %d is for $%.2f' % (checknum, amount)
...
Check 1200 is for $120.15
Check 1201 is for $764.05
Check 1202 is for $823.14

>>> import operator
>>> for cube in imap(operator.pow, xrange(1,4), repeat(3)):
...    print cube
...
1
8
27

>>> reportlines = ['EuroPython', 'Roster', '', 'alex', '', 'laura', '', 'martin', '', 'walter', '', 'samuele']
>>> for name in islice(reportlines, 3, None, 2):
...    print name.title()
...
Alex
Laura
Martin
Walter
Samuele

>>> from operator import itemgetter
>>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3)
>>> di = sorted(sorted(d.iteritems()), key=itemgetter(1))
>>> for k, g in groupby(di, itemgetter(1)):
...     print k, map(itemgetter(0), g)
...
1 ['a', 'c', 'e']
2 ['b', 'd', 'f']
3 ['g']

# Find runs of consecutive numbers using groupby.  The key to the solution
# is differencing with a range so that consecutive numbers all appear in
# the same group.
>>> data = [ 1,  4,5,6, 10, 15,16,17,18, 22, 25,26,27,28]
>>> for k, g in groupby(enumerate(data), lambda t:t[0]-t[1]):
...     print map(operator.itemgetter(1), g)
...
[1]
[4, 5, 6]
[10]
[15, 16, 17, 18]
[22]
[25, 26, 27, 28]

>>> def take(n, iterable):
...     "Return first n items of the iterable as a list"
...     return list(islice(iterable, n))

>>> def enumerate(iterable, start=0):
...     return izip(count(start), iterable)

>>> def tabulate(function, start=0):
...     "Return function(0), function(1), ..."
...     return imap(function, count(start))

>>> def nth(iterable, n, default=None):
...     "Returns the nth item or a default value"
...     return next(islice(iterable, n, None), default)

>>> def quantify(iterable, pred=bool):
...     "Count how many times the predicate is true"
...     return sum(imap(pred, iterable))

>>> def padnone(iterable):
...     "Returns the sequence elements and then returns None indefinitely"
...     return chain(iterable, repeat(None))

>>> def ncycles(iterable, n):
...     "Returns the sequence elements n times"
...     return chain(*repeat(iterable, n))

>>> def dotproduct(vec1, vec2):
...     return sum(imap(operator.mul, vec1, vec2))

>>> def flatten(listOfLists):
...     return list(chain.from_iterable(listOfLists))

>>> def repeatfunc(func, times=None, *args):
...     "Repeat calls to func with specified arguments."
...     "   Example:  repeatfunc(random.random)"
...     if times is None:
...         return starmap(func, repeat(args))
...     else:
...         return starmap(func, repeat(args, times))

>>> def pairwise(iterable):
...     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
...     a, b = tee(iterable)
...     for elem in b:
...         break
...     return izip(a, b)

>>> def grouper(n, iterable, fillvalue=None):
...     "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
...     args = [iter(iterable)] * n
...     return izip_longest(fillvalue=fillvalue, *args)

>>> def roundrobin(*iterables):
...     "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
...     # Recipe credited to George Sakkis
...     pending = len(iterables)
...     nexts = cycle(iter(it).next for it in iterables)
...     while pending:
...         try:
...             for next in nexts:
...                 yield next()
...         except StopIteration:
...             pending -= 1
...             nexts = cycle(islice(nexts, pending))

>>> def powerset(iterable):
...     "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
...     s = list(iterable)
...     return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

>>> def unique_everseen(iterable, key=None):
...     "List unique elements, preserving order. Remember all elements ever seen."
...     # unique_everseen('AAAABBBCCDAABBB') --> A B C D
...     # unique_everseen('ABBCcAD', str.lower) --> A B C D
...     seen = set()
...     seen_add = seen.add
...     if key is None:
...         for element in iterable:
...             if element not in seen:
...                 seen_add(element)
...                 yield element
...     else:
...         for element in iterable:
...             k = key(element)
...             if k not in seen:
...                 seen_add(k)
...                 yield element

>>> def unique_justseen(iterable, key=None):
...     "List unique elements, preserving order. Remember only the element just seen."
...     # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
...     # unique_justseen('ABBCcAD', str.lower) --> A B C A D
...     return imap(next, imap(itemgetter(1), groupby(iterable, key)))

This is not part of the examples but it tests to make sure the definitions
perform as purported.

>>> take(10, count())
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

>>> list(enumerate('abc'))
[(0, 'a'), (1, 'b'), (2, 'c')]

>>> list(islice(tabulate(lambda x: 2*x), 4))
[0, 2, 4, 6]

>>> nth('abcde', 3)
'd'

>>> nth('abcde', 9) is None
True

>>> quantify(xrange(99), lambda x: x%2==0)
50

>>> a = [[1, 2, 3], [4, 5, 6]]
>>> flatten(a)
[1, 2, 3, 4, 5, 6]

>>> list(repeatfunc(pow, 5, 2, 3))
[8, 8, 8, 8, 8]

>>> import random
>>> take(5, imap(int, repeatfunc(random.random)))
[0, 0, 0, 0, 0]

>>> list(pairwise('abcd'))
[('a', 'b'), ('b', 'c'), ('c', 'd')]

>>> list(pairwise([]))
[]

>>> list(pairwise('a'))
[]

>>> list(islice(padnone('abc'), 0, 6))
['a', 'b', 'c', None, None, None]

>>> list(ncycles('abc', 3))
['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']

>>> dotproduct([1,2,3], [4,5,6])
32

>>> list(grouper(3, 'abcdefg', 'x'))
[('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'x', 'x')]

>>> list(roundrobin('abc', 'd', 'ef'))
['a', 'd', 'e', 'b', 'f', 'c']

>>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

>>> all(len(list(powerset(range(n)))) == 2**n for n in range(18))
True

>>> list(powerset('abcde')) == sorted(sorted(set(powerset('abcde'))), key=len)
True

>>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D']

>>> list(unique_everseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'D']

>>> list(unique_justseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D', 'A', 'B']

>>> list(unique_justseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'A', 'D']

"""

__test__ = {'libreftest' : libreftest}
1629
def test_main(verbose=None):
    """Run every unittest class of this module, then the module doctests.

    When *verbose* is true and the interpreter provides
    sys.gettotalrefcount() (debug builds only), the unittest suite is run
    five more times.  The total reference count after each pass is
    printed, so growing numbers can be spotted as a possible leak.
    """
    suites = (TestBasicOps, TestVariousIteratorArgs, TestGC,
              RegressionTests, LengthTransparency,
              SubclassWithKwargsTest, TestExamples)
    test_support.run_unittest(*suites)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        refcounts = []
        while len(refcounts) < 5:
            test_support.run_unittest(*suites)
            gc.collect()
            refcounts.append(sys.gettotalrefcount())
        print(refcounts)

    # doctest the examples in the library reference
    test_support.run_doctest(sys.modules[__name__], verbose)
1648
if __name__ == "__main__":
    # Executed as a script: run the whole suite with verbose output.
    test_main(True)
1651