• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import re
2import sys
3import copy
4import types
5import inspect
6import keyword
7import builtins
8import functools
9import _thread
10
11
# Public API of this module.
__all__ = [
    'dataclass', 'field', 'Field', 'FrozenInstanceError',
    'InitVar', 'MISSING',

    # Helper functions.
    'fields', 'asdict', 'astuple', 'make_dataclass', 'replace',
    'is_dataclass',
]
27
28# Conditions for adding methods.  The boxes indicate what action the
29# dataclass decorator takes.  For all of these tables, when I talk
30# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
31# referring to the arguments to the @dataclass decorator.  When
32# checking if a dunder method already exists, I mean check for an
33# entry in the class's __dict__.  I never check to see if an attribute
34# is defined in a base class.
35
36# Key:
37# +=========+=========================================+
38# + Value   | Meaning                                 |
39# +=========+=========================================+
40# | <blank> | No action: no method is added.          |
41# +---------+-----------------------------------------+
42# | add     | Generated method is added.              |
43# +---------+-----------------------------------------+
44# | raise   | TypeError is raised.                    |
45# +---------+-----------------------------------------+
46# | None    | Attribute is set to None.               |
47# +=========+=========================================+
48
49# __init__
50#
51#   +--- init= parameter
52#   |
53#   v     |       |       |
54#         |  no   |  yes  |  <--- class has __init__ in __dict__?
55# +=======+=======+=======+
56# | False |       |       |
57# +-------+-------+-------+
58# | True  | add   |       |  <- the default
59# +=======+=======+=======+
60
61# __repr__
62#
63#    +--- repr= parameter
64#    |
65#    v    |       |       |
66#         |  no   |  yes  |  <--- class has __repr__ in __dict__?
67# +=======+=======+=======+
68# | False |       |       |
69# +-------+-------+-------+
70# | True  | add   |       |  <- the default
71# +=======+=======+=======+
72
73
74# __setattr__
75# __delattr__
76#
77#    +--- frozen= parameter
78#    |
79#    v    |       |       |
80#         |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
81# +=======+=======+=======+
82# | False |       |       |  <- the default
83# +-------+-------+-------+
84# | True  | add   | raise |
85# +=======+=======+=======+
86# Raise because not adding these methods would break the "frozen-ness"
87# of the class.
88
89# __eq__
90#
91#    +--- eq= parameter
92#    |
93#    v    |       |       |
94#         |  no   |  yes  |  <--- class has __eq__ in __dict__?
95# +=======+=======+=======+
96# | False |       |       |
97# +-------+-------+-------+
98# | True  | add   |       |  <- the default
99# +=======+=======+=======+
100
101# __lt__
102# __le__
103# __gt__
104# __ge__
105#
106#    +--- order= parameter
107#    |
108#    v    |       |       |
109#         |  no   |  yes  |  <--- class has any comparison method in __dict__?
110# +=======+=======+=======+
111# | False |       |       |  <- the default
112# +-------+-------+-------+
113# | True  | add   | raise |
114# +=======+=======+=======+
# Raise because allowing this case would interfere with using
# functools.total_ordering.
117
118# __hash__
119
120#    +------------------- unsafe_hash= parameter
121#    |       +----------- eq= parameter
122#    |       |       +--- frozen= parameter
123#    |       |       |
124#    v       v       v    |        |        |
125#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
126# +=======+=======+=======+========+========+
127# | False | False | False |        |        | No __eq__, use the base class __hash__
128# +-------+-------+-------+--------+--------+
129# | False | False | True  |        |        | No __eq__, use the base class __hash__
130# +-------+-------+-------+--------+--------+
131# | False | True  | False | None   |        | <-- the default, not hashable
132# +-------+-------+-------+--------+--------+
133# | False | True  | True  | add    |        | Frozen, so hashable, allows override
134# +-------+-------+-------+--------+--------+
135# | True  | False | False | add    | raise  | Has no __eq__, but hashable
136# +-------+-------+-------+--------+--------+
137# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
138# +-------+-------+-------+--------+--------+
139# | True  | True  | False | add    | raise  | Not frozen, but hashable
140# +-------+-------+-------+--------+--------+
141# | True  | True  | True  | add    | raise  | Frozen, so hashable
142# +=======+=======+=======+========+========+
143# For boxes that are blank, __hash__ is untouched and therefore
144# inherited from the base class.  If the base is object, then
145# id-based hashing is used.
146#
147# Note that a class may already have __hash__=None if it specified an
148# __eq__ method in the class body (not one that was created by
149# @dataclass).
150#
151# See _hash_action (below) for a coded version of this table.
152
153
# Raised when code tries to assign to or delete a field of a frozen
# dataclass instance.
class FrozenInstanceError(AttributeError):
    """Attempt to modify an instance of a frozen dataclass."""
156
class _HAS_DEFAULT_FACTORY_CLASS:
    """Type of the sentinel meaning "this field uses a default factory".

    The sentinel is used as the default value of such parameters in the
    generated __init__, so its repr is what shows up in the constructor's
    signature.
    """

    def __repr__(self):
        return '<factory>'


