1import re 2import sys 3import copy 4import types 5import inspect 6import keyword 7import builtins 8import functools 9import _thread 10 11 12__all__ = ['dataclass', 13 'field', 14 'Field', 15 'FrozenInstanceError', 16 'InitVar', 17 'MISSING', 18 19 # Helper functions. 20 'fields', 21 'asdict', 22 'astuple', 23 'make_dataclass', 24 'replace', 25 'is_dataclass', 26 ] 27 28# Conditions for adding methods. The boxes indicate what action the 29# dataclass decorator takes. For all of these tables, when I talk 30# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm 31# referring to the arguments to the @dataclass decorator. When 32# checking if a dunder method already exists, I mean check for an 33# entry in the class's __dict__. I never check to see if an attribute 34# is defined in a base class. 35 36# Key: 37# +=========+=========================================+ 38# + Value | Meaning | 39# +=========+=========================================+ 40# | <blank> | No action: no method is added. | 41# +---------+-----------------------------------------+ 42# | add | Generated method is added. | 43# +---------+-----------------------------------------+ 44# | raise | TypeError is raised. | 45# +---------+-----------------------------------------+ 46# | None | Attribute is set to None. | 47# +=========+=========================================+ 48 49# __init__ 50# 51# +--- init= parameter 52# | 53# v | | | 54# | no | yes | <--- class has __init__ in __dict__? 55# +=======+=======+=======+ 56# | False | | | 57# +-------+-------+-------+ 58# | True | add | | <- the default 59# +=======+=======+=======+ 60 61# __repr__ 62# 63# +--- repr= parameter 64# | 65# v | | | 66# | no | yes | <--- class has __repr__ in __dict__? 
67# +=======+=======+=======+ 68# | False | | | 69# +-------+-------+-------+ 70# | True | add | | <- the default 71# +=======+=======+=======+ 72 73 74# __setattr__ 75# __delattr__ 76# 77# +--- frozen= parameter 78# | 79# v | | | 80# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__? 81# +=======+=======+=======+ 82# | False | | | <- the default 83# +-------+-------+-------+ 84# | True | add | raise | 85# +=======+=======+=======+ 86# Raise because not adding these methods would break the "frozen-ness" 87# of the class. 88 89# __eq__ 90# 91# +--- eq= parameter 92# | 93# v | | | 94# | no | yes | <--- class has __eq__ in __dict__? 95# +=======+=======+=======+ 96# | False | | | 97# +-------+-------+-------+ 98# | True | add | | <- the default 99# +=======+=======+=======+ 100 101# __lt__ 102# __le__ 103# __gt__ 104# __ge__ 105# 106# +--- order= parameter 107# | 108# v | | | 109# | no | yes | <--- class has any comparison method in __dict__? 110# +=======+=======+=======+ 111# | False | | | <- the default 112# +-------+-------+-------+ 113# | True | add | raise | 114# +=======+=======+=======+ 115# Raise because to allow this case would interfere with using 116# functools.total_ordering. 
# __hash__

#    +------------------- unsafe_hash= parameter
#    |       +----------- eq= parameter
#    |       |       +--- frozen= parameter
#    |       |       |
#    v       v       v    |        |        |
#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
# +=======+=======+=======+========+========+
# | False | False | False |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | False | True  |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | True  | False | None   |        | <-- the default, not hashable
# +-------+-------+-------+--------+--------+
# | False | True  | True  | add    |        | Frozen, so hashable, allows override
# +-------+-------+-------+--------+--------+
# | True  | False | False | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | False | add    | raise  | Not frozen, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | True  | add    | raise  | Frozen, so hashable
# +=======+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class.  If the base is object, then
# id-based hashing is used.
#
# Note that a class may already have __hash__=None if it specified an
# __eq__ method in the class body (not one that was created by
# @dataclass).
#
# See _hash_action (below) for a coded version of this table.


class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""

# A sentinel object for default values to signal that a default
# factory will be used.  This is given a nice repr() which will appear
# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
    # The repr is what gets displayed as the default value of an
    # __init__ parameter whose field uses a default factory.
    def __repr__(self):
        return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()

# A sentinel object to detect if a parameter is supplied or not.  Use
# a class to give it a better repr.
class _MISSING_TYPE:
    pass
MISSING = _MISSING_TYPE()

# Since most per-field metadata will be unused, create an empty
# read-only proxy that can be shared among all fields.
_EMPTY_METADATA = types.MappingProxyType({})

# Markers for the various kinds of fields and pseudo-fields.
class _FIELD_BASE:
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name
_FIELD = _FIELD_BASE('_FIELD')
_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')

# The name of an attribute on the class where we store the Field
# objects.  Also used to check if a class is a Data Class.
_FIELDS = '__dataclass_fields__'

# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'

# The name of the function, that if it exists, is called at the end of
# __init__.
_POST_INIT_NAME = '__post_init__'

# String regex that string annotations for ClassVar or InitVar must match.
# Allows "identifier.identifier[" or "identifier[".
# https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')


class _InitVarMeta(type):
    # Makes the InitVar[tp] subscription syntax produce an InitVar
    # instance wrapping tp.
    def __getitem__(self, params):
        return InitVar(params)


class InitVar(metaclass=_InitVarMeta):
    __slots__ = ('type', )

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        if isinstance(self.type, type):
            type_name = self.type.__name__
        else:
            # typing objects, e.g. List[int]
            type_name = repr(self.type)
        return f'dataclasses.InitVar[{type_name}]'


# Instances of Field are only ever created from within this module,
# and only from the field() function, although Field instances are
# exposed externally as (conceptually) read-only objects.
#
# name and type are filled in after the fact, not in __init__.
# They're not known at the time this class is instantiated, but it's
# convenient if they're available later.
#
# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        # Share one empty read-only mapping when no metadata is given;
        # otherwise expose the caller's mapping through a read-only proxy.
        self.metadata = (_EMPTY_METADATA
                         if metadata is None else
                         types.MappingProxyType(metadata))
        self._field_type = None

    def __repr__(self):
        return ('Field('
                f'name={self.name!r},'
                f'type={self.type!r},'
                f'default={self.default!r},'
                f'default_factory={self.default_factory!r},'
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'hash={self.hash!r},'
                f'compare={self.compare!r},'
                f'metadata={self.metadata!r},'
                f'_field_type={self._field_type}'
                ')')

    # This is used to support the PEP 487 __set_name__ protocol in the
    # case where we're using a field that contains a descriptor as a
    # default value.  For details on __set_name__, see
    # https://www.python.org/dev/peps/pep-0487/#implementation-details.
    #
    # Note that in _process_class, this Field object is overwritten
    # with the default value, so the end result is a descriptor that
    # had __set_name__ called on it at the right time.
    def __set_name__(self, owner, name):
        func = getattr(type(self.default), '__set_name__', None)
        if func:
            # There is a __set_name__ method on the descriptor, call
            # it.
            func(self.default, owner, name)


class _DataclassParams:
    # Records the arguments that were passed to @dataclass; stored on
    # the class under the _PARAMS attribute name.
    __slots__ = ('init',
                 'repr',
                 'eq',
                 'order',
                 'unsafe_hash',
                 'frozen',
                 )

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        self.init = init
        self.repr = repr
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen

    def __repr__(self):
        return ('_DataclassParams('
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'eq={self.eq!r},'
                f'order={self.order!r},'
                f'unsafe_hash={self.unsafe_hash!r},'
                f'frozen={self.frozen!r}'
                ')')


# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the default value of the field.  default_factory is a
    0-argument function called to initialize a field's value.  If init
    is True, the field will be a parameter to the class's __init__()
    function.  If repr is True, the field will be included in the
    object's repr().  If hash is True, the field will be included in
    the object's hash().  If compare is True, the field will be used
    in comparison functions.  metadata, if specified, must be a
    mapping which is stored but not otherwise examined by dataclass.

    It is an error to specify both default and default_factory.
    """

    if default is not MISSING and default_factory is not MISSING:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata)


def _tuple_str(obj_name, fields):
    # Return a string representing each field of obj_name as a tuple
    # member.  So, if fields is ['x', 'y'] and obj_name is "self",
    # return "(self.x,self.y)".

    # Special case for the 0-tuple.
    if not fields:
        return '()'
    # Note the trailing comma, needed if this turns out to be a 1-tuple.
    return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'


# This function's logic is copied from "recursive_repr" function in
# reprlib module to avoid dependency.
def _recursive_repr(user_function):
    # Decorator to make a repr function return "..." for a recursive
    # call.
    repr_running = set()

    @functools.wraps(user_function)
    def wrapper(self):
        # Key on both object identity and thread id, so the same
        # object being repr'd in two threads is not seen as recursion.
        key = id(self), _thread.get_ident()
        if key in repr_running:
            return '...'
        repr_running.add(key)
        try:
            result = user_function(self)
        finally:
            repr_running.discard(key)
        return result
    return wrapper


def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    # Build the source text of a function called `name` and exec() it.
    #
    # Note that we mutate locals when exec() is called.  Caller
    # beware!  The only callers are internal to this module, so no
    # worries about external callers.
    if locals is None:
        locals = {}
    if 'BUILTINS' not in locals:
        locals['BUILTINS'] = builtins
    return_annotation = ''
    if return_type is not MISSING:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
    args = ','.join(args)
    # The body must be indented deeper than the one-space "def" below.
    body = '\n'.join(f'  {b}' for b in body)

    # Compute the text of the entire function.
    txt = f' def {name}({args}){return_annotation}:\n{body}'

    # Wrap it in an outer function whose parameters are the names in
    # locals, so the generated function sees them as closure variables.
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)


def _field_assign(frozen, name, value, self_name):
    # If we're a frozen class, then assign to our fields in __init__
    # via object.__setattr__.  Otherwise, just use a simple
    # assignment.
    #
    # self_name is what "self" is called in this function: don't
    # hard-code "self", since that might be a field name.
    if frozen:
        return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
    return f'{self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name):
    # Return the text of the line in the body of __init__ that will
    # initialize this field.  As a side effect, default values and
    # default factories are stored in `globals` under '_dflt_<name>'.

    default_name = f'_dflt_{f.name}'
    if f.default_factory is not MISSING:
        if f.init:
            # This field has a default factory.  If a parameter is
            # given, use it.  If not, call the factory.
            globals[default_name] = f.default_factory
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # This is a field that's not in the __init__ params, but
            # has a default factory function.  It needs to be
            # initialized here by calling the factory function,
            # because there's no other way to initialize it.

            # For a field initialized with a default=defaultvalue, the
            # class dict just has the default value
            # (cls.fieldname=defaultvalue).  But that won't work for a
            # default factory, the factory must be called in __init__
            # and we must assign that to self.fieldname.  We can't
            # fall back to the class dict's value, both because it's
            # not set, and because it might be different per-class
            # (which, after all, is why we have a factory function!).

            globals[default_name] = f.default_factory
            value = f'{default_name}()'
    else:
        # No default factory.
        if f.init:
            # With or without a default the assignment is the same;
            # a default only needs to be made available by name so the
            # parameter list can refer to it.
            if f.default is not MISSING:
                globals[default_name] = f.default
            value = f.name
        else:
            # This field does not need initialization.  Signify that
            # to the caller by returning None.
            return None

    # Only test this now, so that we can create variables for the
    # default.  However, return None to signify that we're not going
    # to actually do the assignment statement for InitVars.
    if f._field_type is _FIELD_INITVAR:
        return None

    # Now, actually generate the field assignment.
    return _field_assign(frozen, f.name, value, self_name)


def _init_param(f):
    # Return the __init__ parameter string for this field.  For
    # example, the equivalent of 'x:int=3' (except instead of 'int',
    # reference a variable set to int, and instead of '3', reference a
    # variable set to 3).
    if f.default is MISSING and f.default_factory is MISSING:
        # There's no default, and no default_factory, just output the
        # variable name and type.
        default = ''
    elif f.default is not MISSING:
        # There's a default, this will be the name that's used to look
        # it up.
        default = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # There's a factory function.  Set a marker.
        default = '=_HAS_DEFAULT_FACTORY'
    return f'{f.name}:_type_{f.name}{default}'


def _init_fn(fields, frozen, has_post_init, self_name, globals):
    # fields contains both real fields and InitVar pseudo-fields.

    # Make sure we don't have fields without defaults following fields
    # with defaults.  This actually would be caught when exec-ing the
    # function source code, but catching it here gives a better error
    # message, and future-proofs us in case we build up the function
    # using ast.
    seen_default = False
    for f in fields:
        # Only consider fields in the __init__ call.
        if f.init:
            if not (f.default is MISSING and f.default_factory is MISSING):
                seen_default = True
            elif seen_default:
                raise TypeError(f'non-default argument {f.name!r} '
                                'follows default argument')

    locals = {f'_type_{f.name}': f.type for f in fields}
    locals.update({
        'MISSING': MISSING,
        '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
    })

    body_lines = []
    for f in fields:
        line = _field_init(f, frozen, locals, self_name)
        # line is None means that this field doesn't require
        # initialization (it's a pseudo-field).  Just skip it.
        if line:
            body_lines.append(line)

    # Does this class have a post-init function?
    if has_post_init:
        params_str = ','.join(f.name for f in fields
                              if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')

    # If no body lines, use 'pass'.
    if not body_lines:
        body_lines = ['pass']

    return _create_fn('__init__',
                      [self_name] + [_init_param(f) for f in fields if f.init],
                      body_lines,
                      locals=locals,
                      globals=globals,
                      return_type=None)


def _repr_fn(fields, globals):
    # Generate a __repr__ of the form "QualName(f1=..., f2=...)",
    # guarded against infinite recursion.
    fn = _create_fn('__repr__',
                    ('self',),
                    ['return self.__class__.__qualname__ + f"(' +
                     ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                for f in fields]) +
                     ')"'],
                    globals=globals)
    return _recursive_repr(fn)


def _frozen_get_del_attr(cls, fields, globals):
    # Return a (__setattr__, __delattr__) pair that raises
    # FrozenInstanceError for instances of exactly cls, or for any
    # field name; everything else is delegated to the next class in
    # the MRO.
    locals = {'cls': cls,
              'FrozenInstanceError': FrozenInstanceError}
    if fields:
        fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
    else:
        # Special case for the zero-length tuple.
        fields_str = '()'
    return (_create_fn('__setattr__',
                       ('self', 'name', 'value'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        '  raise FrozenInstanceError(f"cannot assign to field {name!r}")',
                        'super(cls, self).__setattr__(name, value)'),
                       locals=locals,
                       globals=globals),
            _create_fn('__delattr__',
                       ('self', 'name'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        '  raise FrozenInstanceError(f"cannot delete field {name!r}")',
                        'super(cls, self).__delattr__(name)'),
                       locals=locals,
                       globals=globals),
            )


def _cmp_fn(name, op, self_tuple, other_tuple, globals):
    # Create a comparison function.  If the fields in the object are
    # named 'x' and 'y', then self_tuple is the string
    # '(self.x,self.y)' and other_tuple is the string
    # '(other.x,other.y)'.

    return _create_fn(name,
                      ('self', 'other'),
                      ['if other.__class__ is self.__class__:',
                       f' return {self_tuple}{op}{other_tuple}',
                       'return NotImplemented'],
                      globals=globals)


def _hash_fn(fields, globals):
    # Generate a __hash__ that hashes the tuple of the given fields.
    self_tuple = _tuple_str('self', fields)
    return _create_fn('__hash__',
                      ('self',),
                      [f'return hash({self_tuple})'],
                      globals=globals)


def _is_classvar(a_type, typing):
    # This test uses a typing internal class, but it's the best way to
    # test if this is a ClassVar.
    return (a_type is typing.ClassVar
            or (type(a_type) is typing._GenericAlias
                and a_type.__origin__ is typing.ClassVar))


def _is_initvar(a_type, dataclasses):
    # The module we're checking against is the module we're
    # currently in (dataclasses.py).
    return (a_type is dataclasses.InitVar
            or type(a_type) is dataclasses.InitVar)


def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    # Given a type annotation string, does it refer to a_type in
    # a_module?  For example, when checking that annotation denotes a
    # ClassVar, then a_module is typing, and a_type is
    # typing.ClassVar.

    # It's possible to look up a_module given a_type, but it involves
    # looking in sys.modules (again!), and seems like a waste since
    # the caller already knows a_module.

    # - annotation is a string type annotation
    # - cls is the class that this annotation was found in
    # - a_module is the module we want to match
    # - a_type is the type in that module we want to match
    # - is_type_predicate is a function called with (obj, a_module)
    #   that determines if obj is of the desired type.

    # Since this test does not do a local namespace lookup (and
    # instead only a module (global) lookup), there are some things it
    # gets wrong.

    # With string annotations, cv0 will be detected as a ClassVar:
    #   CV = ClassVar
    #   @dataclass
    #   class C0:
    #     cv0: CV

    # But in this example cv1 will not be detected as a ClassVar:
    #   @dataclass
    #   class C1:
    #     CV = ClassVar
    #     cv1: CV

    # In C1, the code in this function (_is_type) will look up "CV" in
    # the module and not find it, so it will not consider cv1 as a
    # ClassVar.  This is a fairly obscure corner case, and the best
    # way to fix it would be to eval() the string "CV" with the
    # correct global and local namespaces.  However that would involve
    # a eval() penalty for every single field of every dataclass
    # that's defined.  It was judged not worth it.

    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if match:
        ns = None
        # cls.__module__ may name a module that is not in sys.modules
        # (for example if it was set to a custom string).  In that case
        # there is no namespace to search, so the annotation simply
        # doesn't match rather than raising AttributeError.
        module = sys.modules.get(cls.__module__)
        module_name = match.group(1)
        if not module_name:
            # No module name, assume the class's module did
            # "from dataclasses import InitVar".
            if module is not None:
                ns = module.__dict__
        else:
            # Look up module_name in the class's module.
            if module and module.__dict__.get(module_name) is a_module:
                ns = sys.modules.get(a_type.__module__).__dict__
        if ns and is_type_predicate(ns.get(match.group(2)), a_module):
            return True
    return False


def _get_field(cls, a_name, a_type):
    # Return a Field object for this field name and type.  ClassVars
    # and InitVars are also returned, but marked as such (see
    # f._field_type).

    # If the default value isn't derived from Field, then it's only a
    # normal default value.  Convert it to a Field().
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        if isinstance(default, types.MemberDescriptorType):
            # This is a field in __slots__, so it has no default value.
            default = MISSING
        f = field(default=default)

    # Only at this point do we know the name and the type.  Set them.
    f.name = a_name
    f.type = a_type

    # Assume it's a normal field until proven otherwise.  We're next
    # going to decide if it's a ClassVar or InitVar, everything else
    # is just a normal field.
    f._field_type = _FIELD

    # In addition to checking for actual types here, also check for
    # string annotations.  get_type_hints() won't always work for us
    # (see https://github.com/python/typing/issues/508 for example),
    # plus it's expensive and would require an eval for every string
    # annotation.  So, make a best effort to see if this is a ClassVar
    # or InitVar using regex's and checking that the thing referenced
    # is actually of the correct type.

    # For the complete discussion, see https://bugs.python.org/issue33453

    # If typing has not been imported, then it's impossible for any
    # annotation to be a ClassVar.  So, only look for ClassVar if
    # typing has been imported by any module (not necessarily cls's
    # module).
    typing = sys.modules.get('typing')
    if typing:
        if (_is_classvar(a_type, typing)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, typing, typing.ClassVar,
                             _is_classvar))):
            f._field_type = _FIELD_CLASSVAR

    # If the type is InitVar, or if it's a matching string annotation,
    # then it's an InitVar.
    if f._field_type is _FIELD:
        # The module we're checking against is the module we're
        # currently in (dataclasses.py).
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, dataclasses, dataclasses.InitVar,
                             _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Validations for individual fields.  This is delayed until now,
    # instead of in the Field() constructor, since only here do we
    # know the field name, which allows for better error reporting.

    # Special restrictions for ClassVar and InitVar.
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')
        # Should I check for other field settings? default_factory
        # seems the most serious to check for.  Maybe add others.  For
        # example, how about init=False (or really,
        # init=<not-the-default-init-value>)?  It makes no sense for
        # ClassVar and InitVar to specify init=<anything>.

    # For real fields, disallow mutable defaults for known types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f


def _set_new_attribute(cls, name, value):
    # Never overwrites an existing attribute.  Returns True if the
    # attribute already exists.
    if name in cls.__dict__:
        return True
    setattr(cls, name, value)
    return False


# Decide if/how we're going to create a hash function.  Key is
# (unsafe_hash, eq, frozen, does-hash-exist).  Value is the action to
# take.  The common case is to do nothing, so instead of providing a
# function that is a no-op, use None to signify that.

def _hash_set_none(cls, fields, globals):
    return None

def _hash_add(cls, fields, globals):
    # A field's hash= setting wins when given; otherwise fall back to
    # its compare= setting.
    flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
    return _hash_fn(flds, globals)

def _hash_exception(cls, fields, globals):
    # Raise an exception.
    raise TypeError('Cannot overwrite attribute __hash__ '
                    f'in class {cls.__name__}')

#
#                +-------------------------------------- unsafe_hash?
#                |      +------------------------------- eq?
#                |      |      +------------------------ frozen?
#                |      |      |      +----------------  has-explicit-hash?
#                |      |      |      |
#                |      |      |      |        +-------  action
#                |      |      |      |        |
#                v      v      v      v        v
_hash_action = {(False, False, False, False): None,
                (False, False, False, True ): None,
                (False, False, True,  False): None,
                (False, False, True,  True ): None,
                (False, True,  False, False): _hash_set_none,
                (False, True,  False, True ): None,
                (False, True,  True,  False): _hash_add,
                (False, True,  True,  True ): None,
                (True,  False, False, False): _hash_add,
                (True,  False, False, True ): _hash_exception,
                (True,  False, True,  False): _hash_add,
                (True,  False, True,  True ): _hash_exception,
                (True,  True,  False, False): _hash_add,
                (True,  True,  False, True ): _hash_exception,
                (True,  True,  True,  False): _hash_add,
                (True,  True,  True,  True ): _hash_exception,
                }
# See https://bugs.python.org/issue32929#msg312829 for an if-statement
# version of this table.
805 806 807def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): 808 # Now that dicts retain insertion order, there's no reason to use 809 # an ordered dict. I am leveraging that ordering here, because 810 # derived class fields overwrite base class fields, but the order 811 # is defined by the base class, which is found first. 812 fields = {} 813 814 if cls.__module__ in sys.modules: 815 globals = sys.modules[cls.__module__].__dict__ 816 else: 817 # Theoretically this can happen if someone writes 818 # a custom string to cls.__module__. In which case 819 # such dataclass won't be fully introspectable 820 # (w.r.t. typing.get_type_hints) but will still function 821 # correctly. 822 globals = {} 823 824 setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, 825 unsafe_hash, frozen)) 826 827 # Find our base classes in reverse MRO order, and exclude 828 # ourselves. In reversed order so that more derived classes 829 # override earlier field definitions in base classes. As long as 830 # we're iterating over them, see if any are frozen. 831 any_frozen_base = False 832 has_dataclass_bases = False 833 for b in cls.__mro__[-1:0:-1]: 834 # Only process classes that have been processed by our 835 # decorator. That is, they have a _FIELDS attribute. 836 base_fields = getattr(b, _FIELDS, None) 837 if base_fields: 838 has_dataclass_bases = True 839 for f in base_fields.values(): 840 fields[f.name] = f 841 if getattr(b, _PARAMS).frozen: 842 any_frozen_base = True 843 844 # Annotations that are defined in this class (not in base 845 # classes). If __annotations__ isn't present, then this class 846 # adds no new annotations. We use this to compute fields that are 847 # added by this class. 848 # 849 # Fields are found from cls_annotations, which is guaranteed to be 850 # ordered. Default values are from class attributes, if a field 851 # has a default. 
If the default value is a Field(), then it 852 # contains additional info beyond (and possibly including) the 853 # actual default value. Pseudo-fields ClassVars and InitVars are 854 # included, despite the fact that they're not real fields. That's 855 # dealt with later. 856 cls_annotations = cls.__dict__.get('__annotations__', {}) 857 858 # Now find fields in our class. While doing so, validate some 859 # things, and set the default values (as class attributes) where 860 # we can. 861 cls_fields = [_get_field(cls, name, type) 862 for name, type in cls_annotations.items()] 863 for f in cls_fields: 864 fields[f.name] = f 865 866 # If the class attribute (which is the default value for this 867 # field) exists and is of type 'Field', replace it with the 868 # real default. This is so that normal class introspection 869 # sees a real default value, not a Field. 870 if isinstance(getattr(cls, f.name, None), Field): 871 if f.default is MISSING: 872 # If there's no default, delete the class attribute. 873 # This happens if we specify field(repr=False), for 874 # example (that is, we specified a field object, but 875 # no default value). Also if we're using a default 876 # factory. The class attribute should not be set at 877 # all in the post-processed class. 878 delattr(cls, f.name) 879 else: 880 setattr(cls, f.name, f.default) 881 882 # Do we have any Field members that don't also have annotations? 883 for name, value in cls.__dict__.items(): 884 if isinstance(value, Field) and not name in cls_annotations: 885 raise TypeError(f'{name!r} is a field but has no type annotation') 886 887 # Check rules that apply if we are derived from any dataclasses. 888 if has_dataclass_bases: 889 # Raise an exception if any of our bases are frozen, but we're not. 890 if any_frozen_base and not frozen: 891 raise TypeError('cannot inherit non-frozen dataclass from a ' 892 'frozen one') 893 894 # Raise an exception if we're frozen, but none of our bases are. 
895 if not any_frozen_base and frozen: 896 raise TypeError('cannot inherit frozen dataclass from a ' 897 'non-frozen one') 898 899 # Remember all of the fields on our class (including bases). This 900 # also marks this class as being a dataclass. 901 setattr(cls, _FIELDS, fields) 902 903 # Was this class defined with an explicit __hash__? Note that if 904 # __eq__ is defined in this class, then python will automatically 905 # set __hash__ to None. This is a heuristic, as it's possible 906 # that such a __hash__ == None was not auto-generated, but it 907 # close enough. 908 class_hash = cls.__dict__.get('__hash__', MISSING) 909 has_explicit_hash = not (class_hash is MISSING or 910 (class_hash is None and '__eq__' in cls.__dict__)) 911 912 # If we're generating ordering methods, we must be generating the 913 # eq methods. 914 if order and not eq: 915 raise ValueError('eq must be true if order is true') 916 917 if init: 918 # Does this class have a post-init function? 919 has_post_init = hasattr(cls, _POST_INIT_NAME) 920 921 # Include InitVars and regular fields (so, not ClassVars). 922 flds = [f for f in fields.values() 923 if f._field_type in (_FIELD, _FIELD_INITVAR)] 924 _set_new_attribute(cls, '__init__', 925 _init_fn(flds, 926 frozen, 927 has_post_init, 928 # The name to use for the "self" 929 # param in __init__. Use "self" 930 # if possible. 931 '__dataclass_self__' if 'self' in fields 932 else 'self', 933 globals, 934 )) 935 936 # Get the fields as a list, and include only real fields. This is 937 # used in all of the following methods. 938 field_list = [f for f in fields.values() if f._field_type is _FIELD] 939 940 if repr: 941 flds = [f for f in field_list if f.repr] 942 _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals)) 943 944 if eq: 945 # Create _eq__ method. There's no need for a __ne__ method, 946 # since python will call __eq__ and negate it. 
947 flds = [f for f in field_list if f.compare] 948 self_tuple = _tuple_str('self', flds) 949 other_tuple = _tuple_str('other', flds) 950 _set_new_attribute(cls, '__eq__', 951 _cmp_fn('__eq__', '==', 952 self_tuple, other_tuple, 953 globals=globals)) 954 955 if order: 956 # Create and set the ordering methods. 957 flds = [f for f in field_list if f.compare] 958 self_tuple = _tuple_str('self', flds) 959 other_tuple = _tuple_str('other', flds) 960 for name, op in [('__lt__', '<'), 961 ('__le__', '<='), 962 ('__gt__', '>'), 963 ('__ge__', '>='), 964 ]: 965 if _set_new_attribute(cls, name, 966 _cmp_fn(name, op, self_tuple, other_tuple, 967 globals=globals)): 968 raise TypeError(f'Cannot overwrite attribute {name} ' 969 f'in class {cls.__name__}. Consider using ' 970 'functools.total_ordering') 971 972 if frozen: 973 for fn in _frozen_get_del_attr(cls, field_list, globals): 974 if _set_new_attribute(cls, fn.__name__, fn): 975 raise TypeError(f'Cannot overwrite attribute {fn.__name__} ' 976 f'in class {cls.__name__}') 977 978 # Decide if/how we're going to create a hash function. 979 hash_action = _hash_action[bool(unsafe_hash), 980 bool(eq), 981 bool(frozen), 982 has_explicit_hash] 983 if hash_action: 984 # No need to call _set_new_attribute here, since by the time 985 # we're here the overwriting is unconditional. 986 cls.__hash__ = hash_action(cls, field_list, globals) 987 988 if not getattr(cls, '__doc__'): 989 # Create a class doc-string. 990 cls.__doc__ = (cls.__name__ + 991 str(inspect.signature(cls)).replace(' -> None', '')) 992 993 return cls 994 995 996def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, 997 unsafe_hash=False, frozen=False): 998 """Returns the same class as was passed in, with dunder methods 999 added based on the fields defined in the class. 1000 1001 Examines PEP 526 __annotations__ to determine fields. 1002 1003 If init is true, an __init__() method is added to the class. 
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):
    """Returns the same class as was passed in, with dunder methods
    added based on the fields defined in the class.

    Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class. If
    repr is true, a __repr__() method is added. If eq is true, an
    __eq__() method is added. If order is true, rich comparison dunder
    methods are added. If unsafe_hash is true, a __hash__() method
    function is added. If frozen is true, fields may not be assigned to
    after instance creation.

    Works both as a bare decorator (@dataclass) and as a decorator
    factory (@dataclass(...)): when cls is None a one-argument
    decorator is returned, otherwise cls is processed immediately.
    """

    def wrap(cls):
        return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)

    # See if we're being called as @dataclass or @dataclass().
    if cls is None:
        # We're called with parens.
        return wrap

    # We're called as @dataclass without parens.
    return wrap(cls)


def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field.

    Raises TypeError if the argument carries no dataclass field table.
    """

    # Might it be worth caching this, per class?
    try:
        field_map = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        # The failed attribute lookup is an implementation detail, so
        # suppress it rather than chaining it onto the TypeError.
        raise TypeError('must be called with a dataclass type or instance') from None

    # Exclude pseudo-fields.  Note that the mapping is sorted by
    # insertion order, so the order of the tuple is as the fields were
    # defined.
    return tuple(f for f in field_map.values() if f._field_type is _FIELD)


def _is_dataclass_instance(obj):
    """Returns True if obj is an instance of a dataclass."""
    # Look at type(obj) so that a dataclass *class* passed in here is
    # not mistaken for an instance.
    return hasattr(type(obj), _FIELDS)


def is_dataclass(obj):
    """Returns True if obj is a dataclass or an instance of a
    dataclass."""
    cls = obj if isinstance(obj, type) else type(obj)
    return hasattr(cls, _FIELDS)


def asdict(obj, *, dict_factory=dict):
    """Return the fields of a dataclass instance as a new dictionary mapping
    field names to field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, 'dict_factory' will be used instead of built-in dict.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """
    if not _is_dataclass_instance(obj):
        raise TypeError("asdict() should be called on dataclass instances")
    return _asdict_inner(obj, dict_factory)
def _asdict_inner(obj, dict_factory):
    """Recursively convert obj for asdict().

    Dataclass instances become dict_factory(...) of (name, value)
    pairs; namedtuples, lists, tuples and dicts are rebuilt as the same
    type with converted contents; anything else is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).

        # I'm not using namedtuple's _asdict()
        # method, because:
        # - it does not recurse in to the namedtuple fields and
        #   convert them to dicts (using dict_factory).
        # - I don't actually want to return a dict here.  The main
        #   use case here is json.dumps, and it handles converting
        #   namedtuples to lists.  Admittedly we're losing some
        #   information here when we produce a json list instead of a
        #   dict.  Note that if we returned dicts here instead of
        #   namedtuples, we could no longer call asdict() on a data
        #   structure where a namedtuple was used as a dict key.

        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is a defaultdict, which has a different constructor
            # from dict: it requires the default_factory as its first
            # argument, so passing it a generator would raise TypeError.
            result = type(obj)(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_asdict_inner(k, dict_factory)] = \
                    _asdict_inner(v, dict_factory)
            return result
        return type(obj)((_asdict_inner(k, dict_factory),
                          _asdict_inner(v, dict_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
def astuple(obj, *, tuple_factory=tuple):
    """Return the fields of a dataclass instance as a new tuple of field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, 'tuple_factory' will be used instead of built-in tuple.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """

    if not _is_dataclass_instance(obj):
        raise TypeError("astuple() should be called on dataclass instances")
    return _astuple_inner(obj, tuple_factory)


def _astuple_inner(obj, tuple_factory):
    """Recursively convert obj for astuple().

    Dataclass instances become tuple_factory(...) of their field
    values; namedtuples, lists, tuples and dicts are rebuilt as the
    same type with converted contents; anything else is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is a defaultdict, which has a different constructor
            # from dict: it requires the default_factory as its first
            # argument, so passing it a generator would raise TypeError.
            result = type(obj)(getattr(obj, 'default_factory'))
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = \
                    _astuple_inner(v, tuple_factory)
            return result
        return type(obj)((_astuple_inner(k, tuple_factory),
                          _astuple_inner(v, tuple_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable
    of either (name), (name, type) or (name, type, Field) objects. If type is
    omitted, use the string 'typing.Any'.  Field objects are created by
    the equivalent of calling 'field(name, type [, Field-info])'.

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.

    The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
    dataclass().

    Raises TypeError if a field entry has the wrong shape, or if a field
    name is not a valid identifier, is a keyword, or is duplicated.
    """

    if namespace is None:
        namespace = {}
    else:
        # Copy namespace since we're going to mutate it.
        namespace = namespace.copy()

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.
    seen = set()
    anns = {}
    for item in fields:
        has_spec = False
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            has_spec = True
        else:
            raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in seen:
            raise TypeError(f'Field name duplicated: {name!r}')

        seen.add(name)
        anns[name] = tp
        if has_spec:
            # Only store the default/Field spec once the name has been
            # validated: using an unvalidated name as a dict key could
            # fail with an unrelated "unhashable type" error.
            namespace[name] = spec

    namespace['__annotations__'] = anns
    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)
def replace(*args, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Raises TypeError if the object is missing, if more than one
    positional argument is given, or if the object is not a dataclass
    instance.  Raises ValueError if 'changes' names a field declared
    with init=False, or if an InitVar without a default is not supplied.
    """
    if len(args) > 1:
        raise TypeError(f'replace() takes 1 positional argument but {len(args)} were given')
    if args:
        obj, = args
    elif 'obj' in changes:
        obj = changes.pop('obj')
        import warnings
        warnings.warn("Passing 'obj' as keyword argument is deprecated",
                      DeprecationWarning, stacklevel=2)
    else:
        raise TypeError("replace() missing 1 required positional argument: 'obj'")

    # We're going to mutate 'changes', but that's okay because it's a
    # new dict, even if called with 'replace(obj, **my_changes)'.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    # It's an error to have init=False fields in 'changes'.
    # If a field is not in 'changes', read its value from the provided obj.

    for f in getattr(obj, _FIELDS).values():
        # Only consider normal fields or InitVars.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # Error if this field is specified in changes.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name not in changes:
            # An InitVar is only mandatory when it has no default; with
            # a default, the getattr() below is expected to find the
            # default stored as a class attribute.
            if f._field_type is _FIELD_INITVAR and f.default is MISSING:
                raise ValueError(f"InitVar {f.name!r} "
                                 'must be specified with replace()')
            changes[f.name] = getattr(obj, f.name)

    # Create the new object, which calls __init__() and
    # __post_init__() (if defined), using all of the init fields we've
    # added and/or left in 'changes'.  If there are values supplied in
    # changes that aren't fields, this will correctly raise a
    # TypeError.
    return obj.__class__(**changes)
replace.__text_signature__ = '(obj, /, **kwargs)'