• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import re
2import sys
3import copy
4import types
5import inspect
6import keyword
7import itertools
8import abc
9from reprlib import recursive_repr
10
11
# Names exported by "from dataclasses import *": the core API first, then
# the helper functions.
__all__ = [
    'dataclass',
    'field',
    'Field',
    'FrozenInstanceError',
    'InitVar',
    'KW_ONLY',
    'MISSING',
    # Helper functions.
    'fields',
    'asdict',
    'astuple',
    'make_dataclass',
    'replace',
    'is_dataclass',
]
28
29# Conditions for adding methods.  The boxes indicate what action the
30# dataclass decorator takes.  For all of these tables, when I talk
31# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm
32# referring to the arguments to the @dataclass decorator.  When
33# checking if a dunder method already exists, I mean check for an
34# entry in the class's __dict__.  I never check to see if an attribute
35# is defined in a base class.
36
37# Key:
38# +=========+=========================================+
39# + Value   | Meaning                                 |
40# +=========+=========================================+
41# | <blank> | No action: no method is added.          |
42# +---------+-----------------------------------------+
43# | add     | Generated method is added.              |
44# +---------+-----------------------------------------+
45# | raise   | TypeError is raised.                    |
46# +---------+-----------------------------------------+
47# | None    | Attribute is set to None.               |
48# +=========+=========================================+
49
50# __init__
51#
52#   +--- init= parameter
53#   |
54#   v     |       |       |
55#         |  no   |  yes  |  <--- class has __init__ in __dict__?
56# +=======+=======+=======+
57# | False |       |       |
58# +-------+-------+-------+
59# | True  | add   |       |  <- the default
60# +=======+=======+=======+
61
62# __repr__
63#
64#    +--- repr= parameter
65#    |
66#    v    |       |       |
67#         |  no   |  yes  |  <--- class has __repr__ in __dict__?
68# +=======+=======+=======+
69# | False |       |       |
70# +-------+-------+-------+
71# | True  | add   |       |  <- the default
72# +=======+=======+=======+
73
74
75# __setattr__
76# __delattr__
77#
78#    +--- frozen= parameter
79#    |
80#    v    |       |       |
81#         |  no   |  yes  |  <--- class has __setattr__ or __delattr__ in __dict__?
82# +=======+=======+=======+
83# | False |       |       |  <- the default
84# +-------+-------+-------+
85# | True  | add   | raise |
86# +=======+=======+=======+
87# Raise because not adding these methods would break the "frozen-ness"
88# of the class.
89
90# __eq__
91#
92#    +--- eq= parameter
93#    |
94#    v    |       |       |
95#         |  no   |  yes  |  <--- class has __eq__ in __dict__?
96# +=======+=======+=======+
97# | False |       |       |
98# +-------+-------+-------+
99# | True  | add   |       |  <- the default
100# +=======+=======+=======+
101
102# __lt__
103# __le__
104# __gt__
105# __ge__
106#
107#    +--- order= parameter
108#    |
109#    v    |       |       |
110#         |  no   |  yes  |  <--- class has any comparison method in __dict__?
111# +=======+=======+=======+
112# | False |       |       |  <- the default
113# +-------+-------+-------+
114# | True  | add   | raise |
115# +=======+=======+=======+
116# Raise because to allow this case would interfere with using
117# functools.total_ordering.
118
119# __hash__
120
121#    +------------------- unsafe_hash= parameter
122#    |       +----------- eq= parameter
123#    |       |       +--- frozen= parameter
124#    |       |       |
125#    v       v       v    |        |        |
126#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
127# +=======+=======+=======+========+========+
128# | False | False | False |        |        | No __eq__, use the base class __hash__
129# +-------+-------+-------+--------+--------+
130# | False | False | True  |        |        | No __eq__, use the base class __hash__
131# +-------+-------+-------+--------+--------+
132# | False | True  | False | None   |        | <-- the default, not hashable
133# +-------+-------+-------+--------+--------+
134# | False | True  | True  | add    |        | Frozen, so hashable, allows override
135# +-------+-------+-------+--------+--------+
136# | True  | False | False | add    | raise  | Has no __eq__, but hashable
137# +-------+-------+-------+--------+--------+
138# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
139# +-------+-------+-------+--------+--------+
140# | True  | True  | False | add    | raise  | Not frozen, but hashable
141# +-------+-------+-------+--------+--------+
142# | True  | True  | True  | add    | raise  | Frozen, so hashable
143# +=======+=======+=======+========+========+
144# For boxes that are blank, __hash__ is untouched and therefore
145# inherited from the base class.  If the base is object, then
146# id-based hashing is used.
147#
148# Note that a class may already have __hash__=None if it specified an
149# __eq__ method in the class body (not one that was created by
150# @dataclass).
151#
152# See _hash_action (below) for a coded version of this table.
153
154# __match_args__
155#
156#    +--- match_args= parameter
157#    |
158#    v    |       |       |
159#         |  no   |  yes  |  <--- class has __match_args__ in __dict__?
160# +=======+=======+=======+
161# | False |       |       |
162# +-------+-------+-------+
163# | True  | add   |       |  <- the default
164# +=======+=======+=======+
# When match_args=True (the default), __match_args__ is added unless the
# class already defines it.  It is a tuple of __init__ parameter names;
# non-init fields must be matched by keyword.
167
168
169# Raised when an attempt is made to modify a frozen class.
170class FrozenInstanceError(AttributeError): pass
171
class _HAS_DEFAULT_FACTORY_CLASS:
    """Sentinel type: "this parameter's default comes from a factory".

    The generated __init__ uses the single instance below as the default
    value of parameters backed by a default_factory, so its repr() is what
    appears in the constructor's signature.
    """

    def __repr__(self):
        return '<factory>'


_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
179
class _MISSING_TYPE:
    """Type of the MISSING sentinel, meaning "no value was supplied".

    A dedicated class, rather than a bare object(), gives the sentinel a
    more recognizable repr.
    """


MISSING = _MISSING_TYPE()
185
class _KW_ONLY_TYPE:
    """Type of the KW_ONLY sentinel.

    Fields that follow a KW_ONLY pseudo-field are keyword-only by default.
    A dedicated class gives the sentinel a more recognizable repr.
    """


KW_ONLY = _KW_ONLY_TYPE()
191
# Most fields carry no metadata, so every Field created without any shares
# this one empty, read-only mapping instead of allocating its own.
_EMPTY_METADATA = types.MappingProxyType(dict())
195
# Markers for the various kinds of fields and pseudo-fields.  A marker's
# repr() is just its own name, which keeps Field reprs readable.
class _FIELD_BASE:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


_FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR = (
    _FIELD_BASE(marker_name)
    for marker_name in ('_FIELD', '_FIELD_CLASSVAR', '_FIELD_INITVAR')
)
205
# Class attribute that stores a dataclass's Field objects.  Its presence is
# also how a class is recognized as a dataclass.
_FIELDS = '__dataclass_fields__'

# Class attribute that stores the arguments given to @dataclass.
_PARAMS = '__dataclass_params__'

# Name of the method that, if the class defines it, is called at the end of
# the generated __init__.
_POST_INIT_NAME = '__post_init__'

