• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Deliberately use "from dataclasses import *".  Every name in __all__
2# is tested, so they all must be present.  This is a way to catch
3# missing ones.
4
5from dataclasses import *
6
7import abc
8import io
9import pickle
10import inspect
11import builtins
12import types
13import weakref
14import traceback
15import unittest
16from unittest.mock import Mock
17from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict
18from typing import get_type_hints
19from collections import deque, OrderedDict, namedtuple, defaultdict
20from copy import deepcopy
21from functools import total_ordering
22
23import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
24import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
25
26from test import support
27
28# Just any custom exception we can catch.
class CustomError(Exception):
    """A throwaway exception type that tests can raise and catch by name."""
30
31class TestCase(unittest.TestCase):
32    def test_no_fields(self):
33        @dataclass
34        class C:
35            pass
36
37        o = C()
38        self.assertEqual(len(fields(C)), 0)
39
40    def test_no_fields_but_member_variable(self):
41        @dataclass
42        class C:
43            i = 0
44
45        o = C()
46        self.assertEqual(len(fields(C)), 0)
47
48    def test_one_field_no_default(self):
49        @dataclass
50        class C:
51            x: int
52
53        o = C(42)
54        self.assertEqual(o.x, 42)
55
56    def test_field_default_default_factory_error(self):
57        msg = "cannot specify both default and default_factory"
58        with self.assertRaisesRegex(ValueError, msg):
59            @dataclass
60            class C:
61                x: int = field(default=1, default_factory=int)
62
63    def test_field_repr(self):
64        int_field = field(default=1, init=True, repr=False)
65        int_field.name = "id"
66        repr_output = repr(int_field)
67        expected_output = "Field(name='id',type=None," \
68                           f"default=1,default_factory={MISSING!r}," \
69                           "init=True,repr=False,hash=None," \
70                           "compare=True,metadata=mappingproxy({})," \
71                           f"kw_only={MISSING!r}," \
72                           "_field_type=None)"
73
74        self.assertEqual(repr_output, expected_output)
75
76    def test_field_recursive_repr(self):
77        rec_field = field()
78        rec_field.type = rec_field
79        rec_field.name = "id"
80        repr_output = repr(rec_field)
81
82        self.assertIn(",type=...,", repr_output)
83
84    def test_recursive_annotation(self):
85        class C:
86            pass
87
88        @dataclass
89        class D:
90            C: C = field()
91
92        self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"]))
93
94    def test_dataclass_params_repr(self):
95        # Even though this is testing an internal implementation detail,
96        # it's testing a feature we want to make sure is correctly implemented
97        # for the sake of dataclasses itself
98        @dataclass(slots=True, frozen=True)
99        class Some: pass
100
101        repr_output = repr(Some.__dataclass_params__)
102        expected_output = "_DataclassParams(init=True,repr=True," \
103                          "eq=True,order=False,unsafe_hash=False,frozen=True," \
104                          "match_args=True,kw_only=False," \
105                          "slots=True,weakref_slot=False)"
106        self.assertEqual(repr_output, expected_output)
107
108    def test_dataclass_params_signature(self):
109        # Even though this is testing an internal implementation detail,
110        # it's testing a feature we want to make sure is correctly implemented
111        # for the sake of dataclasses itself
112        @dataclass
113        class Some: pass
114
115        for param in inspect.signature(dataclass).parameters:
116            if param == 'cls':
117                continue
118            self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param)
119
120    def test_named_init_params(self):
121        @dataclass
122        class C:
123            x: int
124
125        o = C(x=32)
126        self.assertEqual(o.x, 32)
127
128    def test_two_fields_one_default(self):
129        @dataclass
130        class C:
131            x: int
132            y: int = 0
133
134        o = C(3)
135        self.assertEqual((o.x, o.y), (3, 0))
136
137        # Non-defaults following defaults.
138        with self.assertRaisesRegex(TypeError,
139                                    "non-default argument 'y' follows "
140                                    "default argument 'x'"):
141            @dataclass
142            class C:
143                x: int = 0
144                y: int
145
146        # A derived class adds a non-default field after a default one.
147        with self.assertRaisesRegex(TypeError,
148                                    "non-default argument 'y' follows "
149                                    "default argument 'x'"):
150            @dataclass
151            class B:
152                x: int = 0
153
154            @dataclass
155            class C(B):
156                y: int
157
158        # Override a base class field and add a default to
159        #  a field which didn't use to have a default.
160        with self.assertRaisesRegex(TypeError,
161                                    "non-default argument 'y' follows "
162                                    "default argument 'x'"):
163            @dataclass
164            class B:
165                x: int
166                y: int
167
168            @dataclass
169            class C(B):
170                x: int = 0
171
172    def test_overwrite_hash(self):
173        # Test that declaring this class isn't an error.  It should
174        #  use the user-provided __hash__.
175        @dataclass(frozen=True)
176        class C:
177            x: int
178            def __hash__(self):
179                return 301
180        self.assertEqual(hash(C(100)), 301)
181
182        # Test that declaring this class isn't an error.  It should
183        #  use the generated __hash__.
184        @dataclass(frozen=True)
185        class C:
186            x: int
187            def __eq__(self, other):
188                return False
189        self.assertEqual(hash(C(100)), hash((100,)))
190
191        # But this one should generate an exception, because with
192        #  unsafe_hash=True, it's an error to have a __hash__ defined.
193        with self.assertRaisesRegex(TypeError,
194                                    'Cannot overwrite attribute __hash__'):
195            @dataclass(unsafe_hash=True)
196            class C:
197                def __hash__(self):
198                    pass
199
200        # Creating this class should not generate an exception,
201        #  because even though __hash__ exists before @dataclass is
202        #  called, (due to __eq__ being defined), since it's None
203        #  that's okay.
204        @dataclass(unsafe_hash=True)
205        class C:
206            x: int
207            def __eq__(self):
208                pass
209        # The generated hash function works as we'd expect.
210        self.assertEqual(hash(C(10)), hash((10,)))
211
212        # Creating this class should generate an exception, because
213        #  __hash__ exists and is not None, which it would be if it
214        #  had been auto-generated due to __eq__ being defined.
215        with self.assertRaisesRegex(TypeError,
216                                    'Cannot overwrite attribute __hash__'):
217            @dataclass(unsafe_hash=True)
218            class C:
219                x: int
220                def __eq__(self):
221                    pass
222                def __hash__(self):
223                    pass
224
225    def test_overwrite_fields_in_derived_class(self):
226        # Note that x from C1 replaces x in Base, but the order remains
227        #  the same as defined in Base.
228        @dataclass
229        class Base:
230            x: Any = 15.0
231            y: int = 0
232
233        @dataclass
234        class C1(Base):
235            z: int = 10
236            x: int = 15
237
238        o = Base()
239        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
240
241        o = C1()
242        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
243
244        o = C1(x=5)
245        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
246
247    def test_field_named_self(self):
248        @dataclass
249        class C:
250            self: str
251        c=C('foo')
252        self.assertEqual(c.self, 'foo')
253
254        # Make sure the first parameter is not named 'self'.
255        sig = inspect.signature(C.__init__)
256        first = next(iter(sig.parameters))
257        self.assertNotEqual('self', first)
258
259        # But we do use 'self' if no field named self.
260        @dataclass
261        class C:
262            selfx: str
263
264        # Make sure the first parameter is named 'self'.
265        sig = inspect.signature(C.__init__)
266        first = next(iter(sig.parameters))
267        self.assertEqual('self', first)
268
269    def test_field_named_object(self):
270        @dataclass
271        class C:
272            object: str
273        c = C('foo')
274        self.assertEqual(c.object, 'foo')
275
276    def test_field_named_object_frozen(self):
277        @dataclass(frozen=True)
278        class C:
279            object: str
280        c = C('foo')
281        self.assertEqual(c.object, 'foo')
282
283    def test_field_named_BUILTINS_frozen(self):
284        # gh-96151
285        @dataclass(frozen=True)
286        class C:
287            BUILTINS: int
288        c = C(5)
289        self.assertEqual(c.BUILTINS, 5)
290
291    def test_field_with_special_single_underscore_names(self):
292        # gh-98886
293
294        @dataclass
295        class X:
296            x: int = field(default_factory=lambda: 111)
297            _dflt_x: int = field(default_factory=lambda: 222)
298
299        X()
300
301        @dataclass
302        class Y:
303            y: int = field(default_factory=lambda: 111)
304            _HAS_DEFAULT_FACTORY: int = 222
305
306        assert Y(y=222).y == 222
307
308    def test_field_named_like_builtin(self):
309        # Attribute names can shadow built-in names
310        # since code generation is used.
311        # Ensure that this is not happening.
312        exclusions = {'None', 'True', 'False'}
313        builtins_names = sorted(
314            b for b in builtins.__dict__.keys()
315            if not b.startswith('__') and b not in exclusions
316        )
317        attributes = [(name, str) for name in builtins_names]
318        C = make_dataclass('C', attributes)
319
320        c = C(*[name for name in builtins_names])
321
322        for name in builtins_names:
323            self.assertEqual(getattr(c, name), name)
324
325    def test_field_named_like_builtin_frozen(self):
326        # Attribute names can shadow built-in names
327        # since code generation is used.
328        # Ensure that this is not happening
329        # for frozen data classes.
330        exclusions = {'None', 'True', 'False'}
331        builtins_names = sorted(
332            b for b in builtins.__dict__.keys()
333            if not b.startswith('__') and b not in exclusions
334        )
335        attributes = [(name, str) for name in builtins_names]
336        C = make_dataclass('C', attributes, frozen=True)
337
338        c = C(*[name for name in builtins_names])
339
340        for name in builtins_names:
341            self.assertEqual(getattr(c, name), name)
342
343    def test_0_field_compare(self):
344        # Ensure that order=False is the default.
345        @dataclass
346        class C0:
347            pass
348
349        @dataclass(order=False)
350        class C1:
351            pass
352
353        for cls in [C0, C1]:
354            with self.subTest(cls=cls):
355                self.assertEqual(cls(), cls())
356                for idx, fn in enumerate([lambda a, b: a < b,
357                                          lambda a, b: a <= b,
358                                          lambda a, b: a > b,
359                                          lambda a, b: a >= b]):
360                    with self.subTest(idx=idx):
361                        with self.assertRaisesRegex(TypeError,
362                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
363                            fn(cls(), cls())
364
365        @dataclass(order=True)
366        class C:
367            pass
368        self.assertLessEqual(C(), C())
369        self.assertGreaterEqual(C(), C())
370
371    def test_1_field_compare(self):
372        # Ensure that order=False is the default.
373        @dataclass
374        class C0:
375            x: int
376
377        @dataclass(order=False)
378        class C1:
379            x: int
380
381        for cls in [C0, C1]:
382            with self.subTest(cls=cls):
383                self.assertEqual(cls(1), cls(1))
384                self.assertNotEqual(cls(0), cls(1))
385                for idx, fn in enumerate([lambda a, b: a < b,
386                                          lambda a, b: a <= b,
387                                          lambda a, b: a > b,
388                                          lambda a, b: a >= b]):
389                    with self.subTest(idx=idx):
390                        with self.assertRaisesRegex(TypeError,
391                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
392                            fn(cls(0), cls(0))
393
394        @dataclass(order=True)
395        class C:
396            x: int
397        self.assertLess(C(0), C(1))
398        self.assertLessEqual(C(0), C(1))
399        self.assertLessEqual(C(1), C(1))
400        self.assertGreater(C(1), C(0))
401        self.assertGreaterEqual(C(1), C(0))
402        self.assertGreaterEqual(C(1), C(1))
403
404    def test_simple_compare(self):
405        # Ensure that order=False is the default.
406        @dataclass
407        class C0:
408            x: int
409            y: int
410
411        @dataclass(order=False)
412        class C1:
413            x: int
414            y: int
415
416        for cls in [C0, C1]:
417            with self.subTest(cls=cls):
418                self.assertEqual(cls(0, 0), cls(0, 0))
419                self.assertEqual(cls(1, 2), cls(1, 2))
420                self.assertNotEqual(cls(1, 0), cls(0, 0))
421                self.assertNotEqual(cls(1, 0), cls(1, 1))
422                for idx, fn in enumerate([lambda a, b: a < b,
423                                          lambda a, b: a <= b,
424                                          lambda a, b: a > b,
425                                          lambda a, b: a >= b]):
426                    with self.subTest(idx=idx):
427                        with self.assertRaisesRegex(TypeError,
428                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
429                            fn(cls(0, 0), cls(0, 0))
430
431        @dataclass(order=True)
432        class C:
433            x: int
434            y: int
435
436        for idx, fn in enumerate([lambda a, b: a == b,
437                                  lambda a, b: a <= b,
438                                  lambda a, b: a >= b]):
439            with self.subTest(idx=idx):
440                self.assertTrue(fn(C(0, 0), C(0, 0)))
441
442        for idx, fn in enumerate([lambda a, b: a < b,
443                                  lambda a, b: a <= b,
444                                  lambda a, b: a != b]):
445            with self.subTest(idx=idx):
446                self.assertTrue(fn(C(0, 0), C(0, 1)))
447                self.assertTrue(fn(C(0, 1), C(1, 0)))
448                self.assertTrue(fn(C(1, 0), C(1, 1)))
449
450        for idx, fn in enumerate([lambda a, b: a > b,
451                                  lambda a, b: a >= b,
452                                  lambda a, b: a != b]):
453            with self.subTest(idx=idx):
454                self.assertTrue(fn(C(0, 1), C(0, 0)))
455                self.assertTrue(fn(C(1, 0), C(0, 1)))
456                self.assertTrue(fn(C(1, 1), C(1, 0)))
457
458    def test_compare_subclasses(self):
459        # Comparisons fail for subclasses, even if no fields
460        #  are added.
461        @dataclass
462        class B:
463            i: int
464
465        @dataclass
466        class C(B):
467            pass
468
469        for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
470                                              (lambda a, b: a != b, True)]):
471            with self.subTest(idx=idx):
472                self.assertEqual(fn(B(0), C(0)), expected)
473
474        for idx, fn in enumerate([lambda a, b: a < b,
475                                  lambda a, b: a <= b,
476                                  lambda a, b: a > b,
477                                  lambda a, b: a >= b]):
478            with self.subTest(idx=idx):
479                with self.assertRaisesRegex(TypeError,
480                                            "not supported between instances of 'B' and 'C'"):
481                    fn(B(0), C(0))
482
483    def test_eq_order(self):
484        # Test combining eq and order.
485        for (eq,    order, result   ) in [
486            (False, False, 'neither'),
487            (False, True,  'exception'),
488            (True,  False, 'eq_only'),
489            (True,  True,  'both'),
490        ]:
491            with self.subTest(eq=eq, order=order):
492                if result == 'exception':
493                    with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
494                        @dataclass(eq=eq, order=order)
495                        class C:
496                            pass
497                else:
498                    @dataclass(eq=eq, order=order)
499                    class C:
500                        pass
501
502                    if result == 'neither':
503                        self.assertNotIn('__eq__', C.__dict__)
504                        self.assertNotIn('__lt__', C.__dict__)
505                        self.assertNotIn('__le__', C.__dict__)
506                        self.assertNotIn('__gt__', C.__dict__)
507                        self.assertNotIn('__ge__', C.__dict__)
508                    elif result == 'both':
509                        self.assertIn('__eq__', C.__dict__)
510                        self.assertIn('__lt__', C.__dict__)
511                        self.assertIn('__le__', C.__dict__)
512                        self.assertIn('__gt__', C.__dict__)
513                        self.assertIn('__ge__', C.__dict__)
514                    elif result == 'eq_only':
515                        self.assertIn('__eq__', C.__dict__)
516                        self.assertNotIn('__lt__', C.__dict__)
517                        self.assertNotIn('__le__', C.__dict__)
518                        self.assertNotIn('__gt__', C.__dict__)
519                        self.assertNotIn('__ge__', C.__dict__)
520                    else:
521                        assert False, f'unknown result {result!r}'
522
523    def test_field_no_default(self):
524        @dataclass
525        class C:
526            x: int = field()
527
528        self.assertEqual(C(5).x, 5)
529
530        with self.assertRaisesRegex(TypeError,
531                                    r"__init__\(\) missing 1 required "
532                                    "positional argument: 'x'"):
533            C()
534
535    def test_field_default(self):
536        default = object()
537        @dataclass
538        class C:
539            x: object = field(default=default)
540
541        self.assertIs(C.x, default)
542        c = C(10)
543        self.assertEqual(c.x, 10)
544
545        # If we delete the instance attribute, we should then see the
546        #  class attribute.
547        del c.x
548        self.assertIs(c.x, default)
549
550        self.assertIs(C().x, default)
551
552    def test_not_in_repr(self):
553        @dataclass
554        class C:
555            x: int = field(repr=False)
556        with self.assertRaises(TypeError):
557            C()
558        c = C(10)
559        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
560
561        @dataclass
562        class C:
563            x: int = field(repr=False)
564            y: int
565        c = C(10, 20)
566        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
567
568    def test_not_in_compare(self):
569        @dataclass
570        class C:
571            x: int = 0
572            y: int = field(compare=False, default=4)
573
574        self.assertEqual(C(), C(0, 20))
575        self.assertEqual(C(1, 10), C(1, 20))
576        self.assertNotEqual(C(3), C(4, 10))
577        self.assertNotEqual(C(3, 10), C(4, 10))
578
579    def test_no_unhashable_default(self):
580        # See bpo-44674.
581        class Unhashable:
582            __hash__ = None
583
584        unhashable_re = 'mutable default .* for field a is not allowed'
585        with self.assertRaisesRegex(ValueError, unhashable_re):
586            @dataclass
587            class A:
588                a: dict = {}
589
590        with self.assertRaisesRegex(ValueError, unhashable_re):
591            @dataclass
592            class A:
593                a: Any = Unhashable()
594
595        # Make sure that the machinery looking for hashability is using the
596        # class's __hash__, not the instance's __hash__.
597        with self.assertRaisesRegex(ValueError, unhashable_re):
598            unhashable = Unhashable()
599            # This shouldn't make the variable hashable.
600            unhashable.__hash__ = lambda: 0
601            @dataclass
602            class A:
603                a: Any = unhashable
604
605    def test_hash_field_rules(self):
606        # Test all 6 cases of:
607        #  hash=True/False/None
608        #  compare=True/False
609        for (hash_,    compare, result  ) in [
610            (True,     False,   'field' ),
611            (True,     True,    'field' ),
612            (False,    False,   'absent'),
613            (False,    True,    'absent'),
614            (None,     False,   'absent'),
615            (None,     True,    'field' ),
616            ]:
617            with self.subTest(hash=hash_, compare=compare):
618                @dataclass(unsafe_hash=True)
619                class C:
620                    x: int = field(compare=compare, hash=hash_, default=5)
621
622                if result == 'field':
623                    # __hash__ contains the field.
624                    self.assertEqual(hash(C(5)), hash((5,)))
625                elif result == 'absent':
626                    # The field is not present in the hash.
627                    self.assertEqual(hash(C(5)), hash(()))
628                else:
629                    assert False, f'unknown result {result!r}'
630
631    def test_init_false_no_default(self):
632        # If init=False and no default value, then the field won't be
633        #  present in the instance.
634        @dataclass
635        class C:
636            x: int = field(init=False)
637
638        self.assertNotIn('x', C().__dict__)
639
640        @dataclass
641        class C:
642            x: int
643            y: int = 0
644            z: int = field(init=False)
645            t: int = 10
646
647        self.assertNotIn('z', C(0).__dict__)
648        self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
649
650    def test_class_marker(self):
651        @dataclass
652        class C:
653            x: int
654            y: str = field(init=False, default=None)
655            z: str = field(repr=False)
656
657        the_fields = fields(C)
658        # the_fields is a tuple of 3 items, each value
659        #  is in __annotations__.
660        self.assertIsInstance(the_fields, tuple)
661        for f in the_fields:
662            self.assertIs(type(f), Field)
663            self.assertIn(f.name, C.__annotations__)
664
665        self.assertEqual(len(the_fields), 3)
666
667        self.assertEqual(the_fields[0].name, 'x')
668        self.assertEqual(the_fields[0].type, int)
669        self.assertFalse(hasattr(C, 'x'))
670        self.assertTrue (the_fields[0].init)
671        self.assertTrue (the_fields[0].repr)
672        self.assertEqual(the_fields[1].name, 'y')
673        self.assertEqual(the_fields[1].type, str)
674        self.assertIsNone(getattr(C, 'y'))
675        self.assertFalse(the_fields[1].init)
676        self.assertTrue (the_fields[1].repr)
677        self.assertEqual(the_fields[2].name, 'z')
678        self.assertEqual(the_fields[2].type, str)
679        self.assertFalse(hasattr(C, 'z'))
680        self.assertTrue (the_fields[2].init)
681        self.assertFalse(the_fields[2].repr)
682
683    def test_field_order(self):
684        @dataclass
685        class B:
686            a: str = 'B:a'
687            b: str = 'B:b'
688            c: str = 'B:c'
689
690        @dataclass
691        class C(B):
692            b: str = 'C:b'
693
694        self.assertEqual([(f.name, f.default) for f in fields(C)],
695                         [('a', 'B:a'),
696                          ('b', 'C:b'),
697                          ('c', 'B:c')])
698
699        @dataclass
700        class D(B):
701            c: str = 'D:c'
702
703        self.assertEqual([(f.name, f.default) for f in fields(D)],
704                         [('a', 'B:a'),
705                          ('b', 'B:b'),
706                          ('c', 'D:c')])
707
708        @dataclass
709        class E(D):
710            a: str = 'E:a'
711            d: str = 'E:d'
712
713        self.assertEqual([(f.name, f.default) for f in fields(E)],
714                         [('a', 'E:a'),
715                          ('b', 'B:b'),
716                          ('c', 'D:c'),
717                          ('d', 'E:d')])
718
719    def test_class_attrs(self):
720        # We only have a class attribute if a default value is
721        #  specified, either directly or via a field with a default.
722        default = object()
723        @dataclass
724        class C:
725            x: int
726            y: int = field(repr=False)
727            z: object = default
728            t: int = field(default=100)
729
730        self.assertFalse(hasattr(C, 'x'))
731        self.assertFalse(hasattr(C, 'y'))
732        self.assertIs   (C.z, default)
733        self.assertEqual(C.t, 100)
734
735    def test_disallowed_mutable_defaults(self):
736        # For the known types, don't allow mutable default values.
737        for typ, empty, non_empty in [(list, [], [1]),
738                                      (dict, {}, {0:1}),
739                                      (set, set(), set([1])),
740                                      ]:
741            with self.subTest(typ=typ):
742                # Can't use a zero-length value.
743                with self.assertRaisesRegex(ValueError,
744                                            f'mutable default {typ} for field '
745                                            'x is not allowed'):
746                    @dataclass
747                    class Point:
748                        x: typ = empty
749
750
751                # Nor a non-zero-length value
752                with self.assertRaisesRegex(ValueError,
753                                            f'mutable default {typ} for field '
754                                            'y is not allowed'):
755                    @dataclass
756                    class Point:
757                        y: typ = non_empty
758
759                # Check subtypes also fail.
760                class Subclass(typ): pass
761
762                with self.assertRaisesRegex(ValueError,
763                                            "mutable default .*Subclass'>"
764                                            " for field z is not allowed"
765                                            ):
766                    @dataclass
767                    class Point:
768                        z: typ = Subclass()
769
770                # Because this is a ClassVar, it can be mutable.
771                @dataclass
772                class C:
773                    z: ClassVar[typ] = typ()
774
775                # Because this is a ClassVar, it can be mutable.
776                @dataclass
777                class C:
778                    x: ClassVar[typ] = Subclass()
779
780    def test_deliberately_mutable_defaults(self):
781        # If a mutable default isn't in the known list of
782        #  (list, dict, set), then it's okay.
783        class Mutable:
784            def __init__(self):
785                self.l = []
786
787        @dataclass
788        class C:
789            x: Mutable
790
791        # These 2 instances will share this value of x.
792        lst = Mutable()
793        o1 = C(lst)
794        o2 = C(lst)
795        self.assertEqual(o1, o2)
796        o1.x.l.extend([1, 2])
797        self.assertEqual(o1, o2)
798        self.assertEqual(o1.x.l, [1, 2])
799        self.assertIs(o1.x, o2.x)
800
801    def test_no_options(self):
802        # Call with dataclass().
803        @dataclass()
804        class C:
805            x: int
806
807        self.assertEqual(C(42).x, 42)
808
809    def test_not_tuple(self):
810        # Make sure we can't be compared to a tuple.
811        @dataclass
812        class Point:
813            x: int
814            y: int
815        self.assertNotEqual(Point(1, 2), (1, 2))
816
817        # And that we can't compare to another unrelated dataclass.
818        @dataclass
819        class C:
820            x: int
821            y: int
822        self.assertNotEqual(Point(1, 3), C(1, 3))
823
824    def test_not_other_dataclass(self):
825        # Test that some of the problems with namedtuple don't happen
826        #  here.
827        @dataclass
828        class Point3D:
829            x: int
830            y: int
831            z: int
832
833        @dataclass
834        class Date:
835            year: int
836            month: int
837            day: int
838
839        self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
840        self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
841
842        # Make sure we can't unpack.
843        with self.assertRaisesRegex(TypeError, 'unpack'):
844            x, y, z = Point3D(4, 5, 6)
845
846        # Make sure another class with the same field names isn't
847        #  equal.
848        @dataclass
849        class Point3Dv1:
850            x: int = 0
851            y: int = 0
852            z: int = 0
853        self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
854
855    def test_function_annotations(self):
856        # Some dummy class and instance to use as a default.
857        class F:
858            pass
859        f = F()
860
861        def validate_class(cls):
862            # First, check __annotations__, even though they're not
863            #  function annotations.
864            self.assertEqual(cls.__annotations__['i'], int)
865            self.assertEqual(cls.__annotations__['j'], str)
866            self.assertEqual(cls.__annotations__['k'], F)
867            self.assertEqual(cls.__annotations__['l'], float)
868            self.assertEqual(cls.__annotations__['z'], complex)
869
870            # Verify __init__.
871
872            signature = inspect.signature(cls.__init__)
873            # Check the return type, should be None.
874            self.assertIs(signature.return_annotation, None)
875
876            # Check each parameter.
877            params = iter(signature.parameters.values())
878            param = next(params)
879            # This is testing an internal name, and probably shouldn't be tested.
880            self.assertEqual(param.name, 'self')
881            param = next(params)
882            self.assertEqual(param.name, 'i')
883            self.assertIs   (param.annotation, int)
884            self.assertEqual(param.default, inspect.Parameter.empty)
885            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
886            param = next(params)
887            self.assertEqual(param.name, 'j')
888            self.assertIs   (param.annotation, str)
889            self.assertEqual(param.default, inspect.Parameter.empty)
890            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
891            param = next(params)
892            self.assertEqual(param.name, 'k')
893            self.assertIs   (param.annotation, F)
894            # Don't test for the default, since it's set to MISSING.
895            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
896            param = next(params)
897            self.assertEqual(param.name, 'l')
898            self.assertIs   (param.annotation, float)
899            # Don't test for the default, since it's set to MISSING.
900            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
901            self.assertRaises(StopIteration, next, params)
902
903
904        @dataclass
905        class C:
906            i: int
907            j: str
908            k: F = f
909            l: float=field(default=None)
910            z: complex=field(default=3+4j, init=False)
911
912        validate_class(C)
913
914        # Now repeat with __hash__.
915        @dataclass(frozen=True, unsafe_hash=True)
916        class C:
917            i: int
918            j: str
919            k: F = f
920            l: float=field(default=None)
921            z: complex=field(default=3+4j, init=False)
922
923        validate_class(C)
924
925    def test_missing_default(self):
926        # Test that MISSING works the same as a default not being
927        #  specified.
928        @dataclass
929        class C:
930            x: int=field(default=MISSING)
931        with self.assertRaisesRegex(TypeError,
932                                    r'__init__\(\) missing 1 required '
933                                    'positional argument'):
934            C()
935        self.assertNotIn('x', C.__dict__)
936
937        @dataclass
938        class D:
939            x: int
940        with self.assertRaisesRegex(TypeError,
941                                    r'__init__\(\) missing 1 required '
942                                    'positional argument'):
943            D()
944        self.assertNotIn('x', D.__dict__)
945
946    def test_missing_default_factory(self):
947        # Test that MISSING works the same as a default factory not
948        #  being specified (which is really the same as a default not
949        #  being specified, too).
950        @dataclass
951        class C:
952            x: int=field(default_factory=MISSING)
953        with self.assertRaisesRegex(TypeError,
954                                    r'__init__\(\) missing 1 required '
955                                    'positional argument'):
956            C()
957        self.assertNotIn('x', C.__dict__)
958
959        @dataclass
960        class D:
961            x: int=field(default=MISSING, default_factory=MISSING)
962        with self.assertRaisesRegex(TypeError,
963                                    r'__init__\(\) missing 1 required '
964                                    'positional argument'):
965            D()
966        self.assertNotIn('x', D.__dict__)
967
968    def test_missing_repr(self):
969        self.assertIn('MISSING_TYPE object', repr(MISSING))
970
971    def test_dont_include_other_annotations(self):
972        @dataclass
973        class C:
974            i: int
975            def foo(self) -> int:
976                return 4
977            @property
978            def bar(self) -> int:
979                return 5
980        self.assertEqual(list(C.__annotations__), ['i'])
981        self.assertEqual(C(10).foo(), 4)
982        self.assertEqual(C(10).bar, 5)
983        self.assertEqual(C(10).i, 10)
984
985    def test_post_init(self):
986        # Just make sure it gets called
987        @dataclass
988        class C:
989            def __post_init__(self):
990                raise CustomError()
991        with self.assertRaises(CustomError):
992            C()
993
994        @dataclass
995        class C:
996            i: int = 10
997            def __post_init__(self):
998                if self.i == 10:
999                    raise CustomError()
1000        with self.assertRaises(CustomError):
1001            C()
1002        # post-init gets called, but doesn't raise. This is just
1003        #  checking that self is used correctly.
1004        C(5)
1005
1006        # If there's not an __init__, then post-init won't get called.
1007        @dataclass(init=False)
1008        class C:
1009            def __post_init__(self):
1010                raise CustomError()
1011        # Creating the class won't raise
1012        C()
1013
1014        @dataclass
1015        class C:
1016            x: int = 0
1017            def __post_init__(self):
1018                self.x *= 2
1019        self.assertEqual(C().x, 0)
1020        self.assertEqual(C(2).x, 4)
1021
1022        # Make sure that if we're frozen, post-init can't set
1023        #  attributes.
1024        @dataclass(frozen=True)
1025        class C:
1026            x: int = 0
1027            def __post_init__(self):
1028                self.x *= 2
1029        with self.assertRaises(FrozenInstanceError):
1030            C()
1031
1032    def test_post_init_super(self):
1033        # Make sure super() post-init isn't called by default.
1034        class B:
1035            def __post_init__(self):
1036                raise CustomError()
1037
1038        @dataclass
1039        class C(B):
1040            def __post_init__(self):
1041                self.x = 5
1042
1043        self.assertEqual(C().x, 5)
1044
1045        # Now call super(), and it will raise.
1046        @dataclass
1047        class C(B):
1048            def __post_init__(self):
1049                super().__post_init__()
1050
1051        with self.assertRaises(CustomError):
1052            C()
1053
1054        # Make sure post-init is called, even if not defined in our
1055        #  class.
1056        @dataclass
1057        class C(B):
1058            pass
1059
1060        with self.assertRaises(CustomError):
1061            C()
1062
1063    def test_post_init_staticmethod(self):
1064        flag = False
1065        @dataclass
1066        class C:
1067            x: int
1068            y: int
1069            @staticmethod
1070            def __post_init__():
1071                nonlocal flag
1072                flag = True
1073
1074        self.assertFalse(flag)
1075        c = C(3, 4)
1076        self.assertEqual((c.x, c.y), (3, 4))
1077        self.assertTrue(flag)
1078
1079    def test_post_init_classmethod(self):
1080        @dataclass
1081        class C:
1082            flag = False
1083            x: int
1084            y: int
1085            @classmethod
1086            def __post_init__(cls):
1087                cls.flag = True
1088
1089        self.assertFalse(C.flag)
1090        c = C(3, 4)
1091        self.assertEqual((c.x, c.y), (3, 4))
1092        self.assertTrue(C.flag)
1093
1094    def test_post_init_not_auto_added(self):
1095        # See bpo-46757, which had proposed always adding __post_init__.  As
1096        # Raymond Hettinger pointed out, that would be a breaking change.  So,
1097        # add a test to make sure that the current behavior doesn't change.
1098
1099        @dataclass
1100        class A0:
1101            pass
1102
1103        @dataclass
1104        class B0:
1105            b_called: bool = False
1106            def __post_init__(self):
1107                self.b_called = True
1108
1109        @dataclass
1110        class C0(A0, B0):
1111            c_called: bool = False
1112            def __post_init__(self):
1113                super().__post_init__()
1114                self.c_called = True
1115
1116        # Since A0 has no __post_init__, and one wasn't automatically added
1117        # (because that's the rule: it's never added by @dataclass, it's only
1118        # the class author that can add it), then B0.__post_init__ is called.
1119        # Verify that.
1120        c = C0()
1121        self.assertTrue(c.b_called)
1122        self.assertTrue(c.c_called)
1123
1124        ######################################
1125        # Now, the same thing, except A1 defines __post_init__.
1126        @dataclass
1127        class A1:
1128            def __post_init__(self):
1129                pass
1130
1131        @dataclass
1132        class B1:
1133            b_called: bool = False
1134            def __post_init__(self):
1135                self.b_called = True
1136
1137        @dataclass
1138        class C1(A1, B1):
1139            c_called: bool = False
1140            def __post_init__(self):
1141                super().__post_init__()
1142                self.c_called = True
1143
1144        # This time, B1.__post_init__ isn't being called.  This mimics what
1145        # would happen if A1.__post_init__ had been automatically added,
1146        # instead of manually added as we see here.  This test isn't really
1147        # needed, but I'm including it just to demonstrate the changed
1148        # behavior when A1 does define __post_init__.
1149        c = C1()
1150        self.assertFalse(c.b_called)
1151        self.assertTrue(c.c_called)
1152
1153    def test_class_var(self):
1154        # Make sure ClassVars are ignored in __init__, __repr__, etc.
1155        @dataclass
1156        class C:
1157            x: int
1158            y: int = 10
1159            z: ClassVar[int] = 1000
1160            w: ClassVar[int] = 2000
1161            t: ClassVar[int] = 3000
1162            s: ClassVar      = 4000
1163
1164        c = C(5)
1165        self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1166        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields.
1167        self.assertEqual(len(C.__annotations__), 6)         # And 4 ClassVars.
1168        self.assertEqual(c.z, 1000)
1169        self.assertEqual(c.w, 2000)
1170        self.assertEqual(c.t, 3000)
1171        self.assertEqual(c.s, 4000)
1172        C.z += 1
1173        self.assertEqual(c.z, 1001)
1174        c = C(20)
1175        self.assertEqual((c.x, c.y), (20, 10))
1176        self.assertEqual(c.z, 1001)
1177        self.assertEqual(c.w, 2000)
1178        self.assertEqual(c.t, 3000)
1179        self.assertEqual(c.s, 4000)
1180
1181    def test_class_var_no_default(self):
1182        # If a ClassVar has no default value, it should not be set on the class.
1183        @dataclass
1184        class C:
1185            x: ClassVar[int]
1186
1187        self.assertNotIn('x', C.__dict__)
1188
1189    def test_class_var_default_factory(self):
1190        # It makes no sense for a ClassVar to have a default factory. When
1191        #  would it be called? Call it yourself, since it's class-wide.
1192        with self.assertRaisesRegex(TypeError,
1193                                    'cannot have a default factory'):
1194            @dataclass
1195            class C:
1196                x: ClassVar[int] = field(default_factory=int)
1197
1198            self.assertNotIn('x', C.__dict__)
1199
1200    def test_class_var_with_default(self):
1201        # If a ClassVar has a default value, it should be set on the class.
1202        @dataclass
1203        class C:
1204            x: ClassVar[int] = 10
1205        self.assertEqual(C.x, 10)
1206
1207        @dataclass
1208        class C:
1209            x: ClassVar[int] = field(default=10)
1210        self.assertEqual(C.x, 10)
1211
1212    def test_class_var_frozen(self):
1213        # Make sure ClassVars work even if we're frozen.
1214        @dataclass(frozen=True)
1215        class C:
1216            x: int
1217            y: int = 10
1218            z: ClassVar[int] = 1000
1219            w: ClassVar[int] = 2000
1220            t: ClassVar[int] = 3000
1221
1222        c = C(5)
1223        self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1224        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
1225        self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
1226        self.assertEqual(c.z, 1000)
1227        self.assertEqual(c.w, 2000)
1228        self.assertEqual(c.t, 3000)
1229        # We can still modify the ClassVar, it's only instances that are
1230        #  frozen.
1231        C.z += 1
1232        self.assertEqual(c.z, 1001)
1233        c = C(20)
1234        self.assertEqual((c.x, c.y), (20, 10))
1235        self.assertEqual(c.z, 1001)
1236        self.assertEqual(c.w, 2000)
1237        self.assertEqual(c.t, 3000)
1238
1239    def test_init_var_no_default(self):
1240        # If an InitVar has no default value, it should not be set on the class.
1241        @dataclass
1242        class C:
1243            x: InitVar[int]
1244
1245        self.assertNotIn('x', C.__dict__)
1246
1247    def test_init_var_default_factory(self):
1248        # It makes no sense for an InitVar to have a default factory. When
1249        #  would it be called? Call it yourself, since it's class-wide.
1250        with self.assertRaisesRegex(TypeError,
1251                                    'cannot have a default factory'):
1252            @dataclass
1253            class C:
1254                x: InitVar[int] = field(default_factory=int)
1255
1256            self.assertNotIn('x', C.__dict__)
1257
1258    def test_init_var_with_default(self):
1259        # If an InitVar has a default value, it should be set on the class.
1260        @dataclass
1261        class C:
1262            x: InitVar[int] = 10
1263        self.assertEqual(C.x, 10)
1264
1265        @dataclass
1266        class C:
1267            x: InitVar[int] = field(default=10)
1268        self.assertEqual(C.x, 10)
1269
1270    def test_init_var(self):
1271        @dataclass
1272        class C:
1273            x: int = None
1274            init_param: InitVar[int] = None
1275
1276            def __post_init__(self, init_param):
1277                if self.x is None:
1278                    self.x = init_param*2
1279
1280        c = C(init_param=10)
1281        self.assertEqual(c.x, 20)
1282
1283    def test_init_var_preserve_type(self):
1284        self.assertEqual(InitVar[int].type, int)
1285
1286        # Make sure the repr is correct.
1287        self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
1288        self.assertEqual(repr(InitVar[List[int]]),
1289                         'dataclasses.InitVar[typing.List[int]]')
1290        self.assertEqual(repr(InitVar[list[int]]),
1291                         'dataclasses.InitVar[list[int]]')
1292        self.assertEqual(repr(InitVar[int|str]),
1293                         'dataclasses.InitVar[int | str]')
1294
1295    def test_init_var_inheritance(self):
1296        # Note that this deliberately tests that a dataclass need not
1297        #  have a __post_init__ function if it has an InitVar field.
1298        #  It could just be used in a derived class, as shown here.
1299        @dataclass
1300        class Base:
1301            x: int
1302            init_base: InitVar[int]
1303
1304        # We can instantiate by passing the InitVar, even though
1305        #  it's not used.
1306        b = Base(0, 10)
1307        self.assertEqual(vars(b), {'x': 0})
1308
1309        @dataclass
1310        class C(Base):
1311            y: int
1312            init_derived: InitVar[int]
1313
1314            def __post_init__(self, init_base, init_derived):
1315                self.x = self.x + init_base
1316                self.y = self.y + init_derived
1317
1318        c = C(10, 11, 50, 51)
1319        self.assertEqual(vars(c), {'x': 21, 'y': 101})
1320
1321    def test_init_var_name_shadowing(self):
1322        # Because dataclasses rely exclusively on `__annotations__` for
1323        # handling InitVar and `__annotations__` preserves shadowed definitions,
1324        # you can actually shadow an InitVar with a method or property.
1325        #
1326        # This only works when there is no default value; `dataclasses` uses the
1327        # actual name (which will be bound to the shadowing method) for default
1328        # values.
1329        @dataclass
1330        class C:
1331            shadowed: InitVar[int]
1332            _shadowed: int = field(init=False)
1333
1334            def __post_init__(self, shadowed):
1335                self._shadowed = shadowed * 2
1336
1337            @property
1338            def shadowed(self):
1339                return self._shadowed * 3
1340
1341        c = C(5)
1342        self.assertEqual(c.shadowed, 30)
1343
1344    def test_default_factory(self):
1345        # Test a factory that returns a new list.
1346        @dataclass
1347        class C:
1348            x: int
1349            y: list = field(default_factory=list)
1350
1351        c0 = C(3)
1352        c1 = C(3)
1353        self.assertEqual(c0.x, 3)
1354        self.assertEqual(c0.y, [])
1355        self.assertEqual(c0, c1)
1356        self.assertIsNot(c0.y, c1.y)
1357        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1358
1359        # Test a factory that returns a shared list.
1360        l = []
1361        @dataclass
1362        class C:
1363            x: int
1364            y: list = field(default_factory=lambda: l)
1365
1366        c0 = C(3)
1367        c1 = C(3)
1368        self.assertEqual(c0.x, 3)
1369        self.assertEqual(c0.y, [])
1370        self.assertEqual(c0, c1)
1371        self.assertIs(c0.y, c1.y)
1372        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1373
1374        # Test various other field flags.
1375        # repr
1376        @dataclass
1377        class C:
1378            x: list = field(default_factory=list, repr=False)
1379        self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1380        self.assertEqual(C().x, [])
1381
1382        # hash
1383        @dataclass(unsafe_hash=True)
1384        class C:
1385            x: list = field(default_factory=list, hash=False)
1386        self.assertEqual(astuple(C()), ([],))
1387        self.assertEqual(hash(C()), hash(()))
1388
1389        # init (see also test_default_factory_with_no_init)
1390        @dataclass
1391        class C:
1392            x: list = field(default_factory=list, init=False)
1393        self.assertEqual(astuple(C()), ([],))
1394
1395        # compare
1396        @dataclass
1397        class C:
1398            x: list = field(default_factory=list, compare=False)
1399        self.assertEqual(C(), C([1]))
1400
1401    def test_default_factory_with_no_init(self):
1402        # We need a factory with a side effect.
1403        factory = Mock()
1404
1405        @dataclass
1406        class C:
1407            x: list = field(default_factory=factory, init=False)
1408
1409        # Make sure the default factory is called for each new instance.
1410        C().x
1411        self.assertEqual(factory.call_count, 1)
1412        C().x
1413        self.assertEqual(factory.call_count, 2)
1414
1415    def test_default_factory_not_called_if_value_given(self):
1416        # We need a factory that we can test if it's been called.
1417        factory = Mock()
1418
1419        @dataclass
1420        class C:
1421            x: int = field(default_factory=factory)
1422
1423        # Make sure that if a field has a default factory function,
1424        #  it's not called if a value is specified.
1425        C().x
1426        self.assertEqual(factory.call_count, 1)
1427        self.assertEqual(C(10).x, 10)
1428        self.assertEqual(factory.call_count, 1)
1429        C().x
1430        self.assertEqual(factory.call_count, 2)
1431
1432    def test_default_factory_derived(self):
1433        # See bpo-32896.
1434        @dataclass
1435        class Foo:
1436            x: dict = field(default_factory=dict)
1437
1438        @dataclass
1439        class Bar(Foo):
1440            y: int = 1
1441
1442        self.assertEqual(Foo().x, {})
1443        self.assertEqual(Bar().x, {})
1444        self.assertEqual(Bar().y, 1)
1445
1446        @dataclass
1447        class Baz(Foo):
1448            pass
1449        self.assertEqual(Baz().x, {})
1450
1451    def test_intermediate_non_dataclass(self):
1452        # Test that an intermediate class that defines
1453        #  annotations does not define fields.
1454
1455        @dataclass
1456        class A:
1457            x: int
1458
1459        class B(A):
1460            y: int
1461
1462        @dataclass
1463        class C(B):
1464            z: int
1465
1466        c = C(1, 3)
1467        self.assertEqual((c.x, c.z), (1, 3))
1468
1469        # .y was not initialized.
1470        with self.assertRaisesRegex(AttributeError,
1471                                    'object has no attribute'):
1472            c.y
1473
1474        # And if we again derive a non-dataclass, no fields are added.
1475        class D(C):
1476            t: int
1477        d = D(4, 5)
1478        self.assertEqual((d.x, d.z), (4, 5))
1479
1480    def test_classvar_default_factory(self):
1481        # It's an error for a ClassVar to have a factory function.
1482        with self.assertRaisesRegex(TypeError,
1483                                    'cannot have a default factory'):
1484            @dataclass
1485            class C:
1486                x: ClassVar[int] = field(default_factory=int)
1487
1488    def test_is_dataclass(self):
1489        class NotDataClass:
1490            pass
1491
1492        self.assertFalse(is_dataclass(0))
1493        self.assertFalse(is_dataclass(int))
1494        self.assertFalse(is_dataclass(NotDataClass))
1495        self.assertFalse(is_dataclass(NotDataClass()))
1496
1497        @dataclass
1498        class C:
1499            x: int
1500
1501        @dataclass
1502        class D:
1503            d: C
1504            e: int
1505
1506        c = C(10)
1507        d = D(c, 4)
1508
1509        self.assertTrue(is_dataclass(C))
1510        self.assertTrue(is_dataclass(c))
1511        self.assertFalse(is_dataclass(c.x))
1512        self.assertTrue(is_dataclass(d.d))
1513        self.assertFalse(is_dataclass(d.e))
1514
1515    def test_is_dataclass_when_getattr_always_returns(self):
1516        # See bpo-37868.
1517        class A:
1518            def __getattr__(self, key):
1519                return 0
1520        self.assertFalse(is_dataclass(A))
1521        a = A()
1522
1523        # Also test for an instance attribute.
1524        class B:
1525            pass
1526        b = B()
1527        b.__dataclass_fields__ = []
1528
1529        for obj in a, b:
1530            with self.subTest(obj=obj):
1531                self.assertFalse(is_dataclass(obj))
1532
1533                # Indirect tests for _is_dataclass_instance().
1534                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1535                    asdict(obj)
1536                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1537                    astuple(obj)
1538                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1539                    replace(obj, x=0)
1540
1541    def test_is_dataclass_genericalias(self):
1542        @dataclass
1543        class A(types.GenericAlias):
1544            origin: type
1545            args: type
1546        self.assertTrue(is_dataclass(A))
1547        a = A(list, int)
1548        self.assertTrue(is_dataclass(type(a)))
1549        self.assertTrue(is_dataclass(a))
1550
1551    def test_is_dataclass_inheritance(self):
1552        @dataclass
1553        class X:
1554            y: int
1555
1556        class Z(X):
1557            pass
1558
1559        self.assertTrue(is_dataclass(X), "X should be a dataclass")
1560        self.assertTrue(
1561            is_dataclass(Z),
1562            "Z should be a dataclass because it inherits from X",
1563        )
1564        z_instance = Z(y=5)
1565        self.assertTrue(
1566            is_dataclass(z_instance),
1567            "z_instance should be a dataclass because it is an instance of Z",
1568        )
1569
1570    def test_helper_fields_with_class_instance(self):
1571        # Check that we can call fields() on either a class or instance,
1572        #  and get back the same thing.
1573        @dataclass
1574        class C:
1575            x: int
1576            y: float
1577
1578        self.assertEqual(fields(C), fields(C(0, 0.0)))
1579
1580    def test_helper_fields_exception(self):
1581        # Check that TypeError is raised if not passed a dataclass or
1582        #  instance.
1583        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1584            fields(0)
1585
1586        class C: pass
1587        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1588            fields(C)
1589        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1590            fields(C())
1591
1592    def test_clean_traceback_from_fields_exception(self):
1593        stdout = io.StringIO()
1594        try:
1595            fields(object)
1596        except TypeError as exc:
1597            traceback.print_exception(exc, file=stdout)
1598        printed_traceback = stdout.getvalue()
1599        self.assertNotIn("AttributeError", printed_traceback)
1600        self.assertNotIn("__dataclass_fields__", printed_traceback)
1601
1602    def test_helper_asdict(self):
1603        # Basic tests for asdict(), it should return a new dictionary.
1604        @dataclass
1605        class C:
1606            x: int
1607            y: int
1608        c = C(1, 2)
1609
1610        self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1611        self.assertEqual(asdict(c), asdict(c))
1612        self.assertIsNot(asdict(c), asdict(c))
1613        c.x = 42
1614        self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1615        self.assertIs(type(asdict(c)), dict)
1616
1617    def test_helper_asdict_raises_on_classes(self):
1618        # asdict() should raise on a class object.
1619        @dataclass
1620        class C:
1621            x: int
1622            y: int
1623        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1624            asdict(C)
1625        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1626            asdict(int)
1627
1628    def test_helper_asdict_copy_values(self):
1629        @dataclass
1630        class C:
1631            x: int
1632            y: List[int] = field(default_factory=list)
1633        initial = []
1634        c = C(1, initial)
1635        d = asdict(c)
1636        self.assertEqual(d['y'], initial)
1637        self.assertIsNot(d['y'], initial)
1638        c = C(1)
1639        d = asdict(c)
1640        d['y'].append(1)
1641        self.assertEqual(c.y, [])
1642
1643    def test_helper_asdict_nested(self):
1644        @dataclass
1645        class UserId:
1646            token: int
1647            group: int
1648        @dataclass
1649        class User:
1650            name: str
1651            id: UserId
1652        u = User('Joe', UserId(123, 1))
1653        d = asdict(u)
1654        self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1655        self.assertIsNot(asdict(u), asdict(u))
1656        u.id.group = 2
1657        self.assertEqual(asdict(u), {'name': 'Joe',
1658                                     'id': {'token': 123, 'group': 2}})
1659
1660    def test_helper_asdict_builtin_containers(self):
1661        @dataclass
1662        class User:
1663            name: str
1664            id: int
1665        @dataclass
1666        class GroupList:
1667            id: int
1668            users: List[User]
1669        @dataclass
1670        class GroupTuple:
1671            id: int
1672            users: Tuple[User, ...]
1673        @dataclass
1674        class GroupDict:
1675            id: int
1676            users: Dict[str, User]
1677        a = User('Alice', 1)
1678        b = User('Bob', 2)
1679        gl = GroupList(0, [a, b])
1680        gt = GroupTuple(0, (a, b))
1681        gd = GroupDict(0, {'first': a, 'second': b})
1682        self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1683                                                         {'name': 'Bob', 'id': 2}]})
1684        self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1685                                                         {'name': 'Bob', 'id': 2})})
1686        self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1687                                                         'second': {'name': 'Bob', 'id': 2}}})
1688
1689    def test_helper_asdict_builtin_object_containers(self):
1690        @dataclass
1691        class Child:
1692            d: object
1693
1694        @dataclass
1695        class Parent:
1696            child: Child
1697
1698        self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1699        self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1700
1701    def test_helper_asdict_factory(self):
1702        @dataclass
1703        class C:
1704            x: int
1705            y: int
1706        c = C(1, 2)
1707        d = asdict(c, dict_factory=OrderedDict)
1708        self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1709        self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1710        c.x = 42
1711        d = asdict(c, dict_factory=OrderedDict)
1712        self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1713        self.assertIs(type(d), OrderedDict)
1714
1715    def test_helper_asdict_namedtuple(self):
1716        T = namedtuple('T', 'a b c')
1717        @dataclass
1718        class C:
1719            x: str
1720            y: T
1721        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1722
1723        d = asdict(c)
1724        self.assertEqual(d, {'x': 'outer',
1725                             'y': T(1,
1726                                    {'x': 'inner',
1727                                     'y': T(11, 12, 13)},
1728                                    2),
1729                             }
1730                         )
1731
1732        # Now with a dict_factory.  OrderedDict is convenient, but
1733        # since it compares to dicts, we also need to have separate
1734        # assertIs tests.
1735        d = asdict(c, dict_factory=OrderedDict)
1736        self.assertEqual(d, {'x': 'outer',
1737                             'y': T(1,
1738                                    {'x': 'inner',
1739                                     'y': T(11, 12, 13)},
1740                                    2),
1741                             }
1742                         )
1743
1744        # Make sure that the returned dicts are actually OrderedDicts.
1745        self.assertIs(type(d), OrderedDict)
1746        self.assertIs(type(d['y'][1]), OrderedDict)
1747
1748    def test_helper_asdict_namedtuple_key(self):
1749        # Ensure that a field that contains a dict which has a
1750        # namedtuple as a key works with asdict().
1751
1752        @dataclass
1753        class C:
1754            f: dict
1755        T = namedtuple('T', 'a')
1756
1757        c = C({T('an a'): 0})
1758
1759        self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1760
1761    def test_helper_asdict_namedtuple_derived(self):
1762        class T(namedtuple('Tbase', 'a')):
1763            def my_a(self):
1764                return self.a
1765
1766        @dataclass
1767        class C:
1768            f: T
1769
1770        t = T(6)
1771        c = C(t)
1772
1773        d = asdict(c)
1774        self.assertEqual(d, {'f': T(a=6)})
1775        # Make sure that t has been copied, not used directly.
1776        self.assertIsNot(d['f'], t)
1777        self.assertEqual(d['f'].my_a(), 6)
1778
1779    def test_helper_asdict_defaultdict(self):
1780        # Ensure asdict() does not throw exceptions when a
1781        # defaultdict is a member of a dataclass
1782        @dataclass
1783        class C:
1784            mp: DefaultDict[str, List]
1785
1786        dd = defaultdict(list)
1787        dd["x"].append(12)
1788        c = C(mp=dd)
1789        d = asdict(c)
1790
1791        self.assertEqual(d, {"mp": {"x": [12]}})
1792        self.assertTrue(d["mp"] is not c.mp)  # make sure defaultdict is copied
1793
1794    def test_helper_astuple(self):
1795        # Basic tests for astuple(), it should return a new tuple.
1796        @dataclass
1797        class C:
1798            x: int
1799            y: int = 0
1800        c = C(1)
1801
1802        self.assertEqual(astuple(c), (1, 0))
1803        self.assertEqual(astuple(c), astuple(c))
1804        self.assertIsNot(astuple(c), astuple(c))
1805        c.y = 42
1806        self.assertEqual(astuple(c), (1, 42))
1807        self.assertIs(type(astuple(c)), tuple)
1808
1809    def test_helper_astuple_raises_on_classes(self):
1810        # astuple() should raise on a class object.
1811        @dataclass
1812        class C:
1813            x: int
1814            y: int
1815        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1816            astuple(C)
1817        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1818            astuple(int)
1819
1820    def test_helper_astuple_copy_values(self):
1821        @dataclass
1822        class C:
1823            x: int
1824            y: List[int] = field(default_factory=list)
1825        initial = []
1826        c = C(1, initial)
1827        t = astuple(c)
1828        self.assertEqual(t[1], initial)
1829        self.assertIsNot(t[1], initial)
1830        c = C(1)
1831        t = astuple(c)
1832        t[1].append(1)
1833        self.assertEqual(c.y, [])
1834
1835    def test_helper_astuple_nested(self):
1836        @dataclass
1837        class UserId:
1838            token: int
1839            group: int
1840        @dataclass
1841        class User:
1842            name: str
1843            id: UserId
1844        u = User('Joe', UserId(123, 1))
1845        t = astuple(u)
1846        self.assertEqual(t, ('Joe', (123, 1)))
1847        self.assertIsNot(astuple(u), astuple(u))
1848        u.id.group = 2
1849        self.assertEqual(astuple(u), ('Joe', (123, 2)))
1850
1851    def test_helper_astuple_builtin_containers(self):
1852        @dataclass
1853        class User:
1854            name: str
1855            id: int
1856        @dataclass
1857        class GroupList:
1858            id: int
1859            users: List[User]
1860        @dataclass
1861        class GroupTuple:
1862            id: int
1863            users: Tuple[User, ...]
1864        @dataclass
1865        class GroupDict:
1866            id: int
1867            users: Dict[str, User]
1868        a = User('Alice', 1)
1869        b = User('Bob', 2)
1870        gl = GroupList(0, [a, b])
1871        gt = GroupTuple(0, (a, b))
1872        gd = GroupDict(0, {'first': a, 'second': b})
1873        self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1874        self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1875        self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1876
1877    def test_helper_astuple_builtin_object_containers(self):
1878        @dataclass
1879        class Child:
1880            d: object
1881
1882        @dataclass
1883        class Parent:
1884            child: Child
1885
1886        self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1887        self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1888
1889    def test_helper_astuple_factory(self):
1890        @dataclass
1891        class C:
1892            x: int
1893            y: int
1894        NT = namedtuple('NT', 'x y')
1895        def nt(lst):
1896            return NT(*lst)
1897        c = C(1, 2)
1898        t = astuple(c, tuple_factory=nt)
1899        self.assertEqual(t, NT(1, 2))
1900        self.assertIsNot(t, astuple(c, tuple_factory=nt))
1901        c.x = 42
1902        t = astuple(c, tuple_factory=nt)
1903        self.assertEqual(t, NT(42, 2))
1904        self.assertIs(type(t), NT)
1905
1906    def test_helper_astuple_namedtuple(self):
1907        T = namedtuple('T', 'a b c')
1908        @dataclass
1909        class C:
1910            x: str
1911            y: T
1912        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1913
1914        t = astuple(c)
1915        self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1916
1917        # Now, using a tuple_factory.  list is convenient here.
1918        t = astuple(c, tuple_factory=list)
1919        self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1920
1921    def test_helper_astuple_defaultdict(self):
1922        # Ensure astuple() does not throw exceptions when a
1923        # defaultdict is a member of a dataclass
1924        @dataclass
1925        class C:
1926            mp: DefaultDict[str, List]
1927
1928        dd = defaultdict(list)
1929        dd["x"].append(12)
1930        c = C(mp=dd)
1931        t = astuple(c)
1932
1933        self.assertEqual(t, ({"x": [12]},))
1934        self.assertTrue(t[0] is not dd) # make sure defaultdict is copied
1935
1936    def test_dynamic_class_creation(self):
1937        cls_dict = {'__annotations__': {'x': int, 'y': int},
1938                    }
1939
1940        # Create the class.
1941        cls = type('C', (), cls_dict)
1942
1943        # Make it a dataclass.
1944        cls1 = dataclass(cls)
1945
1946        self.assertEqual(cls1, cls)
1947        self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1948
1949    def test_dynamic_class_creation_using_field(self):
1950        cls_dict = {'__annotations__': {'x': int, 'y': int},
1951                    'y': field(default=5),
1952                    }
1953
1954        # Create the class.
1955        cls = type('C', (), cls_dict)
1956
1957        # Make it a dataclass.
1958        cls1 = dataclass(cls)
1959
1960        self.assertEqual(cls1, cls)
1961        self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1962
    def test_init_in_order(self):
        # The generated __init__ must assign fields in declaration order.
        #  The assertions below also show that the init=False field with
        #  a default_factory ('c') is still assigned, while the init=False
        #  field with a plain default ('e') is not.
        @dataclass
        class C:
            a: int
            b: int = field()
            c: list = field(default_factory=list, init=False)
            d: list = field(default_factory=list)
            e: int = field(default=4, init=False)
            f: int = 4

        # Record each attribute assignment as a (name, value) pair
        #  instead of performing it.  (This local name shadows the
        #  builtin setattr inside this test only.)
        calls = []
        def setattr(self, name, value):
            calls.append((name, value))

        # Installed on the class after decoration, so the assignments
        #  made by the generated __init__ are routed through the hook.
        C.__setattr__ = setattr
        c = C(0, 1)
        self.assertEqual(('a', 0), calls[0])
        self.assertEqual(('b', 1), calls[1])
        self.assertEqual(('c', []), calls[2])
        self.assertEqual(('d', []), calls[3])
        self.assertNotIn(('e', 4), calls)
        self.assertEqual(('f', 4), calls[4])
1985
1986    def test_items_in_dicts(self):
1987        @dataclass
1988        class C:
1989            a: int
1990            b: list = field(default_factory=list, init=False)
1991            c: list = field(default_factory=list)
1992            d: int = field(default=4, init=False)
1993            e: int = 0
1994
1995        c = C(0)
1996        # Class dict
1997        self.assertNotIn('a', C.__dict__)
1998        self.assertNotIn('b', C.__dict__)
1999        self.assertNotIn('c', C.__dict__)
2000        self.assertIn('d', C.__dict__)
2001        self.assertEqual(C.d, 4)
2002        self.assertIn('e', C.__dict__)
2003        self.assertEqual(C.e, 0)
2004        # Instance dict
2005        self.assertIn('a', c.__dict__)
2006        self.assertEqual(c.a, 0)
2007        self.assertIn('b', c.__dict__)
2008        self.assertEqual(c.b, [])
2009        self.assertIn('c', c.__dict__)
2010        self.assertEqual(c.c, [])
2011        self.assertNotIn('d', c.__dict__)
2012        self.assertIn('e', c.__dict__)
2013        self.assertEqual(c.e, 0)
2014
2015    def test_alternate_classmethod_constructor(self):
2016        # Since __post_init__ can't take params, use a classmethod
2017        #  alternate constructor.  This is mostly an example to show
2018        #  how to use this technique.
2019        @dataclass
2020        class C:
2021            x: int
2022            @classmethod
2023            def from_file(cls, filename):
2024                # In a real example, create a new instance
2025                #  and populate 'x' from contents of a file.
2026                value_in_file = 20
2027                return cls(value_in_file)
2028
2029        self.assertEqual(C.from_file('filename').x, 20)
2030
2031    def test_field_metadata_default(self):
2032        # Make sure the default metadata is read-only and of
2033        #  zero length.
2034        @dataclass
2035        class C:
2036            i: int
2037
2038        self.assertFalse(fields(C)[0].metadata)
2039        self.assertEqual(len(fields(C)[0].metadata), 0)
2040        with self.assertRaisesRegex(TypeError,
2041                                    'does not support item assignment'):
2042            fields(C)[0].metadata['test'] = 3
2043
2044    def test_field_metadata_mapping(self):
2045        # Make sure only a mapping can be passed as metadata
2046        #  zero length.
2047        with self.assertRaises(TypeError):
2048            @dataclass
2049            class C:
2050                i: int = field(metadata=0)
2051
2052        # Make sure an empty dict works.
2053        d = {}
2054        @dataclass
2055        class C:
2056            i: int = field(metadata=d)
2057        self.assertFalse(fields(C)[0].metadata)
2058        self.assertEqual(len(fields(C)[0].metadata), 0)
2059        # Update should work (see bpo-35960).
2060        d['foo'] = 1
2061        self.assertEqual(len(fields(C)[0].metadata), 1)
2062        self.assertEqual(fields(C)[0].metadata['foo'], 1)
2063        with self.assertRaisesRegex(TypeError,
2064                                    'does not support item assignment'):
2065            fields(C)[0].metadata['test'] = 3
2066
2067        # Make sure a non-empty dict works.
2068        d = {'test': 10, 'bar': '42', 3: 'three'}
2069        @dataclass
2070        class C:
2071            i: int = field(metadata=d)
2072        self.assertEqual(len(fields(C)[0].metadata), 3)
2073        self.assertEqual(fields(C)[0].metadata['test'], 10)
2074        self.assertEqual(fields(C)[0].metadata['bar'], '42')
2075        self.assertEqual(fields(C)[0].metadata[3], 'three')
2076        # Update should work.
2077        d['foo'] = 1
2078        self.assertEqual(len(fields(C)[0].metadata), 4)
2079        self.assertEqual(fields(C)[0].metadata['foo'], 1)
2080        with self.assertRaises(KeyError):
2081            # Non-existent key.
2082            fields(C)[0].metadata['baz']
2083        with self.assertRaisesRegex(TypeError,
2084                                    'does not support item assignment'):
2085            fields(C)[0].metadata['test'] = 3
2086
2087    def test_field_metadata_custom_mapping(self):
2088        # Try a custom mapping.
2089        class SimpleNameSpace:
2090            def __init__(self, **kw):
2091                self.__dict__.update(kw)
2092
2093            def __getitem__(self, item):
2094                if item == 'xyzzy':
2095                    return 'plugh'
2096                return getattr(self, item)
2097
2098            def __len__(self):
2099                return self.__dict__.__len__()
2100
2101        @dataclass
2102        class C:
2103            i: int = field(metadata=SimpleNameSpace(a=10))
2104
2105        self.assertEqual(len(fields(C)[0].metadata), 1)
2106        self.assertEqual(fields(C)[0].metadata['a'], 10)
2107        with self.assertRaises(AttributeError):
2108            fields(C)[0].metadata['b']
2109        # Make sure we're still talking to our custom mapping.
2110        self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
2111
2112    def test_generic_dataclasses(self):
2113        T = TypeVar('T')
2114
2115        @dataclass
2116        class LabeledBox(Generic[T]):
2117            content: T
2118            label: str = '<unknown>'
2119
2120        box = LabeledBox(42)
2121        self.assertEqual(box.content, 42)
2122        self.assertEqual(box.label, '<unknown>')
2123
2124        # Subscripting the resulting class should work, etc.
2125        Alias = List[LabeledBox[int]]
2126
2127    def test_generic_extending(self):
2128        S = TypeVar('S')
2129        T = TypeVar('T')
2130
2131        @dataclass
2132        class Base(Generic[T, S]):
2133            x: T
2134            y: S
2135
2136        @dataclass
2137        class DataDerived(Base[int, T]):
2138            new_field: str
2139        Alias = DataDerived[str]
2140        c = Alias(0, 'test1', 'test2')
2141        self.assertEqual(astuple(c), (0, 'test1', 'test2'))
2142
2143        class NonDataDerived(Base[int, T]):
2144            def new_method(self):
2145                return self.y
2146        Alias = NonDataDerived[float]
2147        c = Alias(10, 1.0)
2148        self.assertEqual(c.new_method(), 1.0)
2149
2150    def test_generic_dynamic(self):
2151        T = TypeVar('T')
2152
2153        @dataclass
2154        class Parent(Generic[T]):
2155            x: T
2156        Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
2157                               bases=(Parent[int], Generic[T]), namespace={'other': 42})
2158        self.assertIs(Child[int](1, 2).z, None)
2159        self.assertEqual(Child[int](1, 2, 3).z, 3)
2160        self.assertEqual(Child[int](1, 2, 3).other, 42)
2161        # Check that type aliases work correctly.
2162        Alias = Child[T]
2163        self.assertEqual(Alias[int](1, 2).x, 1)
2164        # Check MRO resolution.
2165        self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
2166
2167    def test_dataclasses_pickleable(self):
2168        global P, Q, R
2169        @dataclass
2170        class P:
2171            x: int
2172            y: int = 0
2173        @dataclass
2174        class Q:
2175            x: int
2176            y: int = field(default=0, init=False)
2177        @dataclass
2178        class R:
2179            x: int
2180            y: List[int] = field(default_factory=list)
2181        q = Q(1)
2182        q.y = 2
2183        samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
2184        for sample in samples:
2185            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2186                with self.subTest(sample=sample, proto=proto):
2187                    new_sample = pickle.loads(pickle.dumps(sample, proto))
2188                    self.assertEqual(sample.x, new_sample.x)
2189                    self.assertEqual(sample.y, new_sample.y)
2190                    self.assertIsNot(sample, new_sample)
2191                    new_sample.x = 42
2192                    another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
2193                    self.assertEqual(new_sample.x, another_new_sample.x)
2194                    self.assertEqual(sample.y, another_new_sample.y)
2195
    def test_dataclasses_qualnames(self):
        # Every method dataclass() generates must carry a __qualname__
        #  that places it inside the decorated class, exactly as if it
        #  had been written in the class body.
        # All three flags are set so that the comparison, hash and
        #  attribute-setting/deleting methods listed below all exist.
        @dataclass(order=True, unsafe_hash=True, frozen=True)
        class A:
            x: int
            y: int

        self.assertEqual(A.__init__.__name__, "__init__")
        for function in (
            '__eq__',
            '__lt__',
            '__le__',
            '__gt__',
            '__ge__',
            '__hash__',
            '__init__',
            '__repr__',
            '__setattr__',
            '__delattr__',
        ):
            self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")

        # The class-qualified name also appears in the TypeError raised
        #  for a call with missing arguments.
        with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
            A()