# The single sentinel instance.
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
164
class _MISSING_TYPE:
    """Type of the MISSING sentinel.

    MISSING marks "no value was supplied" for parameters such as default
    and default_factory, where None is itself a legitimate value.  A
    dedicated class gives the sentinel a more descriptive repr than a bare
    object() would have.
    """


MISSING = _MISSING_TYPE()
170
# Most fields are declared without metadata.  Rather than wrapping a fresh
# empty dict for each one, every such Field shares this single read-only
# empty mapping.
_EMPTY_METADATA = types.MappingProxyType(dict())
174
class _FIELD_BASE:
    """Named marker that distinguishes real fields from pseudo-fields.

    Instances are compared by identity, and their repr is just the name
    they were created with.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


# One marker per kind: a regular field, a ClassVar and an InitVar.
_FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR = (
    _FIELD_BASE(marker_name)
    for marker_name in ('_FIELD', '_FIELD_CLASSVAR', '_FIELD_INITVAR')
)
184
# Class attribute that holds the Field objects.  Its presence is also how a
# class is recognized as a Data Class.
_FIELDS = '__dataclass_fields__'

# Class attribute that records the arguments given to @dataclass.
_PARAMS = '__dataclass_params__'

# When the class provides a method with this name, the generated __init__
# calls it as its final statement.
_POST_INIT_NAME = '__post_init__'

# Matches the start of a string annotation, as used when checking for
# ClassVar or InitVar.  Group 1 is an optional "module." prefix (the name
# before the dot) and group 2 is the identifier after it.  Whatever follows
# that identifier, such as "[int]", is not examined.  Examples:
# "typing.ClassVar[int]" -> ('typing', 'ClassVar'),
# "InitVar[int]" -> (None, 'InitVar').
# See https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')
201
class _InitVarMeta(type):
    # Subscripting the InitVar class itself (InitVar[T]) produces an
    # InitVar instance that wraps T.
    def __getitem__(cls, params):
        return InitVar(params)


class InitVar(metaclass=_InitVarMeta):
    """Annotation marker for init-only pseudo-fields: ``InitVar[T]``."""

    __slots__ = ('type',)

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        wrapped = self.type
        # Plain classes are shown by name.  Anything else, such as typing
        # constructs like List[int] or strings, is shown by its repr().
        name = wrapped.__name__ if isinstance(wrapped, type) else repr(wrapped)
        return f'dataclasses.InitVar[{name}]'
219
220
# Field instances are created only by field() in this module.  Users see
# them as (conceptually) read-only descriptions of each dataclass field.
#
# name and type are not known when field() runs, so they start out as None.
# They are filled in later, by _get_field, once the class body is processed.
class Field:
    __slots__ = ('name', 'type', 'default', 'default_factory', 'repr',
                 'hash', 'init', 'compare', 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        # Placeholders until the attribute name and annotation are known.
        self.name = None
        self.type = None
        self._field_type = None

        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        # Metadata is exposed read-only.  Fields without metadata share one
        # empty proxy.
        if metadata is None:
            self.metadata = _EMPTY_METADATA
        else:
            self.metadata = types.MappingProxyType(metadata)

    def __repr__(self):
        parts = (
            f'name={self.name!r}',
            f'type={self.type!r}',
            f'default={self.default!r}',
            f'default_factory={self.default_factory!r}',
            f'init={self.init!r}',
            f'repr={self.repr!r}',
            f'hash={self.hash!r}',
            f'compare={self.compare!r}',
            f'metadata={self.metadata!r}',
            f'_field_type={self._field_type}',
        )
        return 'Field(' + ','.join(parts) + ')'

    # PEP 487 __set_name__ support for a default value that is itself a
    # descriptor (see
    # https://www.python.org/dev/peps/pep-0487/#implementation-details).
    # Python calls this hook on the Field because the Field is what sits in
    # the class body.  It is forwarded to the default's own __set_name__, if
    # there is one.  _process_class later replaces this Field in the class
    # with the default value, so the end result is a descriptor that had
    # __set_name__ called at the right time.
    def __set_name__(self, owner, name):
        descriptor_hook = getattr(type(self.default), '__set_name__', None)
        if descriptor_hook:
            descriptor_hook(self.default, owner, name)
287
288
class _DataclassParams:
    """Record of the arguments passed to the @dataclass decorator."""

    __slots__ = ('init', 'repr', 'eq', 'order', 'unsafe_hash', 'frozen')

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        (self.init, self.repr, self.eq,
         self.order, self.unsafe_hash, self.frozen) = (
            init, repr, eq, order, unsafe_hash, frozen)

    def __repr__(self):
        # Every setting, in declaration order, as name=value.
        settings = ','.join(f'{attr}={getattr(self, attr)!r}'
                            for attr in _DataclassParams.__slots__)
        return f'_DataclassParams({settings})'
315
316
# Field objects are built through this function rather than by calling
# Field() directly, so that a type checker can be told (via overloads) that
# the return type depends on the arguments.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the field's default value.  default_factory is a
    zero-argument callable that is invoked to produce the value instead.
    init controls whether the field is a parameter of the generated
    __init__().  repr controls whether it appears in the object's repr().
    hash controls whether it is included in hash(); the default of None
    means "follow compare".  compare controls whether it is used by the
    generated comparison functions.  metadata, if given, must be a mapping;
    dataclass stores it but never examines it.

    Specifying both default and default_factory raises ValueError.
    """
    has_default = default is not MISSING
    has_factory = default_factory is not MISSING
    if has_default and has_factory:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default=default, default_factory=default_factory,
                 init=init, repr=repr, hash=hash, compare=compare,
                 metadata=metadata)