# Pattern that string annotations for ClassVar or InitVar must match: an
# optional "module." prefix (group 1) followed by an identifier (group 2),
# e.g. "typing.ClassVar[" or "InitVar[".
# See https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')
222
223# Atomic immutable types which don't require any recursive handling and for which deepcopy
224# returns the same object. We can provide a fast-path for these types in asdict and astuple.
225_ATOMIC_TYPES = frozenset({
226    # Common JSON Serializable types
227    types.NoneType,
228    bool,
229    int,
230    float,
231    str,
232    # Other common types
233    complex,
234    bytes,
235    # Other types that are also unaffected by deepcopy
236    types.EllipsisType,
237    types.NotImplementedType,
238    types.CodeType,
239    types.BuiltinFunctionType,
240    types.FunctionType,
241    type,
242    range,
243    property,
244})
245
246
class InitVar:
    """Annotation marker for init-only pseudo-fields.

    ``InitVar[T]`` returns an InitVar instance wrapping ``T``.  A field
    annotated this way is an __init__ parameter that is not stored on the
    instance; its value is passed to __post_init__ instead.
    """
    __slots__ = ('type', )

    def __init__(self, type):
        self.type = type

    def __repr__(self):
        wrapped = self.type
        # Plain classes show their bare name; anything else (e.g. typing
        # objects such as List[int]) uses its own repr().
        shown = wrapped.__name__ if isinstance(wrapped, type) else repr(wrapped)
        return f'dataclasses.InitVar[{shown}]'

    def __class_getitem__(cls, type):
        return InitVar(type)
263
class Field:
    """Description of one dataclass field.

    Field objects are meant to be created only from within this module, via
    field(), and are exposed to users as (conceptually) read-only
    descriptions.  ``name`` and ``type`` are not known when a Field is
    constructed, so __init__ sets them to None and they are filled in later,
    when the owning class is processed.
    """

    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 'kw_only',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata, kw_only):
        # Unknown until the owning class is processed.
        self.name = self.type = None
        self.default, self.default_factory = default, default_factory
        self.init, self.repr, self.hash, self.compare = init, repr, hash, compare
        # No metadata: share the module-wide empty proxy.  Otherwise expose
        # the caller's mapping through a read-only view.
        if metadata is None:
            self.metadata = _EMPTY_METADATA
        else:
            self.metadata = types.MappingProxyType(metadata)
        self.kw_only = kw_only
        self._field_type = None

    @recursive_repr()
    def __repr__(self):
        parts = (
            f'name={self.name!r}',
            f'type={self.type!r}',
            f'default={self.default!r}',
            f'default_factory={self.default_factory!r}',
            f'init={self.init!r}',
            f'repr={self.repr!r}',
            f'hash={self.hash!r}',
            f'compare={self.compare!r}',
            f'metadata={self.metadata!r}',
            f'kw_only={self.kw_only!r}',
            f'_field_type={self._field_type}',
        )
        return 'Field(' + ','.join(parts) + ')'

    def __set_name__(self, owner, name):
        # PEP 487 support for a field whose default is itself a descriptor:
        # forward __set_name__ to that default.  In _process_class this
        # Field is later replaced in the class by the default value, so the
        # descriptor ends up having been notified at the right time.
        # See https://peps.python.org/pep-0487/#implementation-details.
        hook = getattr(type(self.default), '__set_name__', None)
        if hook:
            hook(self.default, owner, name)

    __class_getitem__ = classmethod(types.GenericAlias)
336
337
class _DataclassParams:
    """Record of the arguments that were passed to @dataclass for a class."""

    __slots__ = ('init',
                 'repr',
                 'eq',
                 'order',
                 'unsafe_hash',
                 'frozen',
                 'match_args',
                 'kw_only',
                 'slots',
                 'weakref_slot',
                 )

    def __init__(self,
                 init, repr, eq, order, unsafe_hash, frozen,
                 match_args, kw_only, slots, weakref_slot):
        self.init, self.repr, self.eq = init, repr, eq
        self.order, self.unsafe_hash, self.frozen = order, unsafe_hash, frozen
        self.match_args, self.kw_only = match_args, kw_only
        self.slots, self.weakref_slot = slots, weakref_slot

    def __repr__(self):
        # One "name=value" per setting, in slot order, with no space after
        # the separating commas.
        settings = ','.join(
            f'{attr}={getattr(self, attr)!r}'
            for attr in ('init', 'repr', 'eq', 'order', 'unsafe_hash',
                         'frozen', 'match_args', 'kw_only', 'slots',
                         'weakref_slot'))
        return f'_DataclassParams({settings})'
378
379
# Fields are created through this function rather than the Field class
# itself, so that a type checker can be told (via overloads) that the
# result's type depends on the arguments.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None, kw_only=MISSING):
    """Return an object to identify dataclass fields.

    default: the field's default value.
    default_factory: a zero-argument callable used to initialize the
        field's value.
    init: if true, the field is a parameter to the class's __init__().
    repr: if true, the field is included in the object's repr().
    hash: if true, the field is included in the object's hash().
    compare: if true, the field is used in comparison functions.
    metadata: if given, a mapping that is stored (as a read-only view) but
        not otherwise examined by dataclass.
    kw_only: if true, the field becomes a keyword-only parameter to
        __init__().

    Raises ValueError if both default and default_factory are specified.
    """
    has_default = default is not MISSING
    has_factory = default_factory is not MISSING
    if has_default and has_factory:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata, kw_only)
405
406
def _fields_in_init_order(fields):
    """Split *fields* into the two groups of __init__ parameters.

    Returns ``(positional, keyword_only)``, both tuples.  Fields with a false
    ``init`` are left out, and each group keeps the input order.
    """
    positional = tuple(f for f in fields if f.init and not f.kw_only)
    keyword_only = tuple(f for f in fields if f.init and f.kw_only)
    return positional, keyword_only
414
415
def _tuple_str(obj_name, fields):
    """Return source text for a tuple of *obj_name*'s field values.

    For fields named ['x', 'y'] and obj_name "self", this returns
    "(self.x,self.y,)".  Every member is followed by a comma, so a single
    field still produces a 1-tuple.  No fields produce "()".
    """
    if not fields:
        return '()'
    members = ''.join(f'{obj_name}.{f.name},' for f in fields)
    return f'({members})'