2219
2220
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() assigned to a class attribute without a type annotation
    #  must be rejected with TypeError, even when a base class annotates
    #  the same name.

    def test_field_without_annotation(self):
        # The plain case: 'f' is given a field() but no annotation.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            # This is still an error: make sure we don't pick up the
            #  type annotation in the base class.
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same test, but with the base class not a dataclass.
        class B:
            f: int

        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            # This is still an error: make sure we don't pick up the
            #  type annotation in the base class.
            @dataclass
            class C(B):
                f = field()
2254
2255
class TestDocString(unittest.TestCase):
    # A dataclass without a docstring of its own gets a signature-style
    #  one built from its fields; these tests pin the generated text.
    def assertDocStrEqual(self, a, b):
        # Because 3.6 and 3.7 differ in how inspect.signature work
        #  (see bpo #32108), for the time being just compare them with
        #  whitespace stripped.  Only space characters are removed.
        self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))

    # requires_docstrings presumably skips this test when docstrings
    #  are unavailable (e.g. under -OO) -- see test.support.
    @support.requires_docstrings
    def test_existing_docstring_not_overridden(self):
        # A docstring the class already has is left untouched.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        # No fields: just the class name and empty parentheses.
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        # A plain default is shown after '='.
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        # Union[int, None] is rendered as Optional[int].
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory is shown as the placeholder <factory>.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        # A non-typing class is shown with its module-qualified name.
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")

    def test_docstring_with_no_signature(self):
        # See https://github.com/python/cpython/issues/103449
        # A metaclass whose __call__ is `dict` presumably leaves no
        #  usable signature; the docstring then falls back to the bare
        #  class name.
        class Meta(type):
            __call__ = dict
        class Base(metaclass=Meta):
            pass

        @dataclass
        class C(Base):
            pass

        self.assertDocStrEqual(C.__doc__, "C")