340
341
def _tuple_str(obj_name, fields):
    # Return source text for a tuple of obj_name's attributes, one per
    # field (objects with a .name).  With fields x and y and obj_name
    # "self", the result is "(self.x,self.y,)".  The trailing comma keeps a
    # single field a 1-tuple.  With no fields, the result is "()".
    if not fields:
        return '()'
    members = ','.join(f'{obj_name}.{f.name}' for f in fields)
    return '(' + members + ',)'
352
353
# Same logic as reprlib.recursive_repr, copied here to avoid the dependency.
def _recursive_repr(user_function):
    # Decorate a one-argument repr function so that re-entering it for the
    # same object on the same thread returns '...' instead of recursing
    # forever.
    active = set()

    @functools.wraps(user_function)
    def wrapper(self):
        token = (id(self), _thread.get_ident())
        if token in active:
            return '...'
        active.add(token)
        try:
            return user_function(self)
        finally:
            active.discard(token)

    return wrapper
373
374
def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    # Build the function `name` from source text and return it.
    #
    # args is a sequence of parameter strings.  body is a sequence of
    # statement strings, one per line.  The def is nested inside a
    # generated factory __create_fn__(<keys of locals>), which is then
    # called with **locals.  As a result, every entry of locals is a
    # variable visible to the new function.  globals is passed to exec()
    # as the module namespace.
    #
    # Caller beware: locals is mutated (BUILTINS is added, and
    # _return_type too when return_type is given).  Only this module
    # calls _create_fn.
    if locals is None:
        locals = {}
    # Generated code reaches builtins through the name BUILTINS.
    # Presumably this is so that a field/parameter named e.g. "object"
    # cannot shadow them.
    locals.setdefault('BUILTINS', builtins)

    if return_type is MISSING:
        return_annotation = ''
    else:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'

    params_src = ','.join(args)
    body_src = '\n'.join(f'  {stmt}' for stmt in body)
    fn_src = f' def {name}({params_src}){return_annotation}:\n{body_src}'

    factory_params = ', '.join(locals.keys())
    factory_src = (f"def __create_fn__({factory_params}):\n"
                   f"{fn_src}\n return {name}")

    namespace = {}
    exec(factory_src, globals, namespace)
    factory = namespace['__create_fn__']
    return factory(**locals)
400
401
def _field_assign(frozen, name, value, self_name):
    # Return the source text of one assignment in the generated __init__.
    # A plain "self.x=value" would go through the frozen class's own
    # __setattr__, so frozen classes bypass it with object.__setattr__.
    #
    # self_name is whatever the instance parameter is called.  It is not
    # hard-coded as "self" because a field might have that name.
    if not frozen:
        return f'{self_name}.{name}={value}'
    return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
412
413
def _field_init(f, frozen, globals, self_name):
    # Return the source line of __init__'s body that initializes field f,
    # or None when no assignment should be generated.
    #
    # Side effect: f's default value or default factory is registered in
    # `globals` (the namespace the generated __init__ will see) under the
    # name _dflt_<field name>.  frozen and self_name are forwarded to
    # _field_assign.
    default_name = f'_dflt_{f.name}'

    if f.default_factory is not MISSING:
        globals[default_name] = f.default_factory
        if f.init:
            # The parameter defaults to the _HAS_DEFAULT_FACTORY sentinel.
            # Call the factory only when no argument was passed.
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter, but a factory default still has to
            # be called here.  Unlike a plain default, it cannot live in
            # the class dict: the factory exists precisely so that each
            # instance gets its own value.
            value = f'{default_name}()'
    elif not f.init:
        # No factory and not a parameter: any plain default is already
        # available as a class attribute, so nothing to generate.
        return None
    else:
        # A parameter without a factory: assign the argument directly.
        # Its default, if any, must be reachable as _dflt_<name> for the
        # signature.
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name

    # InitVars get their default registered above, but they are never
    # stored on the instance.
    if f._field_type is _FIELD_INITVAR:
        return None

    return _field_assign(frozen, f.name, value, self_name)
466
467
def _init_param(f):
    # Return the text of field f's parameter in the generated __init__
    # signature, e.g. 'x:_type_x=_dflt_x'.  Both the annotation and the
    # default are names (_type_<name>, _dflt_<name>, _HAS_DEFAULT_FACTORY)
    # bound in the generated function's namespace, not literal values.
    if f.default is not MISSING:
        default = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # The sentinel tells the generated body to call the factory.
        default = '=_HAS_DEFAULT_FACTORY'
    else:
        default = ''
    return f'{f.name}:_type_{f.name}{default}'
