1'''This module implements specialized container datatypes providing 2alternatives to Python's general purpose built-in containers, dict, 3list, set, and tuple. 4 5* namedtuple factory function for creating tuple subclasses with named fields 6* deque list-like container with fast appends and pops on either end 7* ChainMap dict-like class for creating a single view of multiple mappings 8* Counter dict subclass for counting hashable objects 9* OrderedDict dict subclass that remembers the order entries were added 10* defaultdict dict subclass that calls a factory function to supply missing values 11* UserDict wrapper around dictionary objects for easier dict subclassing 12* UserList wrapper around list objects for easier list subclassing 13* UserString wrapper around string objects for easier string subclassing 14 15''' 16 17__all__ = ['deque', 'defaultdict', 'namedtuple', 'UserDict', 'UserList', 18 'UserString', 'Counter', 'OrderedDict', 'ChainMap'] 19 20import _collections_abc 21from operator import itemgetter as _itemgetter, eq as _eq 22from keyword import iskeyword as _iskeyword 23import sys as _sys 24import heapq as _heapq 25from _weakref import proxy as _proxy 26from itertools import repeat as _repeat, chain as _chain, starmap as _starmap 27from reprlib import recursive_repr as _recursive_repr 28 29try: 30 from _collections import deque 31except ImportError: 32 pass 33else: 34 _collections_abc.MutableSequence.register(deque) 35 36try: 37 from _collections import defaultdict 38except ImportError: 39 pass 40 41 42def __getattr__(name): 43 # For backwards compatibility, continue to make the collections ABCs 44 # through Python 3.6 available through the collections module. 45 # Note, no new collections ABCs were added in Python 3.7 46 if name in _collections_abc.__all__: 47 obj = getattr(_collections_abc, name) 48 import warnings 49 warnings.warn("Using or importing the ABCs from 'collections' instead " 50 "of from 'collections.abc' is deprecated, " 51 "and in 3.8 it will stop working", 52 DeprecationWarning, stacklevel=2) 53 globals()[name] = obj 54 return obj 55 raise AttributeError(f'module {__name__!r} has no attribute {name!r}') 56 57################################################################################ 58### OrderedDict 59################################################################################ 60 61class _OrderedDictKeysView(_collections_abc.KeysView): 62 63 def __reversed__(self): 64 yield from reversed(self._mapping) 65 66class _OrderedDictItemsView(_collections_abc.ItemsView): 67 68 def __reversed__(self): 69 for key in reversed(self._mapping): 70 yield (key, self._mapping[key]) 71 72class _OrderedDictValuesView(_collections_abc.ValuesView): 73 74 def __reversed__(self): 75 for key in reversed(self._mapping): 76 yield self._mapping[key] 77 78class _Link(object): 79 __slots__ = 'prev', 'next', 'key', '__weakref__' 80 81class OrderedDict(dict): 82 'Dictionary that remembers insertion order' 83 # An inherited dict maps keys to values. 84 # The inherited dict provides __getitem__, __len__, __contains__, and get. 85 # The remaining methods are order-aware. 86 # Big-O running times for all methods are the same as regular dictionaries. 87 88 # The internal self.__map dict maps keys to links in a doubly linked list. 89 # The circular doubly linked list starts and ends with a sentinel element. 90 # The sentinel element never gets deleted (this simplifies the algorithm). 91 # The sentinel is in self.__hardroot with a weakref proxy in self.__root. 92 # The prev links are weakref proxies (to prevent circular references). 93 # Individual links are kept alive by the hard reference in self.__map. 94 # Those hard references disappear when a key is deleted from an OrderedDict. 95 96 def __init__(*args, **kwds): 97 '''Initialize an ordered dictionary. The signature is the same as 98 regular dictionaries. Keyword argument order is preserved. 99 ''' 100 if not args: 101 raise TypeError("descriptor '__init__' of 'OrderedDict' object " 102 "needs an argument") 103 self, *args = args 104 if len(args) > 1: 105 raise TypeError('expected at most 1 arguments, got %d' % len(args)) 106 try: 107 self.__root 108 except AttributeError: 109 self.__hardroot = _Link() 110 self.__root = root = _proxy(self.__hardroot) 111 root.prev = root.next = root 112 self.__map = {} 113 self.__update(*args, **kwds) 114 115 def __setitem__(self, key, value, 116 dict_setitem=dict.__setitem__, proxy=_proxy, Link=_Link): 117 'od.__setitem__(i, y) <==> od[i]=y' 118 # Setting a new item creates a new link at the end of the linked list, 119 # and the inherited dictionary is updated with the new key/value pair. 120 if key not in self: 121 self.__map[key] = link = Link() 122 root = self.__root 123 last = root.prev 124 link.prev, link.next, link.key = last, root, key 125 last.next = link 126 root.prev = proxy(link) 127 dict_setitem(self, key, value) 128 129 def __delitem__(self, key, dict_delitem=dict.__delitem__): 130 'od.__delitem__(y) <==> del od[y]' 131 # Deleting an existing item uses self.__map to find the link which gets 132 # removed by updating the links in the predecessor and successor nodes. 133 dict_delitem(self, key) 134 link = self.__map.pop(key) 135 link_prev = link.prev 136 link_next = link.next 137 link_prev.next = link_next 138 link_next.prev = link_prev 139 link.prev = None 140 link.next = None 141 142 def __iter__(self): 143 'od.__iter__() <==> iter(od)' 144 # Traverse the linked list in order. 145 root = self.__root 146 curr = root.next 147 while curr is not root: 148 yield curr.key 149 curr = curr.next 150 151 def __reversed__(self): 152 'od.__reversed__() <==> reversed(od)' 153 # Traverse the linked list in reverse order. 154 root = self.__root 155 curr = root.prev 156 while curr is not root: 157 yield curr.key 158 curr = curr.prev 159 160 def clear(self): 161 'od.clear() -> None. Remove all items from od.' 162 root = self.__root 163 root.prev = root.next = root 164 self.__map.clear() 165 dict.clear(self) 166 167 def popitem(self, last=True): 168 '''Remove and return a (key, value) pair from the dictionary. 169 170 Pairs are returned in LIFO order if last is true or FIFO order if false. 171 ''' 172 if not self: 173 raise KeyError('dictionary is empty') 174 root = self.__root 175 if last: 176 link = root.prev 177 link_prev = link.prev 178 link_prev.next = root 179 root.prev = link_prev 180 else: 181 link = root.next 182 link_next = link.next 183 root.next = link_next 184 link_next.prev = root 185 key = link.key 186 del self.__map[key] 187 value = dict.pop(self, key) 188 return key, value 189 190 def move_to_end(self, key, last=True): 191 '''Move an existing element to the end (or beginning if last is false). 192 193 Raise KeyError if the element does not exist. 194 ''' 195 link = self.__map[key] 196 link_prev = link.prev 197 link_next = link.next 198 soft_link = link_next.prev 199 link_prev.next = link_next 200 link_next.prev = link_prev 201 root = self.__root 202 if last: 203 last = root.prev 204 link.prev = last 205 link.next = root 206 root.prev = soft_link 207 last.next = link 208 else: 209 first = root.next 210 link.prev = root 211 link.next = first 212 first.prev = soft_link 213 root.next = link 214 215 def __sizeof__(self): 216 sizeof = _sys.getsizeof 217 n = len(self) + 1 # number of links including root 218 size = sizeof(self.__dict__) # instance dictionary 219 size += sizeof(self.__map) * 2 # internal dict and inherited dict 220 size += sizeof(self.__hardroot) * n # link objects 221 size += sizeof(self.__root) * n # proxy objects 222 return size 223 224 update = __update = _collections_abc.MutableMapping.update 225 226 def keys(self): 227 "D.keys() -> a set-like object providing a view on D's keys" 228 return _OrderedDictKeysView(self) 229 230 def items(self): 231 "D.items() -> a set-like object providing a view on D's items" 232 return _OrderedDictItemsView(self) 233 234 def values(self): 235 "D.values() -> an object providing a view on D's values" 236 return _OrderedDictValuesView(self) 237 238 __ne__ = _collections_abc.MutableMapping.__ne__ 239 240 __marker = object() 241 242 def pop(self, key, default=__marker): 243 '''od.pop(k[,d]) -> v, remove specified key and return the corresponding 244 value. If key is not found, d is returned if given, otherwise KeyError 245 is raised. 246 247 ''' 248 if key in self: 249 result = self[key] 250 del self[key] 251 return result 252 if default is self.__marker: 253 raise KeyError(key) 254 return default 255 256 def setdefault(self, key, default=None): 257 '''Insert key with a value of default if key is not in the dictionary. 258 259 Return the value for key if key is in the dictionary, else default. 260 ''' 261 if key in self: 262 return self[key] 263 self[key] = default 264 return default 265 266 @_recursive_repr() 267 def __repr__(self): 268 'od.__repr__() <==> repr(od)' 269 if not self: 270 return '%s()' % (self.__class__.__name__,) 271 return '%s(%r)' % (self.__class__.__name__, list(self.items())) 272 273 def __reduce__(self): 274 'Return state information for pickling' 275 inst_dict = vars(self).copy() 276 for k in vars(OrderedDict()): 277 inst_dict.pop(k, None) 278 return self.__class__, (), inst_dict or None, None, iter(self.items()) 279 280 def copy(self): 281 'od.copy() -> a shallow copy of od' 282 return self.__class__(self) 283 284 @classmethod 285 def fromkeys(cls, iterable, value=None): 286 '''Create a new ordered dictionary with keys from iterable and values set to value. 287 ''' 288 self = cls() 289 for key in iterable: 290 self[key] = value 291 return self 292 293 def __eq__(self, other): 294 '''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive 295 while comparison to a regular mapping is order-insensitive. 296 297 ''' 298 if isinstance(other, OrderedDict): 299 return dict.__eq__(self, other) and all(map(_eq, self, other)) 300 return dict.__eq__(self, other) 301 302 303try: 304 from _collections import OrderedDict 305except ImportError: 306 # Leave the pure Python version in place. 307 pass 308 309 310################################################################################ 311### namedtuple 312################################################################################ 313 314_nt_itemgetters = {} 315 316def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): 317 """Returns a new subclass of tuple with named fields. 318 319 >>> Point = namedtuple('Point', ['x', 'y']) 320 >>> Point.__doc__ # docstring for the new class 321 'Point(x, y)' 322 >>> p = Point(11, y=22) # instantiate with positional args or keywords 323 >>> p[0] + p[1] # indexable like a plain tuple 324 33 325 >>> x, y = p # unpack like a regular tuple 326 >>> x, y 327 (11, 22) 328 >>> p.x + p.y # fields also accessible by name 329 33 330 >>> d = p._asdict() # convert to a dictionary 331 >>> d['x'] 332 11 333 >>> Point(**d) # convert from a dictionary 334 Point(x=11, y=22) 335 >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields 336 Point(x=100, y=22) 337 338 """ 339 340 # Validate the field names. At the user's option, either generate an error 341 # message or automatically replace the field name with a valid name. 342 if isinstance(field_names, str): 343 field_names = field_names.replace(',', ' ').split() 344 field_names = list(map(str, field_names)) 345 typename = _sys.intern(str(typename)) 346 347 if rename: 348 seen = set() 349 for index, name in enumerate(field_names): 350 if (not name.isidentifier() 351 or _iskeyword(name) 352 or name.startswith('_') 353 or name in seen): 354 field_names[index] = f'_{index}' 355 seen.add(name) 356 357 for name in [typename] + field_names: 358 if type(name) is not str: 359 raise TypeError('Type names and field names must be strings') 360 if not name.isidentifier(): 361 raise ValueError('Type names and field names must be valid ' 362 f'identifiers: {name!r}') 363 if _iskeyword(name): 364 raise ValueError('Type names and field names cannot be a ' 365 f'keyword: {name!r}') 366 367 seen = set() 368 for name in field_names: 369 if name.startswith('_') and not rename: 370 raise ValueError('Field names cannot start with an underscore: ' 371 f'{name!r}') 372 if name in seen: 373 raise ValueError(f'Encountered duplicate field name: {name!r}') 374 seen.add(name) 375 376 field_defaults = {} 377 if defaults is not None: 378 defaults = tuple(defaults) 379 if len(defaults) > len(field_names): 380 raise TypeError('Got more default values than field names') 381 field_defaults = dict(reversed(list(zip(reversed(field_names), 382 reversed(defaults))))) 383 384 # Variables used in the methods and docstrings 385 field_names = tuple(map(_sys.intern, field_names)) 386 num_fields = len(field_names) 387 arg_list = repr(field_names).replace("'", "")[1:-1] 388 repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' 389 tuple_new = tuple.__new__ 390 _len = len 391 392 # Create all the named tuple methods to be added to the class namespace 393 394 s = f'def __new__(_cls, {arg_list}): return _tuple_new(_cls, ({arg_list}))' 395 namespace = {'_tuple_new': tuple_new, '__name__': f'namedtuple_{typename}'} 396 # Note: exec() has the side-effect of interning the field names 397 exec(s, namespace) 398 __new__ = namespace['__new__'] 399 __new__.__doc__ = f'Create new instance of {typename}({arg_list})' 400 if defaults is not None: 401 __new__.__defaults__ = defaults 402 403 @classmethod 404 def _make(cls, iterable): 405 result = tuple_new(cls, iterable) 406 if _len(result) != num_fields: 407 raise TypeError(f'Expected {num_fields} arguments, got {len(result)}') 408 return result 409 410 _make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' 411 'or iterable') 412 413 def _replace(_self, **kwds): 414 result = _self._make(map(kwds.pop, field_names, _self)) 415 if kwds: 416 raise ValueError(f'Got unexpected field names: {list(kwds)!r}') 417 return result 418 419 _replace.__doc__ = (f'Return a new {typename} object replacing specified ' 420 'fields with new values') 421 422 def __repr__(self): 423 'Return a nicely formatted representation string' 424 return self.__class__.__name__ + repr_fmt % self 425 426 def _asdict(self): 427 'Return a new OrderedDict which maps field names to their values.' 428 return OrderedDict(zip(self._fields, self)) 429 430 def __getnewargs__(self): 431 'Return self as a plain tuple. Used by copy and pickle.' 432 return tuple(self) 433 434 # Modify function metadata to help with introspection and debugging 435 436 for method in (__new__, _make.__func__, _replace, 437 __repr__, _asdict, __getnewargs__): 438 method.__qualname__ = f'{typename}.{method.__name__}' 439 440 # Build-up the class namespace dictionary 441 # and use type() to build the result class 442 class_namespace = { 443 '__doc__': f'{typename}({arg_list})', 444 '__slots__': (), 445 '_fields': field_names, 446 '_fields_defaults': field_defaults, 447 '__new__': __new__, 448 '_make': _make, 449 '_replace': _replace, 450 '__repr__': __repr__, 451 '_asdict': _asdict, 452 '__getnewargs__': __getnewargs__, 453 } 454 cache = _nt_itemgetters 455 for index, name in enumerate(field_names): 456 try: 457 itemgetter_object, doc = cache[index] 458 except KeyError: 459 itemgetter_object = _itemgetter(index) 460 doc = f'Alias for field number {index}' 461 cache[index] = itemgetter_object, doc 462 class_namespace[name] = property(itemgetter_object, doc=doc) 463 464 result = type(typename, (tuple,), class_namespace) 465 466 # For pickling to work, the __module__ variable needs to be set to the frame 467 # where the named tuple is created. Bypass this step in environments where 468 # sys._getframe is not defined (Jython for example) or sys._getframe is not 469 # defined for arguments greater than 0 (IronPython), or where the user has 470 # specified a particular module. 471 if module is None: 472 try: 473 module = _sys._getframe(1).f_globals.get('__name__', '__main__') 474 except (AttributeError, ValueError): 475 pass 476 if module is not None: 477 result.__module__ = module 478 479 return result 480 481 482######################################################################## 483### Counter 484######################################################################## 485 486def _count_elements(mapping, iterable): 487 'Tally elements from the iterable.' 488 mapping_get = mapping.get 489 for elem in iterable: 490 mapping[elem] = mapping_get(elem, 0) + 1 491 492try: # Load C helper function if available 493 from _collections import _count_elements 494except ImportError: 495 pass 496 497class Counter(dict): 498 '''Dict subclass for counting hashable items. Sometimes called a bag 499 or multiset. Elements are stored as dictionary keys and their counts 500 are stored as dictionary values. 501 502 >>> c = Counter('abcdeabcdabcaba') # count elements from a string 503 504 >>> c.most_common(3) # three most common elements 505 [('a', 5), ('b', 4), ('c', 3)] 506 >>> sorted(c) # list all unique elements 507 ['a', 'b', 'c', 'd', 'e'] 508 >>> ''.join(sorted(c.elements())) # list elements with repetitions 509 'aaaaabbbbcccdde' 510 >>> sum(c.values()) # total of all counts 511 15 512 513 >>> c['a'] # count of letter 'a' 514 5 515 >>> for elem in 'shazam': # update counts from an iterable 516 ... c[elem] += 1 # by adding 1 to each element's count 517 >>> c['a'] # now there are seven 'a' 518 7 519 >>> del c['b'] # remove all 'b' 520 >>> c['b'] # now there are zero 'b' 521 0 522 523 >>> d = Counter('simsalabim') # make another counter 524 >>> c.update(d) # add in the second counter 525 >>> c['a'] # now there are nine 'a' 526 9 527 528 >>> c.clear() # empty the counter 529 >>> c 530 Counter() 531 532 Note: If a count is set to zero or reduced to zero, it will remain 533 in the counter until the entry is deleted or the counter is cleared: 534 535 >>> c = Counter('aaabbc') 536 >>> c['b'] -= 2 # reduce the count of 'b' by two 537 >>> c.most_common() # 'b' is still in, but its count is zero 538 [('a', 3), ('c', 1), ('b', 0)] 539 540 ''' 541 # References: 542 # http://en.wikipedia.org/wiki/Multiset 543 # http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html 544 # http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm 545 # http://code.activestate.com/recipes/259174/ 546 # Knuth, TAOCP Vol. II section 4.6.3 547 548 def __init__(*args, **kwds): 549 '''Create a new, empty Counter object. And if given, count elements 550 from an input iterable. Or, initialize the count from another mapping 551 of elements to their counts. 552 553 >>> c = Counter() # a new, empty counter 554 >>> c = Counter('gallahad') # a new counter from an iterable 555 >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping 556 >>> c = Counter(a=4, b=2) # a new counter from keyword args 557 558 ''' 559 if not args: 560 raise TypeError("descriptor '__init__' of 'Counter' object " 561 "needs an argument") 562 self, *args = args 563 if len(args) > 1: 564 raise TypeError('expected at most 1 arguments, got %d' % len(args)) 565 super(Counter, self).__init__() 566 self.update(*args, **kwds) 567 568 def __missing__(self, key): 569 'The count of elements not in the Counter is zero.' 570 # Needed so that self[missing_item] does not raise KeyError 571 return 0 572 573 def most_common(self, n=None): 574 '''List the n most common elements and their counts from the most 575 common to the least. If n is None, then list all element counts. 576 577 >>> Counter('abcdeabcdabcaba').most_common(3) 578 [('a', 5), ('b', 4), ('c', 3)] 579 580 ''' 581 # Emulate Bag.sortedByCount from Smalltalk 582 if n is None: 583 return sorted(self.items(), key=_itemgetter(1), reverse=True) 584 return _heapq.nlargest(n, self.items(), key=_itemgetter(1)) 585 586 def elements(self): 587 '''Iterator over elements repeating each as many times as its count. 588 589 >>> c = Counter('ABCABC') 590 >>> sorted(c.elements()) 591 ['A', 'A', 'B', 'B', 'C', 'C'] 592 593 # Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1 594 >>> prime_factors = Counter({2: 2, 3: 3, 17: 1}) 595 >>> product = 1 596 >>> for factor in prime_factors.elements(): # loop over factors 597 ... product *= factor # and multiply them 598 >>> product 599 1836 600 601 Note, if an element's count has been set to zero or is a negative 602 number, elements() will ignore it. 603 604 ''' 605 # Emulate Bag.do from Smalltalk and Multiset.begin from C++. 606 return _chain.from_iterable(_starmap(_repeat, self.items())) 607 608 # Override dict methods where necessary 609 610 @classmethod 611 def fromkeys(cls, iterable, v=None): 612 # There is no equivalent method for counters because setting v=1 613 # means that no element can have a count greater than one. 614 raise NotImplementedError( 615 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.') 616 617 def update(*args, **kwds): 618 '''Like dict.update() but add counts instead of replacing them. 619 620 Source can be an iterable, a dictionary, or another Counter instance. 621 622 >>> c = Counter('which') 623 >>> c.update('witch') # add elements from another iterable 624 >>> d = Counter('watch') 625 >>> c.update(d) # add elements from another counter 626 >>> c['h'] # four 'h' in which, witch, and watch 627 4 628 629 ''' 630 # The regular dict.update() operation makes no sense here because the 631 # replace behavior results in the some of original untouched counts 632 # being mixed-in with all of the other counts for a mismash that 633 # doesn't have a straight-forward interpretation in most counting 634 # contexts. Instead, we implement straight-addition. Both the inputs 635 # and outputs are allowed to contain zero and negative counts. 636 637 if not args: 638 raise TypeError("descriptor 'update' of 'Counter' object " 639 "needs an argument") 640 self, *args = args 641 if len(args) > 1: 642 raise TypeError('expected at most 1 arguments, got %d' % len(args)) 643 iterable = args[0] if args else None 644 if iterable is not None: 645 if isinstance(iterable, _collections_abc.Mapping): 646 if self: 647 self_get = self.get 648 for elem, count in iterable.items(): 649 self[elem] = count + self_get(elem, 0) 650 else: 651 super(Counter, self).update(iterable) # fast path when counter is empty 652 else: 653 _count_elements(self, iterable) 654 if kwds: 655 self.update(kwds) 656 657 def subtract(*args, **kwds): 658 '''Like dict.update() but subtracts counts instead of replacing them. 659 Counts can be reduced below zero. Both the inputs and outputs are 660 allowed to contain zero and negative counts. 661 662 Source can be an iterable, a dictionary, or another Counter instance. 663 664 >>> c = Counter('which') 665 >>> c.subtract('witch') # subtract elements from another iterable 666 >>> c.subtract(Counter('watch')) # subtract elements from another counter 667 >>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch 668 0 669 >>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch 670 -1 671 672 ''' 673 if not args: 674 raise TypeError("descriptor 'subtract' of 'Counter' object " 675 "needs an argument") 676 self, *args = args 677 if len(args) > 1: 678 raise TypeError('expected at most 1 arguments, got %d' % len(args)) 679 iterable = args[0] if args else None 680 if iterable is not None: 681 self_get = self.get 682 if isinstance(iterable, _collections_abc.Mapping): 683 for elem, count in iterable.items(): 684 self[elem] = self_get(elem, 0) - count 685 else: 686 for elem in iterable: 687 self[elem] = self_get(elem, 0) - 1 688 if kwds: 689 self.subtract(kwds) 690 691 def copy(self): 692 'Return a shallow copy.' 693 return self.__class__(self) 694 695 def __reduce__(self): 696 return self.__class__, (dict(self),) 697 698 def __delitem__(self, elem): 699 'Like dict.__delitem__() but does not raise KeyError for missing values.' 700 if elem in self: 701 super().__delitem__(elem) 702 703 def __repr__(self): 704 if not self: 705 return '%s()' % self.__class__.__name__ 706 try: 707 items = ', '.join(map('%r: %r'.__mod__, self.most_common())) 708 return '%s({%s})' % (self.__class__.__name__, items) 709 except TypeError: 710 # handle case where values are not orderable 711 return '{0}({1!r})'.format(self.__class__.__name__, dict(self)) 712 713 # Multiset-style mathematical operations discussed in: 714 # Knuth TAOCP Volume II section 4.6.3 exercise 19 715 # and at http://en.wikipedia.org/wiki/Multiset 716 # 717 # Outputs guaranteed to only include positive counts. 718 # 719 # To strip negative and zero counts, add-in an empty counter: 720 # c += Counter() 721 722 def __add__(self, other): 723 '''Add counts from two counters. 724 725 >>> Counter('abbb') + Counter('bcc') 726 Counter({'b': 4, 'c': 2, 'a': 1}) 727 728 ''' 729 if not isinstance(other, Counter): 730 return NotImplemented 731 result = Counter() 732 for elem, count in self.items(): 733 newcount = count + other[elem] 734 if newcount > 0: 735 result[elem] = newcount 736 for elem, count in other.items(): 737 if elem not in self and count > 0: 738 result[elem] = count 739 return result 740 741 def __sub__(self, other): 742 ''' Subtract count, but keep only results with positive counts. 743 744 >>> Counter('abbbc') - Counter('bccd') 745 Counter({'b': 2, 'a': 1}) 746 747 ''' 748 if not isinstance(other, Counter): 749 return NotImplemented 750 result = Counter() 751 for elem, count in self.items(): 752 newcount = count - other[elem] 753 if newcount > 0: 754 result[elem] = newcount 755 for elem, count in other.items(): 756 if elem not in self and count < 0: 757 result[elem] = 0 - count 758 return result 759 760 def __or__(self, other): 761 '''Union is the maximum of value in either of the input counters. 762 763 >>> Counter('abbb') | Counter('bcc') 764 Counter({'b': 3, 'c': 2, 'a': 1}) 765 766 ''' 767 if not isinstance(other, Counter): 768 return NotImplemented 769 result = Counter() 770 for elem, count in self.items(): 771 other_count = other[elem] 772 newcount = other_count if count < other_count else count 773 if newcount > 0: 774 result[elem] = newcount 775 for elem, count in other.items(): 776 if elem not in self and count > 0: 777 result[elem] = count 778 return result 779 780 def __and__(self, other): 781 ''' Intersection is the minimum of corresponding counts. 782 783 >>> Counter('abbb') & Counter('bcc') 784 Counter({'b': 1}) 785 786 ''' 787 if not isinstance(other, Counter): 788 return NotImplemented 789 result = Counter() 790 for elem, count in self.items(): 791 other_count = other[elem] 792 newcount = count if count < other_count else other_count 793 if newcount > 0: 794 result[elem] = newcount 795 return result 796 797 def __pos__(self): 798 'Adds an empty counter, effectively stripping negative and zero counts' 799 result = Counter() 800 for elem, count in self.items(): 801 if count > 0: 802 result[elem] = count 803 return result 804 805 def __neg__(self): 806 '''Subtracts from an empty counter. Strips positive and zero counts, 807 and flips the sign on negative counts. 808 809 ''' 810 result = Counter() 811 for elem, count in self.items(): 812 if count < 0: 813 result[elem] = 0 - count 814 return result 815 816 def _keep_positive(self): 817 '''Internal method to strip elements with a negative or zero count''' 818 nonpositive = [elem for elem, count in self.items() if not count > 0] 819 for elem in nonpositive: 820 del self[elem] 821 return self 822 823 def __iadd__(self, other): 824 '''Inplace add from another counter, keeping only positive counts. 825 826 >>> c = Counter('abbb') 827 >>> c += Counter('bcc') 828 >>> c 829 Counter({'b': 4, 'c': 2, 'a': 1}) 830 831 ''' 832 for elem, count in other.items(): 833 self[elem] += count 834 return self._keep_positive() 835 836 def __isub__(self, other): 837 '''Inplace subtract counter, but keep only results with positive counts. 838 839 >>> c = Counter('abbbc') 840 >>> c -= Counter('bccd') 841 >>> c 842 Counter({'b': 2, 'a': 1}) 843 844 ''' 845 for elem, count in other.items(): 846 self[elem] -= count 847 return self._keep_positive() 848 849 def __ior__(self, other): 850 '''Inplace union is the maximum of value from either counter. 851 852 >>> c = Counter('abbb') 853 >>> c |= Counter('bcc') 854 >>> c 855 Counter({'b': 3, 'c': 2, 'a': 1}) 856 857 ''' 858 for elem, other_count in other.items(): 859 count = self[elem] 860 if other_count > count: 861 self[elem] = other_count 862 return self._keep_positive() 863 864 def __iand__(self, other): 865 '''Inplace intersection is the minimum of corresponding counts. 866 867 >>> c = Counter('abbb') 868 >>> c &= Counter('bcc') 869 >>> c 870 Counter({'b': 1}) 871 872 ''' 873 for elem, count in self.items(): 874 other_count = other[elem] 875 if other_count < count: 876 self[elem] = other_count 877 return self._keep_positive() 878 879 880######################################################################## 881### ChainMap 882######################################################################## 883 884class ChainMap(_collections_abc.MutableMapping): 885 ''' A ChainMap groups multiple dicts (or other mappings) together 886 to create a single, updateable view. 887 888 The underlying mappings are stored in a list. That list is public and can 889 be accessed or updated using the *maps* attribute. There is no other 890 state. 891 892 Lookups search the underlying mappings successively until a key is found. 893 In contrast, writes, updates, and deletions only operate on the first 894 mapping. 895 896 ''' 897 898 def __init__(self, *maps): 899 '''Initialize a ChainMap by setting *maps* to the given mappings. 900 If no mappings are provided, a single empty dictionary is used. 901 902 ''' 903 self.maps = list(maps) or [{}] # always at least one map 904 905 def __missing__(self, key): 906 raise KeyError(key) 907 908 def __getitem__(self, key): 909 for mapping in self.maps: 910 try: 911 return mapping[key] # can't use 'key in mapping' with defaultdict 912 except KeyError: 913 pass 914 return self.__missing__(key) # support subclasses that define __missing__ 915 916 def get(self, key, default=None): 917 return self[key] if key in self else default 918 919 def __len__(self): 920 return len(set().union(*self.maps)) # reuses stored hash values if possible 921 922 def __iter__(self): 923 d = {} 924 for mapping in reversed(self.maps): 925 d.update(mapping) # reuses stored hash values if possible 926 return iter(d) 927 928 def __contains__(self, key): 929 return any(key in m for m in self.maps) 930 931 def __bool__(self): 932 return any(self.maps) 933 934 @_recursive_repr() 935 def __repr__(self): 936 return '{0.__class__.__name__}({1})'.format( 937 self, ', '.join(map(repr, self.maps))) 938 939 @classmethod 940 def fromkeys(cls, iterable, *args): 941 'Create a ChainMap with a single dict created from the iterable.' 942 return cls(dict.fromkeys(iterable, *args)) 943 944 def copy(self): 945 'New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]' 946 return self.__class__(self.maps[0].copy(), *self.maps[1:]) 947 948 __copy__ = copy 949 950 def new_child(self, m=None): # like Django's Context.push() 951 '''New ChainMap with a new map followed by all previous maps. 952 If no map is provided, an empty dict is used. 953 ''' 954 if m is None: 955 m = {} 956 return self.__class__(m, *self.maps) 957 958 @property 959 def parents(self): # like Django's Context.pop() 960 'New ChainMap from maps[1:].' 961 return self.__class__(*self.maps[1:]) 962 963 def __setitem__(self, key, value): 964 self.maps[0][key] = value 965 966 def __delitem__(self, key): 967 try: 968 del self.maps[0][key] 969 except KeyError: 970 raise KeyError('Key not found in the first mapping: {!r}'.format(key)) 971 972 def popitem(self): 973 'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.' 974 try: 975 return self.maps[0].popitem() 976 except KeyError: 977 raise KeyError('No keys found in the first mapping.') 978 979 def pop(self, key, *args): 980 'Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0].' 981 try: 982 return self.maps[0].pop(key, *args) 983 except KeyError: 984 raise KeyError('Key not found in the first mapping: {!r}'.format(key)) 985 986 def clear(self): 987 'Clear maps[0], leaving maps[1:] intact.' 988 self.maps[0].clear() 989 990 991################################################################################ 992### UserDict 993################################################################################ 994 995class UserDict(_collections_abc.MutableMapping): 996 997 # Start by filling-out the abstract methods 998 def __init__(*args, **kwargs): 999 if not args: 1000 raise TypeError("descriptor '__init__' of 'UserDict' object " 1001 "needs an argument") 1002 self, *args = args 1003 if len(args) > 1: 1004 raise TypeError('expected at most 1 arguments, got %d' % len(args)) 1005 if args: 1006 dict = args[0] 1007 elif 'dict' in kwargs: 1008 dict = kwargs.pop('dict') 1009 import warnings 1010 warnings.warn("Passing 'dict' as keyword argument is deprecated", 1011 DeprecationWarning, stacklevel=2) 1012 else: 1013 dict = None 1014 self.data = {} 1015 if dict is not None: 1016 self.update(dict) 1017 if len(kwargs): 1018 self.update(kwargs) 1019 def __len__(self): return len(self.data) 1020 def __getitem__(self, key): 1021 if key in self.data: 1022 return self.data[key] 1023 if hasattr(self.__class__, "__missing__"): 1024 return self.__class__.__missing__(self, key) 1025 raise KeyError(key) 1026 def __setitem__(self, key, item): self.data[key] = item 1027 def __delitem__(self, key): del self.data[key] 1028 def __iter__(self): 1029 return iter(self.data) 1030 1031 # Modify __contains__ to work correctly when __missing__ is present 1032 def __contains__(self, key): 1033 return key in self.data 1034 1035 # Now, add the methods in dicts but not in MutableMapping 1036 def __repr__(self): return repr(self.data) 1037 def copy(self): 1038 if self.__class__ is UserDict: 1039 return UserDict(self.data.copy()) 1040 import copy 1041 data = self.data 1042 try: 1043 self.data = {} 1044 c = copy.copy(self) 1045 finally: 1046 self.data = data 1047 c.update(self) 1048 return c 1049 @classmethod 1050 def fromkeys(cls, iterable, value=None): 1051 d = cls() 1052 for key in iterable: 1053 d[key] = value 1054 return d 1055 1056 1057 1058################################################################################ 1059### UserList 1060################################################################################ 1061 1062class UserList(_collections_abc.MutableSequence): 1063 """A more or less complete user-defined wrapper around list objects.""" 1064 def __init__(self, initlist=None): 1065 self.data = [] 1066 if initlist is not None: 1067 # XXX should this accept an arbitrary sequence? 1068 if type(initlist) == type(self.data): 1069 self.data[:] = initlist 1070 elif isinstance(initlist, UserList): 1071 self.data[:] = initlist.data[:] 1072 else: 1073 self.data = list(initlist) 1074 def __repr__(self): return repr(self.data) 1075 def __lt__(self, other): return self.data < self.__cast(other) 1076 def __le__(self, other): return self.data <= self.__cast(other) 1077 def __eq__(self, other): return self.data == self.__cast(other) 1078 def __gt__(self, other): return self.data > self.__cast(other) 1079 def __ge__(self, other): return self.data >= self.__cast(other) 1080 def __cast(self, other): 1081 return other.data if isinstance(other, UserList) else other 1082 def __contains__(self, item): return item in self.data 1083 def __len__(self): return len(self.data) 1084 def __getitem__(self, i): return self.data[i] 1085 def __setitem__(self, i, item): self.data[i] = item 1086 def __delitem__(self, i): del self.data[i] 1087 def __add__(self, other): 1088 if isinstance(other, UserList): 1089 return self.__class__(self.data + other.data) 1090 elif isinstance(other, type(self.data)): 1091 return self.__class__(self.data + other) 1092 return self.__class__(self.data + list(other)) 1093 def __radd__(self, other): 1094 if isinstance(other, UserList): 1095 return self.__class__(other.data + self.data) 1096 elif isinstance(other, type(self.data)): 1097 return self.__class__(other + self.data) 1098 return self.__class__(list(other) + self.data) 1099 def __iadd__(self, other): 1100 if isinstance(other, UserList): 1101 self.data += other.data 1102 elif isinstance(other, type(self.data)): 1103 self.data += other 1104 else: 1105 self.data += list(other) 1106 return self 1107 def __mul__(self, n): 1108 return self.__class__(self.data*n) 1109 __rmul__ = __mul__ 1110 def __imul__(self, n): 1111 self.data *= n 1112 return self 1113 def append(self, item): self.data.append(item) 1114 def insert(self, i, item): self.data.insert(i, item) 1115 def pop(self, i=-1): return self.data.pop(i) 1116 def remove(self, item): self.data.remove(item) 1117 def clear(self): self.data.clear() 1118 def copy(self): return self.__class__(self) 1119 def count(self, item): return self.data.count(item) 1120 def index(self, item, *args): return self.data.index(item, *args) 1121 def reverse(self): self.data.reverse() 1122 def sort(self, *args, **kwds): self.data.sort(*args, **kwds) 1123 def extend(self, other): 1124 if isinstance(other, UserList): 1125 self.data.extend(other.data) 1126 else: 1127 self.data.extend(other) 1128 1129 1130 1131################################################################################ 1132### UserString 1133################################################################################ 1134 1135class UserString(_collections_abc.Sequence): 1136 def __init__(self, seq): 1137 if isinstance(seq, str): 1138 self.data = seq 1139 elif isinstance(seq, UserString): 1140 self.data = seq.data[:] 1141 else: 1142 self.data = str(seq) 1143 def __str__(self): return str(self.data) 1144 def __repr__(self): return repr(self.data) 1145 def __int__(self): return int(self.data) 1146 def __float__(self): return float(self.data) 1147 def __complex__(self): return complex(self.data) 1148 def __hash__(self): return hash(self.data) 1149 def __getnewargs__(self): 1150 return (self.data[:],) 1151 1152 def __eq__(self, string): 1153 if isinstance(string, UserString): 1154 return self.data == string.data 1155 return self.data == string 1156 def __lt__(self, string): 1157 if isinstance(string, UserString): 1158 return self.data < string.data 1159 return self.data < string 1160 def __le__(self, string): 1161 if isinstance(string, UserString): 1162 return self.data <= string.data 1163 return self.data <= string 1164 def __gt__(self, string): 1165 if isinstance(string, UserString): 1166 return self.data > string.data 1167 return self.data > string 1168 def __ge__(self, string): 1169 if isinstance(string, UserString): 1170 return self.data >= string.data 1171 return self.data >= string 1172 1173 def __contains__(self, char): 1174 if isinstance(char, UserString): 1175 char = char.data 1176 return char in self.data 1177 1178 def __len__(self): return len(self.data) 1179 def __getitem__(self, index): return self.__class__(self.data[index]) 1180 def __add__(self, other): 1181 if isinstance(other, UserString): 1182 return self.__class__(self.data + other.data) 1183 elif isinstance(other, str): 1184 return self.__class__(self.data + other) 1185 return self.__class__(self.data + str(other)) 1186 def __radd__(self, other): 1187 if isinstance(other, str): 1188 return self.__class__(other + self.data) 1189 return self.__class__(str(other) + self.data) 1190 def __mul__(self, n): 1191 return self.__class__(self.data*n) 1192 __rmul__ = __mul__ 1193 def __mod__(self, args): 1194 return self.__class__(self.data % args) 1195 def __rmod__(self, format): 1196 return self.__class__(format % args) 1197 1198 # the following methods are defined in alphabetical order: 1199 def capitalize(self): return self.__class__(self.data.capitalize()) 1200 def casefold(self): 1201 return self.__class__(self.data.casefold()) 1202 def center(self, width, *args): 1203 return self.__class__(self.data.center(width, *args)) 1204 def count(self, sub, start=0, end=_sys.maxsize): 1205 if isinstance(sub, UserString): 1206 sub = sub.data 1207 return self.data.count(sub, start, end) 1208 def encode(self, encoding=None, errors=None): # XXX improve this? 1209 if encoding: 1210 if errors: 1211 return self.__class__(self.data.encode(encoding, errors)) 1212 return self.__class__(self.data.encode(encoding)) 1213 return self.__class__(self.data.encode()) 1214 def endswith(self, suffix, start=0, end=_sys.maxsize): 1215 return self.data.endswith(suffix, start, end) 1216 def expandtabs(self, tabsize=8): 1217 return self.__class__(self.data.expandtabs(tabsize)) 1218 def find(self, sub, start=0, end=_sys.maxsize): 1219 if isinstance(sub, UserString): 1220 sub = sub.data 1221 return self.data.find(sub, start, end) 1222 def format(self, *args, **kwds): 1223 return self.data.format(*args, **kwds) 1224 def format_map(self, mapping): 1225 return self.data.format_map(mapping) 1226 def index(self, sub, start=0, end=_sys.maxsize): 1227 return self.data.index(sub, start, end) 1228 def isalpha(self): return self.data.isalpha() 1229 def isalnum(self): return self.data.isalnum() 1230 def isascii(self): return self.data.isascii() 1231 def isdecimal(self): return self.data.isdecimal() 1232 def isdigit(self): return self.data.isdigit() 1233 def isidentifier(self): return self.data.isidentifier() 1234 def islower(self): return self.data.islower() 1235 def isnumeric(self): return self.data.isnumeric() 1236 def isprintable(self): return self.data.isprintable() 1237 def isspace(self): return self.data.isspace() 1238 def istitle(self): return self.data.istitle() 1239 def isupper(self): return self.data.isupper() 1240 def join(self, seq): return self.data.join(seq) 1241 def ljust(self, width, *args): 1242 return self.__class__(self.data.ljust(width, *args)) 1243 def lower(self): return self.__class__(self.data.lower()) 1244 def lstrip(self, chars=None): return self.__class__(self.data.lstrip(chars)) 1245 maketrans = str.maketrans 1246 def partition(self, sep): 1247 return self.data.partition(sep) 1248 def replace(self, old, new, maxsplit=-1): 1249 if isinstance(old, UserString): 1250 old = old.data 1251 if isinstance(new, UserString): 1252 new = new.data 1253 return self.__class__(self.data.replace(old, new, maxsplit)) 1254 def rfind(self, sub, start=0, end=_sys.maxsize): 1255 if isinstance(sub, UserString): 1256 sub = sub.data 1257 return self.data.rfind(sub, start, end) 1258 def rindex(self, sub, start=0, end=_sys.maxsize): 1259 return self.data.rindex(sub, start, end) 1260 def rjust(self, width, *args): 1261 return self.__class__(self.data.rjust(width, *args)) 1262 def rpartition(self, sep): 1263 return self.data.rpartition(sep) 1264 def rstrip(self, chars=None): 1265 return self.__class__(self.data.rstrip(chars)) 1266 def split(self, sep=None, maxsplit=-1): 1267 return self.data.split(sep, maxsplit) 1268 def rsplit(self, sep=None, maxsplit=-1): 1269 return self.data.rsplit(sep, maxsplit) 1270 def splitlines(self, keepends=False): return self.data.splitlines(keepends) 1271 def startswith(self, prefix, start=0, end=_sys.maxsize): 1272 return self.data.startswith(prefix, start, end) 1273 def strip(self, chars=None): return self.__class__(self.data.strip(chars)) 1274 def swapcase(self): return self.__class__(self.data.swapcase()) 1275 def title(self): return self.__class__(self.data.title()) 1276 def translate(self, *args): 1277 return self.__class__(self.data.translate(*args)) 1278 def upper(self): return self.__class__(self.data.upper()) 1279 def zfill(self, width): return self.__class__(self.data.zfill(width)) 1280