426
427
class _FuncBuilder:
    """Collect source for generated methods, then install them on a class.

    Each add_fn() call appends one function's source text.  A single
    add_fns_to_class() call exec()s all of that source at once, nested
    inside an outer factory function whose parameters are the collected
    locals, and assigns the resulting functions onto the class.
    """

    def __init__(self, globals):
        self.names = []               # function names, in add_fn() order
        self.src = []                 # source text of each function
        self.globals = globals        # globals used when exec()ing the source
        self.locals = {}              # names passed into the outer function
        self.overwrite_errors = {}    # name -> True or extra error text
        self.unconditional_adds = {}  # name -> True: always setattr()

    def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
               overwrite_error=False, unconditional_add=False, decorator=None):
        """Record the source of function *name* for later installation.

        args is an iterable of parameter strings, and body is an iterable of
        already-indented source lines.  locals are merged into the names the
        generated code can see.  If return_type is given, it becomes the
        return annotation.  overwrite_error (True, or a string appended to
        the message) makes it a TypeError for the class to already have
        *name*.  unconditional_add always replaces any existing attribute.
        decorator, if given, is emitted on its own line above the def.
        """
        if locals is not None:
            self.locals.update(locals)

        # Whether overwriting is an error is specific to each method, so key
        # it by name.  It is consulted in add_fns_to_class().
        if overwrite_error:
            self.overwrite_errors[name] = overwrite_error

        # By default, a function the class already has is not overwritten.
        if unconditional_add:
            self.unconditional_adds[name] = True

        self.names.append(name)

        if return_type is not MISSING:
            # Refer to the return type through a local name instead of
            # spelling it in the source text.
            self.locals[f'__dataclass_{name}_return_type__'] = return_type
            return_annotation = f'->__dataclass_{name}_return_type__'
        else:
            return_annotation = ''
        args = ','.join(args)
        body = '\n'.join(body)

        # Every function is nested one space deep inside __create_fn__.
        decorator_line = f' {decorator}\n' if decorator else ''
        self.src.append(f'{decorator_line} def {name}({args}){return_annotation}:\n{body}')

    def add_fns_to_class(self, cls):
        """Compile all recorded functions and assign them onto *cls*.

        Raises TypeError when a function registered with overwrite_error
        already exists on *cls*.  Whether it exists is decided by the value
        that _set_new_attribute() returns.
        """
        # The source of all the functions being generated.
        fns_src = '\n'.join(self.src)

        # The locals they use become parameters of the outer function.
        local_vars = ','.join(self.locals.keys())

        # The outer function returns every generated function as a tuple.
        # The trailing comma keeps a single name a 1-tuple.  No names at
        # all needs the literal empty tuple.
        if not self.names:
            return_names = '()'
        else:
            return_names = f'({",".join(self.names)},)'

        # txt is the entire function we're going to execute, including the
        # bodies of the functions we're defining.  Here's a greatly simplified
        # version:
        # def __create_fn__():
        #  def __init__(self, x, y):
        #   self.x = x
        #   self.y = y
        #  @recursive_repr
        #  def __repr__(self):
        #   return f"cls(x={self.x!r},y={self.y!r})"
        # return __init__,__repr__
        txt = f"def __create_fn__({local_vars}):\n{fns_src}\n return {return_names}"
        ns = {}
        exec(txt, self.globals, ns)
        fns = ns['__create_fn__'](**self.locals)

        # Now that the functions exist, assign them into cls.
        for name, fn in zip(self.names, fns):
            fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
            if self.unconditional_adds.get(name, False):
                setattr(cls, name, fn)
                continue

            already_exists = _set_new_attribute(cls, name, fn)

            # See if it's an error to overwrite this particular function.
            if already_exists and (msg_extra := self.overwrite_errors.get(name)):
                error_msg = (f'Cannot overwrite attribute {fn.__name__} '
                             f'in class {cls.__name__}')
                if msg_extra is not True:
                    error_msg = f'{error_msg} {msg_extra}'
                raise TypeError(error_msg)
517
518
def _field_assign(frozen, name, value, self_name):
    """Return one __init__ body line that assigns *value* to field *name*.

    A frozen class's own __setattr__ refuses assignment, so its __init__
    goes through object.__setattr__.  The generated code reaches it via the
    local name __dataclass_builtins_object__.  self_name is whatever the
    generated function calls its instance parameter.  "self" is not
    hard-coded because it might also be a field name.
    """
    if not frozen:
        return f'  {self_name}.{name}={value}'
    return f'  __dataclass_builtins_object__.__setattr__({self_name},{name!r},{value})'
529
530
def _field_init(f, frozen, globals, self_name, slots):
    """Return the __init__ body line that initializes field *f*, or None.

    f is a Field, which may be an InitVar pseudo-field.  As a side effect,
    any default value or default factory the generated code refers to is
    stored into *globals* (the dict of names visible to the generated
    __init__) under ``__dataclass_dflt_<name>__``.  This happens even for
    InitVars, whose line is then suppressed.

    Returns None when no assignment is needed: for InitVars, and for
    init=False fields that have no factory and are not given an explicit
    assignment, which fall back to the class attribute holding the default.
    """
    default_name = f'__dataclass_dflt_{f.name}__'
    if f.default_factory is not MISSING:
        # The factory is called inside __init__.  A factory-made value cannot
        # live in the class dict the way a plain default does: it may need
        # to differ per instance, which is the point of a factory.
        globals[default_name] = f.default_factory
        if f.init:
            # The _HAS_DEFAULT_FACTORY marker means "argument not supplied".
            value = (f'{default_name}() '
                     f'if {f.name} is __dataclass_HAS_DEFAULT_FACTORY__ '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter, so calling the factory is the only
            # way to initialize it.
            value = f'{default_name}()'
    elif f.init:
        # The parameter carries the value.  A plain default is published so
        # the parameter's "=__dataclass_dflt_<name>__" can refer to it.
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name
    elif slots and f.default is not MISSING:
        # With slots, initialize the field here instead of relying on a
        # class attribute.
        globals[default_name] = f.default
        value = default_name
    else:
        # No initialization needed: reading the attribute uses the class
        # attribute that holds the default.
        return None

    # Checked only now, so any default above is still registered.  The
    # InitVar itself is never assigned to the instance.
    if f._field_type is _FIELD_INITVAR:
        return None

    return _field_assign(frozen, f.name, value, self_name)
589
590
def _init_param(f):
    """Return the __init__ parameter text for field *f*.

    The result looks like 'x:int=3', except that the annotation and the
    default refer to variables visible to the generated code
    (``__dataclass_type_x__`` and ``__dataclass_dflt_x__``) rather than to
    the objects themselves.  A field with a default factory gets the
    ``__dataclass_HAS_DEFAULT_FACTORY__`` marker as its default.
    """
    if f.default is not MISSING:
        suffix = f'=__dataclass_dflt_{f.name}__'
    elif f.default_factory is not MISSING:
        suffix = '=__dataclass_HAS_DEFAULT_FACTORY__'
    else:
        # Neither a default nor a factory: the parameter is required.
        suffix = ''
    return f'{f.name}:__dataclass_type_{f.name}__{suffix}'
608
609
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
             self_name, func_builder, slots):
    """Record the generated __init__ on *func_builder*.

    fields holds every field, including InitVar pseudo-fields, and drives
    the body.  std_fields and kw_only_fields become the positional and the
    keyword-only __init__ parameters.  When has_post_init is true, the body
    ends with a __post_init__ call that receives the InitVar values.

    Raises TypeError if a positional parameter without a default follows
    one that has a default.
    """
    # exec()ing the generated source would also reject this, but checking
    # here gives a clearer message and does not depend on how the function
    # is eventually built.
    last_default = None
    for f in std_fields:
        # Only fields that are __init__ parameters matter here.
        if not f.init:
            continue
        if f.default is MISSING and f.default_factory is MISSING:
            if last_default is not None:
                raise TypeError(f'non-default argument {f.name!r} '
                                f'follows default argument {last_default.name!r}')
        else:
            last_default = f

    fn_locals = {f'__dataclass_type_{f.name}__': f.type for f in fields}
    fn_locals['__dataclass_HAS_DEFAULT_FACTORY__'] = _HAS_DEFAULT_FACTORY
    fn_locals['__dataclass_builtins_object__'] = object

    # Each field yields at most one line.  None means that no initialization
    # is required (e.g. a pseudo-field).  _field_init also records defaults
    # in fn_locals as it goes.
    candidate_lines = (_field_init(f, frozen, fn_locals, self_name, slots)
                       for f in fields)
    body_lines = [line for line in candidate_lines if line]

    if has_post_init:
        initvar_names = ','.join(f.name for f in fields
                                 if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'  {self_name}.{_POST_INIT_NAME}({initvar_names})')

    # A function body cannot be empty.
    body_lines = body_lines or ['  pass']

    params = [_init_param(f) for f in std_fields]
    if kw_only_fields:
        # A bare '*' is only legal when at least one keyword-only parameter
        # follows it, so it is added only in that case (rather than simply
        # concatenating the lists).
        params.append('*')
        params.extend(_init_param(f) for f in kw_only_fields)

    func_builder.add_fn('__init__',
                        [self_name, *params],
                        body_lines,
                        locals=fn_locals,
                        return_type=None)
666
667
def _frozen_get_del_attr(cls, fields, func_builder):
    """Record __setattr__ and __delattr__ for a frozen class.

    Both generated methods raise FrozenInstanceError when the instance's
    exact type is *cls* or when the attribute name is one of *fields*.
    Otherwise they delegate via super(cls, self).  Both are registered with
    overwrite_error=True, so a class that already defines either one is
    rejected when the methods are installed.
    """
    fn_locals = {'cls': cls,
                 'FrozenInstanceError': FrozenInstanceError}
    condition = 'type(self) is cls'
    if fields:
        field_names = ', '.join(repr(f.name) for f in fields)
        condition = f'{condition} or name in {{{field_names}}}'

    for method, params, verb, delegate in (
            ('__setattr__', ('self', 'name', 'value'), 'assign to',
             '__setattr__(name, value)'),
            ('__delattr__', ('self', 'name'), 'delete',
             '__delattr__(name)'),
    ):
        func_builder.add_fn(
            method,
            params,
            (f'  if {condition}:',
             f'   raise FrozenInstanceError(f"cannot {verb} field {{name!r}}")',
             f'  super(cls, self).{delegate}'),
            locals=fn_locals,
            overwrite_error=True)
689
690
def _is_classvar(a_type, typing):
    """Return True if *a_type* is typing.ClassVar, bare or subscripted.

    Recognizing ClassVar[T] relies on typing's internal _GenericAlias class.
    That is an implementation detail, but it is the best available test.
    """
    if a_type is typing.ClassVar:
        return True
    return (type(a_type) is typing._GenericAlias
            and a_type.__origin__ is typing.ClassVar)
697
698
def _is_initvar(a_type, dataclasses):
    """Return True if *a_type* is InitVar itself or an InitVar[...] instance.

    *dataclasses* is the module currently being executed (dataclasses.py),
    passed in by the caller.
    """
    initvar_cls = dataclasses.InitVar
    return a_type is initvar_cls or type(a_type) is initvar_cls
704
705def _is_kw_only(a_type, dataclasses):
706    return a_type is dataclasses.KW_ONLY
707
708
def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    """Return True if the string *annotation* refers to *a_type* in *a_module*.

    For example, when checking whether an annotation denotes a ClassVar,
    a_module is typing and a_type is typing.ClassVar.

    annotation        -- the string type annotation
    cls               -- the class the annotation was found in
    a_module          -- the module we want to match
    a_type            -- the type in that module we want to match
    is_type_predicate -- called as is_type_predicate(obj, a_module) to decide
                         whether obj is of the desired type

    a_module could be looked up from a_type via sys.modules, but the caller
    already knows it, so it is passed in.

    Only module (global) namespaces are searched, never a local namespace,
    so some cases are misjudged.  With string annotations, cv0 below is
    detected as a ClassVar:

        CV = ClassVar
        @dataclass
        class C0:
            cv0: CV

    but cv1 is not, because "CV" is looked up in the module and not found:

        @dataclass
        class C1:
            CV = ClassVar
            cv1: CV

    Fixing that would require eval()ing the string with the correct global
    and local namespaces for every field of every dataclass, a cost judged
    not worth it for such an obscure corner case.

    If a module that must be consulted is missing from sys.modules (for
    instance, a class created by exec() under a made-up __name__), the
    annotation cannot be resolved and False is returned.  Previously this
    raised AttributeError.
    """
    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if not match:
        return False

    module_name, type_name = match.group(1, 2)
    cls_module = sys.modules.get(cls.__module__)
    if cls_module is None:
        # No namespace to search, so the annotation cannot be recognized.
        return False

    if not module_name:
        # No module name: assume the class's module did something like
        # "from dataclasses import InitVar".
        ns = cls_module.__dict__
    elif cls_module.__dict__.get(module_name) is a_module:
        # "module.Name" where `module` is bound to a_module in the class's
        # module: look Name up in the module that defines a_type.
        type_module = sys.modules.get(a_type.__module__)
        if type_module is None:
            return False
        ns = type_module.__dict__
    else:
        return False

    if ns and is_type_predicate(ns.get(type_name), a_module):
        return True
    return False
766
767
def _get_field(cls, a_name, a_type, default_kw_only):
    """Return the Field object describing annotation *a_name* of *cls*.

    ClassVar and InitVar annotations also produce a Field; their kind is
    recorded in ``_field_type`` (_FIELD_CLASSVAR / _FIELD_INITVAR instead of
    _FIELD).  *default_kw_only* becomes ``kw_only`` for real fields and
    InitVars whose field() specification did not set it.

    Raises TypeError if a ClassVar/InitVar has a default_factory or a
    ClassVar sets kw_only, and ValueError if a real field's default is
    unhashable (unhashable is used as a proxy for "mutable").
    """
    # The class attribute (if any) is either a field(...) specification or
    # a bare default value; normalise both into a Field.
    class_value = getattr(cls, a_name, MISSING)
    if isinstance(class_value, Field):
        fld = class_value
    else:
        if isinstance(class_value, types.MemberDescriptorType):
            # A __slots__ member descriptor is not a default value.
            class_value = MISSING
        fld = field(default=class_value)

    # Name and type only become known here, so fill them in now.  Every
    # annotation starts out as a regular field; the checks below may
    # reclassify it as a ClassVar or an InitVar.
    fld.name = a_name
    fld.type = a_type
    fld._field_type = _FIELD

    # Both real objects and string annotations are recognised.  Strings are
    # matched with a regex plus a lookup of the referenced name rather than
    # with get_type_hints(), which is expensive and does not always work
    # (see https://github.com/python/typing/issues/508 and
    # https://bugs.python.org/issue33453).

    # ClassVar is only possible if some module has imported typing.
    typing_module = sys.modules.get('typing')
    if typing_module and (
            _is_classvar(a_type, typing_module)
            or (isinstance(fld.type, str)
                and _is_type(fld.type, cls, typing_module,
                             typing_module.ClassVar, _is_classvar))):
        fld._field_type = _FIELD_CLASSVAR

    # InitVar is checked against this very module (dataclasses).
    if fld._field_type is _FIELD:
        this_module = sys.modules[__name__]
        if (_is_initvar(a_type, this_module)
                or (isinstance(fld.type, str)
                    and _is_type(fld.type, cls, this_module,
                                 this_module.InitVar, _is_initvar))):
            fld._field_type = _FIELD_INITVAR

    # Per-field validation lives here rather than in Field() because only
    # now is the field's name available for the error messages.
    kind = fld._field_type

    # Pseudo-fields (ClassVar, InitVar) cannot use a default factory.
    if (kind in (_FIELD_CLASSVAR, _FIELD_INITVAR)
            and fld.default_factory is not MISSING):
        raise TypeError(f'field {fld.name} cannot have a '
                        'default factory')

    if kind is _FIELD_CLASSVAR:
        # kw_only is meaningless for a ClassVar, so reject an explicit one.
        if fld.kw_only is not MISSING:
            raise TypeError(f'field {fld.name} is a ClassVar but specifies '
                            'kw_only')
    elif fld.kw_only is MISSING:
        # Real fields and InitVars fall back to the class-wide default.
        fld.kw_only = default_kw_only

    # Reject unhashable defaults on real fields.  __hash__ is read from the
    # class, not the instance.
    if kind is _FIELD and fld.default.__class__.__hash__ is None:
        raise ValueError(f'mutable default {type(fld.default)} for field '
                         f'{fld.name} is not allowed: use default_factory')

    return fld
864
865def _set_new_attribute(cls, name, value):
866    # Never overwrites an existing attribute.  Returns True if the
867    # attribute already exists.
868    if name in cls.__dict__:
869        return True
870    setattr(cls, name, value)
871    return False
872
873
874# Decide if/how we're going to create a hash function.  Key is
875# (unsafe_hash, eq, frozen, does-hash-exist).  Value is the action to
876# take.  The common case is to do nothing, so instead of providing a
877# function that is a no-op, use None to signify that.
878
879def _hash_set_none(cls, fields, func_builder):
880    # It's sort of a hack that I'm setting this here, instead of at
881    # func_builder.add_fns_to_class time, but since this is an exceptional case
882    # (it's not setting an attribute to a function, but to a scalar value),
883    # just do it directly here.  I might come to regret this.
884    cls.__hash__ = None
885
def _hash_add(cls, fields, func_builder):
    """Hash action: register a generated ``__hash__`` with *func_builder*.

    The generated method hashes a tuple of the participating fields.  A
    field's ``hash`` flag decides participation when it is not None;
    otherwise its ``compare`` flag does.  The method is registered with
    ``unconditional_add=True``.  *cls* is unused.  Returns None.
    """
    hashed_fields = []
    for fld in fields:
        participates = fld.compare if fld.hash is None else fld.hash
        if participates:
            hashed_fields.append(fld)
    key_expr = _tuple_str('self', hashed_fields)
    body = [f'  return hash({key_expr})']
    func_builder.add_fn('__hash__', ('self',), body, unconditional_add=True)
893
894def _hash_exception(cls, fields, func_builder):
895    # Raise an exception.
896    raise TypeError(f'Cannot overwrite attribute __hash__ '
897                    f'in class {cls.__name__}')
898
899#
900#                +-------------------------------------- unsafe_hash?
901#                |      +------------------------------- eq?
902#                |      |      +------------------------ frozen?
903#                |      |      |      +----------------  has-explicit-hash?
904#                |      |      |      |
905#                |      |      |      |        +-------  action
906#                |      |      |      |        |
907#                v      v      v      v        v
908_hash_action = {(False, False, False, False): None,
909                (False, False, False, True ): None,
910                (False, False, True,  False): None,
911                (False, False, True,  True ): None,
912                (False, True,  False, False): _hash_set_none,
913                (False, True,  False, True ): None,
914                (False, True,  True,  False): _hash_add,
915                (False, True,  True,  True ): None,
916                (True,  False, False, False): _hash_add,
917                (True,  False, False, True ): _hash_exception,
918                (True,  False, True,  False): _hash_add,
919                (True,  False, True,  True ): _hash_exception,
920                (True,  True,  False, False): _hash_add,
921                (True,  True,  False, True ): _hash_exception,
922                (True,  True,  True,  False): _hash_add,
923                (True,  True,  True,  True ): _hash_exception,
924                }
925# See https://bugs.python.org/issue32929#msg312829 for an if-statement
926# version of this table.
927
928
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
                   match_args, kw_only, slots, weakref_slot):
    """Turn *cls* into a dataclass and return it.

    This is the worker behind @dataclass; every argument after *cls* mirrors
    the decorator argument of the same name.  The steps are, in order:

    * record the decorator arguments on the class (under _PARAMS);
    * collect fields: dataclass base-class fields first (reverse MRO), then
      this class's own annotations, where a KW_ONLY marker switches the
      remaining annotations to kw_only;
    * validate frozen/non-frozen inheritance and order-requires-eq;
    * queue __init__, __repr__, __eq__, ordering, frozen-instance and
      __hash__ methods on a _FuncBuilder and attach them to the class;
    * add a default __doc__ and __match_args__, and optionally rebuild the
      class with __slots__.

    Returns *cls*, or the replacement class created by _add_slots() when
    *slots* is true.

    Raises TypeError for, among other things, a Field without annotation,
    a repeated KW_ONLY marker, frozen/non-frozen inheritance mismatches, a
    conflicting explicit __hash__ and weakref_slot without slots; raises
    ValueError if order is true but eq is not.  _get_field() may raise
    TypeError/ValueError for individual fields.
    """
    # Now that dicts retain insertion order, there's no reason to use
    # an ordered dict.  I am leveraging that ordering here, because
    # derived class fields overwrite base class fields, but the order
    # is defined by the base class, which is found first.
    fields = {}

    if cls.__module__ in sys.modules:
        globals = sys.modules[cls.__module__].__dict__
    else:
        # Theoretically this can happen if someone writes
        # a custom string to cls.__module__.  In which case
        # such dataclass won't be fully introspectable
        # (w.r.t. typing.get_type_hints) but will still function
        # correctly.
        globals = {}

    # Record the decorator arguments on the class itself.  Note this uses
    # the kw_only argument as passed, before any KW_ONLY marker rebinds it.
    setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
                                           unsafe_hash, frozen,
                                           match_args, kw_only,
                                           slots, weakref_slot))

    # Find our base classes in reverse MRO order, and exclude
    # ourselves.  In reversed order so that more derived classes
    # override earlier field definitions in base classes.  As long as
    # we're iterating over them, see if all or any of them are frozen.
    any_frozen_base = False
    # all_frozen_bases stays None when no base is a dataclass; otherwise it
    # ends up True only if every dataclass base is frozen.
    all_frozen_bases = None
    has_dataclass_bases = False
    for b in cls.__mro__[-1:0:-1]:
        # Only process classes that have been processed by our
        # decorator.  That is, they have a _FIELDS attribute.
        base_fields = getattr(b, _FIELDS, None)
        if base_fields is not None:
            has_dataclass_bases = True
            for f in base_fields.values():
                fields[f.name] = f
            if all_frozen_bases is None:
                all_frozen_bases = True
            current_frozen = getattr(b, _PARAMS).frozen
            all_frozen_bases = all_frozen_bases and current_frozen
            any_frozen_base = any_frozen_base or current_frozen

    # Annotations defined specifically in this class (not in base classes).
    #
    # Fields are found from cls_annotations, which is guaranteed to be
    # ordered.  Default values are from class attributes, if a field
    # has a default.  If the default value is a Field(), then it
    # contains additional info beyond (and possibly including) the
    # actual default value.  Pseudo-fields ClassVars and InitVars are
    # included, despite the fact that they're not real fields.  That's
    # dealt with later.
    cls_annotations = inspect.get_annotations(cls)

    # Now find fields in our class.  While doing so, validate some
    # things, and set the default values (as class attributes) where
    # we can.
    cls_fields = []
    KW_ONLY_seen = False
    # Get a reference to this module for the _is_kw_only() test.
    dataclasses = sys.modules[__name__]
    # The loop variable is deliberately not called ``type`` so that the
    # builtin is not shadowed anywhere in this function.
    for name, ann_type in cls_annotations.items():
        # See if this is a marker to change the value of kw_only.
        if (_is_kw_only(ann_type, dataclasses)
            or (isinstance(ann_type, str)
                and _is_type(ann_type, cls, dataclasses, dataclasses.KW_ONLY,
                             _is_kw_only))):
            # Switch the default to kw_only=True for all following
            # annotations, and ignore this one: it's not a real field.
            if KW_ONLY_seen:
                raise TypeError(f'{name!r} is KW_ONLY, but KW_ONLY '
                                'has already been specified')
            KW_ONLY_seen = True
            kw_only = True
        else:
            # Otherwise it's a field of some type.
            cls_fields.append(_get_field(cls, name, ann_type, kw_only))

    for f in cls_fields:
        fields[f.name] = f

        # If the class attribute (which is the default value for this
        # field) exists and is of type 'Field', replace it with the
        # real default.  This is so that normal class introspection
        # sees a real default value, not a Field.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                # If there's no default, delete the class attribute.
                # This happens if we specify field(repr=False), for
                # example (that is, we specified a field object, but
                # no default value).  Also if we're using a default
                # factory.  The class attribute should not be set at
                # all in the post-processed class.
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Do we have any Field members that don't also have annotations?
    for name, value in cls.__dict__.items():
        if isinstance(value, Field) and name not in cls_annotations:
            raise TypeError(f'{name!r} is a field but has no type annotation')

    # Check rules that apply if we are derived from any dataclasses.
    if has_dataclass_bases:
        # Raise an exception if any of our bases are frozen, but we're not.
        if any_frozen_base and not frozen:
            raise TypeError('cannot inherit non-frozen dataclass from a '
                            'frozen one')

        # Raise an exception if we're frozen, but at least one of our
        # dataclass bases is not.
        if all_frozen_bases is False and frozen:
            raise TypeError('cannot inherit frozen dataclass from a '
                            'non-frozen one')

    # Remember all of the fields on our class (including bases).  This
    # also marks this class as being a dataclass.
    setattr(cls, _FIELDS, fields)

    # Was this class defined with an explicit __hash__?  Note that if
    # __eq__ is defined in this class, then python will automatically
    # set __hash__ to None.  This is a heuristic, as it's possible
    # that such a __hash__ == None was not auto-generated, but it's
    # close enough.
    class_hash = cls.__dict__.get('__hash__', MISSING)
    has_explicit_hash = not (class_hash is MISSING or
                             (class_hash is None and '__eq__' in cls.__dict__))

    # If we're generating ordering methods, we must be generating the
    # eq methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    # Include InitVars and regular fields (so, not ClassVars).  This is
    # initialized here, outside of the "if init:" test, because std_init_fields
    # is used with match_args, below.
    all_init_fields = [f for f in fields.values()
                       if f._field_type in (_FIELD, _FIELD_INITVAR)]
    (std_init_fields,
     kw_only_init_fields) = _fields_in_init_order(all_init_fields)

    # Generated methods are queued on the builder and attached in one go by
    # add_fns_to_class() further down.
    func_builder = _FuncBuilder(globals)

    if init:
        # Does this class have a post-init function?
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        _init_fn(all_init_fields,
                 std_init_fields,
                 kw_only_init_fields,
                 frozen,
                 has_post_init,
                 # The name to use for the "self"
                 # param in __init__.  Use "self"
                 # if possible.
                 '__dataclass_self__' if 'self' in fields
                 else 'self',
                 func_builder,
                 slots,
                 )

    _set_new_attribute(cls, '__replace__', _replace)

    # Get the fields as a list, and include only real fields.  This is
    # used in all of the following methods.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        flds = [f for f in field_list if f.repr]
        func_builder.add_fn('__repr__',
                            ('self',),
                            ['  return f"{self.__class__.__qualname__}(' +
                             ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                        for f in flds]) + ')"'],
                            locals={'__dataclasses_recursive_repr': recursive_repr},
                            decorator="@__dataclasses_recursive_repr()")

    if eq:
        # Create __eq__ method.  There's no need for a __ne__ method,
        # since python will call __eq__ and negate it.
        cmp_fields = (field for field in field_list if field.compare)
        terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
        field_comparisons = ' and '.join(terms) or 'True'
        func_builder.add_fn('__eq__',
                            ('self', 'other'),
                            [ '  if self is other:',
                              '   return True',
                              '  if other.__class__ is self.__class__:',
                             f'   return {field_comparisons}',
                              '  return NotImplemented'])

    if order:
        # Create and set the ordering methods.
        flds = [f for f in field_list if f.compare]
        self_tuple = _tuple_str('self', flds)
        other_tuple = _tuple_str('other', flds)
        for name, op in [('__lt__', '<'),
                         ('__le__', '<='),
                         ('__gt__', '>'),
                         ('__ge__', '>='),
                         ]:
            # Create a comparison function.  If the fields in the object are
            # named 'x' and 'y', then self_tuple is the string
            # '(self.x,self.y)' and other_tuple is the string
            # '(other.x,other.y)'.
            func_builder.add_fn(name,
                                ('self', 'other'),
                                [ '  if other.__class__ is self.__class__:',
                                 f'   return {self_tuple}{op}{other_tuple}',
                                  '  return NotImplemented'],
                                overwrite_error='Consider using functools.total_ordering')

    if frozen:
        _frozen_get_del_attr(cls, field_list, func_builder)

    # Decide if/how we're going to create a hash function.  Every action
    # returns None: _hash_set_none wants __hash__ to be None anyway, and
    # _hash_exception raises.  For _hash_add, the None is only a
    # placeholder; the generated __hash__ it queued with
    # unconditional_add=True is presumably installed over it by
    # add_fns_to_class() below.
    hash_action = _hash_action[bool(unsafe_hash),
                               bool(eq),
                               bool(frozen),
                               has_explicit_hash]
    if hash_action:
        cls.__hash__ = hash_action(cls, field_list, func_builder)

    # Generate the methods and add them to the class.  This needs to be done
    # before the __doc__ logic below, since inspect will look at the __init__
    # signature.
    func_builder.add_fns_to_class(cls)

    if not getattr(cls, '__doc__'):
        # Create a class doc-string.
        try:
            # In some cases fetching a signature is not possible.
            # But, we surely should not fail in this case.
            text_sig = str(inspect.signature(cls)).replace(' -> None', '')
        except (TypeError, ValueError):
            text_sig = ''
        cls.__doc__ = (cls.__name__ + text_sig)

    if match_args:
        # I could probably compute this once.
        _set_new_attribute(cls, '__match_args__',
                           tuple(f.name for f in std_init_fields))

    # It's an error to specify weakref_slot if slots is False.
    if weakref_slot and not slots:
        raise TypeError('weakref_slot is True but slots is False')
    if slots:
        cls = _add_slots(cls, frozen, weakref_slot)

    abc.update_abstractmethods(cls)

    return cls
1183
1184
1185# _dataclass_getstate and _dataclass_setstate are needed for pickling frozen
1186# classes with slots.  These could be slightly more performant if we generated
1187# the code instead of iterating over fields.  But that can be a project for
1188# another day, if performance becomes an issue.
def _dataclass_getstate(self):
    """Return the pickle state of *self*.

    The state is a list of the instance's field values in the order given
    by fields(), so ClassVar/InitVar pseudo-fields are not included.
    """
    state = []
    for fld in fields(self):
        state.append(getattr(self, fld.name))
    return state
1191
1192
def _dataclass_setstate(self, state):
    """Restore field values from *state* (as produced by _dataclass_getstate).

    Values are paired with fields() positionally; zip() stops at the
    shorter of the two.  object.__setattr__ is called directly, bypassing
    any __setattr__ defined on the class, so that restoring also works when
    the dataclass is frozen.
    """
    assign = object.__setattr__
    for fld, value in zip(fields(self), state):
        assign(self, fld.name, value)
1197
1198
1199def _get_slots(cls):
1200    match cls.__dict__.get('__slots__'):
1201        # `__dictoffset__` and `__weakrefoffset__` can tell us whether
1202        # the base type has dict/weakref slots, in a way that works correctly
1203        # for both Python classes and C extension types. Extension types
1204        # don't use `__slots__` for slot creation
1205        case None:
1206            slots = []
1207            if getattr(cls, '__weakrefoffset__', -1) != 0:
1208                slots.append('__weakref__')
1209            if getattr(cls, '__dictoffset__', -1) != 0:
1210                slots.append('__dict__')
1211            yield from slots
1212        case str(slot):
1213            yield slot
1214        # Slots may be any iterable, but we cannot handle an iterator
1215        # because it will already be (partially) consumed.
1216        case iterable if not hasattr(iterable, '__next__'):
1217            yield from iterable
1218        case _:
1219            raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
1220
1221
def _add_slots(cls, is_frozen, weakref_slot):
    """Return a new class equivalent to dataclass *cls* but using __slots__.

    __slots__ only takes effect when a class is created, so the class is
    re-created via ``type(cls)(name, bases, namespace)`` from a copy of
    *cls*'s namespace.  The new __slots__ holds the field names, plus
    '__weakref__' when *weakref_slot* is true, minus any slot that a base
    class (other than object) already provides.  The fields' class-level
    default values are dropped from the namespace; the defaults remain
    recorded on the Field objects.  The old __dict__ and __weakref__
    descriptors are dropped too.  For frozen classes, pickle support
    (__getstate__/__setstate__) is added unless the class defines its own.
    __qualname__ is carried over to the new class.

    Raises TypeError if *cls* already defines __slots__ itself, or if the
    slots of some base class cannot be determined.

    NOTE(review): methods copied into the new class keep their original
    ``__class__`` closure cell, so zero-argument super() inside them
    presumably still refers to the old class.  This is not handled here.
    """
    if '__slots__' in cls.__dict__:
        raise TypeError(f'{cls.__name__} already specifies __slots__')

    namespace = dict(cls.__dict__)
    field_names = tuple(f.name for f in fields(cls))

    # Slot names already provided by some base (excluding cls itself and
    # object); declaring them again would create duplicate storage.
    taken = set()
    for base in cls.__mro__[1:-1]:
        taken.update(_get_slots(base))

    # gh-93521: '__weakref__' is filtered against the inherited slots just
    # like the field names are.
    wanted = field_names + (('__weakref__',) if weakref_slot else ())
    namespace["__slots__"] = tuple(
        slot for slot in wanted if slot not in taken
    )

    # The slot descriptors must own the field names, so drop the class-level
    # default values stored under them.
    for field_name in field_names:
        namespace.pop(field_name, None)

    # These descriptors belong to the original class, not the new one.
    namespace.pop('__dict__', None)
    namespace.pop('__weakref__', None)  # gh-102069

    qualname = getattr(cls, '__qualname__', None)
    slotted = type(cls)(cls.__name__, cls.__bases__, namespace)
    if qualname is not None:
        slotted.__qualname__ = qualname

    if is_frozen:
        # Frozen classes with slots need explicit pickle support.
        if '__getstate__' not in namespace:
            slotted.__getstate__ = _dataclass_getstate
        if '__setstate__' not in namespace:
            slotted.__setstate__ = _dataclass_setstate

    return slotted
1275
1276
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, match_args=True,
              kw_only=False, slots=False, weakref_slot=False):
    """Add dunder methods based on the fields defined in the class.

    Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class. If repr
    is true, a __repr__() method is added. If eq is true, an __eq__()
    method is added. If order is true, rich comparison dunder methods are
    added (this requires eq to be true). If unsafe_hash is true, a
    __hash__() method is added. If frozen is true, fields may not be
    assigned to after instance creation. If match_args is true, the
    __match_args__ tuple is added. If kw_only is true, then by default
    all fields are keyword-only. If slots is true, a new class with a
    __slots__ attribute is returned. If weakref_slot is true, a
    __weakref__ slot is added as well; this requires slots to be true.

    Works both bare (``@dataclass``) and with arguments
    (``@dataclass(...)``).
    """

    def wrap(cls):
        return _process_class(cls, init, repr, eq, order, unsafe_hash,
                              frozen, match_args, kw_only, slots,
                              weakref_slot)

    if cls is not None:
        # Used bare as @dataclass: process the class right away.
        return wrap(cls)

    # Used as @dataclass(...): hand back the decorator to apply next.
    return wrap
1306
1307
def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field.  ClassVar and InitVar pseudo-fields are excluded.  Fields
    appear in the order they are stored on the class, which is definition
    order with base-class fields first.

    Raises TypeError if the argument is neither a dataclass nor an
    instance of one.
    """
    try:
        field_map = getattr(class_or_instance, _FIELDS)
    except AttributeError:
        raise TypeError('must be called with a dataclass type or instance') from None

    real_fields = [f for f in field_map.values() if f._field_type is _FIELD]
    return tuple(real_fields)
1324
1325
def _is_dataclass_instance(obj):
    """Return True if the type of *obj* is a dataclass.

    A dataclass *class* passed in is therefore normally rejected, since
    its type is its metaclass.
    """
    obj_type = type(obj)
    return hasattr(obj_type, _FIELDS)
1329
1330
def is_dataclass(obj):
    """Return True if *obj* is a dataclass or an instance of one.

    Classes are checked directly; any other object is checked via its type.
    """
    if isinstance(obj, type):
        target = obj
    else:
        target = type(obj)
    return hasattr(target, _FIELDS)
1336
1337
def asdict(obj, *, dict_factory=dict):
    """Return the fields of a dataclass instance as a new dictionary mapping
    field names to field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, 'dict_factory' will be used instead of built-in dict; it is
    called with a list of (name, value) pairs.  The conversion recurses
    into field values that are dataclass instances, and into the
    built-in containers tuple (including namedtuples), list and dict.
    Other objects are copied with 'copy.deepcopy()'.

    Raises TypeError if *obj* is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
1360
1361
def _asdict_inner(obj, dict_factory):
    """Recursively convert *obj* for asdict().

    * Values whose exact type is in _ATOMIC_TYPES are returned unchanged.
    * Dataclass instances become mappings built by *dict_factory*.
    * Lists, tuples (including namedtuples) and dicts (including
      defaultdicts) are rebuilt with converted contents, keeping their type.
    * Anything else is deep-copied.
    """
    obj_type = type(obj)

    if obj_type in _ATOMIC_TYPES:
        return obj

    if hasattr(obj_type, _FIELDS):
        # A nested dataclass instance.
        if dict_factory is not dict:
            return dict_factory([
                (fld.name, _asdict_inner(getattr(obj, fld.name), dict_factory))
                for fld in fields(obj)
            ])
        # Common case: build the dict directly instead of from a pair list.
        return {
            fld.name: _asdict_inner(getattr(obj, fld.name), dict)
            for fld in fields(obj)
        }

    # Exact built-in containers first (cheap identity checks); their
    # subclasses are handled by the issubclass() branches further down.
    if obj_type is list:
        return [_asdict_inner(item, dict_factory) for item in obj]
    if obj_type is dict:
        return {
            _asdict_inner(key, dict_factory): _asdict_inner(val, dict_factory)
            for key, val in obj.items()
        }
    if obj_type is tuple:
        return tuple([_asdict_inner(item, dict_factory) for item in obj])

    if issubclass(obj_type, tuple):
        if hasattr(obj, '_fields'):
            # A namedtuple.  Its constructor takes one positional argument
            # per field, hence the unpacking (see bpo-34363).  The result
            # stays a namedtuple of the same type rather than a dict.
            # namedtuple._asdict() is not used for two reasons: it does not
            # recurse with dict_factory, and a dict result could not be used
            # where the namedtuple served as a dict key.  (json.dumps turns
            # namedtuples into lists anyway.)
            return obj_type(*[_asdict_inner(item, dict_factory) for item in obj])
        return obj_type(_asdict_inner(item, dict_factory) for item in obj)

    if issubclass(obj_type, dict):
        if hasattr(obj_type, 'default_factory'):
            # defaultdict-like: the constructor takes default_factory as its
            # first argument, so fill the result item by item.
            rebuilt = obj_type(obj.default_factory)
            for key, val in obj.items():
                rebuilt[_asdict_inner(key, dict_factory)] = _asdict_inner(val, dict_factory)
            return rebuilt
        return obj_type((_asdict_inner(key, dict_factory),
                         _asdict_inner(val, dict_factory))
                        for key, val in obj.items())

    if issubclass(obj_type, list):
        # Assumes the subclass constructor accepts a single iterable.
        return obj_type(_asdict_inner(item, dict_factory) for item in obj)

    return copy.deepcopy(obj)
1428
1429
def astuple(obj, *, tuple_factory=tuple):
    """Return the fields of a dataclass instance as a new tuple of field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, 'tuple_factory' will be used instead of built-in tuple; it
    is called with a list of the converted field values.  The conversion
    recurses into field values that are dataclass instances, and into the
    built-in containers tuple (including namedtuples), list and dict.
    Other objects are copied with 'copy.deepcopy()'.

    Raises TypeError if *obj* is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
1452
1453
def _astuple_inner(obj, tuple_factory):
    """Recursively convert *obj* for astuple().

    * Values whose exact type is in _ATOMIC_TYPES are returned unchanged.
    * Dataclass instances become ``tuple_factory([...converted field
      values...])``.
    * Lists, tuples (including namedtuples) and dicts (including
      defaultdicts) are rebuilt with converted contents, keeping their type.
    * Anything else is deep-copied.
    """
    obj_type = type(obj)
    if obj_type in _ATOMIC_TYPES:
        return obj
    elif hasattr(obj_type, _FIELDS):
        # Dataclass instance.  This is the _is_dataclass_instance() test,
        # inlined as _asdict_inner does, to save a call per value.
        return tuple_factory([
            _astuple_inner(getattr(obj, f.name), tuple_factory)
            for f in fields(obj)
        ])
    # Exact built-in containers first, mirroring _asdict_inner.  An identity
    # check plus direct construction is cheaper than the isinstance() tests
    # and generator-fed type(obj)(...) calls below, and builds the same
    # result.  Subclasses fall through to the general branches.
    elif obj_type is list:
        return [_astuple_inner(v, tuple_factory) for v in obj]
    elif obj_type is tuple:
        return tuple([_astuple_inner(v, tuple_factory) for v in obj])
    elif obj_type is dict:
        return {
            _astuple_inner(k, tuple_factory): _astuple_inner(v, tuple_factory)
            for k, v in obj.items()
        }
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # obj is a namedtuple.  Recurse into it, but the returned
        # object is another namedtuple of the same type.  This is
        # similar to how other list- or tuple-derived classes are
        # treated (see below), but we just need to create them
        # differently because a namedtuple's __init__ needs to be
        # called differently (see bpo-34363).
        return obj_type(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Assume we can create an object of this type by passing in a
        # generator (which is not true for namedtuples, handled
        # above).
        return obj_type(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        if hasattr(obj_type, 'default_factory'):
            # obj is a defaultdict, which has a different constructor from
            # dict as it requires the default_factory as its first arg.
            result = obj_type(obj.default_factory)
            for k, v in obj.items():
                result[_astuple_inner(k, tuple_factory)] = _astuple_inner(v, tuple_factory)
            return result
        return obj_type((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory))
                        for k, v in obj.items())
    else:
        return copy.deepcopy(obj)
1488
1489
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, unsafe_hash=False,
                   frozen=False, match_args=True, kw_only=False, slots=False,
                   weakref_slot=False, module=None):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable whose
    items are each either a bare field name (a string), a (name, type) pair,
    or a (name, type, spec) triple.  If the type is omitted, the string
    'typing.Any' is used.  The optional 'spec' is placed in the class
    namespace under the field's name, just as if it had been assigned in a
    class body, so it is typically a Field created by field() or a plain
    default value.::

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to::

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.

    The parameters init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only,
    slots, and weakref_slot are passed to dataclass().

    If module parameter is defined, the '__module__' attribute of the dataclass is
    set to that value.  Otherwise it is set to the caller's module when that
    can be determined, so that instances can be pickled.

    Raises TypeError if a field item is malformed, or if a field name is not
    a valid identifier, is a keyword, or is duplicated.
    """

    if namespace is None:
        namespace = {}

    # While we're looking through the field names, validate that they
    # are identifiers, are not keywords, and not duplicates.  The keys of
    # 'annotations' double as the set of names accepted so far.
    annotations = {}
    defaults = {}
    for item in fields:
        has_spec = False
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        else:
            # Anything other than a bare name must be a 2- or 3-item
            # sequence.  Objects without a length (e.g. an int) are reported
            # as invalid fields instead of with len()'s generic message.
            try:
                item_len = len(item)
            except TypeError:
                raise TypeError(f'Invalid field: {item!r}') from None
            if item_len == 2:
                name, tp = item
            elif item_len == 3:
                name, tp, spec = item
                has_spec = True
            else:
                raise TypeError(f'Invalid field: {item!r}')

        if not isinstance(name, str) or not name.isidentifier():
            raise TypeError(f'Field names must be valid identifiers: {name!r}')
        if keyword.iskeyword(name):
            raise TypeError(f'Field names must not be keywords: {name!r}')
        if name in annotations:
            raise TypeError(f'Field name duplicated: {name!r}')

        annotations[name] = tp
        # Store the spec only once the name is known to be valid, so that an
        # unhashable name is reported as an invalid identifier rather than
        # failing as a dict key.
        if has_spec:
            defaults[name] = spec

    # Update 'ns' with the user-supplied namespace plus our calculated values.
    # Field specs override same-named namespace entries, and __annotations__
    # is always the one computed above.
    def exec_body_callback(ns):
        ns.update(namespace)
        ns.update(defaults)
        ns['__annotations__'] = annotations

    # We use `types.new_class()` instead of simply `type()` to allow dynamic creation
    # of generic dataclasses.
    cls = types.new_class(cls_name, bases, {}, exec_body_callback)

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the dataclass is created.  Depth 1 is the caller of this function,
    # so this lookup must stay directly inside make_dataclass.
    if module is None:
        try:
            module = sys._getframemodulename(1) or '__main__'
        except AttributeError:
            try:
                module = sys._getframe(1).f_globals.get('__name__', '__main__')
            except (AttributeError, ValueError):
                pass
    if module is not None:
        cls.__module__ = module

    # Apply the normal decorator.
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     unsafe_hash=unsafe_hash, frozen=frozen,
                     match_args=match_args, kw_only=kw_only, slots=slots,
                     weakref_slot=weakref_slot)
1578
1579
def replace(obj, /, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage::

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Fields not named in 'changes' keep their current value from 'obj'.
    Raises TypeError when 'obj' is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _replace(obj, **changes)
    raise TypeError("replace() should be called on dataclass instances")
1597
1598
def _replace(self, /, **changes):
    """Create a copy of dataclass instance 'self' with 'changes' applied.

    Every init=True field or InitVar missing from 'changes' is filled in from
    'self'; an InitVar with no default must be given explicitly.  Naming a
    field declared with init=False raises TypeError.  The copy is built by
    calling the class, so __init__() and __post_init__() (if defined) run
    normally, and unknown names in 'changes' make that call raise TypeError.
    """
    # 'changes' is the dict freshly built from this call's keyword arguments,
    # so adding entries to it cannot leak back to the caller, even for
    # 'replace(obj, **my_changes)'.
    for fld in getattr(self, _FIELDS).values():
        kind = fld._field_type
        # Only normal fields and InitVars matter here.
        if kind is _FIELD_CLASSVAR:
            continue

        fname = fld.name
        supplied = fname in changes

        if not fld.init:
            # The generated __init__ does not take this field.
            if supplied:
                raise TypeError(f'field {fname} is declared with '
                                f'init=False, it cannot be specified with '
                                f'replace()')
            continue

        if supplied:
            continue

        # An InitVar with no default must be supplied by the caller.
        if kind is _FIELD_INITVAR and fld.default is MISSING:
            raise TypeError(f"InitVar {fname!r} "
                            f'must be specified with replace()')
        changes[fname] = getattr(self, fname)

    return self.__class__(**changes)
1631