1import abc 2import builtins 3import collections 4import collections.abc 5import copy 6from itertools import permutations 7import pickle 8from random import choice 9import sys 10from test import support 11import threading 12import time 13import typing 14import unittest 15import unittest.mock 16from weakref import proxy 17import contextlib 18 19import functools 20 21py_functools = support.import_fresh_module('functools', blocked=['_functools']) 22c_functools = support.import_fresh_module('functools', fresh=['_functools']) 23 24decimal = support.import_fresh_module('decimal', fresh=['_decimal']) 25 26@contextlib.contextmanager 27def replaced_module(name, replacement): 28 original_module = sys.modules[name] 29 sys.modules[name] = replacement 30 try: 31 yield 32 finally: 33 sys.modules[name] = original_module 34 35def capture(*args, **kw): 36 """capture all positional and keyword arguments""" 37 return args, kw 38 39 40def signature(part): 41 """ return the signature of a partial object """ 42 return (part.func, part.args, part.keywords, part.__dict__) 43 44class MyTuple(tuple): 45 pass 46 47class BadTuple(tuple): 48 def __add__(self, other): 49 return list(self) + list(other) 50 51class MyDict(dict): 52 pass 53 54 55class TestPartial: 56 57 def test_basic_examples(self): 58 p = self.partial(capture, 1, 2, a=10, b=20) 59 self.assertTrue(callable(p)) 60 self.assertEqual(p(3, 4, b=30, c=40), 61 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 62 p = self.partial(map, lambda x: x*10) 63 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40]) 64 65 def test_attributes(self): 66 p = self.partial(capture, 1, 2, a=10, b=20) 67 # attributes should be readable 68 self.assertEqual(p.func, capture) 69 self.assertEqual(p.args, (1, 2)) 70 self.assertEqual(p.keywords, dict(a=10, b=20)) 71 72 def test_argument_checking(self): 73 self.assertRaises(TypeError, self.partial) # need at least a func arg 74 try: 75 self.partial(2)() 76 except TypeError: 77 pass 78 else: 79 self.fail('First arg not checked for callability') 80 81 def test_protection_of_callers_dict_argument(self): 82 # a caller's dictionary should not be altered by partial 83 def func(a=10, b=20): 84 return a 85 d = {'a':3} 86 p = self.partial(func, a=5) 87 self.assertEqual(p(**d), 3) 88 self.assertEqual(d, {'a':3}) 89 p(b=7) 90 self.assertEqual(d, {'a':3}) 91 92 def test_kwargs_copy(self): 93 # Issue #29532: Altering a kwarg dictionary passed to a constructor 94 # should not affect a partial object after creation 95 d = {'a': 3} 96 p = self.partial(capture, **d) 97 self.assertEqual(p(), ((), {'a': 3})) 98 d['a'] = 5 99 self.assertEqual(p(), ((), {'a': 3})) 100 101 def test_arg_combinations(self): 102 # exercise special code paths for zero args in either partial 103 # object or the caller 104 p = self.partial(capture) 105 self.assertEqual(p(), ((), {})) 106 self.assertEqual(p(1,2), ((1,2), {})) 107 p = self.partial(capture, 1, 2) 108 self.assertEqual(p(), ((1,2), {})) 109 self.assertEqual(p(3,4), ((1,2,3,4), {})) 110 111 def test_kw_combinations(self): 112 # exercise special code paths for no keyword args in 113 # either the partial object or the caller 114 p = self.partial(capture) 115 self.assertEqual(p.keywords, {}) 116 self.assertEqual(p(), ((), {})) 117 self.assertEqual(p(a=1), ((), {'a':1})) 118 p = self.partial(capture, a=1) 119 self.assertEqual(p.keywords, {'a':1}) 120 self.assertEqual(p(), ((), {'a':1})) 121 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 122 # keyword args in the call override those in the partial object 123 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 124 125 def test_positional(self): 126 # make sure positional arguments are captured correctly 127 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 128 p = self.partial(capture, *args) 129 expected = args + ('x',) 130 got, empty = p('x') 131 self.assertTrue(expected == got and empty == {}) 132 133 def test_keyword(self): 134 # make sure keyword arguments are captured correctly 135 for a in ['a', 0, None, 3.5]: 136 p = self.partial(capture, a=a) 137 expected = {'a':a,'x':None} 138 empty, got = p(x=None) 139 self.assertTrue(expected == got and empty == ()) 140 141 def test_no_side_effects(self): 142 # make sure there are no side effects that affect subsequent calls 143 p = self.partial(capture, 0, a=1) 144 args1, kw1 = p(1, b=2) 145 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 146 args2, kw2 = p() 147 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 148 149 def test_error_propagation(self): 150 def f(x, y): 151 x / y 152 self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0)) 153 self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0) 154 self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0) 155 self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1) 156 157 def test_weakref(self): 158 f = self.partial(int, base=16) 159 p = proxy(f) 160 self.assertEqual(f.func, p.func) 161 f = None 162 self.assertRaises(ReferenceError, getattr, p, 'func') 163 164 def test_with_bound_and_unbound_methods(self): 165 data = list(map(str, range(10))) 166 join = self.partial(str.join, '') 167 self.assertEqual(join(data), '0123456789') 168 join = self.partial(''.join) 169 self.assertEqual(join(data), '0123456789') 170 171 def test_nested_optimization(self): 172 partial = self.partial 173 inner = partial(signature, 'asdf') 174 nested = partial(inner, bar=True) 175 flat = partial(signature, 'asdf', bar=True) 176 self.assertEqual(signature(nested), signature(flat)) 177 178 def test_nested_partial_with_attribute(self): 179 # see issue 25137 180 partial = self.partial 181 182 def foo(bar): 183 return bar 184 185 p = partial(foo, 'first') 186 p2 = partial(p, 'second') 187 p2.new_attr = 'spam' 188 self.assertEqual(p2.new_attr, 'spam') 189 190 def test_repr(self): 191 args = (object(), object()) 192 args_repr = ', '.join(repr(a) for a in args) 193 kwargs = {'a': object(), 'b': object()} 194 kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 195 'b={b!r}, a={a!r}'.format_map(kwargs)] 196 if self.partial in (c_functools.partial, py_functools.partial): 197 name = 'functools.partial' 198 else: 199 name = self.partial.__name__ 200 201 f = self.partial(capture) 202 self.assertEqual(f'{name}({capture!r})', repr(f)) 203 204 f = self.partial(capture, *args) 205 self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) 206 207 f = self.partial(capture, **kwargs) 208 self.assertIn(repr(f), 209 [f'{name}({capture!r}, {kwargs_repr})' 210 for kwargs_repr in kwargs_reprs]) 211 212 f = self.partial(capture, *args, **kwargs) 213 self.assertIn(repr(f), 214 [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' 215 for kwargs_repr in kwargs_reprs]) 216 217 def test_recursive_repr(self): 218 if self.partial in (c_functools.partial, py_functools.partial): 219 name = 'functools.partial' 220 else: 221 name = self.partial.__name__ 222 223 f = self.partial(capture) 224 f.__setstate__((f, (), {}, {})) 225 try: 226 self.assertEqual(repr(f), '%s(...)' % (name,)) 227 finally: 228 f.__setstate__((capture, (), {}, {})) 229 230 f = self.partial(capture) 231 f.__setstate__((capture, (f,), {}, {})) 232 try: 233 self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) 234 finally: 235 f.__setstate__((capture, (), {}, {})) 236 237 f = self.partial(capture) 238 f.__setstate__((capture, (), {'a': f}, {})) 239 try: 240 self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) 241 finally: 242 f.__setstate__((capture, (), {}, {})) 243 244 def test_pickle(self): 245 with self.AllowPickle(): 246 f = self.partial(signature, ['asdf'], bar=[True]) 247 f.attr = [] 248 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 249 f_copy = pickle.loads(pickle.dumps(f, proto)) 250 self.assertEqual(signature(f_copy), signature(f)) 251 252 def test_copy(self): 253 f = self.partial(signature, ['asdf'], bar=[True]) 254 f.attr = [] 255 f_copy = copy.copy(f) 256 self.assertEqual(signature(f_copy), signature(f)) 257 self.assertIs(f_copy.attr, f.attr) 258 self.assertIs(f_copy.args, f.args) 259 self.assertIs(f_copy.keywords, f.keywords) 260 261 def test_deepcopy(self): 262 f = self.partial(signature, ['asdf'], bar=[True]) 263 f.attr = [] 264 f_copy = copy.deepcopy(f) 265 self.assertEqual(signature(f_copy), signature(f)) 266 self.assertIsNot(f_copy.attr, f.attr) 267 self.assertIsNot(f_copy.args, f.args) 268 self.assertIsNot(f_copy.args[0], f.args[0]) 269 self.assertIsNot(f_copy.keywords, f.keywords) 270 self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) 271 272 def test_setstate(self): 273 f = self.partial(signature) 274 f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) 275 276 self.assertEqual(signature(f), 277 (capture, (1,), dict(a=10), dict(attr=[]))) 278 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 279 280 f.__setstate__((capture, (1,), dict(a=10), None)) 281 282 self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) 283 self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) 284 285 f.__setstate__((capture, (1,), None, None)) 286 #self.assertEqual(signature(f), (capture, (1,), {}, {})) 287 self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) 288 self.assertEqual(f(2), ((1, 2), {})) 289 self.assertEqual(f(), ((1,), {})) 290 291 f.__setstate__((capture, (), {}, None)) 292 self.assertEqual(signature(f), (capture, (), {}, {})) 293 self.assertEqual(f(2, b=20), ((2,), {'b': 20})) 294 self.assertEqual(f(2), ((2,), {})) 295 self.assertEqual(f(), ((), {})) 296 297 def test_setstate_errors(self): 298 f = self.partial(signature) 299 self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) 300 self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) 301 self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) 302 self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) 303 self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) 304 self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) 305 self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) 306 307 def test_setstate_subclasses(self): 308 f = self.partial(signature) 309 f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) 310 s = signature(f) 311 self.assertEqual(s, (capture, (1,), dict(a=10), {})) 312 self.assertIs(type(s[1]), tuple) 313 self.assertIs(type(s[2]), dict) 314 r = f() 315 self.assertEqual(r, ((1,), {'a': 10})) 316 self.assertIs(type(r[0]), tuple) 317 self.assertIs(type(r[1]), dict) 318 319 f.__setstate__((capture, BadTuple((1,)), {}, None)) 320 s = signature(f) 321 self.assertEqual(s, (capture, (1,), {}, {})) 322 self.assertIs(type(s[1]), tuple) 323 r = f(2) 324 self.assertEqual(r, ((1, 2), {})) 325 self.assertIs(type(r[0]), tuple) 326 327 def test_recursive_pickle(self): 328 with self.AllowPickle(): 329 f = self.partial(capture) 330 f.__setstate__((f, (), {}, {})) 331 try: 332 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 333 with self.assertRaises(RecursionError): 334 pickle.dumps(f, proto) 335 finally: 336 f.__setstate__((capture, (), {}, {})) 337 338 f = self.partial(capture) 339 f.__setstate__((capture, (f,), {}, {})) 340 try: 341 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 342 f_copy = pickle.loads(pickle.dumps(f, proto)) 343 try: 344 self.assertIs(f_copy.args[0], f_copy) 345 finally: 346 f_copy.__setstate__((capture, (), {}, {})) 347 finally: 348 f.__setstate__((capture, (), {}, {})) 349 350 f = self.partial(capture) 351 f.__setstate__((capture, (), {'a': f}, {})) 352 try: 353 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 354 f_copy = pickle.loads(pickle.dumps(f, proto)) 355 try: 356 self.assertIs(f_copy.keywords['a'], f_copy) 357 finally: 358 f_copy.__setstate__((capture, (), {}, {})) 359 finally: 360 f.__setstate__((capture, (), {}, {})) 361 362 # Issue 6083: Reference counting bug 363 def test_setstate_refcount(self): 364 class BadSequence: 365 def __len__(self): 366 return 4 367 def __getitem__(self, key): 368 if key == 0: 369 return max 370 elif key == 1: 371 return tuple(range(1000000)) 372 elif key in (2, 3): 373 return {} 374 raise IndexError 375 376 f = self.partial(object) 377 self.assertRaises(TypeError, f.__setstate__, BadSequence()) 378 379@unittest.skipUnless(c_functools, 'requires the C _functools module') 380class TestPartialC(TestPartial, unittest.TestCase): 381 if c_functools: 382 partial = c_functools.partial 383 384 class AllowPickle: 385 def __enter__(self): 386 return self 387 def __exit__(self, type, value, tb): 388 return False 389 390 def test_attributes_unwritable(self): 391 # attributes should not be writable 392 p = self.partial(capture, 1, 2, a=10, b=20) 393 self.assertRaises(AttributeError, setattr, p, 'func', map) 394 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) 395 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) 396 397 p = self.partial(hex) 398 try: 399 del p.__dict__ 400 except TypeError: 401 pass 402 else: 403 self.fail('partial object allowed __dict__ to be deleted') 404 405 def test_manually_adding_non_string_keyword(self): 406 p = self.partial(capture) 407 # Adding a non-string/unicode keyword to partial kwargs 408 p.keywords[1234] = 'value' 409 r = repr(p) 410 self.assertIn('1234', r) 411 self.assertIn("'value'", r) 412 with self.assertRaises(TypeError): 413 p() 414 415 def test_keystr_replaces_value(self): 416 p = self.partial(capture) 417 418 class MutatesYourDict(object): 419 def __str__(self): 420 p.keywords[self] = ['sth2'] 421 return 'astr' 422 423 # Replacing the value during key formatting should keep the original 424 # value alive (at least long enough). 425 p.keywords[MutatesYourDict()] = ['sth'] 426 r = repr(p) 427 self.assertIn('astr', r) 428 self.assertIn("['sth']", r) 429 430 431class TestPartialPy(TestPartial, unittest.TestCase): 432 partial = py_functools.partial 433 434 class AllowPickle: 435 def __init__(self): 436 self._cm = replaced_module("functools", py_functools) 437 def __enter__(self): 438 return self._cm.__enter__() 439 def __exit__(self, type, value, tb): 440 return self._cm.__exit__(type, value, tb) 441 442if c_functools: 443 class CPartialSubclass(c_functools.partial): 444 pass 445 446class PyPartialSubclass(py_functools.partial): 447 pass 448 449@unittest.skipUnless(c_functools, 'requires the C _functools module') 450class TestPartialCSubclass(TestPartialC): 451 if c_functools: 452 partial = CPartialSubclass 453 454 # partial subclasses are not optimized for nested calls 455 test_nested_optimization = None 456 457class TestPartialPySubclass(TestPartialPy): 458 partial = PyPartialSubclass 459 460class TestPartialMethod(unittest.TestCase): 461 462 class A(object): 463 nothing = functools.partialmethod(capture) 464 positional = functools.partialmethod(capture, 1) 465 keywords = functools.partialmethod(capture, a=2) 466 both = functools.partialmethod(capture, 3, b=4) 467 468 nested = functools.partialmethod(positional, 5) 469 470 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7) 471 472 static = functools.partialmethod(staticmethod(capture), 8) 473 cls = functools.partialmethod(classmethod(capture), d=9) 474 475 a = A() 476 477 def test_arg_combinations(self): 478 self.assertEqual(self.a.nothing(), ((self.a,), {})) 479 self.assertEqual(self.a.nothing(5), ((self.a, 5), {})) 480 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6})) 481 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6})) 482 483 self.assertEqual(self.a.positional(), ((self.a, 1), {})) 484 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {})) 485 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6})) 486 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6})) 487 488 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2})) 489 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2})) 490 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6})) 491 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6})) 492 493 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4})) 494 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4})) 495 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6})) 496 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 497 498 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6})) 499 500 def test_nested(self): 501 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {})) 502 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {})) 503 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7})) 504 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 505 506 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7})) 507 508 def test_over_partial(self): 509 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6})) 510 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6})) 511 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8})) 512 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 513 514 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8})) 515 516 def test_bound_method_introspection(self): 517 obj = self.a 518 self.assertIs(obj.both.__self__, obj) 519 self.assertIs(obj.nested.__self__, obj) 520 self.assertIs(obj.over_partial.__self__, obj) 521 self.assertIs(obj.cls.__self__, self.A) 522 self.assertIs(self.A.cls.__self__, self.A) 523 524 def test_unbound_method_retrieval(self): 525 obj = self.A 526 self.assertFalse(hasattr(obj.both, "__self__")) 527 self.assertFalse(hasattr(obj.nested, "__self__")) 528 self.assertFalse(hasattr(obj.over_partial, "__self__")) 529 self.assertFalse(hasattr(obj.static, "__self__")) 530 self.assertFalse(hasattr(self.a.static, "__self__")) 531 532 def test_descriptors(self): 533 for obj in [self.A, self.a]: 534 with self.subTest(obj=obj): 535 self.assertEqual(obj.static(), ((8,), {})) 536 self.assertEqual(obj.static(5), ((8, 5), {})) 537 self.assertEqual(obj.static(d=8), ((8,), {'d': 8})) 538 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8})) 539 540 self.assertEqual(obj.cls(), ((self.A,), {'d': 9})) 541 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9})) 542 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9})) 543 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9})) 544 545 def test_overriding_keywords(self): 546 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3})) 547 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3})) 548 549 def test_invalid_args(self): 550 with self.assertRaises(TypeError): 551 class B(object): 552 method = functools.partialmethod(None, 1) 553 554 def test_repr(self): 555 self.assertEqual(repr(vars(self.A)['both']), 556 'functools.partialmethod({}, 3, b=4)'.format(capture)) 557 558 def test_abstract(self): 559 class Abstract(abc.ABCMeta): 560 561 @abc.abstractmethod 562 def add(self, x, y): 563 pass 564 565 add5 = functools.partialmethod(add, 5) 566 567 self.assertTrue(Abstract.add.__isabstractmethod__) 568 self.assertTrue(Abstract.add5.__isabstractmethod__) 569 570 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]: 571 self.assertFalse(getattr(func, '__isabstractmethod__', False)) 572 573 574class TestUpdateWrapper(unittest.TestCase): 575 576 def check_wrapper(self, wrapper, wrapped, 577 assigned=functools.WRAPPER_ASSIGNMENTS, 578 updated=functools.WRAPPER_UPDATES): 579 # Check attributes were assigned 580 for name in assigned: 581 self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) 582 # Check attributes were updated 583 for name in updated: 584 wrapper_attr = getattr(wrapper, name) 585 wrapped_attr = getattr(wrapped, name) 586 for key in wrapped_attr: 587 if name == "__dict__" and key == "__wrapped__": 588 # __wrapped__ is overwritten by the update code 589 continue 590 self.assertIs(wrapped_attr[key], wrapper_attr[key]) 591 # Check __wrapped__ 592 self.assertIs(wrapper.__wrapped__, wrapped) 593 594 595 def _default_update(self): 596 def f(a:'This is a new annotation'): 597 """This is a test""" 598 pass 599 f.attr = 'This is also a test' 600 f.__wrapped__ = "This is a bald faced lie" 601 def wrapper(b:'This is the prior annotation'): 602 pass 603 functools.update_wrapper(wrapper, f) 604 return wrapper, f 605 606 def test_default_update(self): 607 wrapper, f = self._default_update() 608 self.check_wrapper(wrapper, f) 609 self.assertIs(wrapper.__wrapped__, f) 610 self.assertEqual(wrapper.__name__, 'f') 611 self.assertEqual(wrapper.__qualname__, f.__qualname__) 612 self.assertEqual(wrapper.attr, 'This is also a test') 613 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') 614 self.assertNotIn('b', wrapper.__annotations__) 615 616 @unittest.skipIf(sys.flags.optimize >= 2, 617 "Docstrings are omitted with -O2 and above") 618 def test_default_update_doc(self): 619 wrapper, f = self._default_update() 620 self.assertEqual(wrapper.__doc__, 'This is a test') 621 622 def test_no_update(self): 623 def f(): 624 """This is a test""" 625 pass 626 f.attr = 'This is also a test' 627 def wrapper(): 628 pass 629 functools.update_wrapper(wrapper, f, (), ()) 630 self.check_wrapper(wrapper, f, (), ()) 631 self.assertEqual(wrapper.__name__, 'wrapper') 632 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 633 self.assertEqual(wrapper.__doc__, None) 634 self.assertEqual(wrapper.__annotations__, {}) 635 self.assertFalse(hasattr(wrapper, 'attr')) 636 637 def test_selective_update(self): 638 def f(): 639 pass 640 f.attr = 'This is a different test' 641 f.dict_attr = dict(a=1, b=2, c=3) 642 def wrapper(): 643 pass 644 wrapper.dict_attr = {} 645 assign = ('attr',) 646 update = ('dict_attr',) 647 functools.update_wrapper(wrapper, f, assign, update) 648 self.check_wrapper(wrapper, f, assign, update) 649 self.assertEqual(wrapper.__name__, 'wrapper') 650 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 651 self.assertEqual(wrapper.__doc__, None) 652 self.assertEqual(wrapper.attr, 'This is a different test') 653 self.assertEqual(wrapper.dict_attr, f.dict_attr) 654 655 def test_missing_attributes(self): 656 def f(): 657 pass 658 def wrapper(): 659 pass 660 wrapper.dict_attr = {} 661 assign = ('attr',) 662 update = ('dict_attr',) 663 # Missing attributes on wrapped object are ignored 664 functools.update_wrapper(wrapper, f, assign, update) 665 self.assertNotIn('attr', wrapper.__dict__) 666 self.assertEqual(wrapper.dict_attr, {}) 667 # Wrapper must have expected attributes for updating 668 del wrapper.dict_attr 669 with self.assertRaises(AttributeError): 670 functools.update_wrapper(wrapper, f, assign, update) 671 wrapper.dict_attr = 1 672 with self.assertRaises(AttributeError): 673 functools.update_wrapper(wrapper, f, assign, update) 674 675 @support.requires_docstrings 676 @unittest.skipIf(sys.flags.optimize >= 2, 677 "Docstrings are omitted with -O2 and above") 678 def test_builtin_update(self): 679 # Test for bug #1576241 680 def wrapper(): 681 pass 682 functools.update_wrapper(wrapper, max) 683 self.assertEqual(wrapper.__name__, 'max') 684 self.assertTrue(wrapper.__doc__.startswith('max(')) 685 self.assertEqual(wrapper.__annotations__, {}) 686 687 688class TestWraps(TestUpdateWrapper): 689 690 def _default_update(self): 691 def f(): 692 """This is a test""" 693 pass 694 f.attr = 'This is also a test' 695 f.__wrapped__ = "This is still a bald faced lie" 696 @functools.wraps(f) 697 def wrapper(): 698 pass 699 return wrapper, f 700 701 def test_default_update(self): 702 wrapper, f = self._default_update() 703 self.check_wrapper(wrapper, f) 704 self.assertEqual(wrapper.__name__, 'f') 705 self.assertEqual(wrapper.__qualname__, f.__qualname__) 706 self.assertEqual(wrapper.attr, 'This is also a test') 707 708 @unittest.skipIf(sys.flags.optimize >= 2, 709 "Docstrings are omitted with -O2 and above") 710 def test_default_update_doc(self): 711 wrapper, _ = self._default_update() 712 self.assertEqual(wrapper.__doc__, 'This is a test') 713 714 def test_no_update(self): 715 def f(): 716 """This is a test""" 717 pass 718 f.attr = 'This is also a test' 719 @functools.wraps(f, (), ()) 720 def wrapper(): 721 pass 722 self.check_wrapper(wrapper, f, (), ()) 723 self.assertEqual(wrapper.__name__, 'wrapper') 724 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 725 self.assertEqual(wrapper.__doc__, None) 726 self.assertFalse(hasattr(wrapper, 'attr')) 727 728 def test_selective_update(self): 729 def f(): 730 pass 731 f.attr = 'This is a different test' 732 f.dict_attr = dict(a=1, b=2, c=3) 733 def add_dict_attr(f): 734 f.dict_attr = {} 735 return f 736 assign = ('attr',) 737 update = ('dict_attr',) 738 @functools.wraps(f, assign, update) 739 @add_dict_attr 740 def wrapper(): 741 pass 742 self.check_wrapper(wrapper, f, assign, update) 743 self.assertEqual(wrapper.__name__, 'wrapper') 744 self.assertNotEqual(wrapper.__qualname__, f.__qualname__) 745 self.assertEqual(wrapper.__doc__, None) 746 self.assertEqual(wrapper.attr, 'This is a different test') 747 self.assertEqual(wrapper.dict_attr, f.dict_attr) 748 749@unittest.skipUnless(c_functools, 'requires the C _functools module') 750class TestReduce(unittest.TestCase): 751 if c_functools: 752 func = c_functools.reduce 753 754 def test_reduce(self): 755 class Squares: 756 def __init__(self, max): 757 self.max = max 758 self.sofar = [] 759 760 def __len__(self): 761 return len(self.sofar) 762 763 def __getitem__(self, i): 764 if not 0 <= i < self.max: raise IndexError 765 n = len(self.sofar) 766 while n <= i: 767 self.sofar.append(n*n) 768 n += 1 769 return self.sofar[i] 770 def add(x, y): 771 return x + y 772 self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc') 773 self.assertEqual( 774 self.func(add, [['a', 'c'], [], ['d', 'w']], []), 775 ['a','c','d','w'] 776 ) 777 self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040) 778 self.assertEqual( 779 self.func(lambda x, y: x*y, range(2,21), 1), 780 2432902008176640000 781 ) 782 self.assertEqual(self.func(add, Squares(10)), 285) 783 self.assertEqual(self.func(add, Squares(10), 0), 285) 784 self.assertEqual(self.func(add, Squares(0), 0), 0) 785 self.assertRaises(TypeError, self.func) 786 self.assertRaises(TypeError, self.func, 42, 42) 787 self.assertRaises(TypeError, self.func, 42, 42, 42) 788 self.assertEqual(self.func(42, "1"), "1") # func is never called with one item 789 self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item 790 self.assertRaises(TypeError, self.func, 42, (42, 42)) 791 self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value 792 self.assertRaises(TypeError, self.func, add, "") 793 self.assertRaises(TypeError, self.func, add, ()) 794 self.assertRaises(TypeError, self.func, add, object()) 795 796 class TestFailingIter: 797 def __iter__(self): 798 raise RuntimeError 799 self.assertRaises(RuntimeError, self.func, add, TestFailingIter()) 800 801 self.assertEqual(self.func(add, [], None), None) 802 self.assertEqual(self.func(add, [], 42), 42) 803 804 class BadSeq: 805 def __getitem__(self, index): 806 raise ValueError 807 self.assertRaises(ValueError, self.func, 42, BadSeq()) 808 809 # Test reduce()'s use of iterators. 810 def test_iterator_usage(self): 811 class SequenceClass: 812 def __init__(self, n): 813 self.n = n 814 def __getitem__(self, i): 815 if 0 <= i < self.n: 816 return i 817 else: 818 raise IndexError 819 820 from operator import add 821 self.assertEqual(self.func(add, SequenceClass(5)), 10) 822 self.assertEqual(self.func(add, SequenceClass(5), 42), 52) 823 self.assertRaises(TypeError, self.func, add, SequenceClass(0)) 824 self.assertEqual(self.func(add, SequenceClass(0), 42), 42) 825 self.assertEqual(self.func(add, SequenceClass(1)), 0) 826 self.assertEqual(self.func(add, SequenceClass(1), 42), 42) 827 828 d = {"one": 1, "two": 2, "three": 3} 829 self.assertEqual(self.func(add, d), "".join(d.keys())) 830 831 832class TestCmpToKey: 833 834 def test_cmp_to_key(self): 835 def cmp1(x, y): 836 return (x > y) - (x < y) 837 key = self.cmp_to_key(cmp1) 838 self.assertEqual(key(3), key(3)) 839 self.assertGreater(key(3), key(1)) 840 self.assertGreaterEqual(key(3), key(3)) 841 842 def cmp2(x, y): 843 return int(x) - int(y) 844 key = self.cmp_to_key(cmp2) 845 self.assertEqual(key(4.0), key('4')) 846 self.assertLess(key(2), key('35')) 847 self.assertLessEqual(key(2), key('35')) 848 self.assertNotEqual(key(2), key('35')) 849 850 def test_cmp_to_key_arguments(self): 851 def cmp1(x, y): 852 return (x > y) - (x < y) 853 key = self.cmp_to_key(mycmp=cmp1) 854 self.assertEqual(key(obj=3), key(obj=3)) 855 self.assertGreater(key(obj=3), key(obj=1)) 856 with self.assertRaises((TypeError, AttributeError)): 857 key(3) > 1 # rhs is not a K object 858 with self.assertRaises((TypeError, AttributeError)): 859 1 < key(3) # lhs is not a K object 860 with self.assertRaises(TypeError): 861 key = self.cmp_to_key() # too few args 862 with self.assertRaises(TypeError): 863 key = self.cmp_to_key(cmp1, None) # too many args 864 key = self.cmp_to_key(cmp1) 865 with self.assertRaises(TypeError): 866 key() # too few args 867 with self.assertRaises(TypeError): 868 key(None, None) # too many args 869 870 def test_bad_cmp(self): 871 def cmp1(x, y): 872 raise ZeroDivisionError 873 key = self.cmp_to_key(cmp1) 874 with self.assertRaises(ZeroDivisionError): 875 key(3) > key(1) 876 877 class BadCmp: 878 def __lt__(self, other): 879 raise ZeroDivisionError 880 def cmp1(x, y): 881 return BadCmp() 882 with self.assertRaises(ZeroDivisionError): 883 key(3) > key(1) 884 885 def test_obj_field(self): 886 def cmp1(x, y): 887 return (x > y) - (x < y) 888 key = self.cmp_to_key(mycmp=cmp1) 889 self.assertEqual(key(50).obj, 50) 890 891 def test_sort_int(self): 892 def mycmp(x, y): 893 return y - x 894 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), 895 [4, 3, 2, 1, 0]) 896 897 def test_sort_int_str(self): 898 def mycmp(x, y): 899 x, y = int(x), int(y) 900 return (x > y) - (x < y) 901 values = [5, '3', 7, 2, '0', '1', 4, '10', 1] 902 values = sorted(values, key=self.cmp_to_key(mycmp)) 903 self.assertEqual([int(value) for value in values], 904 [0, 1, 1, 2, 3, 4, 5, 7, 10]) 905 906 def test_hash(self): 907 def mycmp(x, y): 908 return y - x 909 key = self.cmp_to_key(mycmp) 910 k = key(10) 911 self.assertRaises(TypeError, hash, k) 912 self.assertNotIsInstance(k, collections.abc.Hashable) 913 914 915@unittest.skipUnless(c_functools, 'requires the C _functools module') 916class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): 917 if c_functools: 918 cmp_to_key = c_functools.cmp_to_key 919 920 921class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): 922 cmp_to_key = staticmethod(py_functools.cmp_to_key) 923 924 925class TestTotalOrdering(unittest.TestCase): 926 927 def test_total_ordering_lt(self): 928 @functools.total_ordering 929 class A: 930 def __init__(self, value): 931 self.value = value 932 def __lt__(self, other): 933 return self.value < other.value 934 def __eq__(self, other): 935 return self.value == other.value 936 self.assertTrue(A(1) < A(2)) 937 self.assertTrue(A(2) > A(1)) 938 self.assertTrue(A(1) <= A(2)) 939 self.assertTrue(A(2) >= A(1)) 940 self.assertTrue(A(2) <= A(2)) 941 self.assertTrue(A(2) >= A(2)) 942 self.assertFalse(A(1) > A(2)) 943 944 def test_total_ordering_le(self): 945 @functools.total_ordering 946 class A: 947 def __init__(self, value): 948 self.value = value 949 def __le__(self, other): 950 return self.value <= other.value 951 def __eq__(self, other): 952 return self.value == other.value 953 self.assertTrue(A(1) < A(2)) 954 self.assertTrue(A(2) > A(1)) 955 self.assertTrue(A(1) <= A(2)) 956 self.assertTrue(A(2) >= A(1)) 957 self.assertTrue(A(2) <= A(2)) 958 self.assertTrue(A(2) >= A(2)) 959 self.assertFalse(A(1) >= A(2)) 960 961 def test_total_ordering_gt(self): 962 @functools.total_ordering 963 class A: 964 def __init__(self, value): 965 self.value = value 966 def __gt__(self, other): 967 return self.value > other.value 968 def __eq__(self, other): 969 return self.value == other.value 970 self.assertTrue(A(1) < A(2)) 971 self.assertTrue(A(2) > A(1)) 972 self.assertTrue(A(1) <= A(2)) 973 self.assertTrue(A(2) >= A(1)) 974 self.assertTrue(A(2) <= A(2)) 975 self.assertTrue(A(2) >= A(2)) 976 self.assertFalse(A(2) < A(1)) 977 978 def test_total_ordering_ge(self): 979 @functools.total_ordering 980 class A: 981 def __init__(self, value): 982 self.value = value 983 def __ge__(self, other): 984 return self.value >= other.value 985 def __eq__(self, other): 986 return self.value == other.value 987 self.assertTrue(A(1) < A(2)) 988 self.assertTrue(A(2) > A(1)) 989 self.assertTrue(A(1) <= A(2)) 990 self.assertTrue(A(2) >= A(1)) 991 self.assertTrue(A(2) <= A(2)) 992 self.assertTrue(A(2) >= A(2)) 993 self.assertFalse(A(2) <= A(1)) 994 995 def test_total_ordering_no_overwrite(self): 996 # new methods should not overwrite existing 997 @functools.total_ordering 998 class A(int): 999 pass 1000 self.assertTrue(A(1) < A(2)) 1001 self.assertTrue(A(2) > A(1)) 1002 self.assertTrue(A(1) <= A(2)) 1003 self.assertTrue(A(2) >= A(1)) 1004 self.assertTrue(A(2) <= A(2)) 1005 self.assertTrue(A(2) >= A(2)) 1006 1007 def test_no_operations_defined(self): 1008 with self.assertRaises(ValueError): 1009 @functools.total_ordering 1010 class A: 1011 pass 1012 1013 def test_type_error_when_not_implemented(self): 1014 # bug 10042; ensure stack overflow does not occur 1015 # when decorated types return NotImplemented 1016 @functools.total_ordering 1017 class ImplementsLessThan: 1018 def __init__(self, value): 1019 self.value = value 1020 def __eq__(self, other): 1021 if isinstance(other, ImplementsLessThan): 1022 return self.value == other.value 1023 return False 1024 def __lt__(self, other): 1025 if isinstance(other, ImplementsLessThan): 1026 return self.value < other.value 1027 return NotImplemented 1028 1029 @functools.total_ordering 1030 class ImplementsGreaterThan: 1031 def __init__(self, value): 1032 self.value = value 1033 def __eq__(self, other): 1034 if isinstance(other, ImplementsGreaterThan): 1035 return self.value == other.value 1036 return False 1037 def __gt__(self, other): 1038 if isinstance(other, ImplementsGreaterThan): 1039 return self.value > other.value 1040 return NotImplemented 1041 1042 @functools.total_ordering 1043 class ImplementsLessThanEqualTo: 1044 def __init__(self, value): 1045 self.value = value 1046 def __eq__(self, other): 1047 if isinstance(other, ImplementsLessThanEqualTo): 1048 return self.value == other.value 1049 return False 1050 def __le__(self, other): 1051 if isinstance(other, ImplementsLessThanEqualTo): 1052 return self.value <= other.value 1053 return NotImplemented 1054 1055 @functools.total_ordering 1056 class ImplementsGreaterThanEqualTo: 1057 def __init__(self, value): 1058 self.value = value 1059 def __eq__(self, other): 1060 if isinstance(other, ImplementsGreaterThanEqualTo): 1061 return self.value == other.value 1062 return False 1063 def __ge__(self, other): 1064 if isinstance(other, ImplementsGreaterThanEqualTo): 1065 return self.value >= other.value 1066 return NotImplemented 1067 1068 @functools.total_ordering 1069 class ComparatorNotImplemented: 1070 def __init__(self, value): 1071 self.value = value 1072 def __eq__(self, other): 1073 if isinstance(other, ComparatorNotImplemented): 1074 return self.value == other.value 1075 return False 1076 def __lt__(self, other): 1077 return NotImplemented 1078 1079 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1080 ImplementsLessThan(-1) < 1 1081 1082 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1083 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1084 1085 with self.subTest("LT < GT"), self.assertRaises(TypeError): 1086 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1087 1088 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1089 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1090 1091 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1092 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1093 1094 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1095 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1096 1097 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1098 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1099 1100 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1101 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1102 1103 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1104 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1105 1106 with self.subTest("GE when equal"): 1107 a = ComparatorNotImplemented(8) 1108 b = ComparatorNotImplemented(8) 1109 self.assertEqual(a, b) 1110 with self.assertRaises(TypeError): 1111 a >= b 1112 1113 with self.subTest("LE when equal"): 1114 a = ComparatorNotImplemented(9) 1115 b = ComparatorNotImplemented(9) 1116 self.assertEqual(a, b) 1117 with self.assertRaises(TypeError): 1118 a <= b 1119 1120 def test_pickle(self): 1121 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1122 for name in '__lt__', '__gt__', '__le__', '__ge__': 1123 with self.subTest(method=name, proto=proto): 1124 method = getattr(Orderable_LT, name) 1125 method_copy = pickle.loads(pickle.dumps(method, proto)) 1126 self.assertIs(method_copy, method) 1127 1128@functools.total_ordering 1129class Orderable_LT: 1130 def __init__(self, value): 1131 self.value = value 1132 def __lt__(self, other): 1133 return self.value < other.value 1134 def __eq__(self, other): 1135 return self.value == other.value 1136 1137 1138class TestLRU: 1139 1140 def test_lru(self): 1141 def orig(x, y): 1142 return 3 * x + y 1143 f = self.module.lru_cache(maxsize=20)(orig) 1144 hits, misses, maxsize, currsize = f.cache_info() 1145 self.assertEqual(maxsize, 20) 1146 self.assertEqual(currsize, 0) 1147 self.assertEqual(hits, 0) 1148 self.assertEqual(misses, 0) 1149 1150 domain = range(5) 1151 for i in range(1000): 1152 x, y = choice(domain), choice(domain) 1153 actual = f(x, y) 1154 expected = orig(x, y) 1155 self.assertEqual(actual, expected) 1156 hits, misses, maxsize, currsize = f.cache_info() 1157 self.assertTrue(hits > misses) 1158 self.assertEqual(hits + misses, 1000) 1159 self.assertEqual(currsize, 20) 1160 1161 f.cache_clear() # test clearing 1162 hits, misses, maxsize, currsize = f.cache_info() 1163 self.assertEqual(hits, 0) 1164 self.assertEqual(misses, 0) 1165 self.assertEqual(currsize, 0) 1166 f(x, y) 1167 hits, misses, maxsize, currsize = f.cache_info() 1168 self.assertEqual(hits, 0) 1169 self.assertEqual(misses, 1) 1170 self.assertEqual(currsize, 1) 1171 1172 # Test bypassing the cache 1173 self.assertIs(f.__wrapped__, orig) 1174 f.__wrapped__(x, y) 1175 hits, misses, maxsize, currsize = f.cache_info() 1176 self.assertEqual(hits, 0) 1177 self.assertEqual(misses, 1) 1178 self.assertEqual(currsize, 1) 1179 1180 # test size zero (which means "never-cache") 1181 @self.module.lru_cache(0) 1182 def f(): 1183 nonlocal f_cnt 1184 f_cnt += 1 1185 return 20 1186 self.assertEqual(f.cache_info().maxsize, 0) 1187 f_cnt = 0 1188 for i in range(5): 1189 self.assertEqual(f(), 20) 1190 self.assertEqual(f_cnt, 5) 1191 hits, misses, maxsize, currsize = f.cache_info() 1192 self.assertEqual(hits, 0) 1193 self.assertEqual(misses, 5) 1194 self.assertEqual(currsize, 0) 1195 1196 # test size one 1197 @self.module.lru_cache(1) 1198 def f(): 1199 nonlocal f_cnt 1200 f_cnt += 1 1201 return 20 1202 self.assertEqual(f.cache_info().maxsize, 1) 1203 f_cnt = 0 1204 for i in range(5): 1205 self.assertEqual(f(), 20) 1206 self.assertEqual(f_cnt, 1) 1207 hits, misses, maxsize, currsize = f.cache_info() 1208 self.assertEqual(hits, 4) 1209 self.assertEqual(misses, 1) 1210 self.assertEqual(currsize, 1) 1211 1212 # test size two 1213 @self.module.lru_cache(2) 1214 def f(x): 1215 nonlocal f_cnt 1216 f_cnt += 1 1217 return x*10 1218 self.assertEqual(f.cache_info().maxsize, 2) 1219 f_cnt = 0 1220 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1221 # * * * * 1222 self.assertEqual(f(x), x*10) 1223 self.assertEqual(f_cnt, 4) 1224 hits, misses, maxsize, currsize = f.cache_info() 1225 self.assertEqual(hits, 12) 1226 self.assertEqual(misses, 4) 1227 self.assertEqual(currsize, 2) 1228 1229 def test_lru_bug_35780(self): 1230 # C version of the lru_cache was not checking to see if 1231 # the user function call has already modified the cache 1232 # (this arises in recursive calls and in multi-threading). 1233 # This cause the cache to have orphan links not referenced 1234 # by the cache dictionary. 1235 1236 once = True # Modified by f(x) below 1237 1238 @self.module.lru_cache(maxsize=10) 1239 def f(x): 1240 nonlocal once 1241 rv = f'.{x}.' 1242 if x == 20 and once: 1243 once = False 1244 rv = f(x) 1245 return rv 1246 1247 # Fill the cache 1248 for x in range(15): 1249 self.assertEqual(f(x), f'.{x}.') 1250 self.assertEqual(f.cache_info().currsize, 10) 1251 1252 # Make a recursive call and make sure the cache remains full 1253 self.assertEqual(f(20), '.20.') 1254 self.assertEqual(f.cache_info().currsize, 10) 1255 1256 def test_lru_hash_only_once(self): 1257 # To protect against weird reentrancy bugs and to improve 1258 # efficiency when faced with slow __hash__ methods, the 1259 # LRU cache guarantees that it will only call __hash__ 1260 # only once per use as an argument to the cached function. 1261 1262 @self.module.lru_cache(maxsize=1) 1263 def f(x, y): 1264 return x * 3 + y 1265 1266 # Simulate the integer 5 1267 mock_int = unittest.mock.Mock() 1268 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1269 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1270 1271 # Add to cache: One use as an argument gives one call 1272 self.assertEqual(f(mock_int, 1), 16) 1273 self.assertEqual(mock_int.__hash__.call_count, 1) 1274 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1275 1276 # Cache hit: One use as an argument gives one additional call 1277 self.assertEqual(f(mock_int, 1), 16) 1278 self.assertEqual(mock_int.__hash__.call_count, 2) 1279 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1280 1281 # Cache eviction: No use as an argument gives no additional call 1282 self.assertEqual(f(6, 2), 20) 1283 self.assertEqual(mock_int.__hash__.call_count, 2) 1284 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1285 1286 # Cache miss: One use as an argument gives one additional call 1287 self.assertEqual(f(mock_int, 1), 16) 1288 self.assertEqual(mock_int.__hash__.call_count, 3) 1289 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1290 1291 def test_lru_reentrancy_with_len(self): 1292 # Test to make sure the LRU cache code isn't thrown-off by 1293 # caching the built-in len() function. Since len() can be 1294 # cached, we shouldn't use it inside the lru code itself. 1295 old_len = builtins.len 1296 try: 1297 builtins.len = self.module.lru_cache(4)(len) 1298 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1299 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1300 finally: 1301 builtins.len = old_len 1302 1303 def test_lru_star_arg_handling(self): 1304 # Test regression that arose in ea064ff3c10f 1305 @functools.lru_cache() 1306 def f(*args): 1307 return args 1308 1309 self.assertEqual(f(1, 2), (1, 2)) 1310 self.assertEqual(f((1, 2)), ((1, 2),)) 1311 1312 def test_lru_type_error(self): 1313 # Regression test for issue #28653. 1314 # lru_cache was leaking when one of the arguments 1315 # wasn't cacheable. 1316 1317 @functools.lru_cache(maxsize=None) 1318 def infinite_cache(o): 1319 pass 1320 1321 @functools.lru_cache(maxsize=10) 1322 def limited_cache(o): 1323 pass 1324 1325 with self.assertRaises(TypeError): 1326 infinite_cache([]) 1327 1328 with self.assertRaises(TypeError): 1329 limited_cache([]) 1330 1331 def test_lru_with_maxsize_none(self): 1332 @self.module.lru_cache(maxsize=None) 1333 def fib(n): 1334 if n < 2: 1335 return n 1336 return fib(n-1) + fib(n-2) 1337 self.assertEqual([fib(n) for n in range(16)], 1338 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1339 self.assertEqual(fib.cache_info(), 1340 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1341 fib.cache_clear() 1342 self.assertEqual(fib.cache_info(), 1343 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1344 1345 def test_lru_with_maxsize_negative(self): 1346 @self.module.lru_cache(maxsize=-10) 1347 def eq(n): 1348 return n 1349 for i in (0, 1): 1350 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1351 self.assertEqual(eq.cache_info(), 1352 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1353 1354 def test_lru_with_exceptions(self): 1355 # Verify that user_function exceptions get passed through without 1356 # creating a hard-to-read chained exception. 1357 # http://bugs.python.org/issue13177 1358 for maxsize in (None, 128): 1359 @self.module.lru_cache(maxsize) 1360 def func(i): 1361 return 'abc'[i] 1362 self.assertEqual(func(0), 'a') 1363 with self.assertRaises(IndexError) as cm: 1364 func(15) 1365 self.assertIsNone(cm.exception.__context__) 1366 # Verify that the previous exception did not result in a cached entry 1367 with self.assertRaises(IndexError): 1368 func(15) 1369 1370 def test_lru_with_types(self): 1371 for maxsize in (None, 128): 1372 @self.module.lru_cache(maxsize=maxsize, typed=True) 1373 def square(x): 1374 return x * x 1375 self.assertEqual(square(3), 9) 1376 self.assertEqual(type(square(3)), type(9)) 1377 self.assertEqual(square(3.0), 9.0) 1378 self.assertEqual(type(square(3.0)), type(9.0)) 1379 self.assertEqual(square(x=3), 9) 1380 self.assertEqual(type(square(x=3)), type(9)) 1381 self.assertEqual(square(x=3.0), 9.0) 1382 self.assertEqual(type(square(x=3.0)), type(9.0)) 1383 self.assertEqual(square.cache_info().hits, 4) 1384 self.assertEqual(square.cache_info().misses, 4) 1385 1386 def test_lru_with_keyword_args(self): 1387 @self.module.lru_cache() 1388 def fib(n): 1389 if n < 2: 1390 return n 1391 return fib(n=n-1) + fib(n=n-2) 1392 self.assertEqual( 1393 [fib(n=number) for number in range(16)], 1394 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1395 ) 1396 self.assertEqual(fib.cache_info(), 1397 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1398 fib.cache_clear() 1399 self.assertEqual(fib.cache_info(), 1400 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1401 1402 def test_lru_with_keyword_args_maxsize_none(self): 1403 @self.module.lru_cache(maxsize=None) 1404 def fib(n): 1405 if n < 2: 1406 return n 1407 return fib(n=n-1) + fib(n=n-2) 1408 self.assertEqual([fib(n=number) for number in range(16)], 1409 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1410 self.assertEqual(fib.cache_info(), 1411 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1412 fib.cache_clear() 1413 self.assertEqual(fib.cache_info(), 1414 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1415 1416 def test_kwargs_order(self): 1417 # PEP 468: Preserving Keyword Argument Order 1418 @self.module.lru_cache(maxsize=10) 1419 def f(**kwargs): 1420 return list(kwargs.items()) 1421 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1422 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1423 self.assertEqual(f.cache_info(), 1424 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1425 1426 def test_lru_cache_decoration(self): 1427 def f(zomg: 'zomg_annotation'): 1428 """f doc string""" 1429 return 42 1430 g = self.module.lru_cache()(f) 1431 for attr in self.module.WRAPPER_ASSIGNMENTS: 1432 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1433 1434 def test_lru_cache_threaded(self): 1435 n, m = 5, 11 1436 def orig(x, y): 1437 return 3 * x + y 1438 f = self.module.lru_cache(maxsize=n*m)(orig) 1439 hits, misses, maxsize, currsize = f.cache_info() 1440 self.assertEqual(currsize, 0) 1441 1442 start = threading.Event() 1443 def full(k): 1444 start.wait(10) 1445 for _ in range(m): 1446 self.assertEqual(f(k, 0), orig(k, 0)) 1447 1448 def clear(): 1449 start.wait(10) 1450 for _ in range(2*m): 1451 f.cache_clear() 1452 1453 orig_si = sys.getswitchinterval() 1454 support.setswitchinterval(1e-6) 1455 try: 1456 # create n threads in order to fill cache 1457 threads = [threading.Thread(target=full, args=[k]) 1458 for k in range(n)] 1459 with support.start_threads(threads): 1460 start.set() 1461 1462 hits, misses, maxsize, currsize = f.cache_info() 1463 if self.module is py_functools: 1464 # XXX: Why can be not equal? 1465 self.assertLessEqual(misses, n) 1466 self.assertLessEqual(hits, m*n - misses) 1467 else: 1468 self.assertEqual(misses, n) 1469 self.assertEqual(hits, m*n - misses) 1470 self.assertEqual(currsize, n) 1471 1472 # create n threads in order to fill cache and 1 to clear it 1473 threads = [threading.Thread(target=clear)] 1474 threads += [threading.Thread(target=full, args=[k]) 1475 for k in range(n)] 1476 start.clear() 1477 with support.start_threads(threads): 1478 start.set() 1479 finally: 1480 sys.setswitchinterval(orig_si) 1481 1482 def test_lru_cache_threaded2(self): 1483 # Simultaneous call with the same arguments 1484 n, m = 5, 7 1485 start = threading.Barrier(n+1) 1486 pause = threading.Barrier(n+1) 1487 stop = threading.Barrier(n+1) 1488 @self.module.lru_cache(maxsize=m*n) 1489 def f(x): 1490 pause.wait(10) 1491 return 3 * x 1492 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1493 def test(): 1494 for i in range(m): 1495 start.wait(10) 1496 self.assertEqual(f(i), 3 * i) 1497 stop.wait(10) 1498 threads = [threading.Thread(target=test) for k in range(n)] 1499 with support.start_threads(threads): 1500 for i in range(m): 1501 start.wait(10) 1502 stop.reset() 1503 pause.wait(10) 1504 start.reset() 1505 stop.wait(10) 1506 pause.reset() 1507 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1508 1509 def test_lru_cache_threaded3(self): 1510 @self.module.lru_cache(maxsize=2) 1511 def f(x): 1512 time.sleep(.01) 1513 return 3 * x 1514 def test(i, x): 1515 with self.subTest(thread=i): 1516 self.assertEqual(f(x), 3 * x, i) 1517 threads = [threading.Thread(target=test, args=(i, v)) 1518 for i, v in enumerate([1, 2, 2, 3, 2])] 1519 with support.start_threads(threads): 1520 pass 1521 1522 def test_need_for_rlock(self): 1523 # This will deadlock on an LRU cache that uses a regular lock 1524 1525 @self.module.lru_cache(maxsize=10) 1526 def test_func(x): 1527 'Used to demonstrate a reentrant lru_cache call within a single thread' 1528 return x 1529 1530 class DoubleEq: 1531 'Demonstrate a reentrant lru_cache call within a single thread' 1532 def __init__(self, x): 1533 self.x = x 1534 def __hash__(self): 1535 return self.x 1536 def __eq__(self, other): 1537 if self.x == 2: 1538 test_func(DoubleEq(1)) 1539 return self.x == other.x 1540 1541 test_func(DoubleEq(1)) # Load the cache 1542 test_func(DoubleEq(2)) # Load the cache 1543 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1544 DoubleEq(2)) # Verify the correct return value 1545 1546 def test_early_detection_of_bad_call(self): 1547 # Issue #22184 1548 with self.assertRaises(TypeError): 1549 @functools.lru_cache 1550 def f(): 1551 pass 1552 1553 def test_lru_method(self): 1554 class X(int): 1555 f_cnt = 0 1556 @self.module.lru_cache(2) 1557 def f(self, x): 1558 self.f_cnt += 1 1559 return x*10+self 1560 a = X(5) 1561 b = X(5) 1562 c = X(7) 1563 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1564 1565 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1566 self.assertEqual(a.f(x), x*10 + 5) 1567 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1568 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1569 1570 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1571 self.assertEqual(b.f(x), x*10 + 5) 1572 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1573 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1574 1575 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1576 self.assertEqual(c.f(x), x*10 + 7) 1577 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1578 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1579 1580 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1581 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1582 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1583 1584 def test_pickle(self): 1585 cls = self.__class__ 1586 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1587 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1588 with self.subTest(proto=proto, func=f): 1589 f_copy = pickle.loads(pickle.dumps(f, proto)) 1590 self.assertIs(f_copy, f) 1591 1592 def test_copy(self): 1593 cls = self.__class__ 1594 def orig(x, y): 1595 return 3 * x + y 1596 part = self.module.partial(orig, 2) 1597 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1598 self.module.lru_cache(2)(part)) 1599 for f in funcs: 1600 with self.subTest(func=f): 1601 f_copy = copy.copy(f) 1602 self.assertIs(f_copy, f) 1603 1604 def test_deepcopy(self): 1605 cls = self.__class__ 1606 def orig(x, y): 1607 return 3 * x + y 1608 part = self.module.partial(orig, 2) 1609 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1610 self.module.lru_cache(2)(part)) 1611 for f in funcs: 1612 with self.subTest(func=f): 1613 f_copy = copy.deepcopy(f) 1614 self.assertIs(f_copy, f) 1615 1616 1617@py_functools.lru_cache() 1618def py_cached_func(x, y): 1619 return 3 * x + y 1620 1621@c_functools.lru_cache() 1622def c_cached_func(x, y): 1623 return 3 * x + y 1624 1625 1626class TestLRUPy(TestLRU, unittest.TestCase): 1627 module = py_functools 1628 cached_func = py_cached_func, 1629 1630 @module.lru_cache() 1631 def cached_meth(self, x, y): 1632 return 3 * x + y 1633 1634 @staticmethod 1635 @module.lru_cache() 1636 def cached_staticmeth(x, y): 1637 return 3 * x + y 1638 1639 1640class TestLRUC(TestLRU, unittest.TestCase): 1641 module = c_functools 1642 cached_func = c_cached_func, 1643 1644 @module.lru_cache() 1645 def cached_meth(self, x, y): 1646 return 3 * x + y 1647 1648 @staticmethod 1649 @module.lru_cache() 1650 def cached_staticmeth(x, y): 1651 return 3 * x + y 1652 1653 1654class TestSingleDispatch(unittest.TestCase): 1655 def test_simple_overloads(self): 1656 @functools.singledispatch 1657 def g(obj): 1658 return "base" 1659 def g_int(i): 1660 return "integer" 1661 g.register(int, g_int) 1662 self.assertEqual(g("str"), "base") 1663 self.assertEqual(g(1), "integer") 1664 self.assertEqual(g([1,2,3]), "base") 1665 1666 def test_mro(self): 1667 @functools.singledispatch 1668 def g(obj): 1669 return "base" 1670 class A: 1671 pass 1672 class C(A): 1673 pass 1674 class B(A): 1675 pass 1676 class D(C, B): 1677 pass 1678 def g_A(a): 1679 return "A" 1680 def g_B(b): 1681 return "B" 1682 g.register(A, g_A) 1683 g.register(B, g_B) 1684 self.assertEqual(g(A()), "A") 1685 self.assertEqual(g(B()), "B") 1686 self.assertEqual(g(C()), "A") 1687 self.assertEqual(g(D()), "B") 1688 1689 def test_register_decorator(self): 1690 @functools.singledispatch 1691 def g(obj): 1692 return "base" 1693 @g.register(int) 1694 def g_int(i): 1695 return "int %s" % (i,) 1696 self.assertEqual(g(""), "base") 1697 self.assertEqual(g(12), "int 12") 1698 self.assertIs(g.dispatch(int), g_int) 1699 self.assertIs(g.dispatch(object), g.dispatch(str)) 1700 # Note: in the assert above this is not g. 1701 # @singledispatch returns the wrapper. 1702 1703 def test_wrapping_attributes(self): 1704 @functools.singledispatch 1705 def g(obj): 1706 "Simple test" 1707 return "Test" 1708 self.assertEqual(g.__name__, "g") 1709 if sys.flags.optimize < 2: 1710 self.assertEqual(g.__doc__, "Simple test") 1711 1712 @unittest.skipUnless(decimal, 'requires _decimal') 1713 @support.cpython_only 1714 def test_c_classes(self): 1715 @functools.singledispatch 1716 def g(obj): 1717 return "base" 1718 @g.register(decimal.DecimalException) 1719 def _(obj): 1720 return obj.args 1721 subn = decimal.Subnormal("Exponent < Emin") 1722 rnd = decimal.Rounded("Number got rounded") 1723 self.assertEqual(g(subn), ("Exponent < Emin",)) 1724 self.assertEqual(g(rnd), ("Number got rounded",)) 1725 @g.register(decimal.Subnormal) 1726 def _(obj): 1727 return "Too small to care." 1728 self.assertEqual(g(subn), "Too small to care.") 1729 self.assertEqual(g(rnd), ("Number got rounded",)) 1730 1731 def test_compose_mro(self): 1732 # None of the examples in this test depend on haystack ordering. 1733 c = collections.abc 1734 mro = functools._compose_mro 1735 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1736 for haystack in permutations(bases): 1737 m = mro(dict, haystack) 1738 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1739 c.Collection, c.Sized, c.Iterable, 1740 c.Container, object]) 1741 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1742 for haystack in permutations(bases): 1743 m = mro(collections.ChainMap, haystack) 1744 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1745 c.Collection, c.Sized, c.Iterable, 1746 c.Container, object]) 1747 1748 # If there's a generic function with implementations registered for 1749 # both Sized and Container, passing a defaultdict to it results in an 1750 # ambiguous dispatch which will cause a RuntimeError (see 1751 # test_mro_conflicts). 1752 bases = [c.Container, c.Sized, str] 1753 for haystack in permutations(bases): 1754 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1755 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1756 c.Container, object]) 1757 1758 # MutableSequence below is registered directly on D. In other words, it 1759 # precedes MutableMapping which means single dispatch will always 1760 # choose MutableSequence here. 1761 class D(collections.defaultdict): 1762 pass 1763 c.MutableSequence.register(D) 1764 bases = [c.MutableSequence, c.MutableMapping] 1765 for haystack in permutations(bases): 1766 m = mro(D, bases) 1767 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1768 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1769 c.Collection, c.Sized, c.Iterable, c.Container, 1770 object]) 1771 1772 # Container and Callable are registered on different base classes and 1773 # a generic function supporting both should always pick the Callable 1774 # implementation if a C instance is passed. 1775 class C(collections.defaultdict): 1776 def __call__(self): 1777 pass 1778 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1779 for haystack in permutations(bases): 1780 m = mro(C, haystack) 1781 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1782 c.Collection, c.Sized, c.Iterable, 1783 c.Container, object]) 1784 1785 def test_register_abc(self): 1786 c = collections.abc 1787 d = {"a": "b"} 1788 l = [1, 2, 3] 1789 s = {object(), None} 1790 f = frozenset(s) 1791 t = (1, 2, 3) 1792 @functools.singledispatch 1793 def g(obj): 1794 return "base" 1795 self.assertEqual(g(d), "base") 1796 self.assertEqual(g(l), "base") 1797 self.assertEqual(g(s), "base") 1798 self.assertEqual(g(f), "base") 1799 self.assertEqual(g(t), "base") 1800 g.register(c.Sized, lambda obj: "sized") 1801 self.assertEqual(g(d), "sized") 1802 self.assertEqual(g(l), "sized") 1803 self.assertEqual(g(s), "sized") 1804 self.assertEqual(g(f), "sized") 1805 self.assertEqual(g(t), "sized") 1806 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1807 self.assertEqual(g(d), "mutablemapping") 1808 self.assertEqual(g(l), "sized") 1809 self.assertEqual(g(s), "sized") 1810 self.assertEqual(g(f), "sized") 1811 self.assertEqual(g(t), "sized") 1812 g.register(collections.ChainMap, lambda obj: "chainmap") 1813 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1814 self.assertEqual(g(l), "sized") 1815 self.assertEqual(g(s), "sized") 1816 self.assertEqual(g(f), "sized") 1817 self.assertEqual(g(t), "sized") 1818 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1819 self.assertEqual(g(d), "mutablemapping") 1820 self.assertEqual(g(l), "mutablesequence") 1821 self.assertEqual(g(s), "sized") 1822 self.assertEqual(g(f), "sized") 1823 self.assertEqual(g(t), "sized") 1824 g.register(c.MutableSet, lambda obj: "mutableset") 1825 self.assertEqual(g(d), "mutablemapping") 1826 self.assertEqual(g(l), "mutablesequence") 1827 self.assertEqual(g(s), "mutableset") 1828 self.assertEqual(g(f), "sized") 1829 self.assertEqual(g(t), "sized") 1830 g.register(c.Mapping, lambda obj: "mapping") 1831 self.assertEqual(g(d), "mutablemapping") # not specific enough 1832 self.assertEqual(g(l), "mutablesequence") 1833 self.assertEqual(g(s), "mutableset") 1834 self.assertEqual(g(f), "sized") 1835 self.assertEqual(g(t), "sized") 1836 g.register(c.Sequence, lambda obj: "sequence") 1837 self.assertEqual(g(d), "mutablemapping") 1838 self.assertEqual(g(l), "mutablesequence") 1839 self.assertEqual(g(s), "mutableset") 1840 self.assertEqual(g(f), "sized") 1841 self.assertEqual(g(t), "sequence") 1842 g.register(c.Set, lambda obj: "set") 1843 self.assertEqual(g(d), "mutablemapping") 1844 self.assertEqual(g(l), "mutablesequence") 1845 self.assertEqual(g(s), "mutableset") 1846 self.assertEqual(g(f), "set") 1847 self.assertEqual(g(t), "sequence") 1848 g.register(dict, lambda obj: "dict") 1849 self.assertEqual(g(d), "dict") 1850 self.assertEqual(g(l), "mutablesequence") 1851 self.assertEqual(g(s), "mutableset") 1852 self.assertEqual(g(f), "set") 1853 self.assertEqual(g(t), "sequence") 1854 g.register(list, lambda obj: "list") 1855 self.assertEqual(g(d), "dict") 1856 self.assertEqual(g(l), "list") 1857 self.assertEqual(g(s), "mutableset") 1858 self.assertEqual(g(f), "set") 1859 self.assertEqual(g(t), "sequence") 1860 g.register(set, lambda obj: "concrete-set") 1861 self.assertEqual(g(d), "dict") 1862 self.assertEqual(g(l), "list") 1863 self.assertEqual(g(s), "concrete-set") 1864 self.assertEqual(g(f), "set") 1865 self.assertEqual(g(t), "sequence") 1866 g.register(frozenset, lambda obj: "frozen-set") 1867 self.assertEqual(g(d), "dict") 1868 self.assertEqual(g(l), "list") 1869 self.assertEqual(g(s), "concrete-set") 1870 self.assertEqual(g(f), "frozen-set") 1871 self.assertEqual(g(t), "sequence") 1872 g.register(tuple, lambda obj: "tuple") 1873 self.assertEqual(g(d), "dict") 1874 self.assertEqual(g(l), "list") 1875 self.assertEqual(g(s), "concrete-set") 1876 self.assertEqual(g(f), "frozen-set") 1877 self.assertEqual(g(t), "tuple") 1878 1879 def test_c3_abc(self): 1880 c = collections.abc 1881 mro = functools._c3_mro 1882 class A(object): 1883 pass 1884 class B(A): 1885 def __len__(self): 1886 return 0 # implies Sized 1887 @c.Container.register 1888 class C(object): 1889 pass 1890 class D(object): 1891 pass # unrelated 1892 class X(D, C, B): 1893 def __call__(self): 1894 pass # implies Callable 1895 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 1896 for abcs in permutations([c.Sized, c.Callable, c.Container]): 1897 self.assertEqual(mro(X, abcs=abcs), expected) 1898 # unrelated ABCs don't appear in the resulting MRO 1899 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 1900 self.assertEqual(mro(X, abcs=many_abcs), expected) 1901 1902 def test_false_meta(self): 1903 # see issue23572 1904 class MetaA(type): 1905 def __len__(self): 1906 return 0 1907 class A(metaclass=MetaA): 1908 pass 1909 class AA(A): 1910 pass 1911 @functools.singledispatch 1912 def fun(a): 1913 return 'base A' 1914 @fun.register(A) 1915 def _(a): 1916 return 'fun A' 1917 aa = AA() 1918 self.assertEqual(fun(aa), 'fun A') 1919 1920 def test_mro_conflicts(self): 1921 c = collections.abc 1922 @functools.singledispatch 1923 def g(arg): 1924 return "base" 1925 class O(c.Sized): 1926 def __len__(self): 1927 return 0 1928 o = O() 1929 self.assertEqual(g(o), "base") 1930 g.register(c.Iterable, lambda arg: "iterable") 1931 g.register(c.Container, lambda arg: "container") 1932 g.register(c.Sized, lambda arg: "sized") 1933 g.register(c.Set, lambda arg: "set") 1934 self.assertEqual(g(o), "sized") 1935 c.Iterable.register(O) 1936 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 1937 c.Container.register(O) 1938 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 1939 c.Set.register(O) 1940 self.assertEqual(g(o), "set") # because c.Set is a subclass of 1941 # c.Sized and c.Container 1942 class P: 1943 pass 1944 p = P() 1945 self.assertEqual(g(p), "base") 1946 c.Iterable.register(P) 1947 self.assertEqual(g(p), "iterable") 1948 c.Container.register(P) 1949 with self.assertRaises(RuntimeError) as re_one: 1950 g(p) 1951 self.assertIn( 1952 str(re_one.exception), 1953 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1954 "or <class 'collections.abc.Iterable'>"), 1955 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 1956 "or <class 'collections.abc.Container'>")), 1957 ) 1958 class Q(c.Sized): 1959 def __len__(self): 1960 return 0 1961 q = Q() 1962 self.assertEqual(g(q), "sized") 1963 c.Iterable.register(Q) 1964 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 1965 c.Set.register(Q) 1966 self.assertEqual(g(q), "set") # because c.Set is a subclass of 1967 # c.Sized and c.Iterable 1968 @functools.singledispatch 1969 def h(arg): 1970 return "base" 1971 @h.register(c.Sized) 1972 def _(arg): 1973 return "sized" 1974 @h.register(c.Container) 1975 def _(arg): 1976 return "container" 1977 # Even though Sized and Container are explicit bases of MutableMapping, 1978 # this ABC is implicitly registered on defaultdict which makes all of 1979 # MutableMapping's bases implicit as well from defaultdict's 1980 # perspective. 1981 with self.assertRaises(RuntimeError) as re_two: 1982 h(collections.defaultdict(lambda: 0)) 1983 self.assertIn( 1984 str(re_two.exception), 1985 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1986 "or <class 'collections.abc.Sized'>"), 1987 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 1988 "or <class 'collections.abc.Container'>")), 1989 ) 1990 class R(collections.defaultdict): 1991 pass 1992 c.MutableSequence.register(R) 1993 @functools.singledispatch 1994 def i(arg): 1995 return "base" 1996 @i.register(c.MutableMapping) 1997 def _(arg): 1998 return "mapping" 1999 @i.register(c.MutableSequence) 2000 def _(arg): 2001 return "sequence" 2002 r = R() 2003 self.assertEqual(i(r), "sequence") 2004 class S: 2005 pass 2006 class T(S, c.Sized): 2007 def __len__(self): 2008 return 0 2009 t = T() 2010 self.assertEqual(h(t), "sized") 2011 c.Container.register(T) 2012 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2013 class U: 2014 def __len__(self): 2015 return 0 2016 u = U() 2017 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2018 # from the existence of __len__() 2019 c.Container.register(U) 2020 # There is no preference for registered versus inferred ABCs. 2021 with self.assertRaises(RuntimeError) as re_three: 2022 h(u) 2023 self.assertIn( 2024 str(re_three.exception), 2025 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2026 "or <class 'collections.abc.Sized'>"), 2027 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2028 "or <class 'collections.abc.Container'>")), 2029 ) 2030 class V(c.Sized, S): 2031 def __len__(self): 2032 return 0 2033 @functools.singledispatch 2034 def j(arg): 2035 return "base" 2036 @j.register(S) 2037 def _(arg): 2038 return "s" 2039 @j.register(c.Container) 2040 def _(arg): 2041 return "container" 2042 v = V() 2043 self.assertEqual(j(v), "s") 2044 c.Container.register(V) 2045 self.assertEqual(j(v), "container") # because it ends up right after 2046 # Sized in the MRO 2047 2048 def test_cache_invalidation(self): 2049 from collections import UserDict 2050 import weakref 2051 2052 class TracingDict(UserDict): 2053 def __init__(self, *args, **kwargs): 2054 super(TracingDict, self).__init__(*args, **kwargs) 2055 self.set_ops = [] 2056 self.get_ops = [] 2057 def __getitem__(self, key): 2058 result = self.data[key] 2059 self.get_ops.append(key) 2060 return result 2061 def __setitem__(self, key, value): 2062 self.set_ops.append(key) 2063 self.data[key] = value 2064 def clear(self): 2065 self.data.clear() 2066 2067 td = TracingDict() 2068 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2069 c = collections.abc 2070 @functools.singledispatch 2071 def g(arg): 2072 return "base" 2073 d = {} 2074 l = [] 2075 self.assertEqual(len(td), 0) 2076 self.assertEqual(g(d), "base") 2077 self.assertEqual(len(td), 1) 2078 self.assertEqual(td.get_ops, []) 2079 self.assertEqual(td.set_ops, [dict]) 2080 self.assertEqual(td.data[dict], g.registry[object]) 2081 self.assertEqual(g(l), "base") 2082 self.assertEqual(len(td), 2) 2083 self.assertEqual(td.get_ops, []) 2084 self.assertEqual(td.set_ops, [dict, list]) 2085 self.assertEqual(td.data[dict], g.registry[object]) 2086 self.assertEqual(td.data[list], g.registry[object]) 2087 self.assertEqual(td.data[dict], td.data[list]) 2088 self.assertEqual(g(l), "base") 2089 self.assertEqual(g(d), "base") 2090 self.assertEqual(td.get_ops, [list, dict]) 2091 self.assertEqual(td.set_ops, [dict, list]) 2092 g.register(list, lambda arg: "list") 2093 self.assertEqual(td.get_ops, [list, dict]) 2094 self.assertEqual(len(td), 0) 2095 self.assertEqual(g(d), "base") 2096 self.assertEqual(len(td), 1) 2097 self.assertEqual(td.get_ops, [list, dict]) 2098 self.assertEqual(td.set_ops, [dict, list, dict]) 2099 self.assertEqual(td.data[dict], 2100 functools._find_impl(dict, g.registry)) 2101 self.assertEqual(g(l), "list") 2102 self.assertEqual(len(td), 2) 2103 self.assertEqual(td.get_ops, [list, dict]) 2104 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2105 self.assertEqual(td.data[list], 2106 functools._find_impl(list, g.registry)) 2107 class X: 2108 pass 2109 c.MutableMapping.register(X) # Will not invalidate the cache, 2110 # not using ABCs yet. 2111 self.assertEqual(g(d), "base") 2112 self.assertEqual(g(l), "list") 2113 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2114 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2115 g.register(c.Sized, lambda arg: "sized") 2116 self.assertEqual(len(td), 0) 2117 self.assertEqual(g(d), "sized") 2118 self.assertEqual(len(td), 1) 2119 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2120 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2121 self.assertEqual(g(l), "list") 2122 self.assertEqual(len(td), 2) 2123 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2124 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2125 self.assertEqual(g(l), "list") 2126 self.assertEqual(g(d), "sized") 2127 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2128 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2129 g.dispatch(list) 2130 g.dispatch(dict) 2131 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2132 list, dict]) 2133 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2134 c.MutableSet.register(X) # Will invalidate the cache. 2135 self.assertEqual(len(td), 2) # Stale cache. 2136 self.assertEqual(g(l), "list") 2137 self.assertEqual(len(td), 1) 2138 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2139 self.assertEqual(len(td), 0) 2140 self.assertEqual(g(d), "mutablemapping") 2141 self.assertEqual(len(td), 1) 2142 self.assertEqual(g(l), "list") 2143 self.assertEqual(len(td), 2) 2144 g.register(dict, lambda arg: "dict") 2145 self.assertEqual(g(d), "dict") 2146 self.assertEqual(g(l), "list") 2147 g._clear_cache() 2148 self.assertEqual(len(td), 0) 2149 2150 def test_annotations(self): 2151 @functools.singledispatch 2152 def i(arg): 2153 return "base" 2154 @i.register 2155 def _(arg: collections.abc.Mapping): 2156 return "mapping" 2157 @i.register 2158 def _(arg: "collections.abc.Sequence"): 2159 return "sequence" 2160 self.assertEqual(i(None), "base") 2161 self.assertEqual(i({"a": 1}), "mapping") 2162 self.assertEqual(i([1, 2, 3]), "sequence") 2163 self.assertEqual(i((1, 2, 3)), "sequence") 2164 self.assertEqual(i("str"), "sequence") 2165 2166 # Registering classes as callables doesn't work with annotations, 2167 # you need to pass the type explicitly. 2168 @i.register(str) 2169 class _: 2170 def __init__(self, arg): 2171 self.arg = arg 2172 2173 def __eq__(self, other): 2174 return self.arg == other 2175 self.assertEqual(i("str"), "str") 2176 2177 def test_invalid_registrations(self): 2178 msg_prefix = "Invalid first argument to `register()`: " 2179 msg_suffix = ( 2180 ". Use either `@register(some_class)` or plain `@register` on an " 2181 "annotated function." 2182 ) 2183 @functools.singledispatch 2184 def i(arg): 2185 return "base" 2186 with self.assertRaises(TypeError) as exc: 2187 @i.register(42) 2188 def _(arg): 2189 return "I annotated with a non-type" 2190 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2191 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2192 with self.assertRaises(TypeError) as exc: 2193 @i.register 2194 def _(arg): 2195 return "I forgot to annotate" 2196 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2197 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2198 )) 2199 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2200 2201 # FIXME: The following will only work after PEP 560 is implemented. 2202 return 2203 2204 with self.assertRaises(TypeError) as exc: 2205 @i.register 2206 def _(arg: typing.Iterable[str]): 2207 # At runtime, dispatching on generics is impossible. 2208 # When registering implementations with singledispatch, avoid 2209 # types from `typing`. Instead, annotate with regular types 2210 # or ABCs. 2211 return "I annotated with a generic collection" 2212 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2213 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2214 )) 2215 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2216 2217 def test_invalid_positional_argument(self): 2218 @functools.singledispatch 2219 def f(*args): 2220 pass 2221 msg = 'f requires at least 1 positional argument' 2222 with self.assertRaisesRegex(TypeError, msg): 2223 f() 2224 2225if __name__ == '__main__': 2226 unittest.main() 2227