• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import copy
2import functools
3import sys
4import unittest
5from test import test_support
6from weakref import proxy
7import pickle
8
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Wrapped in staticmethod so that binding it as a class attribute
    (``partial = PythonPartial`` in a TestCase) does not turn it into a
    bound method when accessed as ``self.partial``.
    """
    def bound(*fargs, **fkeywords):
        # Call-time keywords override the frozen ones without mutating them.
        merged = dict(keywords)
        merged.update(fkeywords)
        return func(*(args + fargs), **merged)
    bound.func, bound.args, bound.keywords = func, args, keywords
    return bound
20
def capture(*args, **kw):
    """Return the positional and keyword arguments exactly as received."""
    result = (args, kw)
    return result
24
def signature(part):
    """Return a partial object's (func, args, keywords, __dict__) tuple."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
28
class MyTuple(tuple):
    """Plain tuple subclass; setstate tests expect partial to store an exact tuple."""
31
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple.

    test_setstate_subclasses feeds it to partial.__setstate__ and checks
    that calls still produce a real tuple of arguments.
    """
    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
35
class MyDict(dict):
    """Plain dict subclass; setstate tests expect partial to store an exact dict."""
38
class TestPartial(unittest.TestCase):
    """Behavioural tests for functools.partial.

    ``partial`` is a class attribute so that subclasses can rerun every test
    here against another implementation (e.g. a partial subclass or a
    pure-Python approximation).
    """

    partial = functools.partial

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        # list() keeps the check valid whether map returns a list or an iterator.
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        try:
            self.partial(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly; separate
        # assertEqual calls report which half of the result was wrong.
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference should invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        # A real list, not map(): the data is consumed twice below and a
        # one-shot iterator would leave the second join with nothing.
        data = [str(i) for i in range(10)]
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares all component objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy must not share any mutable component.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))
        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check was disabled in the original;
        # presumably keywords=None is not normalized to {} by this
        # implementation -- confirm before re-enabling.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # Subclass instances passed to __setstate__ must be stored (and
        # passed on to the callee) as exact tuple/dict objects.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        # Each section resets state in `finally` to break the reference
        # cycle it created.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.assertRaises(RuntimeError):
                    pickle.dumps(f, proto)
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.args[0], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.keywords['a'], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
283                elif key in (2, 3):
284                    return {}
285                raise IndexError
286
287        f = self.partial(object)
288        self.assertRaises(TypeError, f.__setstate__, BadSequence())
289
290class PartialSubclass(functools.partial):
291    pass
292
293class TestPartialSubclass(TestPartial):
294
295    partial = PartialSubclass
296
297class TestPythonPartial(TestPartial):
298
299    partial = PythonPartial
300
301    # the python version isn't picklable
302    test_pickle = None
303    test_setstate = None
304    test_setstate_errors = None
305    test_setstate_subclasses = None
306    test_setstate_refcount = None
307    test_recursive_pickle = None
308
309    # the python version isn't deepcopyable
310    test_deepcopy = None
311
312    # the python version isn't a type
313    test_attributes = None
314
315class TestUpdateWrapper(unittest.TestCase):
316
317    def check_wrapper(self, wrapper, wrapped,
318                      assigned=functools.WRAPPER_ASSIGNMENTS,
319                      updated=functools.WRAPPER_UPDATES):
320        # Check attributes were assigned
321        for name in assigned:
322            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
323        # Check attributes were updated
324        for name in updated:
325            wrapper_attr = getattr(wrapper, name)
326            wrapped_attr = getattr(wrapped, name)
327            for key in wrapped_attr:
328                self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
329
330    def _default_update(self):
331        def f():
332            """This is a test"""
333            pass
334        f.attr = 'This is also a test'
335        def wrapper():
336            pass
337        functools.update_wrapper(wrapper, f)
338        return wrapper, f
339
340    def test_default_update(self):
341        wrapper, f = self._default_update()
342        self.check_wrapper(wrapper, f)
343        self.assertEqual(wrapper.__name__, 'f')
344        self.assertEqual(wrapper.attr, 'This is also a test')
345
346    @unittest.skipIf(sys.flags.optimize >= 2,
347                     "Docstrings are omitted with -O2 and above")
348    def test_default_update_doc(self):
349        wrapper, f = self._default_update()
350        self.assertEqual(wrapper.__doc__, 'This is a test')
351
352    def test_no_update(self):
353        def f():
354            """This is a test"""
355            pass
356        f.attr = 'This is also a test'
357        def wrapper():
358            pass
359        functools.update_wrapper(wrapper, f, (), ())
360        self.check_wrapper(wrapper, f, (), ())
361        self.assertEqual(wrapper.__name__, 'wrapper')
362        self.assertEqual(wrapper.__doc__, None)
363        self.assertFalse(hasattr(wrapper, 'attr'))
364
365    def test_selective_update(self):
366        def f():
367            pass
368        f.attr = 'This is a different test'
369        f.dict_attr = dict(a=1, b=2, c=3)
370        def wrapper():
371            pass
372        wrapper.dict_attr = {}
373        assign = ('attr',)
374        update = ('dict_attr',)
375        functools.update_wrapper(wrapper, f, assign, update)
376        self.check_wrapper(wrapper, f, assign, update)
377        self.assertEqual(wrapper.__name__, 'wrapper')
378        self.assertEqual(wrapper.__doc__, None)
379        self.assertEqual(wrapper.attr, 'This is a different test')
380        self.assertEqual(wrapper.dict_attr, f.dict_attr)
381
382    @test_support.requires_docstrings
383    def test_builtin_update(self):
384        # Test for bug #1576241
385        def wrapper():
386            pass
387        functools.update_wrapper(wrapper, max)
388        self.assertEqual(wrapper.__name__, 'max')
389        self.assertTrue(wrapper.__doc__.startswith('max('))
390
391class TestWraps(TestUpdateWrapper):
392
393    def _default_update(self):
394        def f():
395            """This is a test"""
396            pass
397        f.attr = 'This is also a test'
398        @functools.wraps(f)
399        def wrapper():
400            pass
401        self.check_wrapper(wrapper, f)
402        return wrapper
403
404    def test_default_update(self):
405        wrapper = self._default_update()
406        self.assertEqual(wrapper.__name__, 'f')
407        self.assertEqual(wrapper.attr, 'This is also a test')
408
409    @unittest.skipIf(sys.flags.optimize >= 2,
410                     "Docstrings are omitted with -O2 and above")
411    def test_default_update_doc(self):
412        wrapper = self._default_update()
413        self.assertEqual(wrapper.__doc__, 'This is a test')
414
415    def test_no_update(self):
416        def f():
417            """This is a test"""
418            pass
419        f.attr = 'This is also a test'
420        @functools.wraps(f, (), ())
421        def wrapper():
422            pass
423        self.check_wrapper(wrapper, f, (), ())
424        self.assertEqual(wrapper.__name__, 'wrapper')
425        self.assertEqual(wrapper.__doc__, None)
426        self.assertFalse(hasattr(wrapper, 'attr'))
427
428    def test_selective_update(self):
429        def f():
430            pass
431        f.attr = 'This is a different test'
432        f.dict_attr = dict(a=1, b=2, c=3)
433        def add_dict_attr(f):
434            f.dict_attr = {}
435            return f
436        assign = ('attr',)
437        update = ('dict_attr',)
438        @functools.wraps(f, assign, update)
439        @add_dict_attr
440        def wrapper():
441            pass
442        self.check_wrapper(wrapper, f, assign, update)
443        self.assertEqual(wrapper.__name__, 'wrapper')
444        self.assertEqual(wrapper.__doc__, None)
445        self.assertEqual(wrapper.attr, 'This is a different test')
446        self.assertEqual(wrapper.dict_attr, f.dict_attr)
447
448
449class TestReduce(unittest.TestCase):
450
451    def test_reduce(self):
452        class Squares:
453
454            def __init__(self, max):
455                self.max = max
456                self.sofar = []
457
458            def __len__(self): return len(self.sofar)
459
460            def __getitem__(self, i):
461                if not 0 <= i < self.max: raise IndexError
462                n = len(self.sofar)
463                while n <= i:
464                    self.sofar.append(n*n)
465                    n += 1
466                return self.sofar[i]
467
468        reduce = functools.reduce
469        self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
470        self.assertEqual(
471            reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
472            ['a','c','d','w']
473        )
474        self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
475        self.assertEqual(
476            reduce(lambda x, y: x*y, range(2,21), 1L),
477            2432902008176640000L
478        )
479        self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
480        self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
481        self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
482        self.assertRaises(TypeError, reduce)
483        self.assertRaises(TypeError, reduce, 42, 42)
484        self.assertRaises(TypeError, reduce, 42, 42, 42)
485        self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
486        self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
487        self.assertRaises(TypeError, reduce, 42, (42, 42))
488
489class TestCmpToKey(unittest.TestCase):
490    def test_cmp_to_key(self):
491        def mycmp(x, y):
492            return y - x
493        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
494                         [4, 3, 2, 1, 0])
495
496    def test_hash(self):
497        def mycmp(x, y):
498            return y - x
499        key = functools.cmp_to_key(mycmp)
500        k = key(10)
501        self.assertRaises(TypeError, hash(k))
502
503class TestTotalOrdering(unittest.TestCase):
504
505    def test_total_ordering_lt(self):
506        @functools.total_ordering
507        class A:
508            def __init__(self, value):
509                self.value = value
510            def __lt__(self, other):
511                return self.value < other.value
512            def __eq__(self, other):
513                return self.value == other.value
514        self.assertTrue(A(1) < A(2))
515        self.assertTrue(A(2) > A(1))
516        self.assertTrue(A(1) <= A(2))
517        self.assertTrue(A(2) >= A(1))
518        self.assertTrue(A(2) <= A(2))
519        self.assertTrue(A(2) >= A(2))
520
521    def test_total_ordering_le(self):
522        @functools.total_ordering
523        class A:
524            def __init__(self, value):
525                self.value = value
526            def __le__(self, other):
527                return self.value <= other.value
528            def __eq__(self, other):
529                return self.value == other.value
530        self.assertTrue(A(1) < A(2))
531        self.assertTrue(A(2) > A(1))
532        self.assertTrue(A(1) <= A(2))
533        self.assertTrue(A(2) >= A(1))
534        self.assertTrue(A(2) <= A(2))
535        self.assertTrue(A(2) >= A(2))
536
537    def test_total_ordering_gt(self):
538        @functools.total_ordering
539        class A:
540            def __init__(self, value):
541                self.value = value
542            def __gt__(self, other):
543                return self.value > other.value
544            def __eq__(self, other):
545                return self.value == other.value
546        self.assertTrue(A(1) < A(2))
547        self.assertTrue(A(2) > A(1))
548        self.assertTrue(A(1) <= A(2))
549        self.assertTrue(A(2) >= A(1))
550        self.assertTrue(A(2) <= A(2))
551        self.assertTrue(A(2) >= A(2))
552
553    def test_total_ordering_ge(self):
554        @functools.total_ordering
555        class A:
556            def __init__(self, value):
557                self.value = value
558            def __ge__(self, other):
559                return self.value >= other.value
560            def __eq__(self, other):
561                return self.value == other.value
562        self.assertTrue(A(1) < A(2))
563        self.assertTrue(A(2) > A(1))
564        self.assertTrue(A(1) <= A(2))
565        self.assertTrue(A(2) >= A(1))
566        self.assertTrue(A(2) <= A(2))
567        self.assertTrue(A(2) >= A(2))
568
569    def test_total_ordering_no_overwrite(self):
570        # new methods should not overwrite existing
571        @functools.total_ordering
572        class A(str):
573            pass
574        self.assertTrue(A("a") < A("b"))
575        self.assertTrue(A("b") > A("a"))
576        self.assertTrue(A("a") <= A("b"))
577        self.assertTrue(A("b") >= A("a"))
578        self.assertTrue(A("b") <= A("b"))
579        self.assertTrue(A("b") >= A("b"))
580
581    def test_no_operations_defined(self):
582        with self.assertRaises(ValueError):
583            @functools.total_ordering
584            class A:
585                pass
586
587    def test_bug_10042(self):
588        @functools.total_ordering
589        class TestTO:
590            def __init__(self, value):
591                self.value = value
592            def __eq__(self, other):
593                if isinstance(other, TestTO):
594                    return self.value == other.value
595                return False
596            def __lt__(self, other):
597                if isinstance(other, TestTO):
598                    return self.value < other.value
599                raise TypeError
600        with self.assertRaises(TypeError):
601            TestTO(8) <= ()
602
603def test_main(verbose=None):
604    test_classes = (
605        TestPartial,
606        TestPartialSubclass,
607        TestPythonPartial,
608        TestUpdateWrapper,
609        TestTotalOrdering,
610        TestWraps,
611        TestReduce,
612    )
613    test_support.run_unittest(*test_classes)
614
615    # verify reference counting
616    if verbose and hasattr(sys, "gettotalrefcount"):
617        import gc
618        counts = [None] * 5
619        for i in xrange(len(counts)):
620            test_support.run_unittest(*test_classes)
621            gc.collect()
622            counts[i] = sys.gettotalrefcount()
623        print counts
624
625if __name__ == '__main__':
626    test_main(verbose=True)
627