2357
2358
2359class TestInit(unittest.TestCase):
2360    def test_base_has_init(self):
2361        class B:
2362            def __init__(self):
2363                self.z = 100
2364
2365        # Make sure that declaring this class doesn't raise an error.
2366        #  The issue is that we can't override __init__ in our class,
2367        #  but it should be okay to add __init__ to us if our base has
2368        #  an __init__.
2369        @dataclass
2370        class C(B):
2371            x: int = 0
2372        c = C(10)
2373        self.assertEqual(c.x, 10)
2374        self.assertNotIn('z', vars(c))
2375
2376        # Make sure that if we don't add an init, the base __init__
2377        #  gets called.
2378        @dataclass(init=False)
2379        class C(B):
2380            x: int = 10
2381        c = C()
2382        self.assertEqual(c.x, 10)
2383        self.assertEqual(c.z, 100)
2384
2385    def test_no_init(self):
2386        @dataclass(init=False)
2387        class C:
2388            i: int = 0
2389        self.assertEqual(C().i, 0)
2390
2391        @dataclass(init=False)
2392        class C:
2393            i: int = 2
2394            def __init__(self):
2395                self.i = 3
2396        self.assertEqual(C().i, 3)
2397
2398    def test_overwriting_init(self):
2399        # If the class has __init__, use it no matter the value of
2400        #  init=.
2401
2402        @dataclass
2403        class C:
2404            x: int
2405            def __init__(self, x):
2406                self.x = 2 * x
2407        self.assertEqual(C(3).x, 6)
2408
2409        @dataclass(init=True)
2410        class C:
2411            x: int
2412            def __init__(self, x):
2413                self.x = 2 * x
2414        self.assertEqual(C(4).x, 8)
2415
2416        @dataclass(init=False)
2417        class C:
2418            x: int
2419            def __init__(self, x):
2420                self.x = 2 * x
2421        self.assertEqual(C(5).x, 10)
2422
2423    def test_inherit_from_protocol(self):
2424        # Dataclasses inheriting from protocol should preserve their own `__init__`.
2425        # See bpo-45081.
2426
2427        class P(Protocol):
2428            a: int
2429
2430        @dataclass
2431        class C(P):
2432            a: int
2433
2434        self.assertEqual(C(5).a, 5)
2435
2436        @dataclass
2437        class D(P):
2438            def __init__(self, a):
2439                self.a = a * 2
2440
2441        self.assertEqual(D(5).a, 10)
2442
2443
class TestRepr(unittest.TestCase):
    # The expected strings show that the generated __repr__ uses the
    #  class __qualname__, including the enclosing test's '<locals>'.
    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first, then the subclass's own.
        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring 'x' with a default keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Dataclasses nested in another class show the full dotted path.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        #  The default object repr (module, qualname, address) is used.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        #  repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2513
