"""functools.py - Tools for working with functions and callable objects
"""
# Python module wrapper for _functools C module
# to allow utilities written in Python to be added
# to the functools module.
# Written by Nick Coghlan <ncoghlan at gmail.com>,
# Raymond Hettinger <python at rcn.com>,
# and Łukasz Langa <lukasz at langa.pl>.
#   Copyright (C) 2006-2013 Python Software Foundation.
# See C source code for _functools credits/copyright

__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
           'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
           'partialmethod', 'singledispatch']

try:
    from _functools import reduce
except ImportError:
    pass
from abc import get_cache_token
from collections import namedtuple
# import types, weakref  # Deferred to singledispatch()
from reprlib import recursive_repr
from _thread import RLock


################################################################################
### update_wrapper() and wraps() decorator
################################################################################

# update_wrapper() and wraps() are tools to help write
# wrapper functions that can handle naive introspection

WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Update a wrapper function to look like the wrapped function

       wrapper is the function to be updated
       wrapped is the original function
       assigned is a tuple naming the attributes assigned directly
       from the wrapped function to the wrapper function (defaults to
       functools.WRAPPER_ASSIGNMENTS)
       updated is a tuple naming the attributes of the wrapper that
       are updated with the corresponding attribute from the wrapped
       function (defaults to functools.WRAPPER_UPDATES)

       Attributes missing on *wrapped* are skipped silently.  Returns
       *wrapper* so this can be used as a decorator via partial().
    """
    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Issue #17482: set __wrapped__ last so we don't inadvertently copy it
    # from the wrapped function when updating __dict__
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper

def wraps(wrapped,
          assigned = WRAPPER_ASSIGNMENTS,
          updated = WRAPPER_UPDATES):
    """Decorator factory to apply update_wrapper() to a wrapper function

       Returns a decorator that invokes update_wrapper() with the decorated
       function as the wrapper argument and the arguments to wraps() as the
       remaining arguments. Default arguments are as for update_wrapper().
       This is a convenience function to simplify applying partial() to
       update_wrapper().
    """
    return partial(update_wrapper, wrapped=wrapped,
                   assigned=assigned, updated=updated)


################################################################################
### total_ordering class decorator
################################################################################

# The total ordering functions all invoke the root magic method directly
# rather than using the corresponding operator.  This avoids possible
# infinite recursion that could occur when the operator dispatch logic
# detects a NotImplemented result and then calls a reflected method.
#
# Every helper checks for NotImplemented explicitly before using the result
# in a boolean context: bool(NotImplemented) is deprecated since Python 3.9
# and an error in later releases.

def _gt_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a < b) and (a != b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _le_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (a < b) or (a == b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _ge_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a < b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _ge_from_le(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a <= b) or (a == b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _lt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (a <= b) and (a != b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _gt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a <= b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _lt_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a > b) and (a != b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _ge_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (a > b) or (a == b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _le_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a > b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _le_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a >= b) or (a == b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _gt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (a >= b) and (a != b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _lt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a >= b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

# Maps each possible root comparison to the (name, implementation) pairs
# that total_ordering() derives from it.
_convert = {
    '__lt__': [('__gt__', _gt_from_lt),
               ('__le__', _le_from_lt),
               ('__ge__', _ge_from_lt)],
    '__le__': [('__ge__', _ge_from_le),
               ('__lt__', _lt_from_le),
               ('__gt__', _gt_from_le)],
    '__gt__': [('__lt__', _lt_from_gt),
               ('__ge__', _ge_from_gt),
               ('__le__', _le_from_gt)],
    '__ge__': [('__le__', _le_from_ge),
               ('__gt__', _gt_from_ge),
               ('__lt__', _lt_from_ge)]
}

def total_ordering(cls):
    """Class decorator that fills in missing ordering methods.

    *cls* must define at least one of __lt__, __le__, __gt__ or __ge__
    (any definition other than the one inherited from object counts).  The
    missing rich comparison methods are derived from a single root method
    together with __eq__/__ne__.  Returns *cls*.

    Raises ValueError if no ordering method is defined.
    """
    # Find user-defined comparisons (not those inherited from object).
    roots = {op for op in _convert if getattr(cls, op, None) is not getattr(object, op, None)}
    if not roots:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    root = max(roots)       # prefer __lt__ to __le__ to __gt__ to __ge__
    for opname, opfunc in _convert[root]:
        if opname not in roots:
            opfunc.__name__ = opname
            setattr(cls, opname, opfunc)
    return cls


################################################################################
### cmp_to_key() function converter
################################################################################

def cmp_to_key(mycmp):
    """Convert a cmp= function into a key= function.

    *mycmp* takes two arguments and returns a negative number, zero or a
    positive number.  The returned class wraps each object so that its
    comparison operators delegate to *mycmp*.  Instances are unhashable.
    """
    class K(object):
        __slots__ = ['obj']
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0
        def __gt__(self, other):
            return mycmp(self.obj, other.obj) > 0
        def __eq__(self, other):
            return mycmp(self.obj, other.obj) == 0
        def __le__(self, other):
            return mycmp(self.obj, other.obj) <= 0
        def __ge__(self, other):
            return mycmp(self.obj, other.obj) >= 0
        __hash__ = None
    return K

try:
    from _functools import cmp_to_key
except ImportError:
    pass


################################################################################
### partial() argument application
################################################################################

# Purely functional, no descriptor behaviour
class partial:
    """New function with partial application of the given arguments
    and keywords.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(*args, **keywords):
        """Create a partial object for ``func(*args, **keywords)``.

        Nested partial objects are flattened so that calling the result
        invokes the innermost function directly.

        Raises TypeError if no callable is supplied.
        """
        # *args is used instead of named parameters so that "cls" and "func"
        # may also be passed through as keyword arguments to the target.
        if not args:
            raise TypeError("descriptor '__new__' of partial needs an argument")
        if len(args) < 2:
            raise TypeError("type 'partial' takes at least one argument")
        cls, func, *args = args
        if not callable(func):
            raise TypeError("the first argument must be callable")
        args = tuple(args)

        # Flatten only genuine partial objects.  Any other callable that
        # merely happens to expose a ``func`` attribute is wrapped as-is.
        if isinstance(func, partial):
            args = func.args + args
            # Keywords given now override the ones stored on the inner partial.
            keywords = {**func.keywords, **keywords}
            func = func.func

        self = super(partial, cls).__new__(cls)

        self.func = func
        self.args = args
        self.keywords = keywords
        return self

    def __call__(*args, **keywords):
        """Call func with stored args first, then the call-time args.

        Call-time keywords override stored keywords of the same name.
        """
        if not args:
            raise TypeError("descriptor '__call__' of partial needs an argument")
        self, *args = args
        newkeywords = self.keywords.copy()
        newkeywords.update(keywords)
        return self.func(*self.args, *args, **newkeywords)

    @recursive_repr()
    def __repr__(self):
        qualname = type(self).__qualname__
        args = [repr(self.func)]
        args.extend(repr(x) for x in self.args)
        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
        if type(self).__module__ == "functools":
            return f"functools.{qualname}({', '.join(args)})"
        return f"{qualname}({', '.join(args)})"

    def __reduce__(self):
        # Pickle as type(self)(func) followed by __setstate__(full state).
        return type(self), (self.func,), (self.func, self.args,
               self.keywords or None, self.__dict__ or None)

    def __setstate__(self, state):
        """Restore from the 4-tuple (func, args, keywords, namespace).

        Raises TypeError if *state* is malformed.
        """
        if not isinstance(state, tuple):
            raise TypeError("argument to __setstate__ must be a tuple")
        if len(state) != 4:
            raise TypeError(f"expected 4 items in state, got {len(state)}")
        func, args, kwds, namespace = state
        if (not callable(func) or not isinstance(args, tuple) or
           (kwds is not None and not isinstance(kwds, dict)) or
           (namespace is not None and not isinstance(namespace, dict))):
            raise TypeError("invalid partial state")

        args = tuple(args) # just in case it's a subclass
        if kwds is None:
            kwds = {}
        elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
            kwds = dict(kwds)
        if namespace is None:
            namespace = {}

        self.__dict__ = namespace
        self.func = func
        self.args = args
        self.keywords = kwds

try:
    from _functools import partial
except ImportError:
    pass

# Descriptor version
class partialmethod(object):
    """Method descriptor with partial application of the given arguments
    and keywords.

    Supports wrapping existing descriptors and handles non-descriptor
    callables as instance methods.
    """

    def __init__(self, func, *args, **keywords):
        if not callable(func) and not hasattr(func, "__get__"):
            raise TypeError("{!r} is not callable or a descriptor"
                                 .format(func))

        # func could be a descriptor like classmethod which isn't callable,
        # so we can't inherit from partial (it verifies func is callable)
        if isinstance(func, partialmethod):
            # flattening is mandatory in order to place cls/self before all
            # other arguments
            # it's also more efficient since only one function will be called
            self.func = func.func
            self.args = func.args + args
            self.keywords = func.keywords.copy()
            self.keywords.update(keywords)
        else:
            self.func = func
            self.args = args
            self.keywords = keywords

    def __repr__(self):
        args = ", ".join(map(repr, self.args))
        keywords = ", ".join("{}={!r}".format(k, v)
                                 for k, v in self.keywords.items())
        format_string = "{module}.{cls}({func}, {args}, {keywords})"
        return format_string.format(module=self.__class__.__module__,
                                    cls=self.__class__.__qualname__,
                                    func=self.func,
                                    args=args,
                                    keywords=keywords)

    def _make_unbound_method(self):
        """Return a plain function that inserts the stored args after the
        first (cls/self) argument and merges the stored keywords."""
        def _method(*args, **keywords):
            call_keywords = self.keywords.copy()
            call_keywords.update(keywords)
            cls_or_self, *rest = args
            call_args = (cls_or_self,) + self.args + tuple(rest)
            return self.func(*call_args, **call_keywords)
        _method.__isabstractmethod__ = self.__isabstractmethod__
        _method._partialmethod = self
        return _method

    def __get__(self, obj, cls):
        """Bind through the wrapped descriptor when it produces a new
        callable; otherwise behave like an ordinary instance method."""
        get = getattr(self.func, "__get__", None)
        result = None
        if get is not None:
            new_func = get(obj, cls)
            if new_func is not self.func:
                # Assume __get__ returning something new indicates the
                # creation of an appropriate callable
                result = partial(new_func, *self.args, **self.keywords)
                try:
                    result.__self__ = new_func.__self__
                except AttributeError:
                    pass
        if result is None:
            # If the underlying descriptor didn't do anything, treat this
            # like an instance method
            result = self._make_unbound_method().__get__(obj, cls)
        return result

    @property
    def __isabstractmethod__(self):
        return getattr(self.func, "__isabstractmethod__", False)


################################################################################
### LRU Cache function decorator
################################################################################

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

class _HashedSeq(list):
    """ This class guarantees that hash() will be called no more than once
        per element.  This is important because the lru_cache() will hash
        the key multiple times on a cache miss.

    """

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue

def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str},
             tuple=tuple, type=type, len=len):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is as flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    # All of the code below relies on kwds preserving the order input by the
    # user.  Formerly, we sorted() the kwds before looping.  The new way is
    # *much* faster; however, it means that f(x=1, y=2) will now be treated
    # as a distinct call from f(y=2, x=1) which will be cached separately.
    key = args
    if kwds:
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    May also be applied directly as ``@lru_cache`` (without parentheses), in
    which case the default maxsize of 128 is used.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0.
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # Bare @lru_cache: the user function arrived via the maxsize argument.
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError('Expected maxsize to be an integer or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    """Pure-Python implementation of the caching wrapper used by lru_cache().

    Replaced by the C implementation when ``_functools._lru_cache_wrapper``
    can be imported.  The returned wrapper exposes cache_info() and
    cache_clear().
    """
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  The link update is
                    # already done and the miss was counted above, so we
                    # need only return the computed result.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass


################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################

def _c3_merge(sequences):
    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.

    Adapted from http://www.python.org/download/releases/2.3/mro/.

    Raises RuntimeError if no consistent ordering exists.  Note that the
    lists in *sequences* are consumed (mutated) by the merge.
    """
    result = []
    while True:
        sequences = [s for s in sequences if s]   # purge empty sequences
        if not sequences:
            return result
        for s1 in sequences:   # find merge candidates among seq heads
            candidate = s1[0]
            for s2 in sequences:
                if candidate in s2[1:]:
                    candidate = None
                    break      # reject the current head, it appears later
            else:
                break
        if candidate is None:
            raise RuntimeError("Inconsistent hierarchy")
        result.append(candidate)
        # remove the chosen candidate
        for seq in sequences:
            if seq[0] == candidate:
                del seq[0]

def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    If no *abcs* are given, the algorithm works exactly like the built-in C3
    linearization used for method resolution.

    If given, *abcs* is a list of abstract base classes that should be inserted
    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
    result. The algorithm inserts ABCs where their functionality is introduced,
    i.e. issubclass(cls, abc) returns True for the class itself but returns
    False for all its direct base classes. Implicit ABCs for a given class
    (either registered or inferred from the presence of a special method like
    __len__) are inserted directly after the last ABC explicitly listed in the
    MRO of said class. If two implicit ABCs end up next to each other in the
    resulting MRO, their ordering depends on the order of types in *abcs*.

    """
    for i, base in enumerate(reversed(cls.__bases__)):
        if hasattr(base, '__abstractmethods__'):
            boundary = len(cls.__bases__) - i
            break   # Bases up to the last explicit ABC are considered first.
    else:
        boundary = 0
    abcs = list(abcs) if abcs else []
    explicit_bases = list(cls.__bases__[:boundary])
    abstract_bases = []
    other_bases = list(cls.__bases__[boundary:])
    for base in abcs:
        if issubclass(cls, base) and not any(
                issubclass(b, base) for b in cls.__bases__
            ):
            # If *cls* is the class that introduces behaviour described by
            # an ABC *base*, insert said ABC to its MRO.
            abstract_bases.append(base)
    for base in abstract_bases:
        abcs.remove(base)
    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
    return _c3_merge(
        [[cls]] +
        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
        [explicit_bases] + [abstract_bases] + [other_bases]
    )

def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Includes relevant abstract base classes (with their respective bases) from
    the *types* iterable. Uses a modified C3 linearization algorithm.

    """
    bases = set(cls.__mro__)
    # Remove entries which are already present in the __mro__ or unrelated.
    def is_related(typ):
        return (typ not in bases and hasattr(typ, '__mro__')
                                 and issubclass(cls, typ))
    types = [n for n in types if is_related(n)]
    # Remove entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    def is_strict_base(typ):
        for other in types:
            if typ != other and typ in other.__mro__:
                return True
        return False
    types = [n for n in types if not is_strict_base(n)]
    # Subclasses of the ABCs in *types* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    type_set = set(types)
    mro = []
    for typ in types:
        found = []
        for sub in typ.__subclasses__():
            if sub not in bases and issubclass(cls, sub):
                found.append([s for s in sub.__mro__ if s in type_set])
        if not found:
            mro.append(typ)
            continue
        # Favor subclasses with the biggest number of useful bases
        found.sort(key=len, reverse=True)
        for sub in found:
            for subcls in sub:
                if subcls not in mro:
                    mro.append(subcls)
    return _c3_mro(cls, abcs=mro)

def _find_impl(cls, registry):
    """Returns the best matching implementation from *registry* for type *cls*.

    Where there is no registered implementation for a specific type, its method
    resolution order is used to find a more generic implementation.

    Note: if *registry* does not contain an implementation for the base
    *object* type, this function may return None.

    Raises RuntimeError when two unrelated implicit ABCs match equally well.
    """
    mro = _compose_mro(cls, registry.keys())
    match = None
    for t in mro:
        if match is not None:
            # If *match* is an implicit ABC but there is another unrelated,
            # equally matching implicit ABC, refuse the temptation to guess.
            if (t in registry and t not in cls.__mro__
                              and match not in cls.__mro__
                              and not issubclass(match, t)):
                raise RuntimeError("Ambiguous dispatch: {} or {}".format(
                    match, t))
            break
        if t in registry:
            match = t
    return registry.get(match)

def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the register() attribute of the
    generic function.
    """
    # There are many programs that use functools without singledispatch, so we
    # trade-off making singledispatch marginally slower for the benefit of
    # making start-up of such applications slightly faster.
    import types, weakref

    registry = {}
    dispatch_cache = weakref.WeakKeyDictionary()
    cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.

        """
        nonlocal cache_token
        # Once an ABC has been registered, invalidate cached lookups whenever
        # the global ABC cache token changes (new virtual subclasses).
        if cache_token is not None:
            current_token = get_cache_token()
            if cache_token != current_token:
                dispatch_cache.clear()
                cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.

        Also usable as ``@register(cls)`` or, on a function whose first
        parameter is annotated with a class, as plain ``@register``.
        Raises TypeError when no class can be determined.
        """
        nonlocal cache_token
        if func is None:
            if isinstance(cls, type):
                return lambda f: register(cls, f)
            ann = getattr(cls, '__annotations__', {})
            if not ann:
                raise TypeError(
                    f"Invalid first argument to `register()`: {cls!r}. "
                    f"Use either `@register(some_class)` or plain `@register` "
                    f"on an annotated function."
                )
            func = cls

            # only import typing if annotation parsing is necessary
            from typing import get_type_hints
            argname, cls = next(iter(get_type_hints(func).items()))
            # An explicit check (not ``assert``) so the error is still raised
            # when Python runs with -O.
            if not isinstance(cls, type):
                raise TypeError(
                    f"Invalid annotation for {argname!r}. {cls!r} is not a class."
                )
        registry[cls] = func
        if cache_token is None and hasattr(cls, '__abstractmethods__'):
            cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        if not args:
            raise TypeError(f'{funcname} requires at least '
                            '1 positional argument')

        return dispatch(args[0].__class__)(*args, **kw)

    funcname = getattr(func, '__name__', 'singledispatch function')
    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = types.MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper