• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import abc
2import builtins
3import collections
4import collections.abc
5import copy
6from itertools import permutations
7import pickle
8from random import choice
9import sys
10from test import support
11import threading
12import time
13import typing
14import unittest
15import unittest.mock
16from weakref import proxy
17import contextlib
18
19import functools
20
# Two independently imported copies of functools, so both implementations can
# be tested side by side.  '_functools' is blocked for the first copy, leaving
# only the pure-Python code paths.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
# Second copy, imported together with a fresh '_functools'.  Code further down
# guards on ``if c_functools:``, so this value may be falsy.
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal imported together with a fresh '_decimal'.
# NOTE(review): not referenced in the visible part of this file -- presumably
# used by tests further down; confirm.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The previous entry is restored when the ``with`` block ends, even if the
    block raised.  If *name* had no entry beforehand, the temporary entry is
    removed again on exit rather than failing with ``KeyError`` up front.
    """
    # Private sentinel: None is a legitimate value inside sys.modules, so it
    # cannot be used to mean "there was no entry".
    missing = object()
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            # Nothing to restore; just undo our own insertion.
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
34
def capture(*args, **kw):
    """Echo the call back as a (positional tuple, keyword dict) pair."""
    call_record = (args, kw)
    return call_record
38
39
def signature(part):
    """Summarize a partial object as (func, args, keywords, __dict__)."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
43
class MyTuple(tuple):
    """Tuple subclass that adds nothing except a distinct type."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` produces a list instead of a tuple."""

    def __add__(self, other):
        # Concatenate into a brand-new list rather than a tuple.
        return [*self, *other]
50
class MyDict(dict):
    """Dict subclass that adds nothing except a distinct type."""