485
486
def _init_fn(fields, frozen, has_post_init, self_name, globals):
    # Generate the dataclass __init__.
    #
    # fields holds both real fields and InitVar pseudo-fields.  frozen
    # selects object.__setattr__-style assignment.  When has_post_init is
    # true, a call to __post_init__ is appended, passing the InitVar
    # values.  self_name is the name of the instance parameter.
    #
    # Raises TypeError when a parameter without a default follows one
    # that has a default.  exec() would reject that too, but checking
    # here gives a clearer error message.
    seen_default = False
    for f in fields:
        if not f.init:
            continue  # Not an __init__ parameter.
        if f.default is not MISSING or f.default_factory is not MISSING:
            seen_default = True
        elif seen_default:
            raise TypeError(f'non-default argument {f.name!r} '
                            'follows default argument')

    # Names visible inside the generated function: a _type_<name> per field
    # for the annotations, plus the sentinels used in defaults.
    # _field_init adds the _dflt_<name> entries to this same dict.
    locals = {f'_type_{f.name}': f.type for f in fields}
    locals['MISSING'] = MISSING
    locals['_HAS_DEFAULT_FACTORY'] = _HAS_DEFAULT_FACTORY

    # _field_init returns None for fields that need no assignment.
    candidate_lines = (_field_init(f, frozen, locals, self_name)
                       for f in fields)
    body_lines = [line for line in candidate_lines if line]

    if has_post_init:
        initvar_args = ','.join(f.name for f in fields
                                if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({initvar_args})')

    params = [self_name]
    params.extend(_init_param(f) for f in fields if f.init)

    return _create_fn('__init__',
                      params,
                      body_lines or ['pass'],
                      locals=locals,
                      globals=globals,
                      return_type=None)
535
536
def _repr_fn(fields, globals):
    # Build a __repr__ producing "<qualname>(a=<repr of a>, b=...)" for the
    # given fields.  It is wrapped with _recursive_repr so that
    # self-referential instances print "..." instead of recursing.
    #
    # Each piece is the text "name={self.name!r}".  The doubled braces
    # survive this f-string, so the generated code contains a real f-string
    # replacement field.
    pieces = ', '.join(f"{f.name}={{self.{f.name}!r}}" for f in fields)
    return_stmt = 'return self.__class__.__qualname__ + f"(' + pieces + ')"'
    generated = _create_fn('__repr__', ('self',), [return_stmt],
                           globals=globals)
    return _recursive_repr(generated)
546
547
def _frozen_get_del_attr(cls, fields, globals):
    # Build the (__setattr__, __delattr__) pair for a frozen class cls.
    # Each raises FrozenInstanceError when the instance's exact type is
    # cls, or when the attribute name is one of `fields`.  Otherwise it
    # defers to super(cls, self).
    locals = {'cls': cls,
              'FrozenInstanceError': FrozenInstanceError}

    # Source text of a tuple literal of the field names, e.g. "('x','y',)";
    # "()" when there are no fields.
    names = ','.join(repr(f.name) for f in fields)
    fields_str = '(' + names + ',)' if fields else '()'
    guard = f'if type(self) is cls or name in {fields_str}:'

    setattr_fn = _create_fn(
        '__setattr__',
        ('self', 'name', 'value'),
        (guard,
         ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
         'super(cls, self).__setattr__(name, value)'),
        locals=locals,
        globals=globals)
    delattr_fn = _create_fn(
        '__delattr__',
        ('self', 'name'),
        (guard,
         ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
         'super(cls, self).__delattr__(name)'),
        locals=locals,
        globals=globals)
    return setattr_fn, delattr_fn
571
572
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
    # Build comparison method `name`, which applies `op` to the two tuple
    # expressions.  For fields x and y, self_tuple is '(self.x,self.y)'
    # and other_tuple is '(other.x,other.y)'.  Instances of different
    # classes give NotImplemented.
    same_class_check = 'if other.__class__ is self.__class__:'
    comparison = f' return {self_tuple}{op}{other_tuple}'
    fallback = 'return NotImplemented'
    return _create_fn(name,
                      ('self', 'other'),
                      [same_class_check, comparison, fallback],
                      globals=globals)
585
586
def _hash_fn(fields, globals):
    # Build a __hash__ that hashes the tuple of the given fields' values.
    statement = 'return hash(' + _tuple_str('self', fields) + ')'
    return _create_fn('__hash__', ('self',), [statement], globals=globals)
593
594
def _is_classvar(a_type, typing):
    # True if a_type is typing.ClassVar, either bare or subscripted
    # (ClassVar[T]).  The subscripted check relies on typing's internal
    # _GenericAlias class, which is the most reliable test available.
    if a_type is typing.ClassVar:
        return True
    return (type(a_type) is typing._GenericAlias
            and a_type.__origin__ is typing.ClassVar)
601
602
def _is_initvar(a_type, dataclasses):
    # True if a_type is the InitVar class itself or an InitVar[T]
    # instance.  `dataclasses` is the module defining InitVar (this one).
    initvar_cls = dataclasses.InitVar
    return a_type is initvar_cls or type(a_type) is initvar_cls
608
609
def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    # Given a type annotation string, does it refer to a_type in
    # a_module?  For example, when checking that annotation denotes a
    # ClassVar, a_module is typing and a_type is typing.ClassVar.
    #
    # - annotation is a string type annotation
    # - cls is the class that this annotation was found in
    # - a_module is the module we want to match
    # - a_type is the type in that module we want to match
    # - is_type_predicate is a function called with (obj, a_module)
    #   that determines if obj is of the desired type.
    #
    # Returns True or False.  A class whose __module__ is not present in
    # sys.modules simply yields False (see below).
    #
    # Only a module (global) lookup is done, never a local namespace
    # lookup, so some cases are judged wrongly.  With string annotations,
    # cv0 is detected as a ClassVar:
    #   CV = ClassVar
    #   @dataclass
    #   class C0:
    #     cv0: CV
    #
    # but cv1 is not, because "CV" is looked up in the module, not in the
    # class body:
    #   @dataclass
    #   class C1:
    #     CV = ClassVar
    #     cv1: CV
    #
    # Fixing this would mean eval()-ing every string annotation of every
    # dataclass with the correct namespaces.  That cost was judged not
    # worth paying for such an obscure corner case.
    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if not match:
        return False
    module_name, type_name = match.group(1), match.group(2)

    # cls.__module__ need not name a module in sys.modules, for example a
    # class built by exec() or given a hand-written __module__.
    # _process_class tolerates that case, so treat it as "no match" here
    # rather than calling .__dict__ on None.
    cls_module = sys.modules.get(cls.__module__)
    if not cls_module:
        return False

    ns = None
    if not module_name:
        # No module prefix: assume the class's module did something like
        # "from dataclasses import InitVar", and look the name up there.
        ns = cls_module.__dict__
    elif cls_module.__dict__.get(module_name) is a_module:
        # "mod.Name", where mod is bound to a_module in the class's module.
        type_module = sys.modules.get(a_type.__module__)
        if type_module:
            ns = type_module.__dict__

    if ns and is_type_predicate(ns.get(type_name), a_module):
        return True
    return False
667
668
def _get_field(cls, a_name, a_type):
    # Return the Field for class attribute a_name annotated with a_type.
    # ClassVar and InitVar pseudo-fields are returned as Fields too; they
    # are marked via f._field_type.
    #
    # Raises TypeError if a ClassVar/InitVar has a default_factory, and
    # ValueError if a regular field's default is a list, dict or set.

    # Whatever the class body bound to a_name is the default, unless it is
    # already a Field (from field()) or a __slots__ member descriptor.  A
    # member descriptor means "no default".
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    elif isinstance(default, types.MemberDescriptorType):
        f = field(default=MISSING)
    else:
        f = field(default=default)

    # Name and type are only known now.  Start by assuming a regular field.
    f.name = a_name
    f.type = a_type
    f._field_type = _FIELD

    # String annotations are recognized by regex plus a module lookup
    # (_is_type), not typing.get_type_hints().  get_type_hints() does not
    # always work here (https://github.com/python/typing/issues/508) and
    # would need an eval for every string annotation.  For the complete
    # discussion, see https://bugs.python.org/issue33453.
    is_string_annotation = isinstance(f.type, str)

    # If no module has imported typing, no annotation can be a ClassVar,
    # so skip that check entirely.
    typing = sys.modules.get('typing')
    if typing and (_is_classvar(a_type, typing)
                   or (is_string_annotation
                       and _is_type(f.type, cls, typing, typing.ClassVar,
                                    _is_classvar))):
        f._field_type = _FIELD_CLASSVAR
    else:
        # InitVar is defined in this very module.
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
                or (is_string_annotation
                    and _is_type(f.type, cls, dataclasses,
                                 dataclasses.InitVar, _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Per-field validation happens here rather than in Field(), because
    # only now is the field name available for error messages.
    if (f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR)
            and f.default_factory is not MISSING):
        raise TypeError(f'field {f.name} cannot have a '
                        'default factory')
    # NOTE(review): other settings (e.g. init=) are meaningless for
    # ClassVar/InitVar pseudo-fields but are not rejected.

    # For real fields, reject mutable defaults of well-known types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f
749
750
def _set_new_attribute(cls, name, value):
    # Set cls.<name> = value unless name is already in cls's own __dict__.
    # Inherited attributes do not count.  Returns True when an existing
    # attribute was left in place, False when value was set.
    already_defined = name in cls.__dict__
    if not already_defined:
        setattr(cls, name, value)
    return already_defined
758
759
# How to handle __hash__.  _hash_action maps the key
# (unsafe_hash, eq, frozen, has-explicit-hash) to an action that takes
# (cls, fields, globals).  The common case is "do nothing", so it is
# represented by None rather than by a no-op function.

def _hash_set_none(cls, fields, globals):
    # The "None" box of the __hash__ table: __hash__ is set to None, making
    # instances unhashable.
    return None


def _hash_add(cls, fields, globals):
    # The "add" box: generate __hash__ from the fields that participate.
    # A field's hash= setting wins; hash=None means "follow compare=".
    hashed = []
    for f in fields:
        participates = f.hash if f.hash is not None else f.compare
        if participates:
            hashed.append(f)
    return _hash_fn(hashed, globals)


def _hash_exception(cls, fields, globals):
    # The "raise" box: unsafe_hash=True on a class that already defines
    # __hash__ explicitly.
    message = f'Cannot overwrite attribute __hash__ in class {cls.__name__}'
    raise TypeError(message)


# The same 16-entry decision table as in the __hash__ comment at the top of
# this module, built in the same key order (False before True, key
# positions varying right-most fastest).
# See https://bugs.python.org/issue32929#msg312829 for the original
# if-statement formulation.
_hash_action = {
    (unsafe_hash, eq, frozen, has_explicit_hash): (
        # unsafe_hash=True: always generate, but refuse to replace an
        # explicitly defined __hash__.
        (_hash_exception if has_explicit_hash else _hash_add)
        if unsafe_hash else
        # An explicit __hash__, or no generated __eq__: leave __hash__ alone.
        None
        if has_explicit_hash or not eq else
        # eq=True: frozen classes become hashable, mutable ones unhashable.
        (_hash_add if frozen else _hash_set_none)
    )
    for unsafe_hash in (False, True)
    for eq in (False, True)
    for frozen in (False, True)
    for has_explicit_hash in (False, True)
}
805
806
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
    """Turn cls into a dataclass in place and return it.

    The remaining arguments are the options given to the @dataclass
    decorator.  Fields are collected from dataclass base classes (in
    reverse MRO order) and from the annotations in cls's own __dict__.
    The field mapping and the options are stored on the class, and the
    requested dunder methods are generated.  Methods already present in
    cls.__dict__ are never overwritten, except __hash__, whose handling
    is governed by the _hash_action table.

    Raises TypeError in these cases:
      - a Field has no type annotation;
      - frozen and non-frozen dataclasses are mixed by inheritance;
      - an ordering or frozen method (or, with unsafe_hash, __hash__)
        would overwrite one defined in the class.
    Raises ValueError if order is true but eq is false.
    """
    # Now that dicts retain insertion order, there's no reason to use
    # an ordered dict.  I am leveraging that ordering here, because
    # derived class fields overwrite base class fields, but the order
    # is defined by the base class, which is found first.
    fields = {}

    if cls.__module__ in sys.modules:
        globals = sys.modules[cls.__module__].__dict__
    else:
        # Theoretically this can happen if someone writes
        # a custom string to cls.__module__.  In which case
        # such dataclass won't be fully introspectable
        # (w.r.t. typing.get_type_hints) but will still function
        # correctly.
        globals = {}

    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen))

    # Find our base classes in reverse MRO order, and exclude
    # ourselves.  In reversed order so that more derived classes
    # override earlier field definitions in base classes.  As long as
    # we're iterating over them, see if any are frozen.
    any_frozen_base = False
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        # Only process classes that have been processed by our
        # decorator.  That is, they have a _FIELDS attribute.  Compare
        # against None rather than testing truthiness: a dataclass with
        # no fields has an empty (falsy) mapping but must still take
        # part in the frozen-inheritance checks below.
        base_fields = getattr(b, _FIELDS, None)
        if base_fields is not None:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if getattr(b, _PARAMS).frozen:
                any_frozen_base = True

    # Annotations that are defined in this class (not in base
    # classes).  If __annotations__ isn't present, then this class
    # adds no new annotations.  We use this to compute fields that are
    # added by this class.
    #
    # Fields are found from cls_annotations, which is guaranteed to be
    # ordered.  Default values are from class attributes, if a field
    # has a default.  If the default value is a Field(), then it
    # contains additional info beyond (and possibly including) the
    # actual default value.  Pseudo-fields ClassVars and InitVars are
    # included, despite the fact that they're not real fields.  That's
    # dealt with later.
    cls_annotations = cls.__dict__.get('__annotations__', {})

    # Now find fields in our class.  While doing so, validate some
    # things, and set the default values (as class attributes) where
    # we can.
    cls_fields = [_get_field(cls, name, type)
                  for name, type in cls_annotations.items()]
    for f in cls_fields:
        fields[f.name] = f

        # If the class attribute (which is the default value for this
        # field) exists and is of type 'Field', replace it with the
        # real default.  This is so that normal class introspection
        # sees a real default value, not a Field.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                # If there's no default, delete the class attribute.
                # This happens if we specify field(repr=False), for
                # example (that is, we specified a field object, but
                # no default value).  Also if we're using a default
                # factory.  The class attribute should not be set at
                # all in the post-processed class.
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Do we have any Field members that don't also have annotations?
    for name, value in cls.__dict__.items():
        if isinstance(value, Field) and name not in cls_annotations:
            raise TypeError(f'{name!r} is a field but has no type annotation')

    # Check rules that apply if we are derived from any dataclasses.
    if has_dataclass_bases:
        # Raise an exception if any of our bases are frozen, but we're not.
        if any_frozen_base and not frozen:
            raise TypeError('cannot inherit non-frozen dataclass from a '
                            'frozen one')

        # Raise an exception if we're frozen, but none of our bases are.
        if not any_frozen_base and frozen:
            raise TypeError('cannot inherit frozen dataclass from a '
                            'non-frozen one')

    # Remember all of the fields on our class (including bases).  This
    # also marks this class as being a dataclass.
    setattr(cls, _FIELDS, fields)

    # Was this class defined with an explicit __hash__?  Note that if
    # __eq__ is defined in this class, then python will automatically
    # set __hash__ to None.  This is a heuristic, as it's possible
    # that such a __hash__ == None was not auto-generated, but it's
    # close enough.
    class_hash = cls.__dict__.get('__hash__', MISSING)
    has_explicit_hash = not (class_hash is MISSING or
                             (class_hash is None and '__eq__' in cls.__dict__))

    # If we're generating ordering methods, we must be generating the
    # eq methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        # Does this class have a post-init function?
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        # Include InitVars and regular fields (so, not ClassVars).
        flds = [f for f in fields.values()
                if f._field_type in (_FIELD, _FIELD_INITVAR)]
        # The name to use for the "self" param in __init__.  Use "self"
        # unless some field is itself named "self".
        self_name = '__dataclass_self__' if 'self' in fields else 'self'
        _set_new_attribute(cls, '__init__',
                           _init_fn(flds,
                                    frozen,
                                    has_post_init,
                                    self_name,
                                    globals,
                                    ))

    # Get the fields as a list, and include only real fields.  This is
    # used in all of the following methods.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        flds = [f for f in field_list if f.repr]
        _set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))

    if eq:
        # Create __eq__ method.  There's no need for a __ne__ method,
        # since python will call __eq__ and negate it.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        _set_new_attribute(cls, '__eq__',
                           _cmp_fn('__eq__', '==',
                                   self_tuple, other_tuple,
                                   globals=globals))

    if order:
        # Create and set the ordering methods.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        for name, op in [('__lt__', '<'),
                         ('__le__', '<='),
                         ('__gt__', '>'),
                         ('__ge__', '>='),
                         ]:
            if _set_new_attribute(cls, name,
                                  _cmp_fn(name, op, self_tuple, other_tuple,
                                          globals=globals)):
                raise TypeError(f'Cannot overwrite attribute {name} '
                                f'in class {cls.__name__}. Consider using '
                                'functools.total_ordering')

    if frozen:
        for fn in _frozen_get_del_attr(cls, field_list, globals):
            if _set_new_attribute(cls, fn.__name__, fn):
                raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
                                f'in class {cls.__name__}')

    # Decide if/how we're going to create a hash function.
    hash_action = _hash_action[bool(unsafe_hash),
                               bool(eq),
                               bool(frozen),
                               has_explicit_hash]
    if hash_action:
        # No need to call _set_new_attribute here, since by the time
        # we're here the overwriting is unconditional.
        cls.__hash__ = hash_action(cls, field_list, globals)

    if not getattr(cls, '__doc__'):
        # Create a class doc-string from the class name and its
        # constructor signature.
        try:
            # inspect.signature() raises ValueError/TypeError when no
            # signature can be determined.  A missing cosmetic docstring
            # must not prevent the dataclass from being created.
            text_sig = str(inspect.signature(cls)).replace(' -> None', '')
        except (TypeError, ValueError):
            text_sig = ''
        cls.__doc__ = cls.__name__ + text_sig

    return cls
