"""functools.py - Tools for working with functions and callable objects
"""
# Python module wrapper for _functools C module
# to allow utilities written in Python to be added
# to the functools module.
# Written by Nick Coghlan <ncoghlan at gmail.com>,
# Raymond Hettinger <python at rcn.com>,
# and Łukasz Langa <lukasz at langa.pl>.
#   Copyright (C) 2006-2013 Python Software Foundation.
# See C source code for _functools credits/copyright

__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
           'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
           'partialmethod', 'singledispatch']

from abc import get_cache_token
from collections import namedtuple
from types import MappingProxyType
from weakref import WeakKeyDictionary
from reprlib import recursive_repr
try:
    from _thread import RLock
except ImportError:
    class RLock:
        'Dummy reentrant lock for builds without threads'
        def __enter__(self): pass
        def __exit__(self, exctype, excinst, exctb): pass


################################################################################
### reduce() sequence to a single item
################################################################################

# Sentinel distinguishing "no initial value given" from an explicit None.
_initial_missing = object()

def reduce(function, sequence, initial=_initial_missing):
    """Apply a two-argument function cumulatively to the items of *sequence*.

    Items are combined from left to right, reducing the iterable to a single
    value. If *initial* is given it is placed before the items and also
    serves as the result for an empty iterable.

    Raises TypeError when the iterable is empty and no *initial* is given.

    This pure-Python version only exists so that the name exported in
    __all__ is always defined; the C implementation replaces it below
    whenever _functools is available.
    """
    it = iter(sequence)
    if initial is _initial_missing:
        try:
            value = next(it)
        except StopIteration:
            raise TypeError(
                "reduce() of empty iterable with no initial value") from None
    else:
        value = initial
    for element in it:
        value = function(value, element)
    return value

try:
    from _functools import reduce
except ImportError:
    pass


################################################################################
### update_wrapper() and wraps() decorator
################################################################################

# update_wrapper() and wraps() are tools to help write
# wrapper functions that can handle naive introspection

WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__')
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Update a wrapper function to look like the wrapped function

       wrapper is the function to be updated
       wrapped is the original function
       assigned is a tuple naming the attributes assigned directly
       from the wrapped function to the wrapper function (defaults to
       functools.WRAPPER_ASSIGNMENTS)
       updated is a tuple naming the attributes of the wrapper that
       are updated with the corresponding attribute from the wrapped
       function (defaults to functools.WRAPPER_UPDATES)
    """
    for attr in assigned:
        try:
            value = getattr(wrapped, attr)
        except AttributeError:
            # Missing attributes on *wrapped* are skipped, not copied.
            pass
        else:
            setattr(wrapper, attr, value)
    for attr in updated:
        getattr(wrapper, attr).update(getattr(wrapped, attr, {}))
    # Issue #17482: set __wrapped__ last so we don't inadvertently copy it
    # from the wrapped function when updating __dict__
    wrapper.__wrapped__ = wrapped
    # Return the wrapper so this can be used as a decorator via partial()
    return wrapper

def wraps(wrapped,
          assigned = WRAPPER_ASSIGNMENTS,
          updated = WRAPPER_UPDATES):
    """Decorator factory to apply update_wrapper() to a wrapper function

       Returns a decorator that invokes update_wrapper() with the decorated
       function as the wrapper argument and the arguments to wraps() as the
       remaining arguments. Default arguments are as for update_wrapper().
       This is a convenience function to simplify applying partial() to
       update_wrapper().
    """
    return partial(update_wrapper, wrapped=wrapped,
                   assigned=assigned, updated=updated)


################################################################################
### total_ordering class decorator
################################################################################

# The total ordering functions all invoke the root magic method directly
# rather than using the corresponding operator.  This avoids possible
# infinite recursion that could occur when the operator dispatch logic
# detects a NotImplemented result and then calls a reflected method.
#
# Every helper checks for NotImplemented explicitly before using the result
# in a boolean context: relying on the truthiness of NotImplemented is
# deprecated (DeprecationWarning since Python 3.9).

def _gt_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a < b) and (a != b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _le_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (a < b) or (a == b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _ge_from_lt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a < b).'
    op_result = self.__lt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _ge_from_le(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (not a <= b) or (a == b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _lt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (a <= b) and (a != b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _gt_from_le(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (not a <= b).'
    op_result = self.__le__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _lt_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a > b) and (a != b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result and self != other

def _ge_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a >= b.  Computed by @total_ordering from (a > b) or (a == b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result or self == other

def _le_from_gt(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a > b).'
    op_result = self.__gt__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

def _le_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a <= b.  Computed by @total_ordering from (not a >= b) or (a == b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result or self == other

def _gt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a > b.  Computed by @total_ordering from (a >= b) and (a != b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return op_result and self != other

def _lt_from_ge(self, other, NotImplemented=NotImplemented):
    'Return a < b.  Computed by @total_ordering from (not a >= b).'
    op_result = self.__ge__(other)
    if op_result is NotImplemented:
        return op_result
    return not op_result

# For each possible root comparison, the (name, implementation) pairs used to
# synthesize the three remaining ordering methods.
_convert = {
    '__lt__': [('__gt__', _gt_from_lt),
               ('__le__', _le_from_lt),
               ('__ge__', _ge_from_lt)],
    '__le__': [('__ge__', _ge_from_le),
               ('__lt__', _lt_from_le),
               ('__gt__', _gt_from_le)],
    '__gt__': [('__lt__', _lt_from_gt),
               ('__ge__', _ge_from_gt),
               ('__le__', _le_from_gt)],
    '__ge__': [('__le__', _le_from_ge),
               ('__gt__', _gt_from_ge),
               ('__lt__', _lt_from_ge)]
}

def total_ordering(cls):
    """Class decorator that fills in missing ordering methods"""
    # Find user-defined comparisons (not those inherited from object).
    roots = {op for op in _convert
             if getattr(cls, op, None) is not getattr(object, op, None)}
    if not roots:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    root = max(roots)       # prefer __lt__ to __le__ to __gt__ to __ge__
    for opname, opfunc in _convert[root]:
        # Never overwrite a comparison the class defines itself.
        if opname not in roots:
            opfunc.__name__ = opname
            setattr(cls, opname, opfunc)
    return cls


################################################################################
### cmp_to_key() function converter
################################################################################

def cmp_to_key(mycmp):
    """Convert a cmp= function into a key= function"""
    class K(object):
        # Each K wraps one object; every rich comparison defers to mycmp.
        __slots__ = ['obj']
        def __init__(self, obj):
            self.obj = obj
        def __lt__(self, other):
            return mycmp(self.obj, other.obj) < 0
        def __gt__(self, other):
            return mycmp(self.obj, other.obj) > 0
        def __eq__(self, other):
            return mycmp(self.obj, other.obj) == 0
        def __le__(self, other):
            return mycmp(self.obj, other.obj) <= 0
        def __ge__(self, other):
            return mycmp(self.obj, other.obj) >= 0
        __hash__ = None
    return K

try:
    from _functools import cmp_to_key
except ImportError:
    pass


################################################################################
### partial() argument application
################################################################################

# Purely functional, no descriptor behaviour
class partial:
    """New function with partial application of the given arguments
    and keywords.
    """

    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"

    def __new__(*args, **keywords):
        # Positional-only handling via *args so that keyword arguments
        # literally named "cls" or "func" can still be bound.
        if not args:
            raise TypeError("descriptor '__new__' of partial needs an argument")
        if len(args) < 2:
            raise TypeError("type 'partial' takes at least one argument")
        cls, func, *args = args
        if not callable(func):
            raise TypeError("the first argument must be callable")
        args = tuple(args)

        # Flatten nested partial objects so that only one call layer remains.
        # Test the type rather than hasattr(func, "func"): an unrelated
        # callable that merely exposes a "func" attribute must be wrapped,
        # not unpacked (it may lack .args/.keywords entirely).
        if isinstance(func, partial):
            args = func.args + args
            keywords = {**func.keywords, **keywords}
            func = func.func

        self = super(partial, cls).__new__(cls)

        self.func = func
        self.args = args
        self.keywords = keywords
        return self

    def __call__(*args, **keywords):
        if not args:
            raise TypeError("descriptor '__call__' of partial needs an argument")
        self, *args = args
        # Call-time keywords override the stored ones.
        newkeywords = self.keywords.copy()
        newkeywords.update(keywords)
        return self.func(*self.args, *args, **newkeywords)

    @recursive_repr()
    def __repr__(self):
        qualname = type(self).__qualname__
        args = [repr(self.func)]
        args.extend(repr(x) for x in self.args)
        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
        if type(self).__module__ == "functools":
            return f"functools.{qualname}({', '.join(args)})"
        return f"{qualname}({', '.join(args)})"

    def __reduce__(self):
        return type(self), (self.func,), (self.func, self.args,
               self.keywords or None, self.__dict__ or None)

    def __setstate__(self, state):
        """Restore (func, args, keywords, namespace) produced by __reduce__."""
        if not isinstance(state, tuple):
            raise TypeError("argument to __setstate__ must be a tuple")
        if len(state) != 4:
            raise TypeError(f"expected 4 items in state, got {len(state)}")
        func, args, kwds, namespace = state
        if (not callable(func) or not isinstance(args, tuple) or
           (kwds is not None and not isinstance(kwds, dict)) or
           (namespace is not None and not isinstance(namespace, dict))):
            raise TypeError("invalid partial state")

        args = tuple(args) # just in case it's a subclass
        if kwds is None:
            kwds = {}
        elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
            kwds = dict(kwds)
        if namespace is None:
            namespace = {}

        self.__dict__ = namespace
        self.func = func
        self.args = args
        self.keywords = kwds

try:
    from _functools import partial
except ImportError:
    pass

# Descriptor version
class partialmethod(object):
    """Method descriptor with partial application of the given arguments
    and keywords.

    Supports wrapping existing descriptors and handles non-descriptor
    callables as instance methods.
    """

    def __init__(self, func, *args, **keywords):
        if not callable(func) and not hasattr(func, "__get__"):
            raise TypeError("{!r} is not callable or a descriptor"
                                 .format(func))

        # func could be a descriptor like classmethod which isn't callable,
        # so we can't inherit from partial (it verifies func is callable)
        if isinstance(func, partialmethod):
            # flattening is mandatory in order to place cls/self before all
            # other arguments
            # it's also more efficient since only one function will be called
            self.func = func.func
            self.args = func.args + args
            self.keywords = func.keywords.copy()
            self.keywords.update(keywords)
        else:
            self.func = func
            self.args = args
            self.keywords = keywords

    def __repr__(self):
        # Build the argument list piecewise so that empty args or keywords
        # do not leave dangling ", " separators in the output.
        parts = [repr(self.func)]
        parts.extend(map(repr, self.args))
        parts.extend("{}={!r}".format(k, v) for k, v in self.keywords.items())
        return "{module}.{cls}({args})".format(
            module=self.__class__.__module__,
            cls=self.__class__.__qualname__,
            args=", ".join(parts))

    def _make_unbound_method(self):
        # Plain function that inserts the stored positional arguments after
        # cls/self; binding it via __get__ makes it behave like a method.
        def _method(*args, **keywords):
            call_keywords = self.keywords.copy()
            call_keywords.update(keywords)
            cls_or_self, *rest = args
            call_args = (cls_or_self,) + self.args + tuple(rest)
            return self.func(*call_args, **call_keywords)
        _method.__isabstractmethod__ = self.__isabstractmethod__
        _method._partialmethod = self
        return _method

    def __get__(self, obj, cls):
        get = getattr(self.func, "__get__", None)
        result = None
        if get is not None:
            new_func = get(obj, cls)
            if new_func is not self.func:
                # Assume __get__ returning something new indicates the
                # creation of an appropriate callable
                result = partial(new_func, *self.args, **self.keywords)
                try:
                    result.__self__ = new_func.__self__
                except AttributeError:
                    pass
        if result is None:
            # If the underlying descriptor didn't do anything, treat this
            # like an instance method
            result = self._make_unbound_method().__get__(obj, cls)
        return result

    @property
    def __isabstractmethod__(self):
        return getattr(self.func, "__isabstractmethod__", False)


################################################################################
### LRU Cache function decorator
################################################################################

_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])

class _HashedSeq(list):
    """ This class guarantees that hash() will be called no more than once
        per element.  This is important because the lru_cache() will hash
        the key multiple times on a cache miss.

    """

    __slots__ = 'hashvalue'

    def __init__(self, tup, hash=hash):
        self[:] = tup
        self.hashvalue = hash(tup)

    def __hash__(self):
        return self.hashvalue

def _make_key(args, kwds, typed,
             kwd_mark = (object(),),
             fasttypes = {int, str, frozenset, type(None)},
             tuple=tuple, type=type, len=len):
    """Make a cache key from optionally typed positional and keyword arguments

    The key is constructed in a way that is as flat as possible rather than
    as a nested structure that would take more memory.

    If there is only a single argument and its data type is known to cache
    its hash value, then that argument is returned without a wrapper.  This
    saves space and improves lookup speed.

    """
    key = args
    if kwds:
        # The unique marker separates positional from keyword items, so that
        # f(1, 'a', 2) and f(1, a=2) produce different keys.
        key += kwd_mark
        for item in kwds.items():
            key += item
    if typed:
        key += tuple(type(v) for v in args)
        if kwds:
            key += tuple(type(v) for v in kwds.values())
    elif len(key) == 1 and type(key[0]) in fasttypes:
        return key[0]
    return _HashedSeq(key)

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound. A negative *maxsize* is treated as 0.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    # Early detection of an erroneous call to @lru_cache without any arguments
    # resulting in the inner function being passed to maxsize instead of an
    # integer or None.
    if maxsize is not None and not isinstance(maxsize, int):
        raise TypeError('Expected maxsize to be an integer or None')
    if maxsize is not None and maxsize < 0:
        # A negative size can never be satisfied; treat it as 0 (statistics
        # only, no caching) instead of letting the bounded wrapper silently
        # behave like an undocumented size-1 cache.
        maxsize = 0

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    """Build the caching wrapper around *user_function*.

    Selects one of three strategies from *maxsize*: 0 means statistics only,
    None means an unbounded dict, and a positive int means an LRU cache
    kept in a circular doubly linked list. Attaches cache_info() and
    cache_clear() to the returned wrapper.
    """
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update after a successful call
            nonlocal misses
            result = user_function(*args, **kwds)
            misses += 1
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            result = user_function(*args, **kwds)
            cache[key] = result
            misses += 1
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
            # The user function runs outside the lock so that slow or
            # recursive calls do not block other threads.
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
                misses += 1
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            # root is reset in place, so no nonlocal rebinding is needed.
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass


################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################

def _c3_merge(sequences):
    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.

    Adapted from http://www.python.org/download/releases/2.3/mro/.

    """
    result = []
    while True:
        sequences = [s for s in sequences if s]   # purge empty sequences
        if not sequences:
            return result
        for s1 in sequences:   # find merge candidates among seq heads
            candidate = s1[0]
            for s2 in sequences:
                if candidate in s2[1:]:
                    candidate = None
                    break      # reject the current head, it appears later
            else:
                break
        if candidate is None:
            raise RuntimeError("Inconsistent hierarchy")
        result.append(candidate)
        # remove the chosen candidate
        for seq in sequences:
            if seq[0] == candidate:
                del seq[0]