2514
class TestEq(unittest.TestCase):
    def test_recursive_eq(self):
        # An instance whose field refers back to itself must still
        #  compare equal to itself.
        @dataclass
        class C:
            recursive: object = ...

        looped = C()
        looped.recursive = looped
        self.assertEqual(looped, looped)

    def test_no_eq(self):
        # eq=False and no user __eq__: equal field values are not
        #  enough, but an instance still equals itself.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ is used whatever eq= is set to.  Each
        #  variant compares equal only to its own target value.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, target in variants:
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2569
2570
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        #  Only __lt__ is written by hand; the other comparisons
        #  asserted below come from total_ordering.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                #  sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that __lt__ is still called
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # With order=True, a comparison method already defined in the
        #  class body is an error, and the message points the user at
        #  functools.total_ordering.  One case per comparison method.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2646
class TestHash(unittest.TestCase):
    """Tests for how @dataclass adds, keeps or removes __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True generates a __hash__ that hashes the tuple of
        #  field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Create a dataclass from the given decorator arguments (with a
            #  user-defined __hash__ when with_hash is true) and check that
            #  its __hash__ matches `result`: 'fn', '', 'none' or
            #  'exception'.  with_hash is part of the subTest parameters
            #  because every table row is run both with and without a user
            #  __hash__; without it the two runs would be reported
            #  identically.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # The decorator left __hash__ alone.
                    if with_hash:
                        # So the user-defined __hash__ must still be used.
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    #  The table only expects this with with_hash==True.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                  (False,        False, False,  '',                  ''),
                  (False,        False, True,   '',                  ''),
                  (False,        True,  False,  'none',              ''),
                  (False,        True,  True,   'fn',                ''),
                  (True,         False, False,  'fn',                'exception'),
                  (True,         False, True,   'fn',                'exception'),
                  (True,         True,  False,  'fn',                'exception'),
                  (True,         True,  True,   'fn',                'exception'),
                  ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            #  make sure the data-driven table in the decorator
            #  handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        #  and set to None.  This is normal Python behavior, not
        #  related to dataclasses.  Make sure we don't interfere with
        #  that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        #  unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        #  specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less dataclass hashes like the empty tuple, whether the
        #  hash comes from frozen=True or unsafe_hash=True.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A one-field dataclass hashes like the 1-tuple of its value.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        #  make sure that if the @dataclass parameter name is changed
        #  or the non-default hashing behavior changes, the default
        #  hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        #  specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash depends on the instance's identity,
                    #  not on its field values, so there is no fixed value
                    #  to compare against.  Just check the function used
                    #  is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')
2865
2866
2867class TestFrozen(unittest.TestCase):
2868    def test_frozen(self):
2869        @dataclass(frozen=True)
2870        class C:
2871            i: int
2872
2873        c = C(10)
2874        self.assertEqual(c.i, 10)
2875        with self.assertRaises(FrozenInstanceError):
2876            c.i = 5
2877        self.assertEqual(c.i, 10)
2878
2879    def test_frozen_empty(self):
2880        @dataclass(frozen=True)
2881        class C:
2882            pass
2883
2884        c = C()
2885        self.assertFalse(hasattr(c, 'i'))
2886        with self.assertRaises(FrozenInstanceError):
2887            c.i = 5
2888        self.assertFalse(hasattr(c, 'i'))
2889        with self.assertRaises(FrozenInstanceError):
2890            del c.i
2891
2892    def test_inherit(self):
2893        @dataclass(frozen=True)
2894        class C:
2895            i: int
2896
2897        @dataclass(frozen=True)
2898        class D(C):
2899            j: int
2900
2901        d = D(0, 10)
2902        with self.assertRaises(FrozenInstanceError):
2903            d.i = 5
2904        with self.assertRaises(FrozenInstanceError):
2905            d.j = 6
2906        self.assertEqual(d.i, 0)
2907        self.assertEqual(d.j, 10)
2908
2909    def test_inherit_nonfrozen_from_empty_frozen(self):
2910        @dataclass(frozen=True)
2911        class C:
2912            pass
2913
2914        with self.assertRaisesRegex(TypeError,
2915                                    'cannot inherit non-frozen dataclass from a frozen one'):
2916            @dataclass
2917            class D(C):
2918                j: int
2919
2920    def test_inherit_frozen_mutliple_inheritance(self):
2921        @dataclass
2922        class NotFrozen:
2923            pass
2924
2925        @dataclass(frozen=True)
2926        class Frozen:
2927            pass
2928
2929        class NotDataclass:
2930            pass
2931
2932        for bases in (
2933            (NotFrozen, Frozen),
2934            (Frozen, NotFrozen),
2935            (Frozen, NotDataclass),
2936            (NotDataclass, Frozen),
2937        ):
2938            with self.subTest(bases=bases):
2939                with self.assertRaisesRegex(
2940                    TypeError,
2941                    'cannot inherit non-frozen dataclass from a frozen one',
2942                ):
2943                    @dataclass
2944                    class NotFrozenChild(*bases):
2945                        pass
2946
2947        for bases in (
2948            (NotFrozen, Frozen),
2949            (Frozen, NotFrozen),
2950            (NotFrozen, NotDataclass),
2951            (NotDataclass, NotFrozen),
2952        ):
2953            with self.subTest(bases=bases):
2954                with self.assertRaisesRegex(
2955                    TypeError,
2956                    'cannot inherit frozen dataclass from a non-frozen one',
2957                ):
2958                    @dataclass(frozen=True)
2959                    class FrozenChild(*bases):
2960                        pass
2961
2962    def test_inherit_frozen_mutliple_inheritance_regular_mixins(self):
2963        @dataclass(frozen=True)
2964        class Frozen:
2965            pass
2966
2967        class NotDataclass:
2968            pass
2969
2970        class C1(Frozen, NotDataclass):
2971            pass
2972        self.assertEqual(C1.__mro__, (C1, Frozen, NotDataclass, object))
2973
2974        class C2(NotDataclass, Frozen):
2975            pass
2976        self.assertEqual(C2.__mro__, (C2, NotDataclass, Frozen, object))
2977
2978        @dataclass(frozen=True)
2979        class C3(Frozen, NotDataclass):
2980            pass
2981        self.assertEqual(C3.__mro__, (C3, Frozen, NotDataclass, object))
2982
2983        @dataclass(frozen=True)
2984        class C4(NotDataclass, Frozen):
2985            pass
2986        self.assertEqual(C4.__mro__, (C4, NotDataclass, Frozen, object))
2987
2988    def test_multiple_frozen_dataclasses_inheritance(self):
2989        @dataclass(frozen=True)
2990        class FrozenA:
2991            pass
2992
2993        @dataclass(frozen=True)
2994        class FrozenB:
2995            pass
2996
2997        class C1(FrozenA, FrozenB):
2998            pass
2999        self.assertEqual(C1.__mro__, (C1, FrozenA, FrozenB, object))
3000
3001        class C2(FrozenB, FrozenA):
3002            pass
3003        self.assertEqual(C2.__mro__, (C2, FrozenB, FrozenA, object))
3004
3005        @dataclass(frozen=True)
3006        class C3(FrozenA, FrozenB):
3007            pass
3008        self.assertEqual(C3.__mro__, (C3, FrozenA, FrozenB, object))
3009
3010        @dataclass(frozen=True)
3011        class C4(FrozenB, FrozenA):
3012            pass
3013        self.assertEqual(C4.__mro__, (C4, FrozenB, FrozenA, object))
3014
3015    def test_inherit_nonfrozen_from_empty(self):
3016        @dataclass
3017        class C:
3018            pass
3019
3020        @dataclass
3021        class D(C):
3022            j: int
3023
3024        d = D(3)
3025        self.assertEqual(d.j, 3)
3026        self.assertIsInstance(d, C)
3027
3028    # Test both ways: with an intermediate normal (non-dataclass)
3029    #  class and without an intermediate class.
3030    def test_inherit_nonfrozen_from_frozen(self):
3031        for intermediate_class in [True, False]:
3032            with self.subTest(intermediate_class=intermediate_class):
3033                @dataclass(frozen=True)
3034                class C:
3035                    i: int
3036
3037                if intermediate_class:
3038                    class I(C): pass
3039                else:
3040                    I = C
3041
3042                with self.assertRaisesRegex(TypeError,
3043                                            'cannot inherit non-frozen dataclass from a frozen one'):
3044                    @dataclass
3045                    class D(I):
3046                        pass
3047
3048    def test_inherit_frozen_from_nonfrozen(self):
3049        for intermediate_class in [True, False]:
3050            with self.subTest(intermediate_class=intermediate_class):
3051                @dataclass
3052                class C:
3053                    i: int
3054
3055                if intermediate_class:
3056                    class I(C): pass
3057                else:
3058                    I = C
3059
3060                with self.assertRaisesRegex(TypeError,
3061                                            'cannot inherit frozen dataclass from a non-frozen one'):
3062                    @dataclass(frozen=True)
3063                    class D(I):
3064                        pass
3065
3066    def test_inherit_from_normal_class(self):
3067        for intermediate_class in [True, False]:
3068            with self.subTest(intermediate_class=intermediate_class):
3069                class C:
3070                    pass
3071
3072                if intermediate_class:
3073                    class I(C): pass
3074                else:
3075                    I = C
3076
3077                @dataclass(frozen=True)
3078                class D(I):
3079                    i: int
3080
3081            d = D(10)
3082            with self.assertRaises(FrozenInstanceError):
3083                d.i = 5
3084
3085    def test_non_frozen_normal_derived(self):
3086        # See bpo-32953.
3087
3088        @dataclass(frozen=True)
3089        class D:
3090            x: int
3091            y: int = 10
3092
3093        class S(D):
3094            pass
3095
3096        s = S(3)
3097        self.assertEqual(s.x, 3)
3098        self.assertEqual(s.y, 10)
3099        s.cached = True
3100
3101        # But can't change the frozen attributes.
3102        with self.assertRaises(FrozenInstanceError):
3103            s.x = 5
3104        with self.assertRaises(FrozenInstanceError):
3105            s.y = 5
3106        self.assertEqual(s.x, 3)
3107        self.assertEqual(s.y, 10)
3108        self.assertEqual(s.cached, True)
3109
3110        with self.assertRaises(FrozenInstanceError):
3111            del s.x
3112        self.assertEqual(s.x, 3)
3113        with self.assertRaises(FrozenInstanceError):
3114            del s.y
3115        self.assertEqual(s.y, 10)
3116        del s.cached
3117        self.assertFalse(hasattr(s, 'cached'))
3118        with self.assertRaises(AttributeError) as cm:
3119            del s.cached
3120        self.assertNotIsInstance(cm.exception, FrozenInstanceError)
3121
3122    def test_non_frozen_normal_derived_from_empty_frozen(self):
3123        @dataclass(frozen=True)
3124        class D:
3125            pass
3126
3127        class S(D):
3128            pass
3129
3130        s = S()
3131        self.assertFalse(hasattr(s, 'x'))
3132        s.x = 5
3133        self.assertEqual(s.x, 5)
3134
3135        del s.x
3136        self.assertFalse(hasattr(s, 'x'))
3137        with self.assertRaises(AttributeError) as cm:
3138            del s.x
3139        self.assertNotIsInstance(cm.exception, FrozenInstanceError)
3140
3141    def test_overwriting_frozen(self):
3142        # frozen uses __setattr__ and __delattr__.
3143        with self.assertRaisesRegex(TypeError,
3144                                    'Cannot overwrite attribute __setattr__'):
3145            @dataclass(frozen=True)
3146            class C:
3147                x: int
3148                def __setattr__(self):
3149                    pass
3150
3151        with self.assertRaisesRegex(TypeError,
3152                                    'Cannot overwrite attribute __delattr__'):
3153            @dataclass(frozen=True)
3154            class C:
3155                x: int
3156                def __delattr__(self):
3157                    pass
3158
3159        @dataclass(frozen=False)
3160        class C:
3161            x: int
3162            def __setattr__(self, name, value):
3163                self.__dict__['x'] = value * 2
3164        self.assertEqual(C(10).x, 20)
3165
3166    def test_frozen_hash(self):
3167        @dataclass(frozen=True)
3168        class C:
3169            x: Any
3170
3171        # If x is immutable, we can compute the hash.  No exception is
3172        # raised.
3173        hash(C(3))
3174
3175        # If x is mutable, computing the hash is an error.
3176        with self.assertRaisesRegex(TypeError, 'unhashable type'):
3177            hash(C({}))
3178
3179    def test_frozen_deepcopy_without_slots(self):
3180        # see: https://github.com/python/cpython/issues/89683
3181        @dataclass(frozen=True, slots=False)
3182        class C:
3183            s: str
3184
3185        c = C('hello')
3186        self.assertEqual(deepcopy(c), c)
3187
3188    def test_frozen_deepcopy_with_slots(self):
3189        # see: https://github.com/python/cpython/issues/89683
3190        with self.subTest('generated __slots__'):
3191            @dataclass(frozen=True, slots=True)
3192            class C:
3193                s: str
3194
3195            c = C('hello')
3196            self.assertEqual(deepcopy(c), c)
3197
3198        with self.subTest('user-defined __slots__ and no __{get,set}state__'):
3199            @dataclass(frozen=True, slots=False)
3200            class C:
3201                __slots__ = ('s',)
3202                s: str
3203
3204            # with user-defined slots, __getstate__ and __setstate__ are not
3205            # automatically added, hence the error
3206            err = r"^cannot\ assign\ to\ field\ 's'$"
3207            self.assertRaisesRegex(FrozenInstanceError, err, deepcopy, C(''))
3208
3209        with self.subTest('user-defined __slots__ and __{get,set}state__'):
3210            @dataclass(frozen=True, slots=False)
3211            class C:
3212                __slots__ = ('s',)
3213                __getstate__ = dataclasses._dataclass_getstate
3214                __setstate__ = dataclasses._dataclass_setstate
3215
3216                s: str
3217
3218            c = C('hello')
3219            self.assertEqual(deepcopy(c), c)
3220
3221
3222class TestSlots(unittest.TestCase):
3223    def test_simple(self):
3224        @dataclass
3225        class C:
3226            __slots__ = ('x',)
3227            x: Any
3228
3229        # There was a bug where a variable in a slot was assumed to
3230        #  also have a default value (of type
3231        #  types.MemberDescriptorType).
3232        with self.assertRaisesRegex(TypeError,
3233                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
3234            C()
3235
3236        # We can create an instance, and assign to x.
3237        c = C(10)
3238        self.assertEqual(c.x, 10)
3239        c.x = 5
3240        self.assertEqual(c.x, 5)
3241
3242        # We can't assign to anything else.
3243        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
3244            c.y = 5
3245
3246    def test_derived_added_field(self):
3247        # See bpo-33100.
3248        @dataclass
3249        class Base:
3250            __slots__ = ('x',)
3251            x: Any
3252
3253        @dataclass
3254        class Derived(Base):
3255            x: int
3256            y: int
3257
3258        d = Derived(1, 2)
3259        self.assertEqual((d.x, d.y), (1, 2))
3260
3261        # We can add a new field to the derived instance.
3262        d.z = 10
3263
3264    def test_generated_slots(self):
3265        @dataclass(slots=True)
3266        class C:
3267            x: int
3268            y: int
3269
3270        c = C(1, 2)
3271        self.assertEqual((c.x, c.y), (1, 2))
3272
3273        c.x = 3
3274        c.y = 4
3275        self.assertEqual((c.x, c.y), (3, 4))
3276
3277        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
3278            c.z = 5
3279
3280    def test_add_slots_when_slots_exists(self):
3281        with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
3282            @dataclass(slots=True)
3283            class C:
3284                __slots__ = ('x',)
3285                x: int
3286
3287    def test_generated_slots_value(self):
3288
3289        class Root:
3290            __slots__ = {'x'}
3291
3292        class Root2(Root):
3293            __slots__ = {'k': '...', 'j': ''}
3294
3295        class Root3(Root2):
3296            __slots__ = ['h']
3297
3298        class Root4(Root3):
3299            __slots__ = 'aa'
3300
3301        @dataclass(slots=True)
3302        class Base(Root4):
3303            y: int
3304            j: str
3305            h: str
3306
3307        self.assertEqual(Base.__slots__, ('y', ))
3308
3309        @dataclass(slots=True)
3310        class Derived(Base):
3311            aa: float
3312            x: str
3313            z: int
3314            k: str
3315            h: str
3316
3317        self.assertEqual(Derived.__slots__, ('z', ))
3318
3319        @dataclass
3320        class AnotherDerived(Base):
3321            z: int
3322
3323        self.assertNotIn('__slots__', AnotherDerived.__dict__)
3324
3325    def test_cant_inherit_from_iterator_slots(self):
3326
3327        class Root:
3328            __slots__ = iter(['a'])
3329
3330        class Root2(Root):
3331            __slots__ = ('b', )
3332
3333        with self.assertRaisesRegex(
3334           TypeError,
3335            "^Slots of 'Root' cannot be determined"
3336        ):
3337            @dataclass(slots=True)
3338            class C(Root2):
3339                x: int
3340
3341    def test_returns_new_class(self):
3342        class A:
3343            x: int
3344
3345        B = dataclass(A, slots=True)
3346        self.assertIsNot(A, B)
3347
3348        self.assertFalse(hasattr(A, "__slots__"))
3349        self.assertTrue(hasattr(B, "__slots__"))
3350
    # Can't be local to test_frozen_pickle: pickle records classes by
    #  qualified name and looks them up again when loading, which does not
    #  work for classes defined inside a function.
    # Frozen dataclass with generated __slots__.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsClass:
        foo: str
        bar: int

    # Same fields, frozen but without slots, as the comparison case.
    @dataclass(frozen=True)
    class FrozenWithoutSlotsClass:
        foo: str
        bar: int
3361
3362    def test_frozen_pickle(self):
3363        # bpo-43999
3364
3365        self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
3366        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
3367            with self.subTest(proto=proto):
3368                obj = self.FrozenSlotsClass("a", 1)
3369                p = pickle.loads(pickle.dumps(obj, protocol=proto))
3370                self.assertIsNot(obj, p)
3371                self.assertEqual(obj, p)
3372
3373                obj = self.FrozenWithoutSlotsClass("a", 1)
3374                p = pickle.loads(pickle.dumps(obj, protocol=proto))
3375                self.assertIsNot(obj, p)
3376                self.assertEqual(obj, p)
3377
    # Frozen, slotted dataclass with a user-defined __getstate__ that
    #  records (in getstate_called) that it was invoked.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsGetStateClass:
        foo: str
        bar: int

        # compare=False keeps this bookkeeping flag out of __eq__.
        getstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            # object.__setattr__ is used because the instance is frozen.
            object.__setattr__(self, 'getstate_called', True)
            # Only foo and bar are returned as state, in that order.
            return [self.foo, self.bar]
3388
    # Frozen, slotted dataclass with a user-defined __setstate__ that
    #  records (in setstate_called) that it was invoked.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsSetStateClass:
        foo: str
        bar: int

        # compare=False keeps this bookkeeping flag out of __eq__.
        setstate_called: bool = field(default=False, compare=False)

        def __setstate__(self, state):
            # object.__setattr__ is used because the instance is frozen.
            object.__setattr__(self, 'setstate_called', True)
            # Reads state positionally: state[0] -> foo, state[1] -> bar.
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])
3400
    # Frozen, slotted dataclass defining both __getstate__ and
    #  __setstate__, each recording that it was invoked.
    @dataclass(frozen=True, slots=True)
    class FrozenSlotsAllStateClass:
        foo: str
        bar: int

        # compare=False keeps these bookkeeping flags out of __eq__.
        getstate_called: bool = field(default=False, compare=False)
        setstate_called: bool = field(default=False, compare=False)

        def __getstate__(self):
            # object.__setattr__ is used because the instance is frozen.
            object.__setattr__(self, 'getstate_called', True)
            # The state is the list [foo, bar].
            return [self.foo, self.bar]

        def __setstate__(self, state):
            object.__setattr__(self, 'setstate_called', True)
            # Matches the [foo, bar] layout produced by __getstate__.
            object.__setattr__(self, 'foo', state[0])
            object.__setattr__(self, 'bar', state[1])
3417
3418    def test_frozen_slots_pickle_custom_state(self):
3419        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
3420            with self.subTest(proto=proto):
3421                obj = self.FrozenSlotsGetStateClass('a', 1)
3422                dumped = pickle.dumps(obj, protocol=proto)
3423
3424                self.assertTrue(obj.getstate_called)
3425                self.assertEqual(obj, pickle.loads(dumped))
3426
3427        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
3428            with self.subTest(proto=proto):
3429                obj = self.FrozenSlotsSetStateClass('a', 1)
3430                obj2 = pickle.loads(pickle.dumps(obj, protocol=proto))
3431
3432                self.assertTrue(obj2.setstate_called)
3433                self.assertEqual(obj, obj2)
3434
3435        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
3436            with self.subTest(proto=proto):
3437                obj = self.FrozenSlotsAllStateClass('a', 1)
3438                dumped = pickle.dumps(obj, protocol=proto)
3439
3440                self.assertTrue(obj.getstate_called)
3441
3442                obj2 = pickle.loads(dumped)
3443                self.assertTrue(obj2.setstate_called)
3444                self.assertEqual(obj, obj2)
3445
3446    def test_slots_with_default_no_init(self):
3447        # Originally reported in bpo-44649.
3448        @dataclass(slots=True)
3449        class A:
3450            a: str
3451            b: str = field(default='b', init=False)
3452
3453        obj = A("a")
3454        self.assertEqual(obj.a, 'a')
3455        self.assertEqual(obj.b, 'b')
3456
3457    def test_slots_with_default_factory_no_init(self):
3458        # Originally reported in bpo-44649.
3459        @dataclass(slots=True)
3460        class A:
3461            a: str
3462            b: str = field(default_factory=lambda:'b', init=False)
3463
3464        obj = A("a")
3465        self.assertEqual(obj.a, 'a')
3466        self.assertEqual(obj.b, 'b')
3467
3468    def test_slots_no_weakref(self):
3469        @dataclass(slots=True)
3470        class A:
3471            # No weakref.
3472            pass
3473
3474        self.assertNotIn("__weakref__", A.__slots__)
3475        a = A()
3476        with self.assertRaisesRegex(TypeError,
3477                                    "cannot create weak reference"):
3478            weakref.ref(a)
3479        with self.assertRaises(AttributeError):
3480            a.__weakref__
3481
3482    def test_slots_weakref(self):
3483        @dataclass(slots=True, weakref_slot=True)
3484        class A:
3485            a: int
3486
3487        self.assertIn("__weakref__", A.__slots__)
3488        a = A(1)
3489        a_ref = weakref.ref(a)
3490
3491        self.assertIs(a.__weakref__, a_ref)
3492
3493    def test_slots_weakref_base_str(self):
3494        class Base:
3495            __slots__ = '__weakref__'
3496
3497        @dataclass(slots=True)
3498        class A(Base):
3499            a: int
3500
3501        # __weakref__ is in the base class, not A.  But an A is still weakref-able.
3502        self.assertIn("__weakref__", Base.__slots__)
3503        self.assertNotIn("__weakref__", A.__slots__)
3504        a = A(1)
3505        weakref.ref(a)
3506
3507    def test_slots_weakref_base_tuple(self):
3508        # Same as test_slots_weakref_base, but use a tuple instead of a string
3509        # in the base class.
3510        class Base:
3511            __slots__ = ('__weakref__',)
3512
3513        @dataclass(slots=True)
3514        class A(Base):
3515            a: int
3516
3517        # __weakref__ is in the base class, not A.  But an A is still
3518        # weakref-able.
3519        self.assertIn("__weakref__", Base.__slots__)
3520        self.assertNotIn("__weakref__", A.__slots__)
3521        a = A(1)
3522        weakref.ref(a)
3523
3524    def test_weakref_slot_without_slot(self):
3525        with self.assertRaisesRegex(TypeError,
3526                                    "weakref_slot is True but slots is False"):
3527            @dataclass(weakref_slot=True)
3528            class A:
3529                a: int
3530
3531    def test_weakref_slot_make_dataclass(self):
3532        A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True)
3533        self.assertIn("__weakref__", A.__slots__)
3534        a = A(1)
3535        weakref.ref(a)
3536
3537        # And make sure if raises if slots=True is not given.
3538        with self.assertRaisesRegex(TypeError,
3539                                    "weakref_slot is True but slots is False"):
3540            B = make_dataclass('B', [('a', int),], weakref_slot=True)
3541
3542    def test_weakref_slot_subclass_weakref_slot(self):
3543        @dataclass(slots=True, weakref_slot=True)
3544        class Base:
3545            field: int
3546
3547        # A *can* also specify weakref_slot=True if it wants to (gh-93521)
3548        @dataclass(slots=True, weakref_slot=True)
3549        class A(Base):
3550            ...
3551
3552        # __weakref__ is in the base class, not A.  But an instance of A
3553        # is still weakref-able.
3554        self.assertIn("__weakref__", Base.__slots__)
3555        self.assertNotIn("__weakref__", A.__slots__)
3556        a = A(1)
3557        a_ref = weakref.ref(a)
3558        self.assertIs(a.__weakref__, a_ref)
3559
3560    def test_weakref_slot_subclass_no_weakref_slot(self):
3561        @dataclass(slots=True, weakref_slot=True)
3562        class Base:
3563            field: int
3564
3565        @dataclass(slots=True)
3566        class A(Base):
3567            ...
3568
3569        # __weakref__ is in the base class, not A.  Even though A doesn't
3570        # specify weakref_slot, it should still be weakref-able.
3571        self.assertIn("__weakref__", Base.__slots__)
3572        self.assertNotIn("__weakref__", A.__slots__)
3573        a = A(1)
3574        a_ref = weakref.ref(a)
3575        self.assertIs(a.__weakref__, a_ref)
3576
3577    def test_weakref_slot_normal_base_weakref_slot(self):
3578        class Base:
3579            __slots__ = ('__weakref__',)
3580
3581        @dataclass(slots=True, weakref_slot=True)
3582        class A(Base):
3583            field: int
3584
3585        # __weakref__ is in the base class, not A.  But an instance of
3586        # A is still weakref-able.
3587        self.assertIn("__weakref__", Base.__slots__)
3588        self.assertNotIn("__weakref__", A.__slots__)
3589        a = A(1)
3590        a_ref = weakref.ref(a)
3591        self.assertIs(a.__weakref__, a_ref)
3592
3593
3594    def test_dataclass_derived_weakref_slot(self):
3595        class A:
3596            pass
3597
3598        @dataclass(slots=True, weakref_slot=True)
3599        class B(A):
3600            pass
3601
3602        self.assertEqual(B.__slots__, ())
3603        B()
3604
3605    def test_dataclass_derived_generic(self):
3606        T = typing.TypeVar('T')
3607
3608        @dataclass(slots=True, weakref_slot=True)
3609        class A(typing.Generic[T]):
3610            pass
3611        self.assertEqual(A.__slots__, ('__weakref__',))
3612        self.assertTrue(A.__weakref__)
3613        A()
3614
3615        @dataclass(slots=True, weakref_slot=True)
3616        class B[T2]:
3617            pass
3618        self.assertEqual(B.__slots__, ('__weakref__',))
3619        self.assertTrue(B.__weakref__)
3620        B()
3621
3622    def test_dataclass_derived_generic_from_base(self):
3623        T = typing.TypeVar('T')
3624
3625        class RawBase: ...
3626
3627        @dataclass(slots=True, weakref_slot=True)
3628        class C1(typing.Generic[T], RawBase):
3629            pass
3630        self.assertEqual(C1.__slots__, ())
3631        self.assertTrue(C1.__weakref__)
3632        C1()
3633        @dataclass(slots=True, weakref_slot=True)
3634        class C2(RawBase, typing.Generic[T]):
3635            pass
3636        self.assertEqual(C2.__slots__, ())
3637        self.assertTrue(C2.__weakref__)
3638        C2()
3639
3640        @dataclass(slots=True, weakref_slot=True)
3641        class D[T2](RawBase):
3642            pass
3643        self.assertEqual(D.__slots__, ())
3644        self.assertTrue(D.__weakref__)
3645        D()
3646
3647    def test_dataclass_derived_generic_from_slotted_base(self):
3648        T = typing.TypeVar('T')
3649
3650        class WithSlots:
3651            __slots__ = ('a', 'b')
3652
3653        @dataclass(slots=True, weakref_slot=True)
3654        class E1(WithSlots, Generic[T]):
3655            pass
3656        self.assertEqual(E1.__slots__, ('__weakref__',))
3657        self.assertTrue(E1.__weakref__)
3658        E1()
3659        @dataclass(slots=True, weakref_slot=True)
3660        class E2(Generic[T], WithSlots):
3661            pass
3662        self.assertEqual(E2.__slots__, ('__weakref__',))
3663        self.assertTrue(E2.__weakref__)
3664        E2()
3665
3666        @dataclass(slots=True, weakref_slot=True)
3667        class F[T2](WithSlots):
3668            pass
3669        self.assertEqual(F.__slots__, ('__weakref__',))
3670        self.assertTrue(F.__weakref__)
3671        F()
3672
3673    def test_dataclass_derived_generic_from_slotted_base(self):
3674        T = typing.TypeVar('T')
3675
3676        class WithWeakrefSlot:
3677            __slots__ = ('__weakref__',)
3678
3679        @dataclass(slots=True, weakref_slot=True)
3680        class G1(WithWeakrefSlot, Generic[T]):
3681            pass
3682        self.assertEqual(G1.__slots__, ())
3683        self.assertTrue(G1.__weakref__)
3684        G1()
3685        @dataclass(slots=True, weakref_slot=True)
3686        class G2(Generic[T], WithWeakrefSlot):
3687            pass
3688        self.assertEqual(G2.__slots__, ())
3689        self.assertTrue(G2.__weakref__)
3690        G2()
3691
3692        @dataclass(slots=True, weakref_slot=True)
3693        class H[T2](WithWeakrefSlot):
3694            pass
3695        self.assertEqual(H.__slots__, ())
3696        self.assertTrue(H.__weakref__)
3697        H()
3698
3699    def test_dataclass_slot_dict(self):
3700        class WithDictSlot:
3701            __slots__ = ('__dict__',)
3702
3703        @dataclass(slots=True)
3704        class A(WithDictSlot): ...
3705
3706        self.assertEqual(A.__slots__, ())
3707        self.assertEqual(A().__dict__, {})
3708        A()
3709
3710    @support.cpython_only
3711    def test_dataclass_slot_dict_ctype(self):
3712        # https://github.com/python/cpython/issues/123935
3713        from test.support import import_helper
3714        # Skips test if `_testcapi` is not present:
3715        _testcapi = import_helper.import_module('_testcapi')
3716
3717        @dataclass(slots=True)
3718        class HasDictOffset(_testcapi.HeapCTypeWithDict):
3719            __dict__: dict = {}
3720        self.assertNotEqual(_testcapi.HeapCTypeWithDict.__dictoffset__, 0)
3721        self.assertEqual(HasDictOffset.__slots__, ())
3722
3723        @dataclass(slots=True)
3724        class DoesNotHaveDictOffset(_testcapi.HeapCTypeWithWeakref):
3725            __dict__: dict = {}
3726        self.assertEqual(_testcapi.HeapCTypeWithWeakref.__dictoffset__, 0)
3727        self.assertEqual(DoesNotHaveDictOffset.__slots__, ('__dict__',))
3728
3729    @support.cpython_only
3730    def test_slots_with_wrong_init_subclass(self):
3731        # TODO: This test is for a kinda-buggy behavior.
3732        # Ideally, it should be fixed and `__init_subclass__`
3733        # should be fully supported in the future versions.
3734        # See https://github.com/python/cpython/issues/91126
3735        class WrongSuper:
3736            def __init_subclass__(cls, arg):
3737                pass
3738
3739        with self.assertRaisesRegex(
3740            TypeError,
3741            "missing 1 required positional argument: 'arg'",
3742        ):
3743            @dataclass(slots=True)
3744            class WithWrongSuper(WrongSuper, arg=1):
3745                pass
3746
3747        class CorrectSuper:
3748            args = []
3749            def __init_subclass__(cls, arg="default"):
3750                cls.args.append(arg)
3751
3752        @dataclass(slots=True)
3753        class WithCorrectSuper(CorrectSuper):
3754            pass
3755
3756        # __init_subclass__ is called twice: once for `WithCorrectSuper`
3757        # and once for `WithCorrectSuper__slots__` new class
3758        # that we create internally.
3759        self.assertEqual(CorrectSuper.args, ["default", "default"])
3760
3761
class TestDescriptors(unittest.TestCase):
    """Interaction between dataclass fields, descriptors and __set_name__."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor (it defines __get__) that also records its name.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        #  dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # Now test with a default value and init=False, which is the
        #  only time this is really meaningful.  If not using
        #  init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # D deliberately defines only __set_name__ (no __get__/__set__/
        # __delete__), so it is NOT a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        d = D()
        # Create an attribute on the instance, not type.
        d.__set_name__ = Mock()

        # Make sure d.__set_name__ is not called: special methods are
        # looked up on the type, not the instance.
        @dataclass
        class C:
            i: int = field(default=d, init=False)

        self.assertEqual(d.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Make sure D.__set_name__ is called.
        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)

    def test_init_calls_set(self):
        class D:
            pass

        D.__set__ = Mock()

        @dataclass
        class C:
            i: D = D()

        # Make sure D.__set__ is called by the generated __init__.
        D.__set__.reset_mock()
        c = C(5)
        self.assertEqual(D.__set__.call_count, 1)

    def test_getting_field_calls_get(self):
        class D:
            pass

        D.__set__ = Mock()
        D.__get__ = Mock()

        @dataclass
        class C:
            i: D = D()

        c = C(5)

        # Make sure D.__get__ is called.
        D.__get__.reset_mock()
        value = c.i
        self.assertEqual(D.__get__.call_count, 1)

    def test_setting_field_calls_set(self):
        class D:
            pass

        D.__set__ = Mock()

        @dataclass
        class C:
            i: D = D()

        c = C(5)

        # Make sure D.__set__ is called.
        D.__set__.reset_mock()
        c.i = 10
        self.assertEqual(D.__set__.call_count, 1)

    def test_setting_uninitialized_descriptor_field(self):
        class D:
            pass

        D.__set__ = Mock()

        @dataclass
        class C:
            i: D

        # D.__set__ is not called because there's no D instance to call it on
        D.__set__.reset_mock()
        c = C(5)
        self.assertEqual(D.__set__.call_count, 0)

        # D.__set__ still isn't called after setting i to an instance of D
        # because descriptors don't behave like that when stored as instance vars
        c.i = D()
        c.i = 5
        self.assertEqual(D.__set__.call_count, 0)

    def test_default_value(self):
        class D:
            def __get__(self, instance: Any, owner: object) -> int:
                # Class-level access yields 100, which dataclass then uses
                # as the field's default.
                if instance is None:
                    return 100

                return instance._x

            def __set__(self, instance: Any, value: int) -> None:
                instance._x = value

        @dataclass
        class C:
            i: D = D()

        c = C()
        self.assertEqual(c.i, 100)

        c = C(5)
        self.assertEqual(c.i, 5)

    def test_no_default_value(self):
        class D:
            def __get__(self, instance: Any, owner: object) -> int:
                # Raising AttributeError on class-level access means the
                # field has no default, so it becomes required.
                if instance is None:
                    raise AttributeError()

                return instance._x

            def __set__(self, instance: Any, value: int) -> None:
                instance._x = value

        @dataclass
        class C:
            i: D = D()

        with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'):
            c = C()
3942
3943class TestStringAnnotations(unittest.TestCase):
3944    def test_classvar(self):
3945        # Some expressions recognized as ClassVar really aren't.  But
3946        #  if you're using string annotations, it's not an exact
3947        #  science.
3948        # These tests assume that both "import typing" and "from
3949        # typing import *" have been run in this file.
3950        for typestr in ('ClassVar[int]',
3951                        'ClassVar [int]',
3952                        ' ClassVar [int]',
3953                        'ClassVar',
3954                        ' ClassVar ',
3955                        'typing.ClassVar[int]',
3956                        'typing.ClassVar[str]',
3957                        ' typing.ClassVar[str]',
3958                        'typing .ClassVar[str]',
3959                        'typing. ClassVar[str]',
3960                        'typing.ClassVar [str]',
3961                        'typing.ClassVar [ str]',
3962
3963                        # Not syntactically valid, but these will
3964                        #  be treated as ClassVars.
3965                        'typing.ClassVar.[int]',
3966                        'typing.ClassVar+',
3967                        ):
3968            with self.subTest(typestr=typestr):
3969                @dataclass
3970                class C:
3971                    x: typestr
3972
3973                # x is a ClassVar, so C() takes no args.
3974                C()
3975
3976                # And it won't appear in the class's dict because it doesn't
3977                # have a default.
3978                self.assertNotIn('x', C.__dict__)
3979
3980    def test_isnt_classvar(self):
3981        for typestr in ('CV',
3982                        't.ClassVar',
3983                        't.ClassVar[int]',
3984                        'typing..ClassVar[int]',
3985                        'Classvar',
3986                        'Classvar[int]',
3987                        'typing.ClassVarx[int]',
3988                        'typong.ClassVar[int]',
3989                        'dataclasses.ClassVar[int]',
3990                        'typingxClassVar[str]',
3991                        ):
3992            with self.subTest(typestr=typestr):
3993                @dataclass
3994                class C:
3995                    x: typestr
3996
3997                # x is not a ClassVar, so C() takes one arg.
3998                self.assertEqual(C(10).x, 10)
3999
4000    def test_initvar(self):
4001        # These tests assume that both "import dataclasses" and "from
4002        #  dataclasses import *" have been run in this file.
4003        for typestr in ('InitVar[int]',
4004                        'InitVar [int]'
4005                        ' InitVar [int]',
4006                        'InitVar',
4007                        ' InitVar ',
4008                        'dataclasses.InitVar[int]',
4009                        'dataclasses.InitVar[str]',
4010                        ' dataclasses.InitVar[str]',
4011                        'dataclasses .InitVar[str]',
4012                        'dataclasses. InitVar[str]',
4013                        'dataclasses.InitVar [str]',
4014                        'dataclasses.InitVar [ str]',
4015
4016                        # Not syntactically valid, but these will
4017                        #  be treated as InitVars.
4018                        'dataclasses.InitVar.[int]',
4019                        'dataclasses.InitVar+',
4020                        ):
4021            with self.subTest(typestr=typestr):
4022                @dataclass
4023                class C:
4024                    x: typestr
4025
4026                # x is an InitVar, so doesn't create a member.
4027                with self.assertRaisesRegex(AttributeError,
4028                                            "object has no attribute 'x'"):
4029                    C(1).x
4030
4031    def test_isnt_initvar(self):
4032        for typestr in ('IV',
4033                        'dc.InitVar',
4034                        'xdataclasses.xInitVar',
4035                        'typing.xInitVar[int]',
4036                        ):
4037            with self.subTest(typestr=typestr):
4038                @dataclass
4039                class C:
4040                    x: typestr
4041
4042                # x is not an InitVar, so there will be a member x.
4043                self.assertEqual(C(10).x, 10)
4044
4045    def test_classvar_module_level_import(self):
4046        from test.test_dataclasses import dataclass_module_1
4047        from test.test_dataclasses import dataclass_module_1_str
4048        from test.test_dataclasses import dataclass_module_2
4049        from test.test_dataclasses import dataclass_module_2_str
4050
4051        for m in (dataclass_module_1, dataclass_module_1_str,
4052                  dataclass_module_2, dataclass_module_2_str,
4053                  ):
4054            with self.subTest(m=m):
4055                # There's a difference in how the ClassVars are
4056                # interpreted when using string annotations or
4057                # not. See the imported modules for details.
4058                if m.USING_STRINGS:
4059                    c = m.CV(10)
4060                else:
4061                    c = m.CV()
4062                self.assertEqual(c.cv0, 20)
4063
4064
4065                # There's a difference in how the InitVars are
4066                # interpreted when using string annotations or
4067                # not. See the imported modules for details.
4068                c = m.IV(0, 1, 2, 3, 4)
4069
4070                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
4071                    with self.subTest(field_name=field_name):
4072                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
4073                            # Since field_name is an InitVar, it's
4074                            # not an instance field.
4075                            getattr(c, field_name)
4076
4077                if m.USING_STRINGS:
4078                    # iv4 is interpreted as a normal field.
4079                    self.assertIn('not_iv4', c.__dict__)
4080                    self.assertEqual(c.not_iv4, 4)
4081                else:
4082                    # iv4 is interpreted as an InitVar, so it
4083                    # won't exist on the instance.
4084                    self.assertNotIn('not_iv4', c.__dict__)
4085
4086    def test_text_annotations(self):
4087        from test.test_dataclasses import dataclass_textanno
4088
4089        self.assertEqual(
4090            get_type_hints(dataclass_textanno.Bar),
4091            {'foo': dataclass_textanno.Foo})
4092        self.assertEqual(
4093            get_type_hints(dataclass_textanno.Bar.__init__),
4094            {'foo': dataclass_textanno.Foo,
4095             'return': type(None)})
4096
4097
# Module-level classes built with make_dataclass(), used by TestMakeDataclass.
# ByMakeDataClass omits module=, so its __module__ comes from the calling
# module; ManualModuleMakeDataClass passes module=__name__ explicitly.
# Both are bound to a global with the same name as the class, which is what
# lets them (and their instances) round-trip through pickle.
ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
                                           [('x', int)],
                                           module=__name__)
# These two are expected to fail to pickle: the first is bound to a global
# whose name differs from its class name 'Wrong'; the second claims to live
# in a module named 'custom'.  Presumably pickle's lookup-by-(module, name)
# cannot find either class — the tests only check that PickleError is raised.
WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)])
WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass',
                                          [('x', int)],
                                          module='custom')
4106
4107class TestMakeDataclass(unittest.TestCase):
4108    def test_simple(self):
4109        C = make_dataclass('C',
4110                           [('x', int),
4111                            ('y', int, field(default=5))],
4112                           namespace={'add_one': lambda self: self.x + 1})
4113        c = C(10)
4114        self.assertEqual((c.x, c.y), (10, 5))
4115        self.assertEqual(c.add_one(), 11)
4116
4117
4118    def test_no_mutate_namespace(self):
4119        # Make sure a provided namespace isn't mutated.
4120        ns = {}
4121        C = make_dataclass('C',
4122                           [('x', int),
4123                            ('y', int, field(default=5))],
4124                           namespace=ns)
4125        self.assertEqual(ns, {})
4126
4127    def test_base(self):
4128        class Base1:
4129            pass
4130        class Base2:
4131            pass
4132        C = make_dataclass('C',
4133                           [('x', int)],
4134                           bases=(Base1, Base2))
4135        c = C(2)
4136        self.assertIsInstance(c, C)
4137        self.assertIsInstance(c, Base1)
4138        self.assertIsInstance(c, Base2)
4139
4140    def test_base_dataclass(self):
4141        @dataclass
4142        class Base1:
4143            x: int
4144        class Base2:
4145            pass
4146        C = make_dataclass('C',
4147                           [('y', int)],
4148                           bases=(Base1, Base2))
4149        with self.assertRaisesRegex(TypeError, 'required positional'):
4150            c = C(2)
4151        c = C(1, 2)
4152        self.assertIsInstance(c, C)
4153        self.assertIsInstance(c, Base1)
4154        self.assertIsInstance(c, Base2)
4155
4156        self.assertEqual((c.x, c.y), (1, 2))
4157
4158    def test_init_var(self):
4159        def post_init(self, y):
4160            self.x *= y
4161
4162        C = make_dataclass('C',
4163                           [('x', int),
4164                            ('y', InitVar[int]),
4165                            ],
4166                           namespace={'__post_init__': post_init},
4167                           )
4168        c = C(2, 3)
4169        self.assertEqual(vars(c), {'x': 6})
4170        self.assertEqual(len(fields(c)), 1)
4171
4172    def test_class_var(self):
4173        C = make_dataclass('C',
4174                           [('x', int),
4175                            ('y', ClassVar[int], 10),
4176                            ('z', ClassVar[int], field(default=20)),
4177                            ])
4178        c = C(1)
4179        self.assertEqual(vars(c), {'x': 1})
4180        self.assertEqual(len(fields(c)), 1)
4181        self.assertEqual(C.y, 10)
4182        self.assertEqual(C.z, 20)
4183
4184    def test_other_params(self):
4185        C = make_dataclass('C',
4186                           [('x', int),
4187                            ('y', ClassVar[int], 10),
4188                            ('z', ClassVar[int], field(default=20)),
4189                            ],
4190                           init=False)
4191        # Make sure we have a repr, but no init.
4192        self.assertNotIn('__init__', vars(C))
4193        self.assertIn('__repr__', vars(C))
4194
4195        # Make sure random other params don't work.
4196        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
4197            C = make_dataclass('C',
4198                               [],
4199                               xxinit=False)
4200
4201    def test_no_types(self):
4202        C = make_dataclass('Point', ['x', 'y', 'z'])
4203        c = C(1, 2, 3)
4204        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
4205        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
4206                                             'y': 'typing.Any',
4207                                             'z': 'typing.Any'})
4208
4209        C = make_dataclass('Point', ['x', ('y', int), 'z'])
4210        c = C(1, 2, 3)
4211        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
4212        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
4213                                             'y': int,
4214                                             'z': 'typing.Any'})
4215
4216    def test_module_attr(self):
4217        self.assertEqual(ByMakeDataClass.__module__, __name__)
4218        self.assertEqual(ByMakeDataClass(1).__module__, __name__)
4219        self.assertEqual(WrongModuleMakeDataclass.__module__, "custom")
4220        Nested = make_dataclass('Nested', [])
4221        self.assertEqual(Nested.__module__, __name__)
4222        self.assertEqual(Nested().__module__, __name__)
4223
4224    def test_pickle_support(self):
4225        for klass in [ByMakeDataClass, ManualModuleMakeDataClass]:
4226            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
4227                with self.subTest(proto=proto):
4228                    self.assertEqual(
4229                        pickle.loads(pickle.dumps(klass, proto)),
4230                        klass,
4231                    )
4232                    self.assertEqual(
4233                        pickle.loads(pickle.dumps(klass(1), proto)),
4234                        klass(1),
4235                    )
4236
4237    def test_cannot_be_pickled(self):
4238        for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]:
4239            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
4240                with self.subTest(proto=proto):
4241                    with self.assertRaises(pickle.PickleError):
4242                        pickle.dumps(klass, proto)
4243                    with self.assertRaises(pickle.PickleError):
4244                        pickle.dumps(klass(1), proto)
4245
4246    def test_invalid_type_specification(self):
4247        for bad_field in [(),
4248                          (1, 2, 3, 4),
4249                          ]:
4250            with self.subTest(bad_field=bad_field):
4251                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
4252                    make_dataclass('C', ['a', bad_field])
4253
4254        # And test for things with no len().
4255        for bad_field in [float,
4256                          lambda x:x,
4257                          ]:
4258            with self.subTest(bad_field=bad_field):
4259                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
4260                    make_dataclass('C', ['a', bad_field])
4261
4262    def test_duplicate_field_names(self):
4263        for field in ['a', 'ab']:
4264            with self.subTest(field=field):
4265                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
4266                    make_dataclass('C', [field, 'a', field])
4267
4268    def test_keyword_field_names(self):
4269        for field in ['for', 'async', 'await', 'as']:
4270            with self.subTest(field=field):
4271                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
4272                    make_dataclass('C', ['a', field])
4273                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
4274                    make_dataclass('C', [field])
4275                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
4276                    make_dataclass('C', [field, 'a'])
4277
4278    def test_non_identifier_field_names(self):
4279        for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
4280            with self.subTest(field=field):
4281                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
4282                    make_dataclass('C', ['a', field])
4283                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
4284                    make_dataclass('C', [field])
4285                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
4286                    make_dataclass('C', [field, 'a'])
4287
4288    def test_underscore_field_names(self):
4289        # Unlike namedtuple, it's okay if dataclass field names have
4290        # an underscore.
4291        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
4292
4293    def test_funny_class_names_names(self):
4294        # No reason to prevent weird class names, since
4295        # types.new_class allows them.
4296        for classname in ['()', 'x,y', '*', '2@3', '']:
4297            with self.subTest(classname=classname):
4298                C = make_dataclass(classname, ['a', 'b'])
4299                self.assertEqual(C.__name__, classname)
4300
4301class TestReplace(unittest.TestCase):
4302    def test(self):
4303        @dataclass(frozen=True)
4304        class C:
4305            x: int
4306            y: int
4307
4308        c = C(1, 2)
4309        c1 = replace(c, x=3)
4310        self.assertEqual(c1.x, 3)
4311        self.assertEqual(c1.y, 2)
4312
4313    def test_frozen(self):
4314        @dataclass(frozen=True)
4315        class C:
4316            x: int
4317            y: int
4318            z: int = field(init=False, default=10)
4319            t: int = field(init=False, default=100)
4320
4321        c = C(1, 2)
4322        c1 = replace(c, x=3)
4323        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
4324        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
4325
4326
4327        with self.assertRaisesRegex(TypeError, 'init=False'):
4328            replace(c, x=3, z=20, t=50)
4329        with self.assertRaisesRegex(TypeError, 'init=False'):
4330            replace(c, z=20)
4331            replace(c, x=3, z=20, t=50)
4332
4333        # Make sure the result is still frozen.
4334        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
4335            c1.x = 3
4336
4337        # Make sure we can't replace an attribute that doesn't exist,
4338        #  if we're also replacing one that does exist.  Test this
4339        #  here, because setting attributes on frozen instances is
4340        #  handled slightly differently from non-frozen ones.
4341        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
4342                                             "keyword argument 'a'"):
4343            c1 = replace(c, x=20, a=5)
4344
4345    def test_invalid_field_name(self):
4346        @dataclass(frozen=True)
4347        class C:
4348            x: int
4349            y: int
4350
4351        c = C(1, 2)
4352        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
4353                                    "keyword argument 'z'"):
4354            c1 = replace(c, z=3)
4355
4356    def test_invalid_object(self):
4357        @dataclass(frozen=True)
4358        class C:
4359            x: int
4360            y: int
4361
4362        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
4363            replace(C, x=3)
4364
4365        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
4366            replace(0, x=3)
4367
4368    def test_no_init(self):
4369        @dataclass
4370        class C:
4371            x: int
4372            y: int = field(init=False, default=10)
4373
4374        c = C(1)
4375        c.y = 20
4376
4377        # Make sure y gets the default value.
4378        c1 = replace(c, x=5)
4379        self.assertEqual((c1.x, c1.y), (5, 10))
4380
4381        # Trying to replace y is an error.
4382        with self.assertRaisesRegex(TypeError, 'init=False'):
4383            replace(c, x=2, y=30)
4384
4385        with self.assertRaisesRegex(TypeError, 'init=False'):
4386            replace(c, y=30)
4387
4388    def test_classvar(self):
4389        @dataclass
4390        class C:
4391            x: int
4392            y: ClassVar[int] = 1000
4393
4394        c = C(1)
4395        d = C(2)
4396
4397        self.assertIs(c.y, d.y)
4398        self.assertEqual(c.y, 1000)
4399
4400        # Trying to replace y is an error: can't replace ClassVars.
4401        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
4402                                    "unexpected keyword argument 'y'"):
4403            replace(c, y=30)
4404
4405        replace(c, x=5)
4406
4407    def test_initvar_is_specified(self):
4408        @dataclass
4409        class C:
4410            x: int
4411            y: InitVar[int]
4412
4413            def __post_init__(self, y):
4414                self.x *= y
4415
4416        c = C(1, 10)
4417        self.assertEqual(c.x, 10)
4418        with self.assertRaisesRegex(TypeError, r"InitVar 'y' must be "
4419                                    r"specified with replace\(\)"):
4420            replace(c, x=3)
4421        c = replace(c, x=3, y=5)
4422        self.assertEqual(c.x, 15)
4423
4424    def test_initvar_with_default_value(self):
4425        @dataclass
4426        class C:
4427            x: int
4428            y: InitVar[int] = None
4429            z: InitVar[int] = 42
4430
4431            def __post_init__(self, y, z):
4432                if y is not None:
4433                    self.x += y
4434                if z is not None:
4435                    self.x += z
4436
4437        c = C(x=1, y=10, z=1)
4438        self.assertEqual(replace(c), C(x=12))
4439        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
4440        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))
4441
4442    def test_recursive_repr(self):
4443        @dataclass
4444        class C:
4445            f: "C"
4446
4447        c = C(None)
4448        c.f = c
4449        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")
4450
4451    def test_recursive_repr_two_attrs(self):
4452        @dataclass
4453        class C:
4454            f: "C"
4455            g: "C"
4456
4457        c = C(None, None)
4458        c.f = c
4459        c.g = c
4460        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
4461                                  ".<locals>.C(f=..., g=...)")
4462
4463    def test_recursive_repr_indirection(self):
4464        @dataclass
4465        class C:
4466            f: "D"
4467
4468        @dataclass
4469        class D:
4470            f: "C"
4471
4472        c = C(None)
4473        d = D(None)
4474        c.f = d
4475        d.f = c
4476        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
4477                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
4478                                  ".<locals>.D(f=...))")
4479
4480    def test_recursive_repr_indirection_two(self):
4481        @dataclass
4482        class C:
4483            f: "D"
4484
4485        @dataclass
4486        class D:
4487            f: "E"
4488
4489        @dataclass
4490        class E:
4491            f: "C"
4492
4493        c = C(None)
4494        d = D(None)
4495        e = E(None)
4496        c.f = d
4497        d.f = e
4498        e.f = c
4499        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
4500                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
4501                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
4502                                  ".<locals>.E(f=...)))")
4503
4504    def test_recursive_repr_misc_attrs(self):
4505        @dataclass
4506        class C:
4507            f: "C"
4508            g: int
4509
4510        c = C(None, 1)
4511        c.f = c
4512        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
4513                                  ".<locals>.C(f=..., g=1)")
4514
4515    ## def test_initvar(self):
4516    ##     @dataclass
4517    ##     class C:
4518    ##         x: int
4519    ##         y: InitVar[int]
4520
4521    ##     c = C(1, 10)
4522    ##     d = C(2, 20)
4523
4524    ##     # In our case, replacing an InitVar is a no-op
4525    ##     self.assertEqual(c, replace(c, y=5))
4526
4527    ##     replace(c, x=5)
4528
4529class TestAbstract(unittest.TestCase):
4530    def test_abc_implementation(self):
4531        class Ordered(abc.ABC):
4532            @abc.abstractmethod
4533            def __lt__(self, other):
4534                pass
4535
4536            @abc.abstractmethod
4537            def __le__(self, other):
4538                pass
4539
4540        @dataclass(order=True)
4541        class Date(Ordered):
4542            year: int
4543            month: 'Month'
4544            day: 'int'
4545
4546        self.assertFalse(inspect.isabstract(Date))
4547        self.assertGreater(Date(2020,12,25), Date(2020,8,31))
4548
4549    def test_maintain_abc(self):
4550        class A(abc.ABC):
4551            @abc.abstractmethod
4552            def foo(self):
4553                pass
4554
4555        @dataclass
4556        class Date(A):
4557            year: int
4558            month: 'Month'
4559            day: 'int'
4560
4561        self.assertTrue(inspect.isabstract(Date))
4562        msg = "class Date without an implementation for abstract method 'foo'"
4563        self.assertRaisesRegex(TypeError, msg, Date)
4564
4565
4566class TestMatchArgs(unittest.TestCase):
4567    def test_match_args(self):
4568        @dataclass
4569        class C:
4570            a: int
4571        self.assertEqual(C(42).__match_args__, ('a',))
4572
4573    def test_explicit_match_args(self):
4574        ma = ()
4575        @dataclass
4576        class C:
4577            a: int
4578            __match_args__ = ma
4579        self.assertIs(C(42).__match_args__, ma)
4580
4581    def test_bpo_43764(self):
4582        @dataclass(repr=False, eq=False, init=False)
4583        class X:
4584            a: int
4585            b: int
4586            c: int
4587        self.assertEqual(X.__match_args__, ("a", "b", "c"))
4588
4589    def test_match_args_argument(self):
4590        @dataclass(match_args=False)
4591        class X:
4592            a: int
4593        self.assertNotIn('__match_args__', X.__dict__)
4594
4595        @dataclass(match_args=False)
4596        class Y:
4597            a: int
4598            __match_args__ = ('b',)
4599        self.assertEqual(Y.__match_args__, ('b',))
4600
4601        @dataclass(match_args=False)
4602        class Z(Y):
4603            z: int
4604        self.assertEqual(Z.__match_args__, ('b',))
4605
4606        # Ensure parent dataclass __match_args__ is seen, if child class
4607        # specifies match_args=False.
4608        @dataclass
4609        class A:
4610            a: int
4611            z: int
4612        @dataclass(match_args=False)
4613        class B(A):
4614            b: int
4615        self.assertEqual(B.__match_args__, ('a', 'z'))
4616
4617    def test_make_dataclasses(self):
4618        C = make_dataclass('C', [('x', int), ('y', int)])
4619        self.assertEqual(C.__match_args__, ('x', 'y'))
4620
4621        C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
4622        self.assertEqual(C.__match_args__, ('x', 'y'))
4623
4624        C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
4625        self.assertNotIn('__match__args__', C.__dict__)
4626
4627        C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
4628        self.assertEqual(C.__match_args__, ('z',))
4629
4630
4631class TestKeywordArgs(unittest.TestCase):
4632    def test_no_classvar_kwarg(self):
4633        msg = 'field a is a ClassVar but specifies kw_only'
4634        with self.assertRaisesRegex(TypeError, msg):
4635            @dataclass
4636            class A:
4637                a: ClassVar[int] = field(kw_only=True)
4638
4639        with self.assertRaisesRegex(TypeError, msg):
4640            @dataclass
4641            class A:
4642                a: ClassVar[int] = field(kw_only=False)
4643
4644        with self.assertRaisesRegex(TypeError, msg):
4645            @dataclass(kw_only=True)
4646            class A:
4647                a: ClassVar[int] = field(kw_only=False)
4648
4649    def test_field_marked_as_kwonly(self):
4650        #######################
4651        # Using dataclass(kw_only=True)
4652        @dataclass(kw_only=True)
4653        class A:
4654            a: int
4655        self.assertTrue(fields(A)[0].kw_only)
4656
4657        @dataclass(kw_only=True)
4658        class A:
4659            a: int = field(kw_only=True)
4660        self.assertTrue(fields(A)[0].kw_only)
4661
4662        @dataclass(kw_only=True)
4663        class A:
4664            a: int = field(kw_only=False)
4665        self.assertFalse(fields(A)[0].kw_only)
4666
4667        #######################
4668        # Using dataclass(kw_only=False)
4669        @dataclass(kw_only=False)
4670        class A:
4671            a: int
4672        self.assertFalse(fields(A)[0].kw_only)
4673
4674        @dataclass(kw_only=False)
4675        class A:
4676            a: int = field(kw_only=True)
4677        self.assertTrue(fields(A)[0].kw_only)
4678
4679        @dataclass(kw_only=False)
4680        class A:
4681            a: int = field(kw_only=False)
4682        self.assertFalse(fields(A)[0].kw_only)
4683
4684        #######################
4685        # Not specifying dataclass(kw_only)
4686        @dataclass
4687        class A:
4688            a: int
4689        self.assertFalse(fields(A)[0].kw_only)
4690
4691        @dataclass
4692        class A:
4693            a: int = field(kw_only=True)
4694        self.assertTrue(fields(A)[0].kw_only)
4695
4696        @dataclass
4697        class A:
4698            a: int = field(kw_only=False)
4699        self.assertFalse(fields(A)[0].kw_only)
4700
4701    def test_match_args(self):
4702        # kw fields don't show up in __match_args__.
4703        @dataclass(kw_only=True)
4704        class C:
4705            a: int
4706        self.assertEqual(C(a=42).__match_args__, ())
4707
4708        @dataclass
4709        class C:
4710            a: int
4711            b: int = field(kw_only=True)
4712        self.assertEqual(C(42, b=10).__match_args__, ('a',))
4713
4714    def test_KW_ONLY(self):
4715        @dataclass
4716        class A:
4717            a: int
4718            _: KW_ONLY
4719            b: int
4720            c: int
4721        A(3, c=5, b=4)
4722        msg = "takes 2 positional arguments but 4 were given"
4723        with self.assertRaisesRegex(TypeError, msg):
4724            A(3, 4, 5)
4725
4726
4727        @dataclass(kw_only=True)
4728        class B:
4729            a: int
4730            _: KW_ONLY
4731            b: int
4732            c: int
4733        B(a=3, b=4, c=5)
4734        msg = "takes 1 positional argument but 4 were given"
4735        with self.assertRaisesRegex(TypeError, msg):
4736            B(3, 4, 5)
4737
4738        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
4739        @dataclass
4740        class C:
4741            a: int
4742            _: KW_ONLY
4743            b: int
4744            c: int = field(kw_only=False)
4745        c = C(1, 2, b=3)
4746        self.assertEqual(c.a, 1)
4747        self.assertEqual(c.b, 3)
4748        self.assertEqual(c.c, 2)
4749        c = C(1, b=3, c=2)
4750        self.assertEqual(c.a, 1)
4751        self.assertEqual(c.b, 3)
4752        self.assertEqual(c.c, 2)
4753        c = C(1, b=3, c=2)
4754        self.assertEqual(c.a, 1)
4755        self.assertEqual(c.b, 3)
4756        self.assertEqual(c.c, 2)
4757        c = C(c=2, b=3, a=1)
4758        self.assertEqual(c.a, 1)
4759        self.assertEqual(c.b, 3)
4760        self.assertEqual(c.c, 2)
4761
4762    def test_KW_ONLY_as_string(self):
4763        @dataclass
4764        class A:
4765            a: int
4766            _: 'dataclasses.KW_ONLY'
4767            b: int
4768            c: int
4769        A(3, c=5, b=4)
4770        msg = "takes 2 positional arguments but 4 were given"
4771        with self.assertRaisesRegex(TypeError, msg):
4772            A(3, 4, 5)
4773
4774    def test_KW_ONLY_twice(self):
4775        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"
4776
4777        with self.assertRaisesRegex(TypeError, msg):
4778            @dataclass
4779            class A:
4780                a: int
4781                X: KW_ONLY
4782                Y: KW_ONLY
4783                b: int
4784                c: int
4785
4786        with self.assertRaisesRegex(TypeError, msg):
4787            @dataclass
4788            class A:
4789                a: int
4790                X: KW_ONLY
4791                b: int
4792                Y: KW_ONLY
4793                c: int
4794
4795        with self.assertRaisesRegex(TypeError, msg):
4796            @dataclass
4797            class A:
4798                a: int
4799                X: KW_ONLY
4800                b: int
4801                c: int
4802                Y: KW_ONLY
4803
4804        # But this usage is okay, since it's not using KW_ONLY.
4805        @dataclass
4806        class A:
4807            a: int
4808            _: KW_ONLY
4809            b: int
4810            c: int = field(kw_only=True)
4811
4812        # And if inheriting, it's okay.
4813        @dataclass
4814        class A:
4815            a: int
4816            _: KW_ONLY
4817            b: int
4818            c: int
4819        @dataclass
4820        class B(A):
4821            _: KW_ONLY
4822            d: int
4823
4824        # Make sure the error is raised in a derived class.
4825        with self.assertRaisesRegex(TypeError, msg):
4826            @dataclass
4827            class A:
4828                a: int
4829                _: KW_ONLY
4830                b: int
4831                c: int
4832            @dataclass
4833            class B(A):
4834                X: KW_ONLY
4835                d: int
4836                Y: KW_ONLY
4837
4838
4839    def test_post_init(self):
4840        @dataclass
4841        class A:
4842            a: int
4843            _: KW_ONLY
4844            b: InitVar[int]
4845            c: int
4846            d: InitVar[int]
4847            def __post_init__(self, b, d):
4848                raise CustomError(f'{b=} {d=}')
4849        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
4850            A(1, c=2, b=3, d=4)
4851
4852        @dataclass
4853        class B:
4854            a: int
4855            _: KW_ONLY
4856            b: InitVar[int]
4857            c: int
4858            d: InitVar[int]
4859            def __post_init__(self, b, d):
4860                self.a = b
4861                self.c = d
4862        b = B(1, c=2, b=3, d=4)
4863        self.assertEqual(asdict(b), {'a': 3, 'c': 4})
4864
4865    def test_defaults(self):
4866        # For kwargs, make sure we can have defaults after non-defaults.
4867        @dataclass
4868        class A:
4869            a: int = 0
4870            _: KW_ONLY
4871            b: int
4872            c: int = 1
4873            d: int
4874
4875        a = A(d=4, b=3)
4876        self.assertEqual(a.a, 0)
4877        self.assertEqual(a.b, 3)
4878        self.assertEqual(a.c, 1)
4879        self.assertEqual(a.d, 4)
4880
4881        # Make sure we still check for non-kwarg non-defaults not following
4882        # defaults.
4883        err_regex = "non-default argument 'z' follows default argument 'a'"
4884        with self.assertRaisesRegex(TypeError, err_regex):
4885            @dataclass
4886            class A:
4887                a: int = 0
4888                z: int
4889                _: KW_ONLY
4890                b: int
4891                c: int = 1
4892                d: int
4893
4894    def test_make_dataclass(self):
4895        A = make_dataclass("A", ['a'], kw_only=True)
4896        self.assertTrue(fields(A)[0].kw_only)
4897
4898        B = make_dataclass("B",
4899                           ['a', ('b', int, field(kw_only=False))],
4900                           kw_only=True)
4901        self.assertTrue(fields(B)[0].kw_only)
4902        self.assertFalse(fields(B)[1].kw_only)
4903
4904
if __name__ == "__main__":
    # Allow running this file directly; hand control to unittest's runner.
    unittest.main()
4907