994
995
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False):
    """Returns the same class as was passed in, with dunder methods
    added based on the fields defined in the class.

    Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class.  If
    repr is true, a __repr__() method is added.  If eq is true, an
    __eq__() method is added.  If order is true, rich comparison dunder
    methods (__lt__, __le__, __gt__, __ge__) are added.  If unsafe_hash
    is true, a __hash__() method function is added.  If frozen is true,
    fields may not be assigned to after instance creation.

    Works both as a bare decorator (@dataclass) and as a decorator
    factory (@dataclass(...)).
    """
    def decorate(klass):
        return _process_class(klass, init, repr, eq, order, unsafe_hash,
                              frozen)

    # With parentheses no class is passed in, so return the decorator
    # itself; a bare @dataclass hands us the class to process directly.
    return decorate if cls is None else decorate(cls)
1020
1021
def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field.  Pseudo-fields (ClassVar and InitVar) are excluded.

    Raises TypeError if class_or_instance is neither a dataclass nor an
    instance of one.
    """

    # Might it be worth caching this, per class?
    try:
        field_map = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        # The AttributeError is an implementation detail; suppress it so
        # the traceback shows only the meaningful TypeError.
        raise TypeError('must be called with a dataclass type or instance') from None

    # Exclude pseudo-fields.  Note that fields is sorted by insertion
    # order, so the order of the tuple is as the fields were defined.
    return tuple(f for f in field_map.values() if f._field_type is _FIELD)
1038
1039
def _is_dataclass_instance(obj):
    """Returns True if obj is an instance of a dataclass."""
    # Only obj's type is examined, so passing a dataclass *class* here
    # normally yields False (its type is its metaclass).
    obj_type = type(obj)
    return hasattr(obj_type, _FIELDS)
1043
1044
def is_dataclass(obj):
    """Returns True if obj is a dataclass or an instance of a
    dataclass."""
    if not isinstance(obj, type):
        # Given an instance, inspect its class instead.
        obj = type(obj)
    return hasattr(obj, _FIELDS)
1050
1051
def asdict(obj, *, dict_factory=dict):
    """Return the fields of a dataclass instance as a new dictionary mapping
    field names to field values.

    Example usage:

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, 'dict_factory' will be used instead of built-in dict; it
    is called with a list of (name, value) pairs.  The conversion
    recurses into field values that are dataclass instances and into
    the built-in containers tuple (including namedtuples), list and
    dict.  Any other value is deep-copied.

    Raises TypeError if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
1074
1075
def _asdict_inner(obj, dict_factory):
    """Recursively convert obj on behalf of asdict().

    - Dataclass instances become dict_factory([(name, value), ...])
      over their fields.
    - namedtuples, lists, tuples and dicts (including subclasses and
      defaultdicts) are rebuilt as the same type with converted contents.
    - Anything else is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).

        # I'm not using namedtuple's _asdict()
        # method, because:
        # - it does not recurse in to the namedtuple fields and
        #   convert them to dicts (using dict_factory).
        # - I don't actually want to return a dict here.  The main
        #   use case here is json.dumps, and it handles converting
        #   namedtuples to lists.  Admittedly we're losing some
        #   information here when we produce a json list instead of a
        #   dict.  Note that if we returned dicts here instead of
        #   namedtuples, we could no longer call asdict() on a data
        #   structure where a namedtuple was used as a dict key.

        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is a defaultdict (or subclass).  Its constructor takes
            # the default_factory as the first argument instead of an
            # iterable of pairs, so build the copy item by item.
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_asdict_inner(k, dict_factory)] = _asdict_inner(v, dict_factory)
            return result
        return type(obj)((_asdict_inner(k, dict_factory),
                          _asdict_inner(v, dict_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
1115
1116
def astuple(obj, *, tuple_factory=tuple):
    """Return the fields of a dataclass instance as a new tuple of field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, 'tuple_factory' will be used instead of built-in tuple.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """

    if not _is_dataclass_instance(obj):
        raise TypeError("astuple() should be called on dataclass instances")
    return _astuple_inner(obj, tuple_factory)
1139
1140
def _astuple_inner(obj, tuple_factory):
    """Recursively convert obj on behalf of astuple().

    - Dataclass instances become tuple_factory([value, ...]) over their
      fields.
    - namedtuples, lists, tuples and dicts (including subclasses and
      defaultdicts) are rebuilt as the same type with converted contents.
    - Anything else is deep-copied.
    """
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(type(obj), 'default_factory'):
            # obj is a defaultdict (or subclass).  Its constructor takes
            # the default_factory as the first argument instead of an
            # iterable of pairs, so build the copy item by item.
            result = type(obj)(obj.default_factory)
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
            return result
        return type(obj)((_astuple_inner(k, tuple_factory),
                          _astuple_inner(v, tuple_factory))
                         for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
1166
1167
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable
    whose items are each one of:
      - a bare field name (a str);
      - a (name, type) pair;
      - a (name, type, Field) triple.
    If the type is omitted, the string 'typing.Any' is used.  The
    optional third element (typically the result of calling field())
    becomes the class attribute for that name, as if it had been
    assigned in a class body.

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.
    The caller's namespace mapping is copied, never mutated.

    The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to
    dataclass().

    Raises TypeError if a field item is malformed, or if a field name is
    not a valid identifier, is a keyword, or is duplicated.
    """

    if namespace is None:
        namespace = {}
    else:
        # Copy namespace since we're going to mutate it.
        namespace = namespace.copy()

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.
    seen = set()
    anns = {}
    for item in fields:
        has_spec = False
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            has_spec = True
        else:
            raise TypeError(f'Invalid field: {item!r}')

        # Validate the name before using it as a namespace key, so that a
        # bad (e.g. unhashable) name gets this clear error message.
        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in seen:
            raise TypeError(f'Field name duplicated: {name!r}')

        seen.add(name)
        anns[name] = tp
        if has_spec:
            namespace[name] = spec

    namespace['__annotations__'] = anns
    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace))
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen)
1232
1233
def replace(*args, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Fields declared with init=False may not appear in 'changes'
    (ValueError).  InitVars without a default value must be given in
    'changes' (ValueError).  Unknown names in 'changes' are rejected by
    the class's __init__ (TypeError).
    """
    if len(args) > 1:
        raise TypeError(f'replace() takes 1 positional argument but {len(args)} were given')
    if args:
        obj, = args
    elif 'obj' in changes:
        obj = changes.pop('obj')
        import warnings
        warnings.warn("Passing 'obj' as keyword argument is deprecated",
                      DeprecationWarning, stacklevel=2)
    else:
        raise TypeError("replace() missing 1 required positional argument: 'obj'")

    # We're going to mutate 'changes', but that's okay because it's a
    # new dict, even if called with 'replace(obj, **my_changes)'.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    # It's an error to have init=False fields in 'changes'.
    # If a field is not in 'changes', read its value from the provided obj.

    for f in getattr(obj, _FIELDS).values():
        # Only consider normal fields or InitVars.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # Error if this field is specified in changes.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name not in changes:
            # An InitVar without a default has nothing to carry over, so
            # the caller must supply it.  One with a default is readable
            # through getattr like a regular field, because the default
            # is available as a class attribute.
            if f._field_type is _FIELD_INITVAR and f.default is MISSING:
                raise ValueError(f"InitVar {f.name!r} "
                                 'must be specified with replace()')
            changes[f.name] = getattr(obj, f.name)

    # Create the new object, which calls __init__() and
    # __post_init__() (if defined), using all of the init fields we've
    # added and/or left in 'changes'.  If there are values supplied in
    # changes that aren't fields, this will correctly raise a
    # TypeError.
    return obj.__class__(**changes)
replace.__text_signature__ = '(obj, /, **kwargs)'
1295