def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    If no *abcs* are given, the algorithm works exactly like the built-in C3
    linearization used for method resolution.

    If given, *abcs* is a list of abstract base classes that should be inserted
    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
    result. The algorithm inserts ABCs where their functionality is introduced,
    i.e. issubclass(cls, abc) returns True for the class itself but returns
    False for all its direct base classes. Implicit ABCs for a given class
    (either registered or inferred from the presence of a special method like
    __len__) are inserted directly after the last ABC explicitly listed in the
    MRO of said class. If two implicit ABCs end up next to each other in the
    resulting MRO, their ordering depends on the order of types in *abcs*.

    """
    for i, base in enumerate(reversed(cls.__bases__)):
        if hasattr(base, '__abstractmethods__'):
            boundary = len(cls.__bases__) - i
            break   # Bases up to the last explicit ABC are considered first.
    else:
        boundary = 0
    abcs = list(abcs) if abcs else []
    explicit_bases = list(cls.__bases__[:boundary])
    abstract_bases = []
    other_bases = list(cls.__bases__[boundary:])
    for base in abcs:
        if issubclass(cls, base) and not any(
                issubclass(b, base) for b in cls.__bases__
            ):
            # If *cls* is the class that introduces behaviour described by
            # an ABC *base*, insert said ABC to its MRO.
            abstract_bases.append(base)
    for base in abstract_bases:
        abcs.remove(base)
    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
    return _c3_merge(
        [[cls]] +
        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
        [explicit_bases] + [abstract_bases] + [other_bases]
    )

def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Includes relevant abstract base classes (with their respective bases) from
    the *types* iterable. Uses a modified C3 linearization algorithm.

    """
    bases = set(cls.__mro__)
    # Remove entries which are already present in the __mro__ or unrelated.
    def is_related(typ):
        return (typ not in bases and hasattr(typ, '__mro__')
                                 and issubclass(cls, typ))
    types = [n for n in types if is_related(n)]
    # Remove entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    def is_strict_base(typ):
        for other in types:
            if typ != other and typ in other.__mro__:
                return True
        return False
    types = [n for n in types if not is_strict_base(n)]
    # Subclasses of the ABCs in *types* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    type_set = set(types)
    mro = []
    for typ in types:
        found = []
        for sub in typ.__subclasses__():
            if sub not in bases and issubclass(cls, sub):
                found.append([s for s in sub.__mro__ if s in type_set])
        if not found:
            mro.append(typ)
            continue
        # Favor subclasses with the biggest number of useful bases
        found.sort(key=len, reverse=True)
        for sub in found:
            for subcls in sub:
                if subcls not in mro:
                    mro.append(subcls)
    return _c3_mro(cls, abcs=mro)

