"""Tests for the functools module.

Several test classes below are mixins (``TestPartial``, ``TestReduce``,
``TestCmpToKey``) that are run twice: once against the C-accelerated module
and once against the pure-Python fallback.  The concrete
``unittest.TestCase`` subclasses select the implementation through the
``partial`` / ``reduce`` / ``cmp_to_key`` class attributes.
"""

import abc
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import threading
import time
import typing
import unittest
import unittest.mock
import os
import weakref
import gc
from weakref import proxy
import contextlib

from test.support import import_helper
from test.support import threading_helper
from test.support.script_helper import assert_python_ok

import functools

# Fresh functools with the _functools accelerator blocked, so every name
# resolves to the pure-Python implementation.
py_functools = import_helper.import_fresh_module('functools',
                                                 blocked=['_functools'])
# Fresh functools imported normally (uses _functools when it is built).
c_functools = import_helper.import_fresh_module('functools')

# Fresh decimal module, re-importing its _decimal accelerator as well.
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])


@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The original module object is put back on exit, even if the body raises.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        sys.modules[name] = original_module


def capture(*args, **kw):
    """Return ``(args, kw)`` exactly as received.

    Used as the wrapped callable so tests can inspect what a partial
    forwarded.
    """
    return args, kw


def signature(part):
    """Return the state of a partial object as a comparable 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)``.
    """
    return (part.func, part.args, part.keywords, part.__dict__)


# tuple/dict subclasses fed to __setstate__; test_setstate_subclasses checks
# that partial stores exact tuple/dict instances instead.
class MyTuple(tuple):
    pass


# __add__ returns a list.  test_setstate_subclasses checks that partial still
# produces a real tuple when combining stored and call-time positionals.
class BadTuple(tuple):
    def __add__(self, other):
        return list(self) + list(other)


class MyDict(dict):
    pass


class TestPartial:
    """Behavioral tests shared by every partial implementation.

    Mixin: concrete subclasses provide ``partial`` (the callable under test)
    and ``AllowPickle`` (a context manager used around pickling).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate asserts so a failure reports which part differed.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should have the same state as a single
        # partial built with the combined arguments.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not asserted; accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes f reference itself (as func, as a positional arg,
        # as a keyword value); the finally clause breaks the cycle again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None keywords entry is normalized to an empty dict.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args round-trip through pickle.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through keywords round-trip through pickle.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A non-tuple state must be rejected even though it behaves like a
        # 4-item sequence with otherwise valid-looking items.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op: unlike TestPartialPy, nothing in sys.modules is swapped.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)


class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        # Pickle stores classes by module and qualified name, so make
        # "functools" resolve to the pure-Python module while pickling.
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, type, value, tb):
            return self._cm.__exit__(type, value, tb)


if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass


class PyPartialSubclass(py_functools.partial):
    pass


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None


class TestPartialPySubclass(TestPartialPy):
    partial = PyPartialSubclass


class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the metaclass so Abstract is a real ABC.  Inheriting
        # from ABCMeta would instead define a new metaclass, so the ABC
        # machinery would never see these methods.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))


class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Checks that attributes in *assigned* were copied, that mappings in
        *updated* were merged, and that ``__wrapped__`` points at *wrapped*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Return ``(wrapper, f)`` after a default update_wrapper call."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Must be replaced by f itself on the wrapper (see test_default_update).
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})


class TestWraps(TestUpdateWrapper):
    """Tests for the functools.wraps decorator.

    Reuses TestUpdateWrapper's helpers and inherits its remaining tests.
    """

    def _default_update(self):
        """Return ``(wrapper, f)`` where wrapper is decorated with wraps(f)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)


class TestReduce:
    """Shared tests for reduce(); concrete subclasses set ``reduce``."""

    def test_reduce(self):
        # Old-style sequence (no __iter__) that computes squares lazily.
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.reduce(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.reduce(add, Squares(10)), 285)
        self.assertEqual(self.reduce(add, Squares(10), 0), 285)
        self.assertEqual(self.reduce(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.reduce)
        self.assertRaises(TypeError, self.reduce, 42, 42)
        self.assertRaises(TypeError, self.reduce, 42, 42, 42)
        self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
        self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.reduce, add, "")
        self.assertRaises(TypeError, self.reduce, add, ())
        self.assertRaises(TypeError, self.reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())

        self.assertEqual(self.reduce(add, [], None), None)
        self.assertEqual(self.reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
        self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
        self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
        self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.reduce(add, d), "".join(d.keys()))


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    if c_functools:
        reduce = c_functools.reduce


class TestReducePy(TestReduce, unittest.TestCase):
    # A plain Python function would bind as a method; staticmethod prevents
    # self from being passed.
    reduce = staticmethod(py_functools.reduce)


class TestCmpToKey:
    """Shared tests for cmp_to_key(); concrete subclasses set ``cmp_to_key``."""

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                # too few args
        with self.assertRaises(TypeError):
            key(None, None)                      # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* with zero propagates too.  The key must be rebuilt from
        # the new cmp1 (otherwise the first cmp1 is tested again).  Use "<"
        # so that BadCmp.__lt__ is the method invoked; "BadCmp() > 0" would
        # only give a TypeError.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp1(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # Ensure that the type disallows instantiation (bpo-43916)
        support.check_disallow_instantiation(
            self, type(c_functools.cmp_to_key(None))
        )


class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    cmp_to_key = staticmethod(py_functools.cmp_to_key)


class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))
1000 def test_total_ordering_gt(self): 1001 @functools.total_ordering 1002 class A: 1003 def __init__(self, value): 1004 self.value = value 1005 def __gt__(self, other): 1006 return self.value > other.value 1007 def __eq__(self, other): 1008 return self.value == other.value 1009 self.assertTrue(A(1) < A(2)) 1010 self.assertTrue(A(2) > A(1)) 1011 self.assertTrue(A(1) <= A(2)) 1012 self.assertTrue(A(2) >= A(1)) 1013 self.assertTrue(A(2) <= A(2)) 1014 self.assertTrue(A(2) >= A(2)) 1015 self.assertFalse(A(2) < A(1)) 1016 1017 def test_total_ordering_ge(self): 1018 @functools.total_ordering 1019 class A: 1020 def __init__(self, value): 1021 self.value = value 1022 def __ge__(self, other): 1023 return self.value >= other.value 1024 def __eq__(self, other): 1025 return self.value == other.value 1026 self.assertTrue(A(1) < A(2)) 1027 self.assertTrue(A(2) > A(1)) 1028 self.assertTrue(A(1) <= A(2)) 1029 self.assertTrue(A(2) >= A(1)) 1030 self.assertTrue(A(2) <= A(2)) 1031 self.assertTrue(A(2) >= A(2)) 1032 self.assertFalse(A(2) <= A(1)) 1033 1034 def test_total_ordering_no_overwrite(self): 1035 # new methods should not overwrite existing 1036 @functools.total_ordering 1037 class A(int): 1038 pass 1039 self.assertTrue(A(1) < A(2)) 1040 self.assertTrue(A(2) > A(1)) 1041 self.assertTrue(A(1) <= A(2)) 1042 self.assertTrue(A(2) >= A(1)) 1043 self.assertTrue(A(2) <= A(2)) 1044 self.assertTrue(A(2) >= A(2)) 1045 1046 def test_no_operations_defined(self): 1047 with self.assertRaises(ValueError): 1048 @functools.total_ordering 1049 class A: 1050 pass 1051 1052 def test_type_error_when_not_implemented(self): 1053 # bug 10042; ensure stack overflow does not occur 1054 # when decorated types return NotImplemented 1055 @functools.total_ordering 1056 class ImplementsLessThan: 1057 def __init__(self, value): 1058 self.value = value 1059 def __eq__(self, other): 1060 if isinstance(other, ImplementsLessThan): 1061 return self.value == other.value 1062 return False 1063 def __lt__(self, 
other): 1064 if isinstance(other, ImplementsLessThan): 1065 return self.value < other.value 1066 return NotImplemented 1067 1068 @functools.total_ordering 1069 class ImplementsGreaterThan: 1070 def __init__(self, value): 1071 self.value = value 1072 def __eq__(self, other): 1073 if isinstance(other, ImplementsGreaterThan): 1074 return self.value == other.value 1075 return False 1076 def __gt__(self, other): 1077 if isinstance(other, ImplementsGreaterThan): 1078 return self.value > other.value 1079 return NotImplemented 1080 1081 @functools.total_ordering 1082 class ImplementsLessThanEqualTo: 1083 def __init__(self, value): 1084 self.value = value 1085 def __eq__(self, other): 1086 if isinstance(other, ImplementsLessThanEqualTo): 1087 return self.value == other.value 1088 return False 1089 def __le__(self, other): 1090 if isinstance(other, ImplementsLessThanEqualTo): 1091 return self.value <= other.value 1092 return NotImplemented 1093 1094 @functools.total_ordering 1095 class ImplementsGreaterThanEqualTo: 1096 def __init__(self, value): 1097 self.value = value 1098 def __eq__(self, other): 1099 if isinstance(other, ImplementsGreaterThanEqualTo): 1100 return self.value == other.value 1101 return False 1102 def __ge__(self, other): 1103 if isinstance(other, ImplementsGreaterThanEqualTo): 1104 return self.value >= other.value 1105 return NotImplemented 1106 1107 @functools.total_ordering 1108 class ComparatorNotImplemented: 1109 def __init__(self, value): 1110 self.value = value 1111 def __eq__(self, other): 1112 if isinstance(other, ComparatorNotImplemented): 1113 return self.value == other.value 1114 return False 1115 def __lt__(self, other): 1116 return NotImplemented 1117 1118 with self.subTest("LT < 1"), self.assertRaises(TypeError): 1119 ImplementsLessThan(-1) < 1 1120 1121 with self.subTest("LT < LE"), self.assertRaises(TypeError): 1122 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) 1123 1124 with self.subTest("LT < GT"), self.assertRaises(TypeError): 
1125 ImplementsLessThan(1) < ImplementsGreaterThan(1) 1126 1127 with self.subTest("LE <= LT"), self.assertRaises(TypeError): 1128 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) 1129 1130 with self.subTest("LE <= GE"), self.assertRaises(TypeError): 1131 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) 1132 1133 with self.subTest("GT > GE"), self.assertRaises(TypeError): 1134 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) 1135 1136 with self.subTest("GT > LT"), self.assertRaises(TypeError): 1137 ImplementsGreaterThan(5) > ImplementsLessThan(5) 1138 1139 with self.subTest("GE >= GT"), self.assertRaises(TypeError): 1140 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) 1141 1142 with self.subTest("GE >= LE"), self.assertRaises(TypeError): 1143 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) 1144 1145 with self.subTest("GE when equal"): 1146 a = ComparatorNotImplemented(8) 1147 b = ComparatorNotImplemented(8) 1148 self.assertEqual(a, b) 1149 with self.assertRaises(TypeError): 1150 a >= b 1151 1152 with self.subTest("LE when equal"): 1153 a = ComparatorNotImplemented(9) 1154 b = ComparatorNotImplemented(9) 1155 self.assertEqual(a, b) 1156 with self.assertRaises(TypeError): 1157 a <= b 1158 1159 def test_pickle(self): 1160 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1161 for name in '__lt__', '__gt__', '__le__', '__ge__': 1162 with self.subTest(method=name, proto=proto): 1163 method = getattr(Orderable_LT, name) 1164 method_copy = pickle.loads(pickle.dumps(method, proto)) 1165 self.assertIs(method_copy, method) 1166 1167 1168 def test_total_ordering_for_metaclasses_issue_44605(self): 1169 1170 @functools.total_ordering 1171 class SortableMeta(type): 1172 def __new__(cls, name, bases, ns): 1173 return super().__new__(cls, name, bases, ns) 1174 1175 def __lt__(self, other): 1176 if not isinstance(other, SortableMeta): 1177 pass 1178 return self.__name__ < other.__name__ 1179 1180 def __eq__(self, other): 1181 
if not isinstance(other, SortableMeta): 1182 pass 1183 return self.__name__ == other.__name__ 1184 1185 class B(metaclass=SortableMeta): 1186 pass 1187 1188 class A(metaclass=SortableMeta): 1189 pass 1190 1191 self.assertTrue(A < B) 1192 self.assertFalse(A > B) 1193 1194 1195@functools.total_ordering 1196class Orderable_LT: 1197 def __init__(self, value): 1198 self.value = value 1199 def __lt__(self, other): 1200 return self.value < other.value 1201 def __eq__(self, other): 1202 return self.value == other.value 1203 1204 1205class TestCache: 1206 # This tests that the pass-through is working as designed. 1207 # The underlying functionality is tested in TestLRU. 1208 1209 def test_cache(self): 1210 @self.module.cache 1211 def fib(n): 1212 if n < 2: 1213 return n 1214 return fib(n-1) + fib(n-2) 1215 self.assertEqual([fib(n) for n in range(16)], 1216 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1217 self.assertEqual(fib.cache_info(), 1218 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1219 fib.cache_clear() 1220 self.assertEqual(fib.cache_info(), 1221 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1222 1223 1224class TestLRU: 1225 1226 def test_lru(self): 1227 def orig(x, y): 1228 return 3 * x + y 1229 f = self.module.lru_cache(maxsize=20)(orig) 1230 hits, misses, maxsize, currsize = f.cache_info() 1231 self.assertEqual(maxsize, 20) 1232 self.assertEqual(currsize, 0) 1233 self.assertEqual(hits, 0) 1234 self.assertEqual(misses, 0) 1235 1236 domain = range(5) 1237 for i in range(1000): 1238 x, y = choice(domain), choice(domain) 1239 actual = f(x, y) 1240 expected = orig(x, y) 1241 self.assertEqual(actual, expected) 1242 hits, misses, maxsize, currsize = f.cache_info() 1243 self.assertTrue(hits > misses) 1244 self.assertEqual(hits + misses, 1000) 1245 self.assertEqual(currsize, 20) 1246 1247 f.cache_clear() # test clearing 1248 hits, misses, maxsize, currsize = f.cache_info() 1249 self.assertEqual(hits, 0) 
1250 self.assertEqual(misses, 0) 1251 self.assertEqual(currsize, 0) 1252 f(x, y) 1253 hits, misses, maxsize, currsize = f.cache_info() 1254 self.assertEqual(hits, 0) 1255 self.assertEqual(misses, 1) 1256 self.assertEqual(currsize, 1) 1257 1258 # Test bypassing the cache 1259 self.assertIs(f.__wrapped__, orig) 1260 f.__wrapped__(x, y) 1261 hits, misses, maxsize, currsize = f.cache_info() 1262 self.assertEqual(hits, 0) 1263 self.assertEqual(misses, 1) 1264 self.assertEqual(currsize, 1) 1265 1266 # test size zero (which means "never-cache") 1267 @self.module.lru_cache(0) 1268 def f(): 1269 nonlocal f_cnt 1270 f_cnt += 1 1271 return 20 1272 self.assertEqual(f.cache_info().maxsize, 0) 1273 f_cnt = 0 1274 for i in range(5): 1275 self.assertEqual(f(), 20) 1276 self.assertEqual(f_cnt, 5) 1277 hits, misses, maxsize, currsize = f.cache_info() 1278 self.assertEqual(hits, 0) 1279 self.assertEqual(misses, 5) 1280 self.assertEqual(currsize, 0) 1281 1282 # test size one 1283 @self.module.lru_cache(1) 1284 def f(): 1285 nonlocal f_cnt 1286 f_cnt += 1 1287 return 20 1288 self.assertEqual(f.cache_info().maxsize, 1) 1289 f_cnt = 0 1290 for i in range(5): 1291 self.assertEqual(f(), 20) 1292 self.assertEqual(f_cnt, 1) 1293 hits, misses, maxsize, currsize = f.cache_info() 1294 self.assertEqual(hits, 4) 1295 self.assertEqual(misses, 1) 1296 self.assertEqual(currsize, 1) 1297 1298 # test size two 1299 @self.module.lru_cache(2) 1300 def f(x): 1301 nonlocal f_cnt 1302 f_cnt += 1 1303 return x*10 1304 self.assertEqual(f.cache_info().maxsize, 2) 1305 f_cnt = 0 1306 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1307 # * * * * 1308 self.assertEqual(f(x), x*10) 1309 self.assertEqual(f_cnt, 4) 1310 hits, misses, maxsize, currsize = f.cache_info() 1311 self.assertEqual(hits, 12) 1312 self.assertEqual(misses, 4) 1313 self.assertEqual(currsize, 2) 1314 1315 def test_lru_no_args(self): 1316 @self.module.lru_cache 1317 def square(x): 1318 return x ** 2 1319 1320 
self.assertEqual(list(map(square, [10, 20, 10])), 1321 [100, 400, 100]) 1322 self.assertEqual(square.cache_info().hits, 1) 1323 self.assertEqual(square.cache_info().misses, 2) 1324 self.assertEqual(square.cache_info().maxsize, 128) 1325 self.assertEqual(square.cache_info().currsize, 2) 1326 1327 def test_lru_bug_35780(self): 1328 # C version of the lru_cache was not checking to see if 1329 # the user function call has already modified the cache 1330 # (this arises in recursive calls and in multi-threading). 1331 # This cause the cache to have orphan links not referenced 1332 # by the cache dictionary. 1333 1334 once = True # Modified by f(x) below 1335 1336 @self.module.lru_cache(maxsize=10) 1337 def f(x): 1338 nonlocal once 1339 rv = f'.{x}.' 1340 if x == 20 and once: 1341 once = False 1342 rv = f(x) 1343 return rv 1344 1345 # Fill the cache 1346 for x in range(15): 1347 self.assertEqual(f(x), f'.{x}.') 1348 self.assertEqual(f.cache_info().currsize, 10) 1349 1350 # Make a recursive call and make sure the cache remains full 1351 self.assertEqual(f(20), '.20.') 1352 self.assertEqual(f.cache_info().currsize, 10) 1353 1354 def test_lru_bug_36650(self): 1355 # C version of lru_cache was treating a call with an empty **kwargs 1356 # dictionary as being distinct from a call with no keywords at all. 1357 # This did not result in an incorrect answer, but it did trigger 1358 # an unexpected cache miss. 1359 1360 @self.module.lru_cache() 1361 def f(x): 1362 pass 1363 1364 f(0) 1365 f(0, **{}) 1366 self.assertEqual(f.cache_info().hits, 1) 1367 1368 def test_lru_hash_only_once(self): 1369 # To protect against weird reentrancy bugs and to improve 1370 # efficiency when faced with slow __hash__ methods, the 1371 # LRU cache guarantees that it will only call __hash__ 1372 # only once per use as an argument to the cached function. 
1373 1374 @self.module.lru_cache(maxsize=1) 1375 def f(x, y): 1376 return x * 3 + y 1377 1378 # Simulate the integer 5 1379 mock_int = unittest.mock.Mock() 1380 mock_int.__mul__ = unittest.mock.Mock(return_value=15) 1381 mock_int.__hash__ = unittest.mock.Mock(return_value=999) 1382 1383 # Add to cache: One use as an argument gives one call 1384 self.assertEqual(f(mock_int, 1), 16) 1385 self.assertEqual(mock_int.__hash__.call_count, 1) 1386 self.assertEqual(f.cache_info(), (0, 1, 1, 1)) 1387 1388 # Cache hit: One use as an argument gives one additional call 1389 self.assertEqual(f(mock_int, 1), 16) 1390 self.assertEqual(mock_int.__hash__.call_count, 2) 1391 self.assertEqual(f.cache_info(), (1, 1, 1, 1)) 1392 1393 # Cache eviction: No use as an argument gives no additional call 1394 self.assertEqual(f(6, 2), 20) 1395 self.assertEqual(mock_int.__hash__.call_count, 2) 1396 self.assertEqual(f.cache_info(), (1, 2, 1, 1)) 1397 1398 # Cache miss: One use as an argument gives one additional call 1399 self.assertEqual(f(mock_int, 1), 16) 1400 self.assertEqual(mock_int.__hash__.call_count, 3) 1401 self.assertEqual(f.cache_info(), (1, 3, 1, 1)) 1402 1403 def test_lru_reentrancy_with_len(self): 1404 # Test to make sure the LRU cache code isn't thrown-off by 1405 # caching the built-in len() function. Since len() can be 1406 # cached, we shouldn't use it inside the lru code itself. 1407 old_len = builtins.len 1408 try: 1409 builtins.len = self.module.lru_cache(4)(len) 1410 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1411 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1412 finally: 1413 builtins.len = old_len 1414 1415 def test_lru_star_arg_handling(self): 1416 # Test regression that arose in ea064ff3c10f 1417 @functools.lru_cache() 1418 def f(*args): 1419 return args 1420 1421 self.assertEqual(f(1, 2), (1, 2)) 1422 self.assertEqual(f((1, 2)), ((1, 2),)) 1423 1424 def test_lru_type_error(self): 1425 # Regression test for issue #28653. 
1426 # lru_cache was leaking when one of the arguments 1427 # wasn't cacheable. 1428 1429 @functools.lru_cache(maxsize=None) 1430 def infinite_cache(o): 1431 pass 1432 1433 @functools.lru_cache(maxsize=10) 1434 def limited_cache(o): 1435 pass 1436 1437 with self.assertRaises(TypeError): 1438 infinite_cache([]) 1439 1440 with self.assertRaises(TypeError): 1441 limited_cache([]) 1442 1443 def test_lru_with_maxsize_none(self): 1444 @self.module.lru_cache(maxsize=None) 1445 def fib(n): 1446 if n < 2: 1447 return n 1448 return fib(n-1) + fib(n-2) 1449 self.assertEqual([fib(n) for n in range(16)], 1450 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1451 self.assertEqual(fib.cache_info(), 1452 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1453 fib.cache_clear() 1454 self.assertEqual(fib.cache_info(), 1455 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1456 1457 def test_lru_with_maxsize_negative(self): 1458 @self.module.lru_cache(maxsize=-10) 1459 def eq(n): 1460 return n 1461 for i in (0, 1): 1462 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1463 self.assertEqual(eq.cache_info(), 1464 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) 1465 1466 def test_lru_with_exceptions(self): 1467 # Verify that user_function exceptions get passed through without 1468 # creating a hard-to-read chained exception. 
1469 # http://bugs.python.org/issue13177 1470 for maxsize in (None, 128): 1471 @self.module.lru_cache(maxsize) 1472 def func(i): 1473 return 'abc'[i] 1474 self.assertEqual(func(0), 'a') 1475 with self.assertRaises(IndexError) as cm: 1476 func(15) 1477 self.assertIsNone(cm.exception.__context__) 1478 # Verify that the previous exception did not result in a cached entry 1479 with self.assertRaises(IndexError): 1480 func(15) 1481 1482 def test_lru_with_types(self): 1483 for maxsize in (None, 128): 1484 @self.module.lru_cache(maxsize=maxsize, typed=True) 1485 def square(x): 1486 return x * x 1487 self.assertEqual(square(3), 9) 1488 self.assertEqual(type(square(3)), type(9)) 1489 self.assertEqual(square(3.0), 9.0) 1490 self.assertEqual(type(square(3.0)), type(9.0)) 1491 self.assertEqual(square(x=3), 9) 1492 self.assertEqual(type(square(x=3)), type(9)) 1493 self.assertEqual(square(x=3.0), 9.0) 1494 self.assertEqual(type(square(x=3.0)), type(9.0)) 1495 self.assertEqual(square.cache_info().hits, 4) 1496 self.assertEqual(square.cache_info().misses, 4) 1497 1498 def test_lru_with_keyword_args(self): 1499 @self.module.lru_cache() 1500 def fib(n): 1501 if n < 2: 1502 return n 1503 return fib(n=n-1) + fib(n=n-2) 1504 self.assertEqual( 1505 [fib(n=number) for number in range(16)], 1506 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1507 ) 1508 self.assertEqual(fib.cache_info(), 1509 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1510 fib.cache_clear() 1511 self.assertEqual(fib.cache_info(), 1512 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1513 1514 def test_lru_with_keyword_args_maxsize_none(self): 1515 @self.module.lru_cache(maxsize=None) 1516 def fib(n): 1517 if n < 2: 1518 return n 1519 return fib(n=n-1) + fib(n=n-2) 1520 self.assertEqual([fib(n=number) for number in range(16)], 1521 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1522 self.assertEqual(fib.cache_info(), 1523 
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1524 fib.cache_clear() 1525 self.assertEqual(fib.cache_info(), 1526 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1527 1528 def test_kwargs_order(self): 1529 # PEP 468: Preserving Keyword Argument Order 1530 @self.module.lru_cache(maxsize=10) 1531 def f(**kwargs): 1532 return list(kwargs.items()) 1533 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1534 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1535 self.assertEqual(f.cache_info(), 1536 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1537 1538 def test_lru_cache_decoration(self): 1539 def f(zomg: 'zomg_annotation'): 1540 """f doc string""" 1541 return 42 1542 g = self.module.lru_cache()(f) 1543 for attr in self.module.WRAPPER_ASSIGNMENTS: 1544 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1545 1546 def test_lru_cache_threaded(self): 1547 n, m = 5, 11 1548 def orig(x, y): 1549 return 3 * x + y 1550 f = self.module.lru_cache(maxsize=n*m)(orig) 1551 hits, misses, maxsize, currsize = f.cache_info() 1552 self.assertEqual(currsize, 0) 1553 1554 start = threading.Event() 1555 def full(k): 1556 start.wait(10) 1557 for _ in range(m): 1558 self.assertEqual(f(k, 0), orig(k, 0)) 1559 1560 def clear(): 1561 start.wait(10) 1562 for _ in range(2*m): 1563 f.cache_clear() 1564 1565 orig_si = sys.getswitchinterval() 1566 support.setswitchinterval(1e-6) 1567 try: 1568 # create n threads in order to fill cache 1569 threads = [threading.Thread(target=full, args=[k]) 1570 for k in range(n)] 1571 with threading_helper.start_threads(threads): 1572 start.set() 1573 1574 hits, misses, maxsize, currsize = f.cache_info() 1575 if self.module is py_functools: 1576 # XXX: Why can be not equal? 
1577 self.assertLessEqual(misses, n) 1578 self.assertLessEqual(hits, m*n - misses) 1579 else: 1580 self.assertEqual(misses, n) 1581 self.assertEqual(hits, m*n - misses) 1582 self.assertEqual(currsize, n) 1583 1584 # create n threads in order to fill cache and 1 to clear it 1585 threads = [threading.Thread(target=clear)] 1586 threads += [threading.Thread(target=full, args=[k]) 1587 for k in range(n)] 1588 start.clear() 1589 with threading_helper.start_threads(threads): 1590 start.set() 1591 finally: 1592 sys.setswitchinterval(orig_si) 1593 1594 def test_lru_cache_threaded2(self): 1595 # Simultaneous call with the same arguments 1596 n, m = 5, 7 1597 start = threading.Barrier(n+1) 1598 pause = threading.Barrier(n+1) 1599 stop = threading.Barrier(n+1) 1600 @self.module.lru_cache(maxsize=m*n) 1601 def f(x): 1602 pause.wait(10) 1603 return 3 * x 1604 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1605 def test(): 1606 for i in range(m): 1607 start.wait(10) 1608 self.assertEqual(f(i), 3 * i) 1609 stop.wait(10) 1610 threads = [threading.Thread(target=test) for k in range(n)] 1611 with threading_helper.start_threads(threads): 1612 for i in range(m): 1613 start.wait(10) 1614 stop.reset() 1615 pause.wait(10) 1616 start.reset() 1617 stop.wait(10) 1618 pause.reset() 1619 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1620 1621 def test_lru_cache_threaded3(self): 1622 @self.module.lru_cache(maxsize=2) 1623 def f(x): 1624 time.sleep(.01) 1625 return 3 * x 1626 def test(i, x): 1627 with self.subTest(thread=i): 1628 self.assertEqual(f(x), 3 * x, i) 1629 threads = [threading.Thread(target=test, args=(i, v)) 1630 for i, v in enumerate([1, 2, 2, 3, 2])] 1631 with threading_helper.start_threads(threads): 1632 pass 1633 1634 def test_need_for_rlock(self): 1635 # This will deadlock on an LRU cache that uses a regular lock 1636 1637 @self.module.lru_cache(maxsize=10) 1638 def test_func(x): 1639 'Used to demonstrate a reentrant lru_cache call within a single thread' 1640 
return x 1641 1642 class DoubleEq: 1643 'Demonstrate a reentrant lru_cache call within a single thread' 1644 def __init__(self, x): 1645 self.x = x 1646 def __hash__(self): 1647 return self.x 1648 def __eq__(self, other): 1649 if self.x == 2: 1650 test_func(DoubleEq(1)) 1651 return self.x == other.x 1652 1653 test_func(DoubleEq(1)) # Load the cache 1654 test_func(DoubleEq(2)) # Load the cache 1655 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1656 DoubleEq(2)) # Verify the correct return value 1657 1658 def test_lru_method(self): 1659 class X(int): 1660 f_cnt = 0 1661 @self.module.lru_cache(2) 1662 def f(self, x): 1663 self.f_cnt += 1 1664 return x*10+self 1665 a = X(5) 1666 b = X(5) 1667 c = X(7) 1668 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1669 1670 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1671 self.assertEqual(a.f(x), x*10 + 5) 1672 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1673 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1674 1675 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1676 self.assertEqual(b.f(x), x*10 + 5) 1677 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1678 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1679 1680 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1681 self.assertEqual(c.f(x), x*10 + 7) 1682 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1683 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1684 1685 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1686 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1687 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1688 1689 def test_pickle(self): 1690 cls = self.__class__ 1691 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: 1692 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1693 with self.subTest(proto=proto, func=f): 1694 f_copy = pickle.loads(pickle.dumps(f, proto)) 1695 self.assertIs(f_copy, f) 1696 1697 def test_copy(self): 1698 cls = self.__class__ 1699 def orig(x, y): 1700 return 3 * x + 
y 1701 part = self.module.partial(orig, 2) 1702 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1703 self.module.lru_cache(2)(part)) 1704 for f in funcs: 1705 with self.subTest(func=f): 1706 f_copy = copy.copy(f) 1707 self.assertIs(f_copy, f) 1708 1709 def test_deepcopy(self): 1710 cls = self.__class__ 1711 def orig(x, y): 1712 return 3 * x + y 1713 part = self.module.partial(orig, 2) 1714 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1715 self.module.lru_cache(2)(part)) 1716 for f in funcs: 1717 with self.subTest(func=f): 1718 f_copy = copy.deepcopy(f) 1719 self.assertIs(f_copy, f) 1720 1721 def test_lru_cache_parameters(self): 1722 @self.module.lru_cache(maxsize=2) 1723 def f(): 1724 return 1 1725 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) 1726 1727 @self.module.lru_cache(maxsize=1000, typed=True) 1728 def f(): 1729 return 1 1730 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) 1731 1732 def test_lru_cache_weakrefable(self): 1733 @self.module.lru_cache 1734 def test_function(x): 1735 return x 1736 1737 class A: 1738 @self.module.lru_cache 1739 def test_method(self, x): 1740 return (self, x) 1741 1742 @staticmethod 1743 @self.module.lru_cache 1744 def test_staticmethod(x): 1745 return (self, x) 1746 1747 refs = [weakref.ref(test_function), 1748 weakref.ref(A.test_method), 1749 weakref.ref(A.test_staticmethod)] 1750 1751 for ref in refs: 1752 self.assertIsNotNone(ref()) 1753 1754 del A 1755 del test_function 1756 gc.collect() 1757 1758 for ref in refs: 1759 self.assertIsNone(ref()) 1760 1761 1762@py_functools.lru_cache() 1763def py_cached_func(x, y): 1764 return 3 * x + y 1765 1766@c_functools.lru_cache() 1767def c_cached_func(x, y): 1768 return 3 * x + y 1769 1770 1771class TestLRUPy(TestLRU, unittest.TestCase): 1772 module = py_functools 1773 cached_func = py_cached_func, 1774 1775 @module.lru_cache() 1776 def cached_meth(self, x, y): 1777 return 3 * x + y 1778 1779 
@staticmethod 1780 @module.lru_cache() 1781 def cached_staticmeth(x, y): 1782 return 3 * x + y 1783 1784 1785class TestLRUC(TestLRU, unittest.TestCase): 1786 module = c_functools 1787 cached_func = c_cached_func, 1788 1789 @module.lru_cache() 1790 def cached_meth(self, x, y): 1791 return 3 * x + y 1792 1793 @staticmethod 1794 @module.lru_cache() 1795 def cached_staticmeth(x, y): 1796 return 3 * x + y 1797 1798 1799class TestSingleDispatch(unittest.TestCase): 1800 def test_simple_overloads(self): 1801 @functools.singledispatch 1802 def g(obj): 1803 return "base" 1804 def g_int(i): 1805 return "integer" 1806 g.register(int, g_int) 1807 self.assertEqual(g("str"), "base") 1808 self.assertEqual(g(1), "integer") 1809 self.assertEqual(g([1,2,3]), "base") 1810 1811 def test_mro(self): 1812 @functools.singledispatch 1813 def g(obj): 1814 return "base" 1815 class A: 1816 pass 1817 class C(A): 1818 pass 1819 class B(A): 1820 pass 1821 class D(C, B): 1822 pass 1823 def g_A(a): 1824 return "A" 1825 def g_B(b): 1826 return "B" 1827 g.register(A, g_A) 1828 g.register(B, g_B) 1829 self.assertEqual(g(A()), "A") 1830 self.assertEqual(g(B()), "B") 1831 self.assertEqual(g(C()), "A") 1832 self.assertEqual(g(D()), "B") 1833 1834 def test_register_decorator(self): 1835 @functools.singledispatch 1836 def g(obj): 1837 return "base" 1838 @g.register(int) 1839 def g_int(i): 1840 return "int %s" % (i,) 1841 self.assertEqual(g(""), "base") 1842 self.assertEqual(g(12), "int 12") 1843 self.assertIs(g.dispatch(int), g_int) 1844 self.assertIs(g.dispatch(object), g.dispatch(str)) 1845 # Note: in the assert above this is not g. 1846 # @singledispatch returns the wrapper. 
1847 1848 def test_wrapping_attributes(self): 1849 @functools.singledispatch 1850 def g(obj): 1851 "Simple test" 1852 return "Test" 1853 self.assertEqual(g.__name__, "g") 1854 if sys.flags.optimize < 2: 1855 self.assertEqual(g.__doc__, "Simple test") 1856 1857 @unittest.skipUnless(decimal, 'requires _decimal') 1858 @support.cpython_only 1859 def test_c_classes(self): 1860 @functools.singledispatch 1861 def g(obj): 1862 return "base" 1863 @g.register(decimal.DecimalException) 1864 def _(obj): 1865 return obj.args 1866 subn = decimal.Subnormal("Exponent < Emin") 1867 rnd = decimal.Rounded("Number got rounded") 1868 self.assertEqual(g(subn), ("Exponent < Emin",)) 1869 self.assertEqual(g(rnd), ("Number got rounded",)) 1870 @g.register(decimal.Subnormal) 1871 def _(obj): 1872 return "Too small to care." 1873 self.assertEqual(g(subn), "Too small to care.") 1874 self.assertEqual(g(rnd), ("Number got rounded",)) 1875 1876 def test_compose_mro(self): 1877 # None of the examples in this test depend on haystack ordering. 1878 c = collections.abc 1879 mro = functools._compose_mro 1880 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1881 for haystack in permutations(bases): 1882 m = mro(dict, haystack) 1883 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1884 c.Collection, c.Sized, c.Iterable, 1885 c.Container, object]) 1886 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict] 1887 for haystack in permutations(bases): 1888 m = mro(collections.ChainMap, haystack) 1889 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping, 1890 c.Collection, c.Sized, c.Iterable, 1891 c.Container, object]) 1892 1893 # If there's a generic function with implementations registered for 1894 # both Sized and Container, passing a defaultdict to it results in an 1895 # ambiguous dispatch which will cause a RuntimeError (see 1896 # test_mro_conflicts). 
1897 bases = [c.Container, c.Sized, str] 1898 for haystack in permutations(bases): 1899 m = mro(collections.defaultdict, [c.Sized, c.Container, str]) 1900 self.assertEqual(m, [collections.defaultdict, dict, c.Sized, 1901 c.Container, object]) 1902 1903 # MutableSequence below is registered directly on D. In other words, it 1904 # precedes MutableMapping which means single dispatch will always 1905 # choose MutableSequence here. 1906 class D(collections.defaultdict): 1907 pass 1908 c.MutableSequence.register(D) 1909 bases = [c.MutableSequence, c.MutableMapping] 1910 for haystack in permutations(bases): 1911 m = mro(D, bases) 1912 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1913 collections.defaultdict, dict, c.MutableMapping, c.Mapping, 1914 c.Collection, c.Sized, c.Iterable, c.Container, 1915 object]) 1916 1917 # Container and Callable are registered on different base classes and 1918 # a generic function supporting both should always pick the Callable 1919 # implementation if a C instance is passed. 
1920 class C(collections.defaultdict): 1921 def __call__(self): 1922 pass 1923 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1924 for haystack in permutations(bases): 1925 m = mro(C, haystack) 1926 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping, 1927 c.Collection, c.Sized, c.Iterable, 1928 c.Container, object]) 1929 1930 def test_register_abc(self): 1931 c = collections.abc 1932 d = {"a": "b"} 1933 l = [1, 2, 3] 1934 s = {object(), None} 1935 f = frozenset(s) 1936 t = (1, 2, 3) 1937 @functools.singledispatch 1938 def g(obj): 1939 return "base" 1940 self.assertEqual(g(d), "base") 1941 self.assertEqual(g(l), "base") 1942 self.assertEqual(g(s), "base") 1943 self.assertEqual(g(f), "base") 1944 self.assertEqual(g(t), "base") 1945 g.register(c.Sized, lambda obj: "sized") 1946 self.assertEqual(g(d), "sized") 1947 self.assertEqual(g(l), "sized") 1948 self.assertEqual(g(s), "sized") 1949 self.assertEqual(g(f), "sized") 1950 self.assertEqual(g(t), "sized") 1951 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1952 self.assertEqual(g(d), "mutablemapping") 1953 self.assertEqual(g(l), "sized") 1954 self.assertEqual(g(s), "sized") 1955 self.assertEqual(g(f), "sized") 1956 self.assertEqual(g(t), "sized") 1957 g.register(collections.ChainMap, lambda obj: "chainmap") 1958 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1959 self.assertEqual(g(l), "sized") 1960 self.assertEqual(g(s), "sized") 1961 self.assertEqual(g(f), "sized") 1962 self.assertEqual(g(t), "sized") 1963 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1964 self.assertEqual(g(d), "mutablemapping") 1965 self.assertEqual(g(l), "mutablesequence") 1966 self.assertEqual(g(s), "sized") 1967 self.assertEqual(g(f), "sized") 1968 self.assertEqual(g(t), "sized") 1969 g.register(c.MutableSet, lambda obj: "mutableset") 1970 self.assertEqual(g(d), "mutablemapping") 1971 self.assertEqual(g(l), "mutablesequence") 1972 self.assertEqual(g(s), 
"mutableset") 1973 self.assertEqual(g(f), "sized") 1974 self.assertEqual(g(t), "sized") 1975 g.register(c.Mapping, lambda obj: "mapping") 1976 self.assertEqual(g(d), "mutablemapping") # not specific enough 1977 self.assertEqual(g(l), "mutablesequence") 1978 self.assertEqual(g(s), "mutableset") 1979 self.assertEqual(g(f), "sized") 1980 self.assertEqual(g(t), "sized") 1981 g.register(c.Sequence, lambda obj: "sequence") 1982 self.assertEqual(g(d), "mutablemapping") 1983 self.assertEqual(g(l), "mutablesequence") 1984 self.assertEqual(g(s), "mutableset") 1985 self.assertEqual(g(f), "sized") 1986 self.assertEqual(g(t), "sequence") 1987 g.register(c.Set, lambda obj: "set") 1988 self.assertEqual(g(d), "mutablemapping") 1989 self.assertEqual(g(l), "mutablesequence") 1990 self.assertEqual(g(s), "mutableset") 1991 self.assertEqual(g(f), "set") 1992 self.assertEqual(g(t), "sequence") 1993 g.register(dict, lambda obj: "dict") 1994 self.assertEqual(g(d), "dict") 1995 self.assertEqual(g(l), "mutablesequence") 1996 self.assertEqual(g(s), "mutableset") 1997 self.assertEqual(g(f), "set") 1998 self.assertEqual(g(t), "sequence") 1999 g.register(list, lambda obj: "list") 2000 self.assertEqual(g(d), "dict") 2001 self.assertEqual(g(l), "list") 2002 self.assertEqual(g(s), "mutableset") 2003 self.assertEqual(g(f), "set") 2004 self.assertEqual(g(t), "sequence") 2005 g.register(set, lambda obj: "concrete-set") 2006 self.assertEqual(g(d), "dict") 2007 self.assertEqual(g(l), "list") 2008 self.assertEqual(g(s), "concrete-set") 2009 self.assertEqual(g(f), "set") 2010 self.assertEqual(g(t), "sequence") 2011 g.register(frozenset, lambda obj: "frozen-set") 2012 self.assertEqual(g(d), "dict") 2013 self.assertEqual(g(l), "list") 2014 self.assertEqual(g(s), "concrete-set") 2015 self.assertEqual(g(f), "frozen-set") 2016 self.assertEqual(g(t), "sequence") 2017 g.register(tuple, lambda obj: "tuple") 2018 self.assertEqual(g(d), "dict") 2019 self.assertEqual(g(l), "list") 2020 self.assertEqual(g(s), 
"concrete-set") 2021 self.assertEqual(g(f), "frozen-set") 2022 self.assertEqual(g(t), "tuple") 2023 2024 def test_c3_abc(self): 2025 c = collections.abc 2026 mro = functools._c3_mro 2027 class A(object): 2028 pass 2029 class B(A): 2030 def __len__(self): 2031 return 0 # implies Sized 2032 @c.Container.register 2033 class C(object): 2034 pass 2035 class D(object): 2036 pass # unrelated 2037 class X(D, C, B): 2038 def __call__(self): 2039 pass # implies Callable 2040 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 2041 for abcs in permutations([c.Sized, c.Callable, c.Container]): 2042 self.assertEqual(mro(X, abcs=abcs), expected) 2043 # unrelated ABCs don't appear in the resulting MRO 2044 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 2045 self.assertEqual(mro(X, abcs=many_abcs), expected) 2046 2047 def test_false_meta(self): 2048 # see issue23572 2049 class MetaA(type): 2050 def __len__(self): 2051 return 0 2052 class A(metaclass=MetaA): 2053 pass 2054 class AA(A): 2055 pass 2056 @functools.singledispatch 2057 def fun(a): 2058 return 'base A' 2059 @fun.register(A) 2060 def _(a): 2061 return 'fun A' 2062 aa = AA() 2063 self.assertEqual(fun(aa), 'fun A') 2064 2065 def test_mro_conflicts(self): 2066 c = collections.abc 2067 @functools.singledispatch 2068 def g(arg): 2069 return "base" 2070 class O(c.Sized): 2071 def __len__(self): 2072 return 0 2073 o = O() 2074 self.assertEqual(g(o), "base") 2075 g.register(c.Iterable, lambda arg: "iterable") 2076 g.register(c.Container, lambda arg: "container") 2077 g.register(c.Sized, lambda arg: "sized") 2078 g.register(c.Set, lambda arg: "set") 2079 self.assertEqual(g(o), "sized") 2080 c.Iterable.register(O) 2081 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 2082 c.Container.register(O) 2083 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 2084 c.Set.register(O) 2085 self.assertEqual(g(o), "set") # because c.Set is a subclass of 2086 # c.Sized and 
c.Container 2087 class P: 2088 pass 2089 p = P() 2090 self.assertEqual(g(p), "base") 2091 c.Iterable.register(P) 2092 self.assertEqual(g(p), "iterable") 2093 c.Container.register(P) 2094 with self.assertRaises(RuntimeError) as re_one: 2095 g(p) 2096 self.assertIn( 2097 str(re_one.exception), 2098 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2099 "or <class 'collections.abc.Iterable'>"), 2100 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 2101 "or <class 'collections.abc.Container'>")), 2102 ) 2103 class Q(c.Sized): 2104 def __len__(self): 2105 return 0 2106 q = Q() 2107 self.assertEqual(g(q), "sized") 2108 c.Iterable.register(Q) 2109 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 2110 c.Set.register(Q) 2111 self.assertEqual(g(q), "set") # because c.Set is a subclass of 2112 # c.Sized and c.Iterable 2113 @functools.singledispatch 2114 def h(arg): 2115 return "base" 2116 @h.register(c.Sized) 2117 def _(arg): 2118 return "sized" 2119 @h.register(c.Container) 2120 def _(arg): 2121 return "container" 2122 # Even though Sized and Container are explicit bases of MutableMapping, 2123 # this ABC is implicitly registered on defaultdict which makes all of 2124 # MutableMapping's bases implicit as well from defaultdict's 2125 # perspective. 
2126 with self.assertRaises(RuntimeError) as re_two: 2127 h(collections.defaultdict(lambda: 0)) 2128 self.assertIn( 2129 str(re_two.exception), 2130 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2131 "or <class 'collections.abc.Sized'>"), 2132 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2133 "or <class 'collections.abc.Container'>")), 2134 ) 2135 class R(collections.defaultdict): 2136 pass 2137 c.MutableSequence.register(R) 2138 @functools.singledispatch 2139 def i(arg): 2140 return "base" 2141 @i.register(c.MutableMapping) 2142 def _(arg): 2143 return "mapping" 2144 @i.register(c.MutableSequence) 2145 def _(arg): 2146 return "sequence" 2147 r = R() 2148 self.assertEqual(i(r), "sequence") 2149 class S: 2150 pass 2151 class T(S, c.Sized): 2152 def __len__(self): 2153 return 0 2154 t = T() 2155 self.assertEqual(h(t), "sized") 2156 c.Container.register(T) 2157 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 2158 class U: 2159 def __len__(self): 2160 return 0 2161 u = U() 2162 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 2163 # from the existence of __len__() 2164 c.Container.register(U) 2165 # There is no preference for registered versus inferred ABCs. 
2166 with self.assertRaises(RuntimeError) as re_three: 2167 h(u) 2168 self.assertIn( 2169 str(re_three.exception), 2170 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 2171 "or <class 'collections.abc.Sized'>"), 2172 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 2173 "or <class 'collections.abc.Container'>")), 2174 ) 2175 class V(c.Sized, S): 2176 def __len__(self): 2177 return 0 2178 @functools.singledispatch 2179 def j(arg): 2180 return "base" 2181 @j.register(S) 2182 def _(arg): 2183 return "s" 2184 @j.register(c.Container) 2185 def _(arg): 2186 return "container" 2187 v = V() 2188 self.assertEqual(j(v), "s") 2189 c.Container.register(V) 2190 self.assertEqual(j(v), "container") # because it ends up right after 2191 # Sized in the MRO 2192 2193 def test_cache_invalidation(self): 2194 from collections import UserDict 2195 import weakref 2196 2197 class TracingDict(UserDict): 2198 def __init__(self, *args, **kwargs): 2199 super(TracingDict, self).__init__(*args, **kwargs) 2200 self.set_ops = [] 2201 self.get_ops = [] 2202 def __getitem__(self, key): 2203 result = self.data[key] 2204 self.get_ops.append(key) 2205 return result 2206 def __setitem__(self, key, value): 2207 self.set_ops.append(key) 2208 self.data[key] = value 2209 def clear(self): 2210 self.data.clear() 2211 2212 td = TracingDict() 2213 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td): 2214 c = collections.abc 2215 @functools.singledispatch 2216 def g(arg): 2217 return "base" 2218 d = {} 2219 l = [] 2220 self.assertEqual(len(td), 0) 2221 self.assertEqual(g(d), "base") 2222 self.assertEqual(len(td), 1) 2223 self.assertEqual(td.get_ops, []) 2224 self.assertEqual(td.set_ops, [dict]) 2225 self.assertEqual(td.data[dict], g.registry[object]) 2226 self.assertEqual(g(l), "base") 2227 self.assertEqual(len(td), 2) 2228 self.assertEqual(td.get_ops, []) 2229 self.assertEqual(td.set_ops, [dict, list]) 2230 self.assertEqual(td.data[dict], g.registry[object]) 2231 
self.assertEqual(td.data[list], g.registry[object]) 2232 self.assertEqual(td.data[dict], td.data[list]) 2233 self.assertEqual(g(l), "base") 2234 self.assertEqual(g(d), "base") 2235 self.assertEqual(td.get_ops, [list, dict]) 2236 self.assertEqual(td.set_ops, [dict, list]) 2237 g.register(list, lambda arg: "list") 2238 self.assertEqual(td.get_ops, [list, dict]) 2239 self.assertEqual(len(td), 0) 2240 self.assertEqual(g(d), "base") 2241 self.assertEqual(len(td), 1) 2242 self.assertEqual(td.get_ops, [list, dict]) 2243 self.assertEqual(td.set_ops, [dict, list, dict]) 2244 self.assertEqual(td.data[dict], 2245 functools._find_impl(dict, g.registry)) 2246 self.assertEqual(g(l), "list") 2247 self.assertEqual(len(td), 2) 2248 self.assertEqual(td.get_ops, [list, dict]) 2249 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2250 self.assertEqual(td.data[list], 2251 functools._find_impl(list, g.registry)) 2252 class X: 2253 pass 2254 c.MutableMapping.register(X) # Will not invalidate the cache, 2255 # not using ABCs yet. 
2256 self.assertEqual(g(d), "base") 2257 self.assertEqual(g(l), "list") 2258 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2259 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2260 g.register(c.Sized, lambda arg: "sized") 2261 self.assertEqual(len(td), 0) 2262 self.assertEqual(g(d), "sized") 2263 self.assertEqual(len(td), 1) 2264 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2265 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2266 self.assertEqual(g(l), "list") 2267 self.assertEqual(len(td), 2) 2268 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2269 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2270 self.assertEqual(g(l), "list") 2271 self.assertEqual(g(d), "sized") 2272 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2273 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2274 g.dispatch(list) 2275 g.dispatch(dict) 2276 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2277 list, dict]) 2278 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2279 c.MutableSet.register(X) # Will invalidate the cache. 2280 self.assertEqual(len(td), 2) # Stale cache. 
2281 self.assertEqual(g(l), "list") 2282 self.assertEqual(len(td), 1) 2283 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2284 self.assertEqual(len(td), 0) 2285 self.assertEqual(g(d), "mutablemapping") 2286 self.assertEqual(len(td), 1) 2287 self.assertEqual(g(l), "list") 2288 self.assertEqual(len(td), 2) 2289 g.register(dict, lambda arg: "dict") 2290 self.assertEqual(g(d), "dict") 2291 self.assertEqual(g(l), "list") 2292 g._clear_cache() 2293 self.assertEqual(len(td), 0) 2294 2295 def test_annotations(self): 2296 @functools.singledispatch 2297 def i(arg): 2298 return "base" 2299 @i.register 2300 def _(arg: collections.abc.Mapping): 2301 return "mapping" 2302 @i.register 2303 def _(arg: "collections.abc.Sequence"): 2304 return "sequence" 2305 self.assertEqual(i(None), "base") 2306 self.assertEqual(i({"a": 1}), "mapping") 2307 self.assertEqual(i([1, 2, 3]), "sequence") 2308 self.assertEqual(i((1, 2, 3)), "sequence") 2309 self.assertEqual(i("str"), "sequence") 2310 2311 # Registering classes as callables doesn't work with annotations, 2312 # you need to pass the type explicitly. 
2313 @i.register(str) 2314 class _: 2315 def __init__(self, arg): 2316 self.arg = arg 2317 2318 def __eq__(self, other): 2319 return self.arg == other 2320 self.assertEqual(i("str"), "str") 2321 2322 def test_method_register(self): 2323 class A: 2324 @functools.singledispatchmethod 2325 def t(self, arg): 2326 self.arg = "base" 2327 @t.register(int) 2328 def _(self, arg): 2329 self.arg = "int" 2330 @t.register(str) 2331 def _(self, arg): 2332 self.arg = "str" 2333 a = A() 2334 2335 a.t(0) 2336 self.assertEqual(a.arg, "int") 2337 aa = A() 2338 self.assertFalse(hasattr(aa, 'arg')) 2339 a.t('') 2340 self.assertEqual(a.arg, "str") 2341 aa = A() 2342 self.assertFalse(hasattr(aa, 'arg')) 2343 a.t(0.0) 2344 self.assertEqual(a.arg, "base") 2345 aa = A() 2346 self.assertFalse(hasattr(aa, 'arg')) 2347 2348 def test_staticmethod_register(self): 2349 class A: 2350 @functools.singledispatchmethod 2351 @staticmethod 2352 def t(arg): 2353 return arg 2354 @t.register(int) 2355 @staticmethod 2356 def _(arg): 2357 return isinstance(arg, int) 2358 @t.register(str) 2359 @staticmethod 2360 def _(arg): 2361 return isinstance(arg, str) 2362 a = A() 2363 2364 self.assertTrue(A.t(0)) 2365 self.assertTrue(A.t('')) 2366 self.assertEqual(A.t(0.0), 0.0) 2367 2368 def test_classmethod_register(self): 2369 class A: 2370 def __init__(self, arg): 2371 self.arg = arg 2372 2373 @functools.singledispatchmethod 2374 @classmethod 2375 def t(cls, arg): 2376 return cls("base") 2377 @t.register(int) 2378 @classmethod 2379 def _(cls, arg): 2380 return cls("int") 2381 @t.register(str) 2382 @classmethod 2383 def _(cls, arg): 2384 return cls("str") 2385 2386 self.assertEqual(A.t(0).arg, "int") 2387 self.assertEqual(A.t('').arg, "str") 2388 self.assertEqual(A.t(0.0).arg, "base") 2389 2390 def test_callable_register(self): 2391 class A: 2392 def __init__(self, arg): 2393 self.arg = arg 2394 2395 @functools.singledispatchmethod 2396 @classmethod 2397 def t(cls, arg): 2398 return cls("base") 2399 2400 
@A.t.register(int) 2401 @classmethod 2402 def _(cls, arg): 2403 return cls("int") 2404 @A.t.register(str) 2405 @classmethod 2406 def _(cls, arg): 2407 return cls("str") 2408 2409 self.assertEqual(A.t(0).arg, "int") 2410 self.assertEqual(A.t('').arg, "str") 2411 self.assertEqual(A.t(0.0).arg, "base") 2412 2413 def test_abstractmethod_register(self): 2414 class Abstract(metaclass=abc.ABCMeta): 2415 2416 @functools.singledispatchmethod 2417 @abc.abstractmethod 2418 def add(self, x, y): 2419 pass 2420 2421 self.assertTrue(Abstract.add.__isabstractmethod__) 2422 self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) 2423 2424 with self.assertRaises(TypeError): 2425 Abstract() 2426 2427 def test_type_ann_register(self): 2428 class A: 2429 @functools.singledispatchmethod 2430 def t(self, arg): 2431 return "base" 2432 @t.register 2433 def _(self, arg: int): 2434 return "int" 2435 @t.register 2436 def _(self, arg: str): 2437 return "str" 2438 a = A() 2439 2440 self.assertEqual(a.t(0), "int") 2441 self.assertEqual(a.t(''), "str") 2442 self.assertEqual(a.t(0.0), "base") 2443 2444 def test_staticmethod_type_ann_register(self): 2445 class A: 2446 @functools.singledispatchmethod 2447 @staticmethod 2448 def t(arg): 2449 return arg 2450 @t.register 2451 @staticmethod 2452 def _(arg: int): 2453 return isinstance(arg, int) 2454 @t.register 2455 @staticmethod 2456 def _(arg: str): 2457 return isinstance(arg, str) 2458 a = A() 2459 2460 self.assertTrue(A.t(0)) 2461 self.assertTrue(A.t('')) 2462 self.assertEqual(A.t(0.0), 0.0) 2463 2464 def test_classmethod_type_ann_register(self): 2465 class A: 2466 def __init__(self, arg): 2467 self.arg = arg 2468 2469 @functools.singledispatchmethod 2470 @classmethod 2471 def t(cls, arg): 2472 return cls("base") 2473 @t.register 2474 @classmethod 2475 def _(cls, arg: int): 2476 return cls("int") 2477 @t.register 2478 @classmethod 2479 def _(cls, arg: str): 2480 return cls("str") 2481 2482 self.assertEqual(A.t(0).arg, "int") 2483 
self.assertEqual(A.t('').arg, "str") 2484 self.assertEqual(A.t(0.0).arg, "base") 2485 2486 def test_method_wrapping_attributes(self): 2487 class A: 2488 @functools.singledispatchmethod 2489 def func(self, arg: int) -> str: 2490 """My function docstring""" 2491 return str(arg) 2492 @functools.singledispatchmethod 2493 @classmethod 2494 def cls_func(cls, arg: int) -> str: 2495 """My function docstring""" 2496 return str(arg) 2497 @functools.singledispatchmethod 2498 @staticmethod 2499 def static_func(arg: int) -> str: 2500 """My function docstring""" 2501 return str(arg) 2502 2503 for meth in ( 2504 A.func, 2505 A().func, 2506 A.cls_func, 2507 A().cls_func, 2508 A.static_func, 2509 A().static_func 2510 ): 2511 with self.subTest(meth=meth): 2512 self.assertEqual(meth.__doc__, 'My function docstring') 2513 self.assertEqual(meth.__annotations__['arg'], int) 2514 2515 self.assertEqual(A.func.__name__, 'func') 2516 self.assertEqual(A().func.__name__, 'func') 2517 self.assertEqual(A.cls_func.__name__, 'cls_func') 2518 self.assertEqual(A().cls_func.__name__, 'cls_func') 2519 self.assertEqual(A.static_func.__name__, 'static_func') 2520 self.assertEqual(A().static_func.__name__, 'static_func') 2521 2522 def test_double_wrapped_methods(self): 2523 def classmethod_friendly_decorator(func): 2524 wrapped = func.__func__ 2525 @classmethod 2526 @functools.wraps(wrapped) 2527 def wrapper(*args, **kwargs): 2528 return wrapped(*args, **kwargs) 2529 return wrapper 2530 2531 class WithoutSingleDispatch: 2532 @classmethod 2533 @contextlib.contextmanager 2534 def cls_context_manager(cls, arg: int) -> str: 2535 try: 2536 yield str(arg) 2537 finally: 2538 return 'Done' 2539 2540 @classmethod_friendly_decorator 2541 @classmethod 2542 def decorated_classmethod(cls, arg: int) -> str: 2543 return str(arg) 2544 2545 class WithSingleDispatch: 2546 @functools.singledispatchmethod 2547 @classmethod 2548 @contextlib.contextmanager 2549 def cls_context_manager(cls, arg: int) -> str: 2550 """My 
function docstring""" 2551 try: 2552 yield str(arg) 2553 finally: 2554 return 'Done' 2555 2556 @functools.singledispatchmethod 2557 @classmethod_friendly_decorator 2558 @classmethod 2559 def decorated_classmethod(cls, arg: int) -> str: 2560 """My function docstring""" 2561 return str(arg) 2562 2563 # These are sanity checks 2564 # to test the test itself is working as expected 2565 with WithoutSingleDispatch.cls_context_manager(5) as foo: 2566 without_single_dispatch_foo = foo 2567 2568 with WithSingleDispatch.cls_context_manager(5) as foo: 2569 single_dispatch_foo = foo 2570 2571 self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) 2572 self.assertEqual(single_dispatch_foo, '5') 2573 2574 self.assertEqual( 2575 WithoutSingleDispatch.decorated_classmethod(5), 2576 WithSingleDispatch.decorated_classmethod(5) 2577 ) 2578 2579 self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') 2580 2581 # Behavioural checks now follow 2582 for method_name in ('cls_context_manager', 'decorated_classmethod'): 2583 with self.subTest(method=method_name): 2584 self.assertEqual( 2585 getattr(WithSingleDispatch, method_name).__name__, 2586 getattr(WithoutSingleDispatch, method_name).__name__ 2587 ) 2588 2589 self.assertEqual( 2590 getattr(WithSingleDispatch(), method_name).__name__, 2591 getattr(WithoutSingleDispatch(), method_name).__name__ 2592 ) 2593 2594 for meth in ( 2595 WithSingleDispatch.cls_context_manager, 2596 WithSingleDispatch().cls_context_manager, 2597 WithSingleDispatch.decorated_classmethod, 2598 WithSingleDispatch().decorated_classmethod 2599 ): 2600 with self.subTest(meth=meth): 2601 self.assertEqual(meth.__doc__, 'My function docstring') 2602 self.assertEqual(meth.__annotations__['arg'], int) 2603 2604 self.assertEqual( 2605 WithSingleDispatch.cls_context_manager.__name__, 2606 'cls_context_manager' 2607 ) 2608 self.assertEqual( 2609 WithSingleDispatch().cls_context_manager.__name__, 2610 'cls_context_manager' 2611 ) 2612 self.assertEqual( 
2613 WithSingleDispatch.decorated_classmethod.__name__, 2614 'decorated_classmethod' 2615 ) 2616 self.assertEqual( 2617 WithSingleDispatch().decorated_classmethod.__name__, 2618 'decorated_classmethod' 2619 ) 2620 2621 def test_invalid_registrations(self): 2622 msg_prefix = "Invalid first argument to `register()`: " 2623 msg_suffix = ( 2624 ". Use either `@register(some_class)` or plain `@register` on an " 2625 "annotated function." 2626 ) 2627 @functools.singledispatch 2628 def i(arg): 2629 return "base" 2630 with self.assertRaises(TypeError) as exc: 2631 @i.register(42) 2632 def _(arg): 2633 return "I annotated with a non-type" 2634 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42")) 2635 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2636 with self.assertRaises(TypeError) as exc: 2637 @i.register 2638 def _(arg): 2639 return "I forgot to annotate" 2640 self.assertTrue(str(exc.exception).startswith(msg_prefix + 2641 "<function TestSingleDispatch.test_invalid_registrations.<locals>._" 2642 )) 2643 self.assertTrue(str(exc.exception).endswith(msg_suffix)) 2644 2645 with self.assertRaises(TypeError) as exc: 2646 @i.register 2647 def _(arg: typing.Iterable[str]): 2648 # At runtime, dispatching on generics is impossible. 2649 # When registering implementations with singledispatch, avoid 2650 # types from `typing`. Instead, annotate with regular types 2651 # or ABCs. 2652 return "I annotated with a generic collection" 2653 self.assertTrue(str(exc.exception).startswith( 2654 "Invalid annotation for 'arg'." 2655 )) 2656 self.assertTrue(str(exc.exception).endswith( 2657 'typing.Iterable[str] is not a class.' 
2658 )) 2659 2660 def test_invalid_positional_argument(self): 2661 @functools.singledispatch 2662 def f(*args): 2663 pass 2664 msg = 'f requires at least 1 positional argument' 2665 with self.assertRaisesRegex(TypeError, msg): 2666 f() 2667 2668 def test_register_genericalias(self): 2669 @functools.singledispatch 2670 def f(arg): 2671 return "default" 2672 2673 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2674 f.register(list[int], lambda arg: "types.GenericAlias") 2675 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2676 f.register(typing.List[int], lambda arg: "typing.GenericAlias") 2677 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2678 f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") 2679 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2680 f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") 2681 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2682 f.register(typing.Any, lambda arg: "typing.Any") 2683 2684 self.assertEqual(f([1]), "default") 2685 self.assertEqual(f([1.0]), "default") 2686 self.assertEqual(f(""), "default") 2687 self.assertEqual(f(b""), "default") 2688 2689 def test_register_genericalias_decorator(self): 2690 @functools.singledispatch 2691 def f(arg): 2692 return "default" 2693 2694 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2695 f.register(list[int]) 2696 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2697 f.register(typing.List[int]) 2698 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2699 f.register(list[int] | str) 2700 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2701 f.register(typing.List[int] | str) 2702 with self.assertRaisesRegex(TypeError, "Invalid first argument to "): 2703 f.register(typing.Any) 2704 2705 def test_register_genericalias_annotation(self): 2706 
@functools.singledispatch 2707 def f(arg): 2708 return "default" 2709 2710 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2711 @f.register 2712 def _(arg: list[int]): 2713 return "types.GenericAlias" 2714 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2715 @f.register 2716 def _(arg: typing.List[float]): 2717 return "typing.GenericAlias" 2718 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2719 @f.register 2720 def _(arg: list[int] | str): 2721 return "types.UnionType(types.GenericAlias)" 2722 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2723 @f.register 2724 def _(arg: typing.List[float] | bytes): 2725 return "typing.Union[typing.GenericAlias]" 2726 with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): 2727 @f.register 2728 def _(arg: typing.Any): 2729 return "typing.Any" 2730 2731 self.assertEqual(f([1]), "default") 2732 self.assertEqual(f([1.0]), "default") 2733 self.assertEqual(f(""), "default") 2734 self.assertEqual(f(b""), "default") 2735 2736 2737class CachedCostItem: 2738 _cost = 1 2739 2740 def __init__(self): 2741 self.lock = py_functools.RLock() 2742 2743 @py_functools.cached_property 2744 def cost(self): 2745 """The cost of the item.""" 2746 with self.lock: 2747 self._cost += 1 2748 return self._cost 2749 2750 2751class OptionallyCachedCostItem: 2752 _cost = 1 2753 2754 def get_cost(self): 2755 """The cost of the item.""" 2756 self._cost += 1 2757 return self._cost 2758 2759 cached_cost = py_functools.cached_property(get_cost) 2760 2761 2762class CachedCostItemWait: 2763 2764 def __init__(self, event): 2765 self._cost = 1 2766 self.lock = py_functools.RLock() 2767 self.event = event 2768 2769 @py_functools.cached_property 2770 def cost(self): 2771 self.event.wait(1) 2772 with self.lock: 2773 self._cost += 1 2774 return self._cost 2775 2776 2777class CachedCostItemWithSlots: 2778 __slots__ = ('_cost') 2779 2780 def __init__(self): 2781 
self._cost = 1 2782 2783 @py_functools.cached_property 2784 def cost(self): 2785 raise RuntimeError('never called, slots not supported') 2786 2787 2788class TestCachedProperty(unittest.TestCase): 2789 def test_cached(self): 2790 item = CachedCostItem() 2791 self.assertEqual(item.cost, 2) 2792 self.assertEqual(item.cost, 2) # not 3 2793 2794 def test_cached_attribute_name_differs_from_func_name(self): 2795 item = OptionallyCachedCostItem() 2796 self.assertEqual(item.get_cost(), 2) 2797 self.assertEqual(item.cached_cost, 3) 2798 self.assertEqual(item.get_cost(), 4) 2799 self.assertEqual(item.cached_cost, 3) 2800 2801 def test_threaded(self): 2802 go = threading.Event() 2803 item = CachedCostItemWait(go) 2804 2805 num_threads = 3 2806 2807 orig_si = sys.getswitchinterval() 2808 sys.setswitchinterval(1e-6) 2809 try: 2810 threads = [ 2811 threading.Thread(target=lambda: item.cost) 2812 for k in range(num_threads) 2813 ] 2814 with threading_helper.start_threads(threads): 2815 go.set() 2816 finally: 2817 sys.setswitchinterval(orig_si) 2818 2819 self.assertEqual(item.cost, 2) 2820 2821 def test_object_with_slots(self): 2822 item = CachedCostItemWithSlots() 2823 with self.assertRaisesRegex( 2824 TypeError, 2825 "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.", 2826 ): 2827 item.cost 2828 2829 def test_immutable_dict(self): 2830 class MyMeta(type): 2831 @py_functools.cached_property 2832 def prop(self): 2833 return True 2834 2835 class MyClass(metaclass=MyMeta): 2836 pass 2837 2838 with self.assertRaisesRegex( 2839 TypeError, 2840 "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.", 2841 ): 2842 MyClass.prop 2843 2844 def test_reuse_different_names(self): 2845 """Disallow this case because decorated function a would not be cached.""" 2846 with self.assertRaises(RuntimeError) as ctx: 2847 class ReusedCachedProperty: 2848 @py_functools.cached_property 2849 def a(self): 2850 pass 
2851 2852 b = a 2853 2854 self.assertEqual( 2855 str(ctx.exception.__context__), 2856 str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) 2857 ) 2858 2859 def test_reuse_same_name(self): 2860 """Reusing a cached_property on different classes under the same name is OK.""" 2861 counter = 0 2862 2863 @py_functools.cached_property 2864 def _cp(_self): 2865 nonlocal counter 2866 counter += 1 2867 return counter 2868 2869 class A: 2870 cp = _cp 2871 2872 class B: 2873 cp = _cp 2874 2875 a = A() 2876 b = B() 2877 2878 self.assertEqual(a.cp, 1) 2879 self.assertEqual(b.cp, 2) 2880 self.assertEqual(a.cp, 1) 2881 2882 def test_set_name_not_called(self): 2883 cp = py_functools.cached_property(lambda s: None) 2884 class Foo: 2885 pass 2886 2887 Foo.cp = cp 2888 2889 with self.assertRaisesRegex( 2890 TypeError, 2891 "Cannot use cached_property instance without calling __set_name__ on it.", 2892 ): 2893 Foo().cp 2894 2895 def test_access_from_class(self): 2896 self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property) 2897 2898 def test_doc(self): 2899 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") 2900 2901 2902if __name__ == '__main__': 2903 unittest.main() 2904