53
54
class TestPartial:
    """Checks shared by every ``partial`` implementation under test.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and
    supply ``partial`` (the callable under test) and ``AllowPickle`` (a
    context-manager class entered around the pickling tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, whether that
        # happens while building the partial or while calling it.
        with self.assertRaises(
                TypeError, msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertions so a failure reports the actual values.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        # An exception raised by the wrapped function must reach the caller
        # however the arguments are split between partial and call.
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial must look the same as one flat partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself through __setstate__
        # and resets the state afterwards to break the reference cycle.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and instance attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates the containers and their contents.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the assertion below is disabled in the original; the
        # reason is not visible here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # Malformed state: wrong length, wrong container type, or wrong
        # element types must all raise TypeError.
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state must end up as exact tuple/dict.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through args survives a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through keywords survives a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Looks like a 4-item state but is not a tuple.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared ``TestPartial`` checks on ``c_functools.partial``."""

    if c_functools:
        # Guarded: the class body runs even when the decorator will skip the
        # class, and c_functools may then be falsy.
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager; __exit__ returns False so exceptions from
        # the managed block propagate.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # Deleting __dict__ must be refused with TypeError.
        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # Calling with a non-string keyword name is an error.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            # Formatting this key as a string rewrites its own dict entry.
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared ``TestPartial`` checks on ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager installing ``py_functools`` as ``functools``."""

        def __init__(self):
            # Created here, entered later: nothing is swapped until __enter__.
            self._module_swap = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._module_swap
            return swap.__enter__()

        def __exit__(self, type, value, tb):
            swap = self._module_swap
            return swap.__exit__(type, value, tb)
441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of ``c_functools.partial`` with no additions."""
445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of ``py_functools.partial`` with no additions."""
448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-runs everything inherited from TestPartialC with CPartialSubclass as
    # the callable under test.  CPartialSubclass is only defined when
    # c_functools is truthy, hence the guard.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    # (rebinding the inherited test to None removes it from this class)
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    # Re-runs everything inherited from TestPartialPy with PyPartialSubclass
    # as the callable under test.
    partial = PyPartialSubclass
459
class TestPartialMethod(unittest.TestCase):
    """Exercise ``functools.partialmethod`` through the fixture class ``A``."""

    class A(object):
        # Every attribute ends up calling ``capture``, so each call's result
        # shows exactly which arguments were forwarded.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod layered on another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered on a functools.partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping staticmethod / classmethod descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        a = self.a

        self.assertEqual(a.nothing(), ((a,), {}))
        self.assertEqual(a.nothing(5), ((a, 5), {}))
        self.assertEqual(a.nothing(c=6), ((a,), {'c': 6}))
        self.assertEqual(a.nothing(5, c=6), ((a, 5), {'c': 6}))

        self.assertEqual(a.positional(), ((a, 1), {}))
        self.assertEqual(a.positional(5), ((a, 1, 5), {}))
        self.assertEqual(a.positional(c=6), ((a, 1), {'c': 6}))
        self.assertEqual(a.positional(5, c=6), ((a, 1, 5), {'c': 6}))

        self.assertEqual(a.keywords(), ((a,), {'a': 2}))
        self.assertEqual(a.keywords(5), ((a, 5), {'a': 2}))
        self.assertEqual(a.keywords(c=6), ((a,), {'a': 2, 'c': 6}))
        self.assertEqual(a.keywords(5, c=6), ((a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(a.both(), ((a, 3), {'b': 4}))
        self.assertEqual(a.both(5), ((a, 3, 5), {'b': 4}))
        self.assertEqual(a.both(c=6), ((a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(a.both(5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

        # Same call made through the class, passing the instance explicitly.
        self.assertEqual(self.A.both(a, 5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        a = self.a

        self.assertEqual(a.nested(), ((a, 1, 5), {}))
        self.assertEqual(a.nested(6), ((a, 1, 5, 6), {}))
        self.assertEqual(a.nested(d=7), ((a, 1, 5), {'d': 7}))
        self.assertEqual(a.nested(6, d=7), ((a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(a, 6, d=7), ((a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        a = self.a

        self.assertEqual(a.over_partial(), ((a, 7), {'c': 6}))
        self.assertEqual(a.over_partial(5), ((a, 7, 5), {'c': 6}))
        self.assertEqual(a.over_partial(d=8), ((a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(a.over_partial(5, d=8), ((a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(a, 5, d=8), ((a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        inst = self.a
        owner = self.A
        self.assertIs(inst.both.__self__, inst)
        self.assertIs(inst.nested.__self__, inst)
        self.assertIs(inst.over_partial.__self__, inst)
        # The classmethod-based attribute is bound to the class either way.
        self.assertIs(inst.cls.__self__, owner)
        self.assertIs(owner.cls.__self__, owner)

    def test_unbound_method_retrieval(self):
        owner = self.A
        for attr_name in ('both', 'nested', 'over_partial', 'static'):
            self.assertFalse(hasattr(getattr(owner, attr_name), "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        owner = self.A
        for obj in (owner, self.a):
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((owner,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((owner, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((owner,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((owner, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        a = self.a
        self.assertEqual(a.keywords(a=3), ((a,), {'a': 3}))
        self.assertEqual(self.A.keywords(a, a=3), ((a,), {'a': 3}))

    def test_invalid_args(self):
        # The TypeError must surface while the class body is executed.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # Fetch through vars() to get the raw descriptor, not a bound result.
        descriptor = vars(self.A)['both']
        expected = 'functools.partialmethod({}, 3, b=4)'.format(capture)
        self.assertEqual(repr(descriptor), expected)

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        owner = self.A
        non_abstract = (owner.static, owner.cls, owner.over_partial,
                        owner.nested, owner.both)
        for member in non_abstract:
            self.assertFalse(getattr(member, '__isabstractmethod__', False))
572
573
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Shared helper: verify that *wrapper* carries the state copied from
        # *wrapped* for the given *assigned* and *updated* attribute names.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using update_wrapper's defaults.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A bogus pre-existing __wrapped__ that must not leak to the wrapper.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples the wrapper keeps its own state.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly named attributes are assigned or updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
686
687
class TestWraps(TestUpdateWrapper):
    """Repeat the wrapper tests through the ``functools.wraps`` decorator."""

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A bogus pre-existing __wrapped__ on the wrapped function.
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples the wrapper keeps its own state.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # Give the wrapper an empty dict_attr before wraps() updates it.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
748
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for ``reduce`` as exposed by ``c_functools``."""

    if c_functools:
        # Guarded: the class body runs even when the decorator will skip the
        # class, and c_functools may then be falsy.
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of squares computed on demand; valid indexes are
            # 0 <= i < max.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        # An explicit initial value is returned as-is for an empty sequence,
        # even when that value is None.
        self.assertIsNone(self.func(add, [], None))
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Iterable only through __getitem__; yields 0 .. n-1.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Reducing a dict folds over its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
830
831
832class TestCmpToKey:
833
834    def test_cmp_to_key(self):
835        def cmp1(x, y):
836            return (x > y) - (x < y)
837        key = self.cmp_to_key(cmp1)
838        self.assertEqual(key(3), key(3))
839        self.assertGreater(key(3), key(1))
840        self.assertGreaterEqual(key(3), key(3))
841
842        def cmp2(x, y):
843            return int(x) - int(y)
844        key = self.cmp_to_key(cmp2)
845        self.assertEqual(key(4.0), key('4'))
846        self.assertLess(key(2), key('35'))
847        self.assertLessEqual(key(2), key('35'))
848        self.assertNotEqual(key(2), key('35'))
849
850    def test_cmp_to_key_arguments(self):
851        def cmp1(x, y):
852            return (x > y) - (x < y)
853        key = self.cmp_to_key(mycmp=cmp1)
854        self.assertEqual(key(obj=3), key(obj=3))
855        self.assertGreater(key(obj=3), key(obj=1))
856        with self.assertRaises((TypeError, AttributeError)):
857            key(3) > 1    # rhs is not a K object
858        with self.assertRaises((TypeError, AttributeError)):
859            1 < key(3)    # lhs is not a K object
860        with self.assertRaises(TypeError):
861            key = self.cmp_to_key()             # too few args
862        with self.assertRaises(TypeError):
863            key = self.cmp_to_key(cmp1, None)   # too many args
864        key = self.cmp_to_key(cmp1)
865        with self.assertRaises(TypeError):
866            key()                                    # too few args
867        with self.assertRaises(TypeError):
868            key(None, None)                          # too many args
869
870    def test_bad_cmp(self):
871        def cmp1(x, y):
872            raise ZeroDivisionError
873        key = self.cmp_to_key(cmp1)
874        with self.assertRaises(ZeroDivisionError):
875            key(3) > key(1)
876
877        class BadCmp:
878            def __lt__(self, other):
879                raise ZeroDivisionError
880        def cmp1(x, y):
881            return BadCmp()
882        with self.assertRaises(ZeroDivisionError):
883            key(3) > key(1)
884
885    def test_obj_field(self):
886        def cmp1(x, y):
887            return (x > y) - (x < y)
888        key = self.cmp_to_key(mycmp=cmp1)
889        self.assertEqual(key(50).obj, 50)
890
891    def test_sort_int(self):
892        def mycmp(x, y):
893            return y - x
894        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
895                         [4, 3, 2, 1, 0])
896
897    def test_sort_int_str(self):
898        def mycmp(x, y):
899            x, y = int(x), int(y)
900            return (x > y) - (x < y)
901        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
902        values = sorted(values, key=self.cmp_to_key(mycmp))
903        self.assertEqual([int(value) for value in values],
904                         [0, 1, 1, 2, 3, 4, 5, 7, 10])
905
906    def test_hash(self):
907        def mycmp(x, y):
908            return y - x
909        key = self.cmp_to_key(mycmp)
910        k = key(10)
911        self.assertRaises(TypeError, hash, k)
912        self.assertNotIsInstance(k, collections.abc.Hashable)
913
914
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C-accelerated module."""

    # Guarded lookup: c_functools is falsy when the accelerator is missing,
    # in which case the decorator above skips this class.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
919
920
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against py_functools."""

    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
923
924
925class TestTotalOrdering(unittest.TestCase):
926
927    def test_total_ordering_lt(self):
928        @functools.total_ordering
929        class A:
930            def __init__(self, value):
931                self.value = value
932            def __lt__(self, other):
933                return self.value < other.value
934            def __eq__(self, other):
935                return self.value == other.value
936        self.assertTrue(A(1) < A(2))
937        self.assertTrue(A(2) > A(1))
938        self.assertTrue(A(1) <= A(2))
939        self.assertTrue(A(2) >= A(1))
940        self.assertTrue(A(2) <= A(2))
941        self.assertTrue(A(2) >= A(2))
942        self.assertFalse(A(1) > A(2))
943
944    def test_total_ordering_le(self):
945        @functools.total_ordering
946        class A:
947            def __init__(self, value):
948                self.value = value
949            def __le__(self, other):
950                return self.value <= other.value
951            def __eq__(self, other):
952                return self.value == other.value
953        self.assertTrue(A(1) < A(2))
954        self.assertTrue(A(2) > A(1))
955        self.assertTrue(A(1) <= A(2))
956        self.assertTrue(A(2) >= A(1))
957        self.assertTrue(A(2) <= A(2))
958        self.assertTrue(A(2) >= A(2))
959        self.assertFalse(A(1) >= A(2))
960
961    def test_total_ordering_gt(self):
962        @functools.total_ordering
963        class A:
964            def __init__(self, value):
965                self.value = value
966            def __gt__(self, other):
967                return self.value > other.value
968            def __eq__(self, other):
969                return self.value == other.value
970        self.assertTrue(A(1) < A(2))
971        self.assertTrue(A(2) > A(1))
972        self.assertTrue(A(1) <= A(2))
973        self.assertTrue(A(2) >= A(1))
974        self.assertTrue(A(2) <= A(2))
975        self.assertTrue(A(2) >= A(2))
976        self.assertFalse(A(2) < A(1))
977
978    def test_total_ordering_ge(self):
979        @functools.total_ordering
980        class A:
981            def __init__(self, value):
982                self.value = value
983            def __ge__(self, other):
984                return self.value >= other.value
985            def __eq__(self, other):
986                return self.value == other.value
987        self.assertTrue(A(1) < A(2))
988        self.assertTrue(A(2) > A(1))
989        self.assertTrue(A(1) <= A(2))
990        self.assertTrue(A(2) >= A(1))
991        self.assertTrue(A(2) <= A(2))
992        self.assertTrue(A(2) >= A(2))
993        self.assertFalse(A(2) <= A(1))
994
995    def test_total_ordering_no_overwrite(self):
996        # new methods should not overwrite existing
997        @functools.total_ordering
998        class A(int):
999            pass
1000        self.assertTrue(A(1) < A(2))
1001        self.assertTrue(A(2) > A(1))
1002        self.assertTrue(A(1) <= A(2))
1003        self.assertTrue(A(2) >= A(1))
1004        self.assertTrue(A(2) <= A(2))
1005        self.assertTrue(A(2) >= A(2))
1006
1007    def test_no_operations_defined(self):
1008        with self.assertRaises(ValueError):
1009            @functools.total_ordering
1010            class A:
1011                pass
1012
1013    def test_type_error_when_not_implemented(self):
1014        # bug 10042; ensure stack overflow does not occur
1015        # when decorated types return NotImplemented
1016        @functools.total_ordering
1017        class ImplementsLessThan:
1018            def __init__(self, value):
1019                self.value = value
1020            def __eq__(self, other):
1021                if isinstance(other, ImplementsLessThan):
1022                    return self.value == other.value
1023                return False
1024            def __lt__(self, other):
1025                if isinstance(other, ImplementsLessThan):
1026                    return self.value < other.value
1027                return NotImplemented
1028
1029        @functools.total_ordering
1030        class ImplementsGreaterThan:
1031            def __init__(self, value):
1032                self.value = value
1033            def __eq__(self, other):
1034                if isinstance(other, ImplementsGreaterThan):
1035                    return self.value == other.value
1036                return False
1037            def __gt__(self, other):
1038                if isinstance(other, ImplementsGreaterThan):
1039                    return self.value > other.value
1040                return NotImplemented
1041
1042        @functools.total_ordering
1043        class ImplementsLessThanEqualTo:
1044            def __init__(self, value):
1045                self.value = value
1046            def __eq__(self, other):
1047                if isinstance(other, ImplementsLessThanEqualTo):
1048                    return self.value == other.value
1049                return False
1050            def __le__(self, other):
1051                if isinstance(other, ImplementsLessThanEqualTo):
1052                    return self.value <= other.value
1053                return NotImplemented
1054
1055        @functools.total_ordering
1056        class ImplementsGreaterThanEqualTo:
1057            def __init__(self, value):
1058                self.value = value
1059            def __eq__(self, other):
1060                if isinstance(other, ImplementsGreaterThanEqualTo):
1061                    return self.value == other.value
1062                return False
1063            def __ge__(self, other):
1064                if isinstance(other, ImplementsGreaterThanEqualTo):
1065                    return self.value >= other.value
1066                return NotImplemented
1067
1068        @functools.total_ordering
1069        class ComparatorNotImplemented:
1070            def __init__(self, value):
1071                self.value = value
1072            def __eq__(self, other):
1073                if isinstance(other, ComparatorNotImplemented):
1074                    return self.value == other.value
1075                return False
1076            def __lt__(self, other):
1077                return NotImplemented
1078
1079        with self.subTest("LT < 1"), self.assertRaises(TypeError):
1080            ImplementsLessThan(-1) < 1
1081
1082        with self.subTest("LT < LE"), self.assertRaises(TypeError):
1083            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1084
1085        with self.subTest("LT < GT"), self.assertRaises(TypeError):
1086            ImplementsLessThan(1) < ImplementsGreaterThan(1)
1087
1088        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1089            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1090
1091        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1092            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1093
1094        with self.subTest("GT > GE"), self.assertRaises(TypeError):
1095            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1096
1097        with self.subTest("GT > LT"), self.assertRaises(TypeError):
1098            ImplementsGreaterThan(5) > ImplementsLessThan(5)
1099
1100        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1101            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1102
1103        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1104            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1105
1106        with self.subTest("GE when equal"):
1107            a = ComparatorNotImplemented(8)
1108            b = ComparatorNotImplemented(8)
1109            self.assertEqual(a, b)
1110            with self.assertRaises(TypeError):
1111                a >= b
1112
1113        with self.subTest("LE when equal"):
1114            a = ComparatorNotImplemented(9)
1115            b = ComparatorNotImplemented(9)
1116            self.assertEqual(a, b)
1117            with self.assertRaises(TypeError):
1118                a <= b
1119
1120    def test_pickle(self):
1121        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1122            for name in '__lt__', '__gt__', '__le__', '__ge__':
1123                with self.subTest(method=name, proto=proto):
1124                    method = getattr(Orderable_LT, name)
1125                    method_copy = pickle.loads(pickle.dumps(method, proto))
1126                    self.assertIs(method_copy, method)
1127
# Module-level (and therefore picklable by qualified name) class whose
# ordering is derived by total_ordering from __lt__ and __eq__.
@functools.total_ordering
class Orderable_LT:

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1136
1137
1138class TestLRU:
1139
1140    def test_lru(self):
1141        def orig(x, y):
1142            return 3 * x + y
1143        f = self.module.lru_cache(maxsize=20)(orig)
1144        hits, misses, maxsize, currsize = f.cache_info()
1145        self.assertEqual(maxsize, 20)
1146        self.assertEqual(currsize, 0)
1147        self.assertEqual(hits, 0)
1148        self.assertEqual(misses, 0)
1149
1150        domain = range(5)
1151        for i in range(1000):
1152            x, y = choice(domain), choice(domain)
1153            actual = f(x, y)
1154            expected = orig(x, y)
1155            self.assertEqual(actual, expected)
1156        hits, misses, maxsize, currsize = f.cache_info()
1157        self.assertTrue(hits > misses)
1158        self.assertEqual(hits + misses, 1000)
1159        self.assertEqual(currsize, 20)
1160
1161        f.cache_clear()   # test clearing
1162        hits, misses, maxsize, currsize = f.cache_info()
1163        self.assertEqual(hits, 0)
1164        self.assertEqual(misses, 0)
1165        self.assertEqual(currsize, 0)
1166        f(x, y)
1167        hits, misses, maxsize, currsize = f.cache_info()
1168        self.assertEqual(hits, 0)
1169        self.assertEqual(misses, 1)
1170        self.assertEqual(currsize, 1)
1171
1172        # Test bypassing the cache
1173        self.assertIs(f.__wrapped__, orig)
1174        f.__wrapped__(x, y)
1175        hits, misses, maxsize, currsize = f.cache_info()
1176        self.assertEqual(hits, 0)
1177        self.assertEqual(misses, 1)
1178        self.assertEqual(currsize, 1)
1179
1180        # test size zero (which means "never-cache")
1181        @self.module.lru_cache(0)
1182        def f():
1183            nonlocal f_cnt
1184            f_cnt += 1
1185            return 20
1186        self.assertEqual(f.cache_info().maxsize, 0)
1187        f_cnt = 0
1188        for i in range(5):
1189            self.assertEqual(f(), 20)
1190        self.assertEqual(f_cnt, 5)
1191        hits, misses, maxsize, currsize = f.cache_info()
1192        self.assertEqual(hits, 0)
1193        self.assertEqual(misses, 5)
1194        self.assertEqual(currsize, 0)
1195
1196        # test size one
1197        @self.module.lru_cache(1)
1198        def f():
1199            nonlocal f_cnt
1200            f_cnt += 1
1201            return 20
1202        self.assertEqual(f.cache_info().maxsize, 1)
1203        f_cnt = 0
1204        for i in range(5):
1205            self.assertEqual(f(), 20)
1206        self.assertEqual(f_cnt, 1)
1207        hits, misses, maxsize, currsize = f.cache_info()
1208        self.assertEqual(hits, 4)
1209        self.assertEqual(misses, 1)
1210        self.assertEqual(currsize, 1)
1211
1212        # test size two
1213        @self.module.lru_cache(2)
1214        def f(x):
1215            nonlocal f_cnt
1216            f_cnt += 1
1217            return x*10
1218        self.assertEqual(f.cache_info().maxsize, 2)
1219        f_cnt = 0
1220        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1221            #    *  *              *                          *
1222            self.assertEqual(f(x), x*10)
1223        self.assertEqual(f_cnt, 4)
1224        hits, misses, maxsize, currsize = f.cache_info()
1225        self.assertEqual(hits, 12)
1226        self.assertEqual(misses, 4)
1227        self.assertEqual(currsize, 2)
1228
1229    def test_lru_bug_35780(self):
1230        # C version of the lru_cache was not checking to see if
1231        # the user function call has already modified the cache
1232        # (this arises in recursive calls and in multi-threading).
1233        # This cause the cache to have orphan links not referenced
1234        # by the cache dictionary.
1235
1236        once = True                 # Modified by f(x) below
1237
1238        @self.module.lru_cache(maxsize=10)
1239        def f(x):
1240            nonlocal once
1241            rv = f'.{x}.'
1242            if x == 20 and once:
1243                once = False
1244                rv = f(x)
1245            return rv
1246
1247        # Fill the cache
1248        for x in range(15):
1249            self.assertEqual(f(x), f'.{x}.')
1250        self.assertEqual(f.cache_info().currsize, 10)
1251
1252        # Make a recursive call and make sure the cache remains full
1253        self.assertEqual(f(20), '.20.')
1254        self.assertEqual(f.cache_info().currsize, 10)
1255
1256    def test_lru_hash_only_once(self):
1257        # To protect against weird reentrancy bugs and to improve
1258        # efficiency when faced with slow __hash__ methods, the
1259        # LRU cache guarantees that it will only call __hash__
1260        # only once per use as an argument to the cached function.
1261
1262        @self.module.lru_cache(maxsize=1)
1263        def f(x, y):
1264            return x * 3 + y
1265
1266        # Simulate the integer 5
1267        mock_int = unittest.mock.Mock()
1268        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1269        mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1270
1271        # Add to cache:  One use as an argument gives one call
1272        self.assertEqual(f(mock_int, 1), 16)
1273        self.assertEqual(mock_int.__hash__.call_count, 1)
1274        self.assertEqual(f.cache_info(), (0, 1, 1, 1))
1275
1276        # Cache hit: One use as an argument gives one additional call
1277        self.assertEqual(f(mock_int, 1), 16)
1278        self.assertEqual(mock_int.__hash__.call_count, 2)
1279        self.assertEqual(f.cache_info(), (1, 1, 1, 1))
1280
1281        # Cache eviction: No use as an argument gives no additional call
1282        self.assertEqual(f(6, 2), 20)
1283        self.assertEqual(mock_int.__hash__.call_count, 2)
1284        self.assertEqual(f.cache_info(), (1, 2, 1, 1))
1285
1286        # Cache miss: One use as an argument gives one additional call
1287        self.assertEqual(f(mock_int, 1), 16)
1288        self.assertEqual(mock_int.__hash__.call_count, 3)
1289        self.assertEqual(f.cache_info(), (1, 3, 1, 1))
1290
1291    def test_lru_reentrancy_with_len(self):
1292        # Test to make sure the LRU cache code isn't thrown-off by
1293        # caching the built-in len() function.  Since len() can be
1294        # cached, we shouldn't use it inside the lru code itself.
1295        old_len = builtins.len
1296        try:
1297            builtins.len = self.module.lru_cache(4)(len)
1298            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1299                self.assertEqual(len('abcdefghijklmn'[:i]), i)
1300        finally:
1301            builtins.len = old_len
1302
1303    def test_lru_star_arg_handling(self):
1304        # Test regression that arose in ea064ff3c10f
1305        @functools.lru_cache()
1306        def f(*args):
1307            return args
1308
1309        self.assertEqual(f(1, 2), (1, 2))
1310        self.assertEqual(f((1, 2)), ((1, 2),))
1311
1312    def test_lru_type_error(self):
1313        # Regression test for issue #28653.
1314        # lru_cache was leaking when one of the arguments
1315        # wasn't cacheable.
1316
1317        @functools.lru_cache(maxsize=None)
1318        def infinite_cache(o):
1319            pass
1320
1321        @functools.lru_cache(maxsize=10)
1322        def limited_cache(o):
1323            pass
1324
1325        with self.assertRaises(TypeError):
1326            infinite_cache([])
1327
1328        with self.assertRaises(TypeError):
1329            limited_cache([])
1330
1331    def test_lru_with_maxsize_none(self):
1332        @self.module.lru_cache(maxsize=None)
1333        def fib(n):
1334            if n < 2:
1335                return n
1336            return fib(n-1) + fib(n-2)
1337        self.assertEqual([fib(n) for n in range(16)],
1338            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1339        self.assertEqual(fib.cache_info(),
1340            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1341        fib.cache_clear()
1342        self.assertEqual(fib.cache_info(),
1343            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1344
1345    def test_lru_with_maxsize_negative(self):
1346        @self.module.lru_cache(maxsize=-10)
1347        def eq(n):
1348            return n
1349        for i in (0, 1):
1350            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1351        self.assertEqual(eq.cache_info(),
1352            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
1353
1354    def test_lru_with_exceptions(self):
1355        # Verify that user_function exceptions get passed through without
1356        # creating a hard-to-read chained exception.
1357        # http://bugs.python.org/issue13177
1358        for maxsize in (None, 128):
1359            @self.module.lru_cache(maxsize)
1360            def func(i):
1361                return 'abc'[i]
1362            self.assertEqual(func(0), 'a')
1363            with self.assertRaises(IndexError) as cm:
1364                func(15)
1365            self.assertIsNone(cm.exception.__context__)
1366            # Verify that the previous exception did not result in a cached entry
1367            with self.assertRaises(IndexError):
1368                func(15)
1369
1370    def test_lru_with_types(self):
1371        for maxsize in (None, 128):
1372            @self.module.lru_cache(maxsize=maxsize, typed=True)
1373            def square(x):
1374                return x * x
1375            self.assertEqual(square(3), 9)
1376            self.assertEqual(type(square(3)), type(9))
1377            self.assertEqual(square(3.0), 9.0)
1378            self.assertEqual(type(square(3.0)), type(9.0))
1379            self.assertEqual(square(x=3), 9)
1380            self.assertEqual(type(square(x=3)), type(9))
1381            self.assertEqual(square(x=3.0), 9.0)
1382            self.assertEqual(type(square(x=3.0)), type(9.0))
1383            self.assertEqual(square.cache_info().hits, 4)
1384            self.assertEqual(square.cache_info().misses, 4)
1385
1386    def test_lru_with_keyword_args(self):
1387        @self.module.lru_cache()
1388        def fib(n):
1389            if n < 2:
1390                return n
1391            return fib(n=n-1) + fib(n=n-2)
1392        self.assertEqual(
1393            [fib(n=number) for number in range(16)],
1394            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1395        )
1396        self.assertEqual(fib.cache_info(),
1397            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1398        fib.cache_clear()
1399        self.assertEqual(fib.cache_info(),
1400            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1401
1402    def test_lru_with_keyword_args_maxsize_none(self):
1403        @self.module.lru_cache(maxsize=None)
1404        def fib(n):
1405            if n < 2:
1406                return n
1407            return fib(n=n-1) + fib(n=n-2)
1408        self.assertEqual([fib(n=number) for number in range(16)],
1409            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1410        self.assertEqual(fib.cache_info(),
1411            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1412        fib.cache_clear()
1413        self.assertEqual(fib.cache_info(),
1414            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1415
1416    def test_kwargs_order(self):
1417        # PEP 468: Preserving Keyword Argument Order
1418        @self.module.lru_cache(maxsize=10)
1419        def f(**kwargs):
1420            return list(kwargs.items())
1421        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1422        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1423        self.assertEqual(f.cache_info(),
1424            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1425
1426    def test_lru_cache_decoration(self):
1427        def f(zomg: 'zomg_annotation'):
1428            """f doc string"""
1429            return 42
1430        g = self.module.lru_cache()(f)
1431        for attr in self.module.WRAPPER_ASSIGNMENTS:
1432            self.assertEqual(getattr(g, attr), getattr(f, attr))
1433
1434    def test_lru_cache_threaded(self):
1435        n, m = 5, 11
1436        def orig(x, y):
1437            return 3 * x + y
1438        f = self.module.lru_cache(maxsize=n*m)(orig)
1439        hits, misses, maxsize, currsize = f.cache_info()
1440        self.assertEqual(currsize, 0)
1441
1442        start = threading.Event()
1443        def full(k):
1444            start.wait(10)
1445            for _ in range(m):
1446                self.assertEqual(f(k, 0), orig(k, 0))
1447
1448        def clear():
1449            start.wait(10)
1450            for _ in range(2*m):
1451                f.cache_clear()
1452
1453        orig_si = sys.getswitchinterval()
1454        support.setswitchinterval(1e-6)
1455        try:
1456            # create n threads in order to fill cache
1457            threads = [threading.Thread(target=full, args=[k])
1458                       for k in range(n)]
1459            with support.start_threads(threads):
1460                start.set()
1461
1462            hits, misses, maxsize, currsize = f.cache_info()
1463            if self.module is py_functools:
1464                # XXX: Why can be not equal?
1465                self.assertLessEqual(misses, n)
1466                self.assertLessEqual(hits, m*n - misses)
1467            else:
1468                self.assertEqual(misses, n)
1469                self.assertEqual(hits, m*n - misses)
1470            self.assertEqual(currsize, n)
1471
1472            # create n threads in order to fill cache and 1 to clear it
1473            threads = [threading.Thread(target=clear)]
1474            threads += [threading.Thread(target=full, args=[k])
1475                        for k in range(n)]
1476            start.clear()
1477            with support.start_threads(threads):
1478                start.set()
1479        finally:
1480            sys.setswitchinterval(orig_si)
1481
1482    def test_lru_cache_threaded2(self):
1483        # Simultaneous call with the same arguments
1484        n, m = 5, 7
1485        start = threading.Barrier(n+1)
1486        pause = threading.Barrier(n+1)
1487        stop = threading.Barrier(n+1)
1488        @self.module.lru_cache(maxsize=m*n)
1489        def f(x):
1490            pause.wait(10)
1491            return 3 * x
1492        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1493        def test():
1494            for i in range(m):
1495                start.wait(10)
1496                self.assertEqual(f(i), 3 * i)
1497                stop.wait(10)
1498        threads = [threading.Thread(target=test) for k in range(n)]
1499        with support.start_threads(threads):
1500            for i in range(m):
1501                start.wait(10)
1502                stop.reset()
1503                pause.wait(10)
1504                start.reset()
1505                stop.wait(10)
1506                pause.reset()
1507                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1508
1509    def test_lru_cache_threaded3(self):
1510        @self.module.lru_cache(maxsize=2)
1511        def f(x):
1512            time.sleep(.01)
1513            return 3 * x
1514        def test(i, x):
1515            with self.subTest(thread=i):
1516                self.assertEqual(f(x), 3 * x, i)
1517        threads = [threading.Thread(target=test, args=(i, v))
1518                   for i, v in enumerate([1, 2, 2, 3, 2])]
1519        with support.start_threads(threads):
1520            pass
1521
1522    def test_need_for_rlock(self):
1523        # This will deadlock on an LRU cache that uses a regular lock
1524
1525        @self.module.lru_cache(maxsize=10)
1526        def test_func(x):
1527            'Used to demonstrate a reentrant lru_cache call within a single thread'
1528            return x
1529
1530        class DoubleEq:
1531            'Demonstrate a reentrant lru_cache call within a single thread'
1532            def __init__(self, x):
1533                self.x = x
1534            def __hash__(self):
1535                return self.x
1536            def __eq__(self, other):
1537                if self.x == 2:
1538                    test_func(DoubleEq(1))
1539                return self.x == other.x
1540
1541        test_func(DoubleEq(1))                      # Load the cache
1542        test_func(DoubleEq(2))                      # Load the cache
1543        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
1544                         DoubleEq(2))               # Verify the correct return value
1545
1546    def test_early_detection_of_bad_call(self):
1547        # Issue #22184
1548        with self.assertRaises(TypeError):
1549            @functools.lru_cache
1550            def f():
1551                pass
1552
1553    def test_lru_method(self):
1554        class X(int):
1555            f_cnt = 0
1556            @self.module.lru_cache(2)
1557            def f(self, x):
1558                self.f_cnt += 1
1559                return x*10+self
1560        a = X(5)
1561        b = X(5)
1562        c = X(7)
1563        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1564
1565        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1566            self.assertEqual(a.f(x), x*10 + 5)
1567        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1568        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1569
1570        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1571            self.assertEqual(b.f(x), x*10 + 5)
1572        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1573        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1574
1575        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1576            self.assertEqual(c.f(x), x*10 + 7)
1577        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1578        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1579
1580        self.assertEqual(a.f.cache_info(), X.f.cache_info())
1581        self.assertEqual(b.f.cache_info(), X.f.cache_info())
1582        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1583
1584    def test_pickle(self):
1585        cls = self.__class__
1586        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1587            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1588                with self.subTest(proto=proto, func=f):
1589                    f_copy = pickle.loads(pickle.dumps(f, proto))
1590                    self.assertIs(f_copy, f)
1591
1592    def test_copy(self):
1593        cls = self.__class__
1594        def orig(x, y):
1595            return 3 * x + y
1596        part = self.module.partial(orig, 2)
1597        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1598                 self.module.lru_cache(2)(part))
1599        for f in funcs:
1600            with self.subTest(func=f):
1601                f_copy = copy.copy(f)
1602                self.assertIs(f_copy, f)
1603
1604    def test_deepcopy(self):
1605        cls = self.__class__
1606        def orig(x, y):
1607            return 3 * x + y
1608        part = self.module.partial(orig, 2)
1609        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1610                 self.module.lru_cache(2)(part))
1611        for f in funcs:
1612            with self.subTest(func=f):
1613                f_copy = copy.deepcopy(f)
1614                self.assertIs(f_copy, f)
1615
1616
# Module-level function cached by the pure-Python functools variant;
# TestLRUPy exposes it through its cached_func attribute.
@py_functools.lru_cache()
def py_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1620
# Module-level function cached by the functools variant loaded with a fresh
# _functools; TestLRUC exposes it through its cached_func attribute.
@c_functools.lru_cache()
def c_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1624
1625
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU cases against py_functools, the functools
    # module imported with _functools blocked.
    module = py_functools
    # Kept in a one-element tuple; the tests read it as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1638
1639
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU cases against c_functools, the functools
    # module imported with a fresh _functools.
    module = c_functools
    # Kept in a one-element tuple; the tests read it as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1652
1653
1654class TestSingleDispatch(unittest.TestCase):
1655    def test_simple_overloads(self):
1656        @functools.singledispatch
1657        def g(obj):
1658            return "base"
1659        def g_int(i):
1660            return "integer"
1661        g.register(int, g_int)
1662        self.assertEqual(g("str"), "base")
1663        self.assertEqual(g(1), "integer")
1664        self.assertEqual(g([1,2,3]), "base")
1665
1666    def test_mro(self):
1667        @functools.singledispatch
1668        def g(obj):
1669            return "base"
1670        class A:
1671            pass
1672        class C(A):
1673            pass
1674        class B(A):
1675            pass
1676        class D(C, B):
1677            pass
1678        def g_A(a):
1679            return "A"
1680        def g_B(b):
1681            return "B"
1682        g.register(A, g_A)
1683        g.register(B, g_B)
1684        self.assertEqual(g(A()), "A")
1685        self.assertEqual(g(B()), "B")
1686        self.assertEqual(g(C()), "A")
1687        self.assertEqual(g(D()), "B")
1688
1689    def test_register_decorator(self):
1690        @functools.singledispatch
1691        def g(obj):
1692            return "base"
1693        @g.register(int)
1694        def g_int(i):
1695            return "int %s" % (i,)
1696        self.assertEqual(g(""), "base")
1697        self.assertEqual(g(12), "int 12")
1698        self.assertIs(g.dispatch(int), g_int)
1699        self.assertIs(g.dispatch(object), g.dispatch(str))
1700        # Note: in the assert above this is not g.
1701        # @singledispatch returns the wrapper.
1702
1703    def test_wrapping_attributes(self):
1704        @functools.singledispatch
1705        def g(obj):
1706            "Simple test"
1707            return "Test"
1708        self.assertEqual(g.__name__, "g")
1709        if sys.flags.optimize < 2:
1710            self.assertEqual(g.__doc__, "Simple test")
1711
1712    @unittest.skipUnless(decimal, 'requires _decimal')
1713    @support.cpython_only
1714    def test_c_classes(self):
1715        @functools.singledispatch
1716        def g(obj):
1717            return "base"
1718        @g.register(decimal.DecimalException)
1719        def _(obj):
1720            return obj.args
1721        subn = decimal.Subnormal("Exponent < Emin")
1722        rnd = decimal.Rounded("Number got rounded")
1723        self.assertEqual(g(subn), ("Exponent < Emin",))
1724        self.assertEqual(g(rnd), ("Number got rounded",))
1725        @g.register(decimal.Subnormal)
1726        def _(obj):
1727            return "Too small to care."
1728        self.assertEqual(g(subn), "Too small to care.")
1729        self.assertEqual(g(rnd), ("Number got rounded",))
1730
1731    def test_compose_mro(self):
1732        # None of the examples in this test depend on haystack ordering.
1733        c = collections.abc
1734        mro = functools._compose_mro
1735        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1736        for haystack in permutations(bases):
1737            m = mro(dict, haystack)
1738            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1739                                 c.Collection, c.Sized, c.Iterable,
1740                                 c.Container, object])
1741        bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
1742        for haystack in permutations(bases):
1743            m = mro(collections.ChainMap, haystack)
1744            self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
1745                                 c.Collection, c.Sized, c.Iterable,
1746                                 c.Container, object])
1747
1748        # If there's a generic function with implementations registered for
1749        # both Sized and Container, passing a defaultdict to it results in an
1750        # ambiguous dispatch which will cause a RuntimeError (see
1751        # test_mro_conflicts).
1752        bases = [c.Container, c.Sized, str]
1753        for haystack in permutations(bases):
1754            m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1755            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1756                                 c.Container, object])
1757
1758        # MutableSequence below is registered directly on D. In other words, it
1759        # precedes MutableMapping which means single dispatch will always
1760        # choose MutableSequence here.
1761        class D(collections.defaultdict):
1762            pass
1763        c.MutableSequence.register(D)
1764        bases = [c.MutableSequence, c.MutableMapping]
1765        for haystack in permutations(bases):
1766            m = mro(D, bases)
1767            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
1768                                 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
1769                                 c.Collection, c.Sized, c.Iterable, c.Container,
1770                                 object])
1771
1772        # Container and Callable are registered on different base classes and
1773        # a generic function supporting both should always pick the Callable
1774        # implementation if a C instance is passed.
1775        class C(collections.defaultdict):
1776            def __call__(self):
1777                pass
1778        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1779        for haystack in permutations(bases):
1780            m = mro(C, haystack)
1781            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
1782                                 c.Collection, c.Sized, c.Iterable,
1783                                 c.Container, object])
1784
1785    def test_register_abc(self):
1786        c = collections.abc
1787        d = {"a": "b"}
1788        l = [1, 2, 3]
1789        s = {object(), None}
1790        f = frozenset(s)
1791        t = (1, 2, 3)
1792        @functools.singledispatch
1793        def g(obj):
1794            return "base"
1795        self.assertEqual(g(d), "base")
1796        self.assertEqual(g(l), "base")
1797        self.assertEqual(g(s), "base")
1798        self.assertEqual(g(f), "base")
1799        self.assertEqual(g(t), "base")
1800        g.register(c.Sized, lambda obj: "sized")
1801        self.assertEqual(g(d), "sized")
1802        self.assertEqual(g(l), "sized")
1803        self.assertEqual(g(s), "sized")
1804        self.assertEqual(g(f), "sized")
1805        self.assertEqual(g(t), "sized")
1806        g.register(c.MutableMapping, lambda obj: "mutablemapping")
1807        self.assertEqual(g(d), "mutablemapping")
1808        self.assertEqual(g(l), "sized")
1809        self.assertEqual(g(s), "sized")
1810        self.assertEqual(g(f), "sized")
1811        self.assertEqual(g(t), "sized")
1812        g.register(collections.ChainMap, lambda obj: "chainmap")
1813        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
1814        self.assertEqual(g(l), "sized")
1815        self.assertEqual(g(s), "sized")
1816        self.assertEqual(g(f), "sized")
1817        self.assertEqual(g(t), "sized")
1818        g.register(c.MutableSequence, lambda obj: "mutablesequence")
1819        self.assertEqual(g(d), "mutablemapping")
1820        self.assertEqual(g(l), "mutablesequence")
1821        self.assertEqual(g(s), "sized")
1822        self.assertEqual(g(f), "sized")
1823        self.assertEqual(g(t), "sized")
1824        g.register(c.MutableSet, lambda obj: "mutableset")
1825        self.assertEqual(g(d), "mutablemapping")
1826        self.assertEqual(g(l), "mutablesequence")
1827        self.assertEqual(g(s), "mutableset")
1828        self.assertEqual(g(f), "sized")
1829        self.assertEqual(g(t), "sized")
1830        g.register(c.Mapping, lambda obj: "mapping")
1831        self.assertEqual(g(d), "mutablemapping")  # not specific enough
1832        self.assertEqual(g(l), "mutablesequence")
1833        self.assertEqual(g(s), "mutableset")
1834        self.assertEqual(g(f), "sized")
1835        self.assertEqual(g(t), "sized")
1836        g.register(c.Sequence, lambda obj: "sequence")
1837        self.assertEqual(g(d), "mutablemapping")
1838        self.assertEqual(g(l), "mutablesequence")
1839        self.assertEqual(g(s), "mutableset")
1840        self.assertEqual(g(f), "sized")
1841        self.assertEqual(g(t), "sequence")
1842        g.register(c.Set, lambda obj: "set")
1843        self.assertEqual(g(d), "mutablemapping")
1844        self.assertEqual(g(l), "mutablesequence")
1845        self.assertEqual(g(s), "mutableset")
1846        self.assertEqual(g(f), "set")
1847        self.assertEqual(g(t), "sequence")
1848        g.register(dict, lambda obj: "dict")
1849        self.assertEqual(g(d), "dict")
1850        self.assertEqual(g(l), "mutablesequence")
1851        self.assertEqual(g(s), "mutableset")
1852        self.assertEqual(g(f), "set")
1853        self.assertEqual(g(t), "sequence")
1854        g.register(list, lambda obj: "list")
1855        self.assertEqual(g(d), "dict")
1856        self.assertEqual(g(l), "list")
1857        self.assertEqual(g(s), "mutableset")
1858        self.assertEqual(g(f), "set")
1859        self.assertEqual(g(t), "sequence")
1860        g.register(set, lambda obj: "concrete-set")
1861        self.assertEqual(g(d), "dict")
1862        self.assertEqual(g(l), "list")
1863        self.assertEqual(g(s), "concrete-set")
1864        self.assertEqual(g(f), "set")
1865        self.assertEqual(g(t), "sequence")
1866        g.register(frozenset, lambda obj: "frozen-set")
1867        self.assertEqual(g(d), "dict")
1868        self.assertEqual(g(l), "list")
1869        self.assertEqual(g(s), "concrete-set")
1870        self.assertEqual(g(f), "frozen-set")
1871        self.assertEqual(g(t), "sequence")
1872        g.register(tuple, lambda obj: "tuple")
1873        self.assertEqual(g(d), "dict")
1874        self.assertEqual(g(l), "list")
1875        self.assertEqual(g(s), "concrete-set")
1876        self.assertEqual(g(f), "frozen-set")
1877        self.assertEqual(g(t), "tuple")
1878
1879    def test_c3_abc(self):
1880        c = collections.abc
1881        mro = functools._c3_mro
1882        class A(object):
1883            pass
1884        class B(A):
1885            def __len__(self):
1886                return 0   # implies Sized
1887        @c.Container.register
1888        class C(object):
1889            pass
1890        class D(object):
1891            pass   # unrelated
1892        class X(D, C, B):
1893            def __call__(self):
1894                pass   # implies Callable
1895        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1896        for abcs in permutations([c.Sized, c.Callable, c.Container]):
1897            self.assertEqual(mro(X, abcs=abcs), expected)
1898        # unrelated ABCs don't appear in the resulting MRO
1899        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1900        self.assertEqual(mro(X, abcs=many_abcs), expected)
1901
1902    def test_false_meta(self):
1903        # see issue23572
1904        class MetaA(type):
1905            def __len__(self):
1906                return 0
1907        class A(metaclass=MetaA):
1908            pass
1909        class AA(A):
1910            pass
1911        @functools.singledispatch
1912        def fun(a):
1913            return 'base A'
1914        @fun.register(A)
1915        def _(a):
1916            return 'fun A'
1917        aa = AA()
1918        self.assertEqual(fun(aa), 'fun A')
1919
    def test_mro_conflicts(self):
        # Checks which implementation is chosen when a class satisfies
        # several ABCs that have implementations registered, and that
        # RuntimeError ("Ambiguous dispatch") is raised when no choice can be
        # made.  The generic functions g, h, i and j each cover one group of
        # scenarios; classes O..V are the arguments.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        # O names Sized as an explicit base class.
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no ABC among its bases; ABCs are attached to it only through
        # register() calls.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two classes may be named in either order in the message.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        # Q repeats the O scenario, but without registering Container.
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        # h only has implementations for Sized and Container.
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # R is a defaultdict subclass that additionally gets MutableSequence
        # registered on it directly; i has implementations for both
        # MutableMapping and MutableSequence.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        # T lists Sized as an explicit base, after the plain class S.
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        # U has no ABC base at all; it merely defines __len__().
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V has the same bases as T but in the opposite order; j has
        # implementations for S and Container (none for Sized).
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
2047
    def test_cache_invalidation(self):
        # Tracks every read and write that singledispatch performs on its
        # dispatch cache and checks when that cache is filled, reused and
        # emptied.  The assertions compare the exact sequence of cache
        # operations, so the order of the statements below matters.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Dict that records the keys of item reads (get_ops) and item
            # writes (set_ops), in the order they happen.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                # The lookup comes first: a miss raises before anything is
                # recorded, so get_ops only lists successful reads.
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Empties the underlying dict directly; the recorded
                # operations are left untouched.
                self.data.clear()

        td = TracingDict()
        # While this block is active, weakref.WeakKeyDictionary() returns td,
        # so the cache that singledispatch creates for g is the tracing dict
        # (the assertions on td.data below rely on that).
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First calls: nothing cached yet, so each new type is written
            # to the cache and nothing is read from it.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are served from the cache: reads only.
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering a new implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # From here on g has an implementation registered for an ABC.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() goes through the same cache as a call does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            # The stale entries are only dropped on the next dispatch, which
            # then caches the one type it was asked about.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # The cache can also be emptied explicitly.
            g._clear_cache()
            self.assertEqual(len(td), 0)
2149
2150    def test_annotations(self):
2151        @functools.singledispatch
2152        def i(arg):
2153            return "base"
2154        @i.register
2155        def _(arg: collections.abc.Mapping):
2156            return "mapping"
2157        @i.register
2158        def _(arg: "collections.abc.Sequence"):
2159            return "sequence"
2160        self.assertEqual(i(None), "base")
2161        self.assertEqual(i({"a": 1}), "mapping")
2162        self.assertEqual(i([1, 2, 3]), "sequence")
2163        self.assertEqual(i((1, 2, 3)), "sequence")
2164        self.assertEqual(i("str"), "sequence")
2165
2166        # Registering classes as callables doesn't work with annotations,
2167        # you need to pass the type explicitly.
2168        @i.register(str)
2169        class _:
2170            def __init__(self, arg):
2171                self.arg = arg
2172
2173            def __eq__(self, other):
2174                return self.arg == other
2175        self.assertEqual(i("str"), "str")
2176
2177    def test_invalid_registrations(self):
2178        msg_prefix = "Invalid first argument to `register()`: "
2179        msg_suffix = (
2180            ". Use either `@register(some_class)` or plain `@register` on an "
2181            "annotated function."
2182        )
2183        @functools.singledispatch
2184        def i(arg):
2185            return "base"
2186        with self.assertRaises(TypeError) as exc:
2187            @i.register(42)
2188            def _(arg):
2189                return "I annotated with a non-type"
2190        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2191        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2192        with self.assertRaises(TypeError) as exc:
2193            @i.register
2194            def _(arg):
2195                return "I forgot to annotate"
2196        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2197            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2198        ))
2199        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2200
2201        # FIXME: The following will only work after PEP 560 is implemented.
2202        return
2203
2204        with self.assertRaises(TypeError) as exc:
2205            @i.register
2206            def _(arg: typing.Iterable[str]):
2207                # At runtime, dispatching on generics is impossible.
2208                # When registering implementations with singledispatch, avoid
2209                # types from `typing`. Instead, annotate with regular types
2210                # or ABCs.
2211                return "I annotated with a generic collection"
2212        self.assertTrue(str(exc.exception).startswith(msg_prefix +
2213            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2214        ))
2215        self.assertTrue(str(exc.exception).endswith(msg_suffix))
2216
2217    def test_invalid_positional_argument(self):
2218        @functools.singledispatch
2219        def f(*args):
2220            pass
2221        msg = 'f requires at least 1 positional argument'
2222        with self.assertRaisesRegex(TypeError, msg):
2223            f()
2224
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
2227