def _find_impl(cls, registry):
    """Returns the best matching implementation from *registry* for type *cls*.

    Where there is no registered implementation for a specific type, its method
    resolution order is used to find a more generic implementation.

    Note: if *registry* does not contain an implementation for the base
    *object* type, this function may return None.

    """
    mro = _compose_mro(cls, registry.keys())
    match = None
    for t in mro:
        if match is not None:
            # If *match* is an implicit ABC but there is another unrelated,
            # equally matching implicit ABC, refuse the temptation to guess.
            if (t in registry and t not in cls.__mro__
                              and match not in cls.__mro__
                              and not issubclass(match, t)):
                raise RuntimeError("Ambiguous dispatch: {} or {}".format(
                    match, t))
            break
        if t in registry:
            match = t
    return registry.get(match)

def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the register() attribute of the
    generic function.

    Calling the generic function without any positional argument raises
    TypeError, since there is nothing to dispatch on.

    """
    registry = {}
    dispatch_cache = WeakKeyDictionary()
    # Only tracked once an ABC has been registered; a changed ABC cache token
    # means virtual-subclass registrations changed and the dispatch cache
    # may be stale.
    cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.

        """
        nonlocal cache_token
        if cache_token is not None:
            current_token = get_cache_token()
            if cache_token != current_token:
                dispatch_cache.clear()
                cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.

        """
        nonlocal cache_token
        if func is None:
            return lambda f: register(cls, f)
        registry[cls] = func
        if cache_token is None and hasattr(cls, '__abstractmethods__'):
            cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    funcname = getattr(func, '__name__', 'singledispatch function')

    def wrapper(*args, **kw):
        if not args:
            # Without a first positional argument there is nothing to
            # dispatch on; report that clearly instead of an IndexError.
            raise TypeError(f'{funcname} requires at least '
                            '1 positional argument')
        return dispatch(args[0].__class__)(*args, **kw)

    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper