• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Deliberately use "from dataclasses import *".  Every name in __all__
2# is tested, so they all must be present.  This is a way to catch
3# missing ones.
4
5from dataclasses import *
6
7import pickle
8import inspect
9import builtins
10import unittest
11from unittest.mock import Mock
12from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
13from typing import get_type_hints
14from collections import deque, OrderedDict, namedtuple
15from functools import total_ordering
16
17import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
18import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
19
class CustomError(Exception):
    """Arbitrary exception type that test fixtures raise so tests can catch it."""
22
23class TestCase(unittest.TestCase):
24    def test_no_fields(self):
25        @dataclass
26        class C:
27            pass
28
29        o = C()
30        self.assertEqual(len(fields(C)), 0)
31
32    def test_no_fields_but_member_variable(self):
33        @dataclass
34        class C:
35            i = 0
36
37        o = C()
38        self.assertEqual(len(fields(C)), 0)
39
40    def test_one_field_no_default(self):
41        @dataclass
42        class C:
43            x: int
44
45        o = C(42)
46        self.assertEqual(o.x, 42)
47
48    def test_field_default_default_factory_error(self):
49        msg = "cannot specify both default and default_factory"
50        with self.assertRaisesRegex(ValueError, msg):
51            @dataclass
52            class C:
53                x: int = field(default=1, default_factory=int)
54
55    def test_field_repr(self):
56        int_field = field(default=1, init=True, repr=False)
57        int_field.name = "id"
58        repr_output = repr(int_field)
59        expected_output = "Field(name='id',type=None," \
60                           f"default=1,default_factory={MISSING!r}," \
61                           "init=True,repr=False,hash=None," \
62                           "compare=True,metadata=mappingproxy({})," \
63                           "_field_type=None)"
64
65        self.assertEqual(repr_output, expected_output)
66
67    def test_named_init_params(self):
68        @dataclass
69        class C:
70            x: int
71
72        o = C(x=32)
73        self.assertEqual(o.x, 32)
74
75    def test_two_fields_one_default(self):
76        @dataclass
77        class C:
78            x: int
79            y: int = 0
80
81        o = C(3)
82        self.assertEqual((o.x, o.y), (3, 0))
83
84        # Non-defaults following defaults.
85        with self.assertRaisesRegex(TypeError,
86                                    "non-default argument 'y' follows "
87                                    "default argument"):
88            @dataclass
89            class C:
90                x: int = 0
91                y: int
92
93        # A derived class adds a non-default field after a default one.
94        with self.assertRaisesRegex(TypeError,
95                                    "non-default argument 'y' follows "
96                                    "default argument"):
97            @dataclass
98            class B:
99                x: int = 0
100
101            @dataclass
102            class C(B):
103                y: int
104
105        # Override a base class field and add a default to
106        #  a field which didn't use to have a default.
107        with self.assertRaisesRegex(TypeError,
108                                    "non-default argument 'y' follows "
109                                    "default argument"):
110            @dataclass
111            class B:
112                x: int
113                y: int
114
115            @dataclass
116            class C(B):
117                x: int = 0
118
119    def test_overwrite_hash(self):
120        # Test that declaring this class isn't an error.  It should
121        #  use the user-provided __hash__.
122        @dataclass(frozen=True)
123        class C:
124            x: int
125            def __hash__(self):
126                return 301
127        self.assertEqual(hash(C(100)), 301)
128
129        # Test that declaring this class isn't an error.  It should
130        #  use the generated __hash__.
131        @dataclass(frozen=True)
132        class C:
133            x: int
134            def __eq__(self, other):
135                return False
136        self.assertEqual(hash(C(100)), hash((100,)))
137
138        # But this one should generate an exception, because with
139        #  unsafe_hash=True, it's an error to have a __hash__ defined.
140        with self.assertRaisesRegex(TypeError,
141                                    'Cannot overwrite attribute __hash__'):
142            @dataclass(unsafe_hash=True)
143            class C:
144                def __hash__(self):
145                    pass
146
147        # Creating this class should not generate an exception,
148        #  because even though __hash__ exists before @dataclass is
149        #  called, (due to __eq__ being defined), since it's None
150        #  that's okay.
151        @dataclass(unsafe_hash=True)
152        class C:
153            x: int
154            def __eq__(self):
155                pass
156        # The generated hash function works as we'd expect.
157        self.assertEqual(hash(C(10)), hash((10,)))
158
159        # Creating this class should generate an exception, because
160        #  __hash__ exists and is not None, which it would be if it
161        #  had been auto-generated due to __eq__ being defined.
162        with self.assertRaisesRegex(TypeError,
163                                    'Cannot overwrite attribute __hash__'):
164            @dataclass(unsafe_hash=True)
165            class C:
166                x: int
167                def __eq__(self):
168                    pass
169                def __hash__(self):
170                    pass
171
172    def test_overwrite_fields_in_derived_class(self):
173        # Note that x from C1 replaces x in Base, but the order remains
174        #  the same as defined in Base.
175        @dataclass
176        class Base:
177            x: Any = 15.0
178            y: int = 0
179
180        @dataclass
181        class C1(Base):
182            z: int = 10
183            x: int = 15
184
185        o = Base()
186        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
187
188        o = C1()
189        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
190
191        o = C1(x=5)
192        self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
193
194    def test_field_named_self(self):
195        @dataclass
196        class C:
197            self: str
198        c=C('foo')
199        self.assertEqual(c.self, 'foo')
200
201        # Make sure the first parameter is not named 'self'.
202        sig = inspect.signature(C.__init__)
203        first = next(iter(sig.parameters))
204        self.assertNotEqual('self', first)
205
206        # But we do use 'self' if no field named self.
207        @dataclass
208        class C:
209            selfx: str
210
211        # Make sure the first parameter is named 'self'.
212        sig = inspect.signature(C.__init__)
213        first = next(iter(sig.parameters))
214        self.assertEqual('self', first)
215
216    def test_field_named_object(self):
217        @dataclass
218        class C:
219            object: str
220        c = C('foo')
221        self.assertEqual(c.object, 'foo')
222
223    def test_field_named_object_frozen(self):
224        @dataclass(frozen=True)
225        class C:
226            object: str
227        c = C('foo')
228        self.assertEqual(c.object, 'foo')
229
230    def test_field_named_like_builtin(self):
231        # Attribute names can shadow built-in names
232        # since code generation is used.
233        # Ensure that this is not happening.
234        exclusions = {'None', 'True', 'False'}
235        builtins_names = sorted(
236            b for b in builtins.__dict__.keys()
237            if not b.startswith('__') and b not in exclusions
238        )
239        attributes = [(name, str) for name in builtins_names]
240        C = make_dataclass('C', attributes)
241
242        c = C(*[name for name in builtins_names])
243
244        for name in builtins_names:
245            self.assertEqual(getattr(c, name), name)
246
247    def test_field_named_like_builtin_frozen(self):
248        # Attribute names can shadow built-in names
249        # since code generation is used.
250        # Ensure that this is not happening
251        # for frozen data classes.
252        exclusions = {'None', 'True', 'False'}
253        builtins_names = sorted(
254            b for b in builtins.__dict__.keys()
255            if not b.startswith('__') and b not in exclusions
256        )
257        attributes = [(name, str) for name in builtins_names]
258        C = make_dataclass('C', attributes, frozen=True)
259
260        c = C(*[name for name in builtins_names])
261
262        for name in builtins_names:
263            self.assertEqual(getattr(c, name), name)
264
265    def test_0_field_compare(self):
266        # Ensure that order=False is the default.
267        @dataclass
268        class C0:
269            pass
270
271        @dataclass(order=False)
272        class C1:
273            pass
274
275        for cls in [C0, C1]:
276            with self.subTest(cls=cls):
277                self.assertEqual(cls(), cls())
278                for idx, fn in enumerate([lambda a, b: a < b,
279                                          lambda a, b: a <= b,
280                                          lambda a, b: a > b,
281                                          lambda a, b: a >= b]):
282                    with self.subTest(idx=idx):
283                        with self.assertRaisesRegex(TypeError,
284                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
285                            fn(cls(), cls())
286
287        @dataclass(order=True)
288        class C:
289            pass
290        self.assertLessEqual(C(), C())
291        self.assertGreaterEqual(C(), C())
292
293    def test_1_field_compare(self):
294        # Ensure that order=False is the default.
295        @dataclass
296        class C0:
297            x: int
298
299        @dataclass(order=False)
300        class C1:
301            x: int
302
303        for cls in [C0, C1]:
304            with self.subTest(cls=cls):
305                self.assertEqual(cls(1), cls(1))
306                self.assertNotEqual(cls(0), cls(1))
307                for idx, fn in enumerate([lambda a, b: a < b,
308                                          lambda a, b: a <= b,
309                                          lambda a, b: a > b,
310                                          lambda a, b: a >= b]):
311                    with self.subTest(idx=idx):
312                        with self.assertRaisesRegex(TypeError,
313                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
314                            fn(cls(0), cls(0))
315
316        @dataclass(order=True)
317        class C:
318            x: int
319        self.assertLess(C(0), C(1))
320        self.assertLessEqual(C(0), C(1))
321        self.assertLessEqual(C(1), C(1))
322        self.assertGreater(C(1), C(0))
323        self.assertGreaterEqual(C(1), C(0))
324        self.assertGreaterEqual(C(1), C(1))
325
326    def test_simple_compare(self):
327        # Ensure that order=False is the default.
328        @dataclass
329        class C0:
330            x: int
331            y: int
332
333        @dataclass(order=False)
334        class C1:
335            x: int
336            y: int
337
338        for cls in [C0, C1]:
339            with self.subTest(cls=cls):
340                self.assertEqual(cls(0, 0), cls(0, 0))
341                self.assertEqual(cls(1, 2), cls(1, 2))
342                self.assertNotEqual(cls(1, 0), cls(0, 0))
343                self.assertNotEqual(cls(1, 0), cls(1, 1))
344                for idx, fn in enumerate([lambda a, b: a < b,
345                                          lambda a, b: a <= b,
346                                          lambda a, b: a > b,
347                                          lambda a, b: a >= b]):
348                    with self.subTest(idx=idx):
349                        with self.assertRaisesRegex(TypeError,
350                                                    f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
351                            fn(cls(0, 0), cls(0, 0))
352
353        @dataclass(order=True)
354        class C:
355            x: int
356            y: int
357
358        for idx, fn in enumerate([lambda a, b: a == b,
359                                  lambda a, b: a <= b,
360                                  lambda a, b: a >= b]):
361            with self.subTest(idx=idx):
362                self.assertTrue(fn(C(0, 0), C(0, 0)))
363
364        for idx, fn in enumerate([lambda a, b: a < b,
365                                  lambda a, b: a <= b,
366                                  lambda a, b: a != b]):
367            with self.subTest(idx=idx):
368                self.assertTrue(fn(C(0, 0), C(0, 1)))
369                self.assertTrue(fn(C(0, 1), C(1, 0)))
370                self.assertTrue(fn(C(1, 0), C(1, 1)))
371
372        for idx, fn in enumerate([lambda a, b: a > b,
373                                  lambda a, b: a >= b,
374                                  lambda a, b: a != b]):
375            with self.subTest(idx=idx):
376                self.assertTrue(fn(C(0, 1), C(0, 0)))
377                self.assertTrue(fn(C(1, 0), C(0, 1)))
378                self.assertTrue(fn(C(1, 1), C(1, 0)))
379
380    def test_compare_subclasses(self):
381        # Comparisons fail for subclasses, even if no fields
382        #  are added.
383        @dataclass
384        class B:
385            i: int
386
387        @dataclass
388        class C(B):
389            pass
390
391        for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
392                                              (lambda a, b: a != b, True)]):
393            with self.subTest(idx=idx):
394                self.assertEqual(fn(B(0), C(0)), expected)
395
396        for idx, fn in enumerate([lambda a, b: a < b,
397                                  lambda a, b: a <= b,
398                                  lambda a, b: a > b,
399                                  lambda a, b: a >= b]):
400            with self.subTest(idx=idx):
401                with self.assertRaisesRegex(TypeError,
402                                            "not supported between instances of 'B' and 'C'"):
403                    fn(B(0), C(0))
404
405    def test_eq_order(self):
406        # Test combining eq and order.
407        for (eq,    order, result   ) in [
408            (False, False, 'neither'),
409            (False, True,  'exception'),
410            (True,  False, 'eq_only'),
411            (True,  True,  'both'),
412        ]:
413            with self.subTest(eq=eq, order=order):
414                if result == 'exception':
415                    with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
416                        @dataclass(eq=eq, order=order)
417                        class C:
418                            pass
419                else:
420                    @dataclass(eq=eq, order=order)
421                    class C:
422                        pass
423
424                    if result == 'neither':
425                        self.assertNotIn('__eq__', C.__dict__)
426                        self.assertNotIn('__lt__', C.__dict__)
427                        self.assertNotIn('__le__', C.__dict__)
428                        self.assertNotIn('__gt__', C.__dict__)
429                        self.assertNotIn('__ge__', C.__dict__)
430                    elif result == 'both':
431                        self.assertIn('__eq__', C.__dict__)
432                        self.assertIn('__lt__', C.__dict__)
433                        self.assertIn('__le__', C.__dict__)
434                        self.assertIn('__gt__', C.__dict__)
435                        self.assertIn('__ge__', C.__dict__)
436                    elif result == 'eq_only':
437                        self.assertIn('__eq__', C.__dict__)
438                        self.assertNotIn('__lt__', C.__dict__)
439                        self.assertNotIn('__le__', C.__dict__)
440                        self.assertNotIn('__gt__', C.__dict__)
441                        self.assertNotIn('__ge__', C.__dict__)
442                    else:
443                        assert False, f'unknown result {result!r}'
444
445    def test_field_no_default(self):
446        @dataclass
447        class C:
448            x: int = field()
449
450        self.assertEqual(C(5).x, 5)
451
452        with self.assertRaisesRegex(TypeError,
453                                    r"__init__\(\) missing 1 required "
454                                    "positional argument: 'x'"):
455            C()
456
457    def test_field_default(self):
458        default = object()
459        @dataclass
460        class C:
461            x: object = field(default=default)
462
463        self.assertIs(C.x, default)
464        c = C(10)
465        self.assertEqual(c.x, 10)
466
467        # If we delete the instance attribute, we should then see the
468        #  class attribute.
469        del c.x
470        self.assertIs(c.x, default)
471
472        self.assertIs(C().x, default)
473
474    def test_not_in_repr(self):
475        @dataclass
476        class C:
477            x: int = field(repr=False)
478        with self.assertRaises(TypeError):
479            C()
480        c = C(10)
481        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
482
483        @dataclass
484        class C:
485            x: int = field(repr=False)
486            y: int
487        c = C(10, 20)
488        self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
489
490    def test_not_in_compare(self):
491        @dataclass
492        class C:
493            x: int = 0
494            y: int = field(compare=False, default=4)
495
496        self.assertEqual(C(), C(0, 20))
497        self.assertEqual(C(1, 10), C(1, 20))
498        self.assertNotEqual(C(3), C(4, 10))
499        self.assertNotEqual(C(3, 10), C(4, 10))
500
501    def test_hash_field_rules(self):
502        # Test all 6 cases of:
503        #  hash=True/False/None
504        #  compare=True/False
505        for (hash_,    compare, result  ) in [
506            (True,     False,   'field' ),
507            (True,     True,    'field' ),
508            (False,    False,   'absent'),
509            (False,    True,    'absent'),
510            (None,     False,   'absent'),
511            (None,     True,    'field' ),
512            ]:
513            with self.subTest(hash=hash_, compare=compare):
514                @dataclass(unsafe_hash=True)
515                class C:
516                    x: int = field(compare=compare, hash=hash_, default=5)
517
518                if result == 'field':
519                    # __hash__ contains the field.
520                    self.assertEqual(hash(C(5)), hash((5,)))
521                elif result == 'absent':
522                    # The field is not present in the hash.
523                    self.assertEqual(hash(C(5)), hash(()))
524                else:
525                    assert False, f'unknown result {result!r}'
526
527    def test_init_false_no_default(self):
528        # If init=False and no default value, then the field won't be
529        #  present in the instance.
530        @dataclass
531        class C:
532            x: int = field(init=False)
533
534        self.assertNotIn('x', C().__dict__)
535
536        @dataclass
537        class C:
538            x: int
539            y: int = 0
540            z: int = field(init=False)
541            t: int = 10
542
543        self.assertNotIn('z', C(0).__dict__)
544        self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
545
546    def test_class_marker(self):
547        @dataclass
548        class C:
549            x: int
550            y: str = field(init=False, default=None)
551            z: str = field(repr=False)
552
553        the_fields = fields(C)
554        # the_fields is a tuple of 3 items, each value
555        #  is in __annotations__.
556        self.assertIsInstance(the_fields, tuple)
557        for f in the_fields:
558            self.assertIs(type(f), Field)
559            self.assertIn(f.name, C.__annotations__)
560
561        self.assertEqual(len(the_fields), 3)
562
563        self.assertEqual(the_fields[0].name, 'x')
564        self.assertEqual(the_fields[0].type, int)
565        self.assertFalse(hasattr(C, 'x'))
566        self.assertTrue (the_fields[0].init)
567        self.assertTrue (the_fields[0].repr)
568        self.assertEqual(the_fields[1].name, 'y')
569        self.assertEqual(the_fields[1].type, str)
570        self.assertIsNone(getattr(C, 'y'))
571        self.assertFalse(the_fields[1].init)
572        self.assertTrue (the_fields[1].repr)
573        self.assertEqual(the_fields[2].name, 'z')
574        self.assertEqual(the_fields[2].type, str)
575        self.assertFalse(hasattr(C, 'z'))
576        self.assertTrue (the_fields[2].init)
577        self.assertFalse(the_fields[2].repr)
578
579    def test_field_order(self):
580        @dataclass
581        class B:
582            a: str = 'B:a'
583            b: str = 'B:b'
584            c: str = 'B:c'
585
586        @dataclass
587        class C(B):
588            b: str = 'C:b'
589
590        self.assertEqual([(f.name, f.default) for f in fields(C)],
591                         [('a', 'B:a'),
592                          ('b', 'C:b'),
593                          ('c', 'B:c')])
594
595        @dataclass
596        class D(B):
597            c: str = 'D:c'
598
599        self.assertEqual([(f.name, f.default) for f in fields(D)],
600                         [('a', 'B:a'),
601                          ('b', 'B:b'),
602                          ('c', 'D:c')])
603
604        @dataclass
605        class E(D):
606            a: str = 'E:a'
607            d: str = 'E:d'
608
609        self.assertEqual([(f.name, f.default) for f in fields(E)],
610                         [('a', 'E:a'),
611                          ('b', 'B:b'),
612                          ('c', 'D:c'),
613                          ('d', 'E:d')])
614
615    def test_class_attrs(self):
616        # We only have a class attribute if a default value is
617        #  specified, either directly or via a field with a default.
618        default = object()
619        @dataclass
620        class C:
621            x: int
622            y: int = field(repr=False)
623            z: object = default
624            t: int = field(default=100)
625
626        self.assertFalse(hasattr(C, 'x'))
627        self.assertFalse(hasattr(C, 'y'))
628        self.assertIs   (C.z, default)
629        self.assertEqual(C.t, 100)
630
631    def test_disallowed_mutable_defaults(self):
632        # For the known types, don't allow mutable default values.
633        for typ, empty, non_empty in [(list, [], [1]),
634                                      (dict, {}, {0:1}),
635                                      (set, set(), set([1])),
636                                      ]:
637            with self.subTest(typ=typ):
638                # Can't use a zero-length value.
639                with self.assertRaisesRegex(ValueError,
640                                            f'mutable default {typ} for field '
641                                            'x is not allowed'):
642                    @dataclass
643                    class Point:
644                        x: typ = empty
645
646
647                # Nor a non-zero-length value
648                with self.assertRaisesRegex(ValueError,
649                                            f'mutable default {typ} for field '
650                                            'y is not allowed'):
651                    @dataclass
652                    class Point:
653                        y: typ = non_empty
654
655                # Check subtypes also fail.
656                class Subclass(typ): pass
657
658                with self.assertRaisesRegex(ValueError,
659                                            f"mutable default .*Subclass'>"
660                                            ' for field z is not allowed'
661                                            ):
662                    @dataclass
663                    class Point:
664                        z: typ = Subclass()
665
666                # Because this is a ClassVar, it can be mutable.
667                @dataclass
668                class C:
669                    z: ClassVar[typ] = typ()
670
671                # Because this is a ClassVar, it can be mutable.
672                @dataclass
673                class C:
674                    x: ClassVar[typ] = Subclass()
675
676    def test_deliberately_mutable_defaults(self):
677        # If a mutable default isn't in the known list of
678        #  (list, dict, set), then it's okay.
679        class Mutable:
680            def __init__(self):
681                self.l = []
682
683        @dataclass
684        class C:
685            x: Mutable
686
687        # These 2 instances will share this value of x.
688        lst = Mutable()
689        o1 = C(lst)
690        o2 = C(lst)
691        self.assertEqual(o1, o2)
692        o1.x.l.extend([1, 2])
693        self.assertEqual(o1, o2)
694        self.assertEqual(o1.x.l, [1, 2])
695        self.assertIs(o1.x, o2.x)
696
697    def test_no_options(self):
698        # Call with dataclass().
699        @dataclass()
700        class C:
701            x: int
702
703        self.assertEqual(C(42).x, 42)
704
705    def test_not_tuple(self):
706        # Make sure we can't be compared to a tuple.
707        @dataclass
708        class Point:
709            x: int
710            y: int
711        self.assertNotEqual(Point(1, 2), (1, 2))
712
713        # And that we can't compare to another unrelated dataclass.
714        @dataclass
715        class C:
716            x: int
717            y: int
718        self.assertNotEqual(Point(1, 3), C(1, 3))
719
720    def test_not_other_dataclass(self):
721        # Test that some of the problems with namedtuple don't happen
722        #  here.
723        @dataclass
724        class Point3D:
725            x: int
726            y: int
727            z: int
728
729        @dataclass
730        class Date:
731            year: int
732            month: int
733            day: int
734
735        self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
736        self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
737
738        # Make sure we can't unpack.
739        with self.assertRaisesRegex(TypeError, 'unpack'):
740            x, y, z = Point3D(4, 5, 6)
741
742        # Make sure another class with the same field names isn't
743        #  equal.
744        @dataclass
745        class Point3Dv1:
746            x: int = 0
747            y: int = 0
748            z: int = 0
749        self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
750
751    def test_function_annotations(self):
752        # Some dummy class and instance to use as a default.
753        class F:
754            pass
755        f = F()
756
757        def validate_class(cls):
758            # First, check __annotations__, even though they're not
759            #  function annotations.
760            self.assertEqual(cls.__annotations__['i'], int)
761            self.assertEqual(cls.__annotations__['j'], str)
762            self.assertEqual(cls.__annotations__['k'], F)
763            self.assertEqual(cls.__annotations__['l'], float)
764            self.assertEqual(cls.__annotations__['z'], complex)
765
766            # Verify __init__.
767
768            signature = inspect.signature(cls.__init__)
769            # Check the return type, should be None.
770            self.assertIs(signature.return_annotation, None)
771
772            # Check each parameter.
773            params = iter(signature.parameters.values())
774            param = next(params)
775            # This is testing an internal name, and probably shouldn't be tested.
776            self.assertEqual(param.name, 'self')
777            param = next(params)
778            self.assertEqual(param.name, 'i')
779            self.assertIs   (param.annotation, int)
780            self.assertEqual(param.default, inspect.Parameter.empty)
781            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
782            param = next(params)
783            self.assertEqual(param.name, 'j')
784            self.assertIs   (param.annotation, str)
785            self.assertEqual(param.default, inspect.Parameter.empty)
786            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
787            param = next(params)
788            self.assertEqual(param.name, 'k')
789            self.assertIs   (param.annotation, F)
790            # Don't test for the default, since it's set to MISSING.
791            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
792            param = next(params)
793            self.assertEqual(param.name, 'l')
794            self.assertIs   (param.annotation, float)
795            # Don't test for the default, since it's set to MISSING.
796            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
797            self.assertRaises(StopIteration, next, params)
798
799
800        @dataclass
801        class C:
802            i: int
803            j: str
804            k: F = f
805            l: float=field(default=None)
806            z: complex=field(default=3+4j, init=False)
807
808        validate_class(C)
809
810        # Now repeat with __hash__.
811        @dataclass(frozen=True, unsafe_hash=True)
812        class C:
813            i: int
814            j: str
815            k: F = f
816            l: float=field(default=None)
817            z: complex=field(default=3+4j, init=False)
818
819        validate_class(C)
820
821    def test_missing_default(self):
822        # Test that MISSING works the same as a default not being
823        #  specified.
824        @dataclass
825        class C:
826            x: int=field(default=MISSING)
827        with self.assertRaisesRegex(TypeError,
828                                    r'__init__\(\) missing 1 required '
829                                    'positional argument'):
830            C()
831        self.assertNotIn('x', C.__dict__)
832
833        @dataclass
834        class D:
835            x: int
836        with self.assertRaisesRegex(TypeError,
837                                    r'__init__\(\) missing 1 required '
838                                    'positional argument'):
839            D()
840        self.assertNotIn('x', D.__dict__)
841
842    def test_missing_default_factory(self):
843        # Test that MISSING works the same as a default factory not
844        #  being specified (which is really the same as a default not
845        #  being specified, too).
846        @dataclass
847        class C:
848            x: int=field(default_factory=MISSING)
849        with self.assertRaisesRegex(TypeError,
850                                    r'__init__\(\) missing 1 required '
851                                    'positional argument'):
852            C()
853        self.assertNotIn('x', C.__dict__)
854
855        @dataclass
856        class D:
857            x: int=field(default=MISSING, default_factory=MISSING)
858        with self.assertRaisesRegex(TypeError,
859                                    r'__init__\(\) missing 1 required '
860                                    'positional argument'):
861            D()
862        self.assertNotIn('x', D.__dict__)
863
864    def test_missing_repr(self):
865        self.assertIn('MISSING_TYPE object', repr(MISSING))
866
867    def test_dont_include_other_annotations(self):
868        @dataclass
869        class C:
870            i: int
871            def foo(self) -> int:
872                return 4
873            @property
874            def bar(self) -> int:
875                return 5
876        self.assertEqual(list(C.__annotations__), ['i'])
877        self.assertEqual(C(10).foo(), 4)
878        self.assertEqual(C(10).bar, 5)
879        self.assertEqual(C(10).i, 10)
880
881    def test_post_init(self):
882        # Just make sure it gets called
883        @dataclass
884        class C:
885            def __post_init__(self):
886                raise CustomError()
887        with self.assertRaises(CustomError):
888            C()
889
890        @dataclass
891        class C:
892            i: int = 10
893            def __post_init__(self):
894                if self.i == 10:
895                    raise CustomError()
896        with self.assertRaises(CustomError):
897            C()
898        # post-init gets called, but doesn't raise. This is just
899        #  checking that self is used correctly.
900        C(5)
901
902        # If there's not an __init__, then post-init won't get called.
903        @dataclass(init=False)
904        class C:
905            def __post_init__(self):
906                raise CustomError()
907        # Creating the class won't raise
908        C()
909
910        @dataclass
911        class C:
912            x: int = 0
913            def __post_init__(self):
914                self.x *= 2
915        self.assertEqual(C().x, 0)
916        self.assertEqual(C(2).x, 4)
917
918        # Make sure that if we're frozen, post-init can't set
919        #  attributes.
920        @dataclass(frozen=True)
921        class C:
922            x: int = 0
923            def __post_init__(self):
924                self.x *= 2
925        with self.assertRaises(FrozenInstanceError):
926            C()
927
928    def test_post_init_super(self):
929        # Make sure super() post-init isn't called by default.
930        class B:
931            def __post_init__(self):
932                raise CustomError()
933
934        @dataclass
935        class C(B):
936            def __post_init__(self):
937                self.x = 5
938
939        self.assertEqual(C().x, 5)
940
941        # Now call super(), and it will raise.
942        @dataclass
943        class C(B):
944            def __post_init__(self):
945                super().__post_init__()
946
947        with self.assertRaises(CustomError):
948            C()
949
950        # Make sure post-init is called, even if not defined in our
951        #  class.
952        @dataclass
953        class C(B):
954            pass
955
956        with self.assertRaises(CustomError):
957            C()
958
959    def test_post_init_staticmethod(self):
960        flag = False
961        @dataclass
962        class C:
963            x: int
964            y: int
965            @staticmethod
966            def __post_init__():
967                nonlocal flag
968                flag = True
969
970        self.assertFalse(flag)
971        c = C(3, 4)
972        self.assertEqual((c.x, c.y), (3, 4))
973        self.assertTrue(flag)
974
975    def test_post_init_classmethod(self):
976        @dataclass
977        class C:
978            flag = False
979            x: int
980            y: int
981            @classmethod
982            def __post_init__(cls):
983                cls.flag = True
984
985        self.assertFalse(C.flag)
986        c = C(3, 4)
987        self.assertEqual((c.x, c.y), (3, 4))
988        self.assertTrue(C.flag)
989
990    def test_class_var(self):
991        # Make sure ClassVars are ignored in __init__, __repr__, etc.
992        @dataclass
993        class C:
994            x: int
995            y: int = 10
996            z: ClassVar[int] = 1000
997            w: ClassVar[int] = 2000
998            t: ClassVar[int] = 3000
999            s: ClassVar      = 4000
1000
1001        c = C(5)
1002        self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1003        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields.
1004        self.assertEqual(len(C.__annotations__), 6)         # And 4 ClassVars.
1005        self.assertEqual(c.z, 1000)
1006        self.assertEqual(c.w, 2000)
1007        self.assertEqual(c.t, 3000)
1008        self.assertEqual(c.s, 4000)
1009        C.z += 1
1010        self.assertEqual(c.z, 1001)
1011        c = C(20)
1012        self.assertEqual((c.x, c.y), (20, 10))
1013        self.assertEqual(c.z, 1001)
1014        self.assertEqual(c.w, 2000)
1015        self.assertEqual(c.t, 3000)
1016        self.assertEqual(c.s, 4000)
1017
1018    def test_class_var_no_default(self):
1019        # If a ClassVar has no default value, it should not be set on the class.
1020        @dataclass
1021        class C:
1022            x: ClassVar[int]
1023
1024        self.assertNotIn('x', C.__dict__)
1025
1026    def test_class_var_default_factory(self):
1027        # It makes no sense for a ClassVar to have a default factory. When
1028        #  would it be called? Call it yourself, since it's class-wide.
1029        with self.assertRaisesRegex(TypeError,
1030                                    'cannot have a default factory'):
1031            @dataclass
1032            class C:
1033                x: ClassVar[int] = field(default_factory=int)
1034
1035            self.assertNotIn('x', C.__dict__)
1036
1037    def test_class_var_with_default(self):
1038        # If a ClassVar has a default value, it should be set on the class.
1039        @dataclass
1040        class C:
1041            x: ClassVar[int] = 10
1042        self.assertEqual(C.x, 10)
1043
1044        @dataclass
1045        class C:
1046            x: ClassVar[int] = field(default=10)
1047        self.assertEqual(C.x, 10)
1048
1049    def test_class_var_frozen(self):
1050        # Make sure ClassVars work even if we're frozen.
1051        @dataclass(frozen=True)
1052        class C:
1053            x: int
1054            y: int = 10
1055            z: ClassVar[int] = 1000
1056            w: ClassVar[int] = 2000
1057            t: ClassVar[int] = 3000
1058
1059        c = C(5)
1060        self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1061        self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
1062        self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
1063        self.assertEqual(c.z, 1000)
1064        self.assertEqual(c.w, 2000)
1065        self.assertEqual(c.t, 3000)
1066        # We can still modify the ClassVar, it's only instances that are
1067        #  frozen.
1068        C.z += 1
1069        self.assertEqual(c.z, 1001)
1070        c = C(20)
1071        self.assertEqual((c.x, c.y), (20, 10))
1072        self.assertEqual(c.z, 1001)
1073        self.assertEqual(c.w, 2000)
1074        self.assertEqual(c.t, 3000)
1075
1076    def test_init_var_no_default(self):
1077        # If an InitVar has no default value, it should not be set on the class.
1078        @dataclass
1079        class C:
1080            x: InitVar[int]
1081
1082        self.assertNotIn('x', C.__dict__)
1083
1084    def test_init_var_default_factory(self):
1085        # It makes no sense for an InitVar to have a default factory. When
1086        #  would it be called? Call it yourself, since it's class-wide.
1087        with self.assertRaisesRegex(TypeError,
1088                                    'cannot have a default factory'):
1089            @dataclass
1090            class C:
1091                x: InitVar[int] = field(default_factory=int)
1092
1093            self.assertNotIn('x', C.__dict__)
1094
1095    def test_init_var_with_default(self):
1096        # If an InitVar has a default value, it should be set on the class.
1097        @dataclass
1098        class C:
1099            x: InitVar[int] = 10
1100        self.assertEqual(C.x, 10)
1101
1102        @dataclass
1103        class C:
1104            x: InitVar[int] = field(default=10)
1105        self.assertEqual(C.x, 10)
1106
1107    def test_init_var(self):
1108        @dataclass
1109        class C:
1110            x: int = None
1111            init_param: InitVar[int] = None
1112
1113            def __post_init__(self, init_param):
1114                if self.x is None:
1115                    self.x = init_param*2
1116
1117        c = C(init_param=10)
1118        self.assertEqual(c.x, 20)
1119
1120    def test_init_var_preserve_type(self):
1121        self.assertEqual(InitVar[int].type, int)
1122
1123        # Make sure the repr is correct.
1124        self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
1125        self.assertEqual(repr(InitVar[List[int]]),
1126                         'dataclasses.InitVar[typing.List[int]]')
1127
1128    def test_init_var_inheritance(self):
1129        # Note that this deliberately tests that a dataclass need not
1130        #  have a __post_init__ function if it has an InitVar field.
1131        #  It could just be used in a derived class, as shown here.
1132        @dataclass
1133        class Base:
1134            x: int
1135            init_base: InitVar[int]
1136
1137        # We can instantiate by passing the InitVar, even though
1138        #  it's not used.
1139        b = Base(0, 10)
1140        self.assertEqual(vars(b), {'x': 0})
1141
1142        @dataclass
1143        class C(Base):
1144            y: int
1145            init_derived: InitVar[int]
1146
1147            def __post_init__(self, init_base, init_derived):
1148                self.x = self.x + init_base
1149                self.y = self.y + init_derived
1150
1151        c = C(10, 11, 50, 51)
1152        self.assertEqual(vars(c), {'x': 21, 'y': 101})
1153
1154    def test_default_factory(self):
1155        # Test a factory that returns a new list.
1156        @dataclass
1157        class C:
1158            x: int
1159            y: list = field(default_factory=list)
1160
1161        c0 = C(3)
1162        c1 = C(3)
1163        self.assertEqual(c0.x, 3)
1164        self.assertEqual(c0.y, [])
1165        self.assertEqual(c0, c1)
1166        self.assertIsNot(c0.y, c1.y)
1167        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1168
1169        # Test a factory that returns a shared list.
1170        l = []
1171        @dataclass
1172        class C:
1173            x: int
1174            y: list = field(default_factory=lambda: l)
1175
1176        c0 = C(3)
1177        c1 = C(3)
1178        self.assertEqual(c0.x, 3)
1179        self.assertEqual(c0.y, [])
1180        self.assertEqual(c0, c1)
1181        self.assertIs(c0.y, c1.y)
1182        self.assertEqual(astuple(C(5, [1])), (5, [1]))
1183
1184        # Test various other field flags.
1185        # repr
1186        @dataclass
1187        class C:
1188            x: list = field(default_factory=list, repr=False)
1189        self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1190        self.assertEqual(C().x, [])
1191
1192        # hash
1193        @dataclass(unsafe_hash=True)
1194        class C:
1195            x: list = field(default_factory=list, hash=False)
1196        self.assertEqual(astuple(C()), ([],))
1197        self.assertEqual(hash(C()), hash(()))
1198
1199        # init (see also test_default_factory_with_no_init)
1200        @dataclass
1201        class C:
1202            x: list = field(default_factory=list, init=False)
1203        self.assertEqual(astuple(C()), ([],))
1204
1205        # compare
1206        @dataclass
1207        class C:
1208            x: list = field(default_factory=list, compare=False)
1209        self.assertEqual(C(), C([1]))
1210
1211    def test_default_factory_with_no_init(self):
1212        # We need a factory with a side effect.
1213        factory = Mock()
1214
1215        @dataclass
1216        class C:
1217            x: list = field(default_factory=factory, init=False)
1218
1219        # Make sure the default factory is called for each new instance.
1220        C().x
1221        self.assertEqual(factory.call_count, 1)
1222        C().x
1223        self.assertEqual(factory.call_count, 2)
1224
1225    def test_default_factory_not_called_if_value_given(self):
1226        # We need a factory that we can test if it's been called.
1227        factory = Mock()
1228
1229        @dataclass
1230        class C:
1231            x: int = field(default_factory=factory)
1232
1233        # Make sure that if a field has a default factory function,
1234        #  it's not called if a value is specified.
1235        C().x
1236        self.assertEqual(factory.call_count, 1)
1237        self.assertEqual(C(10).x, 10)
1238        self.assertEqual(factory.call_count, 1)
1239        C().x
1240        self.assertEqual(factory.call_count, 2)
1241
1242    def test_default_factory_derived(self):
1243        # See bpo-32896.
1244        @dataclass
1245        class Foo:
1246            x: dict = field(default_factory=dict)
1247
1248        @dataclass
1249        class Bar(Foo):
1250            y: int = 1
1251
1252        self.assertEqual(Foo().x, {})
1253        self.assertEqual(Bar().x, {})
1254        self.assertEqual(Bar().y, 1)
1255
1256        @dataclass
1257        class Baz(Foo):
1258            pass
1259        self.assertEqual(Baz().x, {})
1260
1261    def test_intermediate_non_dataclass(self):
1262        # Test that an intermediate class that defines
1263        #  annotations does not define fields.
1264
1265        @dataclass
1266        class A:
1267            x: int
1268
1269        class B(A):
1270            y: int
1271
1272        @dataclass
1273        class C(B):
1274            z: int
1275
1276        c = C(1, 3)
1277        self.assertEqual((c.x, c.z), (1, 3))
1278
1279        # .y was not initialized.
1280        with self.assertRaisesRegex(AttributeError,
1281                                    'object has no attribute'):
1282            c.y
1283
1284        # And if we again derive a non-dataclass, no fields are added.
1285        class D(C):
1286            t: int
1287        d = D(4, 5)
1288        self.assertEqual((d.x, d.z), (4, 5))
1289
1290    def test_classvar_default_factory(self):
1291        # It's an error for a ClassVar to have a factory function.
1292        with self.assertRaisesRegex(TypeError,
1293                                    'cannot have a default factory'):
1294            @dataclass
1295            class C:
1296                x: ClassVar[int] = field(default_factory=int)
1297
1298    def test_is_dataclass(self):
1299        class NotDataClass:
1300            pass
1301
1302        self.assertFalse(is_dataclass(0))
1303        self.assertFalse(is_dataclass(int))
1304        self.assertFalse(is_dataclass(NotDataClass))
1305        self.assertFalse(is_dataclass(NotDataClass()))
1306
1307        @dataclass
1308        class C:
1309            x: int
1310
1311        @dataclass
1312        class D:
1313            d: C
1314            e: int
1315
1316        c = C(10)
1317        d = D(c, 4)
1318
1319        self.assertTrue(is_dataclass(C))
1320        self.assertTrue(is_dataclass(c))
1321        self.assertFalse(is_dataclass(c.x))
1322        self.assertTrue(is_dataclass(d.d))
1323        self.assertFalse(is_dataclass(d.e))
1324
1325    def test_is_dataclass_when_getattr_always_returns(self):
1326        # See bpo-37868.
1327        class A:
1328            def __getattr__(self, key):
1329                return 0
1330        self.assertFalse(is_dataclass(A))
1331        a = A()
1332
1333        # Also test for an instance attribute.
1334        class B:
1335            pass
1336        b = B()
1337        b.__dataclass_fields__ = []
1338
1339        for obj in a, b:
1340            with self.subTest(obj=obj):
1341                self.assertFalse(is_dataclass(obj))
1342
1343                # Indirect tests for _is_dataclass_instance().
1344                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1345                    asdict(obj)
1346                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1347                    astuple(obj)
1348                with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1349                    replace(obj, x=0)
1350
1351    def test_helper_fields_with_class_instance(self):
1352        # Check that we can call fields() on either a class or instance,
1353        #  and get back the same thing.
1354        @dataclass
1355        class C:
1356            x: int
1357            y: float
1358
1359        self.assertEqual(fields(C), fields(C(0, 0.0)))
1360
1361    def test_helper_fields_exception(self):
1362        # Check that TypeError is raised if not passed a dataclass or
1363        #  instance.
1364        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1365            fields(0)
1366
1367        class C: pass
1368        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1369            fields(C)
1370        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1371            fields(C())
1372
1373    def test_helper_asdict(self):
1374        # Basic tests for asdict(), it should return a new dictionary.
1375        @dataclass
1376        class C:
1377            x: int
1378            y: int
1379        c = C(1, 2)
1380
1381        self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1382        self.assertEqual(asdict(c), asdict(c))
1383        self.assertIsNot(asdict(c), asdict(c))
1384        c.x = 42
1385        self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1386        self.assertIs(type(asdict(c)), dict)
1387
1388    def test_helper_asdict_raises_on_classes(self):
1389        # asdict() should raise on a class object.
1390        @dataclass
1391        class C:
1392            x: int
1393            y: int
1394        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1395            asdict(C)
1396        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1397            asdict(int)
1398
1399    def test_helper_asdict_copy_values(self):
1400        @dataclass
1401        class C:
1402            x: int
1403            y: List[int] = field(default_factory=list)
1404        initial = []
1405        c = C(1, initial)
1406        d = asdict(c)
1407        self.assertEqual(d['y'], initial)
1408        self.assertIsNot(d['y'], initial)
1409        c = C(1)
1410        d = asdict(c)
1411        d['y'].append(1)
1412        self.assertEqual(c.y, [])
1413
1414    def test_helper_asdict_nested(self):
1415        @dataclass
1416        class UserId:
1417            token: int
1418            group: int
1419        @dataclass
1420        class User:
1421            name: str
1422            id: UserId
1423        u = User('Joe', UserId(123, 1))
1424        d = asdict(u)
1425        self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1426        self.assertIsNot(asdict(u), asdict(u))
1427        u.id.group = 2
1428        self.assertEqual(asdict(u), {'name': 'Joe',
1429                                     'id': {'token': 123, 'group': 2}})
1430
1431    def test_helper_asdict_builtin_containers(self):
1432        @dataclass
1433        class User:
1434            name: str
1435            id: int
1436        @dataclass
1437        class GroupList:
1438            id: int
1439            users: List[User]
1440        @dataclass
1441        class GroupTuple:
1442            id: int
1443            users: Tuple[User, ...]
1444        @dataclass
1445        class GroupDict:
1446            id: int
1447            users: Dict[str, User]
1448        a = User('Alice', 1)
1449        b = User('Bob', 2)
1450        gl = GroupList(0, [a, b])
1451        gt = GroupTuple(0, (a, b))
1452        gd = GroupDict(0, {'first': a, 'second': b})
1453        self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1454                                                         {'name': 'Bob', 'id': 2}]})
1455        self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1456                                                         {'name': 'Bob', 'id': 2})})
1457        self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1458                                                         'second': {'name': 'Bob', 'id': 2}}})
1459
1460    def test_helper_asdict_builtin_object_containers(self):
1461        @dataclass
1462        class Child:
1463            d: object
1464
1465        @dataclass
1466        class Parent:
1467            child: Child
1468
1469        self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1470        self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1471
1472    def test_helper_asdict_factory(self):
1473        @dataclass
1474        class C:
1475            x: int
1476            y: int
1477        c = C(1, 2)
1478        d = asdict(c, dict_factory=OrderedDict)
1479        self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1480        self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1481        c.x = 42
1482        d = asdict(c, dict_factory=OrderedDict)
1483        self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1484        self.assertIs(type(d), OrderedDict)
1485
1486    def test_helper_asdict_namedtuple(self):
1487        T = namedtuple('T', 'a b c')
1488        @dataclass
1489        class C:
1490            x: str
1491            y: T
1492        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1493
1494        d = asdict(c)
1495        self.assertEqual(d, {'x': 'outer',
1496                             'y': T(1,
1497                                    {'x': 'inner',
1498                                     'y': T(11, 12, 13)},
1499                                    2),
1500                             }
1501                         )
1502
1503        # Now with a dict_factory.  OrderedDict is convenient, but
1504        # since it compares to dicts, we also need to have separate
1505        # assertIs tests.
1506        d = asdict(c, dict_factory=OrderedDict)
1507        self.assertEqual(d, {'x': 'outer',
1508                             'y': T(1,
1509                                    {'x': 'inner',
1510                                     'y': T(11, 12, 13)},
1511                                    2),
1512                             }
1513                         )
1514
1515        # Make sure that the returned dicts are actually OrderedDicts.
1516        self.assertIs(type(d), OrderedDict)
1517        self.assertIs(type(d['y'][1]), OrderedDict)
1518
1519    def test_helper_asdict_namedtuple_key(self):
1520        # Ensure that a field that contains a dict which has a
1521        # namedtuple as a key works with asdict().
1522
1523        @dataclass
1524        class C:
1525            f: dict
1526        T = namedtuple('T', 'a')
1527
1528        c = C({T('an a'): 0})
1529
1530        self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1531
1532    def test_helper_asdict_namedtuple_derived(self):
1533        class T(namedtuple('Tbase', 'a')):
1534            def my_a(self):
1535                return self.a
1536
1537        @dataclass
1538        class C:
1539            f: T
1540
1541        t = T(6)
1542        c = C(t)
1543
1544        d = asdict(c)
1545        self.assertEqual(d, {'f': T(a=6)})
1546        # Make sure that t has been copied, not used directly.
1547        self.assertIsNot(d['f'], t)
1548        self.assertEqual(d['f'].my_a(), 6)
1549
1550    def test_helper_astuple(self):
1551        # Basic tests for astuple(), it should return a new tuple.
1552        @dataclass
1553        class C:
1554            x: int
1555            y: int = 0
1556        c = C(1)
1557
1558        self.assertEqual(astuple(c), (1, 0))
1559        self.assertEqual(astuple(c), astuple(c))
1560        self.assertIsNot(astuple(c), astuple(c))
1561        c.y = 42
1562        self.assertEqual(astuple(c), (1, 42))
1563        self.assertIs(type(astuple(c)), tuple)
1564
1565    def test_helper_astuple_raises_on_classes(self):
1566        # astuple() should raise on a class object.
1567        @dataclass
1568        class C:
1569            x: int
1570            y: int
1571        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1572            astuple(C)
1573        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1574            astuple(int)
1575
1576    def test_helper_astuple_copy_values(self):
1577        @dataclass
1578        class C:
1579            x: int
1580            y: List[int] = field(default_factory=list)
1581        initial = []
1582        c = C(1, initial)
1583        t = astuple(c)
1584        self.assertEqual(t[1], initial)
1585        self.assertIsNot(t[1], initial)
1586        c = C(1)
1587        t = astuple(c)
1588        t[1].append(1)
1589        self.assertEqual(c.y, [])
1590
1591    def test_helper_astuple_nested(self):
1592        @dataclass
1593        class UserId:
1594            token: int
1595            group: int
1596        @dataclass
1597        class User:
1598            name: str
1599            id: UserId
1600        u = User('Joe', UserId(123, 1))
1601        t = astuple(u)
1602        self.assertEqual(t, ('Joe', (123, 1)))
1603        self.assertIsNot(astuple(u), astuple(u))
1604        u.id.group = 2
1605        self.assertEqual(astuple(u), ('Joe', (123, 2)))
1606
1607    def test_helper_astuple_builtin_containers(self):
1608        @dataclass
1609        class User:
1610            name: str
1611            id: int
1612        @dataclass
1613        class GroupList:
1614            id: int
1615            users: List[User]
1616        @dataclass
1617        class GroupTuple:
1618            id: int
1619            users: Tuple[User, ...]
1620        @dataclass
1621        class GroupDict:
1622            id: int
1623            users: Dict[str, User]
1624        a = User('Alice', 1)
1625        b = User('Bob', 2)
1626        gl = GroupList(0, [a, b])
1627        gt = GroupTuple(0, (a, b))
1628        gd = GroupDict(0, {'first': a, 'second': b})
1629        self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1630        self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1631        self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1632
1633    def test_helper_astuple_builtin_object_containers(self):
1634        @dataclass
1635        class Child:
1636            d: object
1637
1638        @dataclass
1639        class Parent:
1640            child: Child
1641
1642        self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1643        self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1644
1645    def test_helper_astuple_factory(self):
1646        @dataclass
1647        class C:
1648            x: int
1649            y: int
1650        NT = namedtuple('NT', 'x y')
1651        def nt(lst):
1652            return NT(*lst)
1653        c = C(1, 2)
1654        t = astuple(c, tuple_factory=nt)
1655        self.assertEqual(t, NT(1, 2))
1656        self.assertIsNot(t, astuple(c, tuple_factory=nt))
1657        c.x = 42
1658        t = astuple(c, tuple_factory=nt)
1659        self.assertEqual(t, NT(42, 2))
1660        self.assertIs(type(t), NT)
1661
1662    def test_helper_astuple_namedtuple(self):
1663        T = namedtuple('T', 'a b c')
1664        @dataclass
1665        class C:
1666            x: str
1667            y: T
1668        c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1669
1670        t = astuple(c)
1671        self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1672
1673        # Now, using a tuple_factory.  list is convenient here.
1674        t = astuple(c, tuple_factory=list)
1675        self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1676
1677    def test_dynamic_class_creation(self):
1678        cls_dict = {'__annotations__': {'x': int, 'y': int},
1679                    }
1680
1681        # Create the class.
1682        cls = type('C', (), cls_dict)
1683
1684        # Make it a dataclass.
1685        cls1 = dataclass(cls)
1686
1687        self.assertEqual(cls1, cls)
1688        self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1689
1690    def test_dynamic_class_creation_using_field(self):
1691        cls_dict = {'__annotations__': {'x': int, 'y': int},
1692                    'y': field(default=5),
1693                    }
1694
1695        # Create the class.
1696        cls = type('C', (), cls_dict)
1697
1698        # Make it a dataclass.
1699        cls1 = dataclass(cls)
1700
1701        self.assertEqual(cls1, cls)
1702        self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1703
1704    def test_init_in_order(self):
1705        @dataclass
1706        class C:
1707            a: int
1708            b: int = field()
1709            c: list = field(default_factory=list, init=False)
1710            d: list = field(default_factory=list)
1711            e: int = field(default=4, init=False)
1712            f: int = 4
1713
1714        calls = []
1715        def setattr(self, name, value):
1716            calls.append((name, value))
1717
1718        C.__setattr__ = setattr
1719        c = C(0, 1)
1720        self.assertEqual(('a', 0), calls[0])
1721        self.assertEqual(('b', 1), calls[1])
1722        self.assertEqual(('c', []), calls[2])
1723        self.assertEqual(('d', []), calls[3])
1724        self.assertNotIn(('e', 4), calls)
1725        self.assertEqual(('f', 4), calls[4])
1726
1727    def test_items_in_dicts(self):
1728        @dataclass
1729        class C:
1730            a: int
1731            b: list = field(default_factory=list, init=False)
1732            c: list = field(default_factory=list)
1733            d: int = field(default=4, init=False)
1734            e: int = 0
1735
1736        c = C(0)
1737        # Class dict
1738        self.assertNotIn('a', C.__dict__)
1739        self.assertNotIn('b', C.__dict__)
1740        self.assertNotIn('c', C.__dict__)
1741        self.assertIn('d', C.__dict__)
1742        self.assertEqual(C.d, 4)
1743        self.assertIn('e', C.__dict__)
1744        self.assertEqual(C.e, 0)
1745        # Instance dict
1746        self.assertIn('a', c.__dict__)
1747        self.assertEqual(c.a, 0)
1748        self.assertIn('b', c.__dict__)
1749        self.assertEqual(c.b, [])
1750        self.assertIn('c', c.__dict__)
1751        self.assertEqual(c.c, [])
1752        self.assertNotIn('d', c.__dict__)
1753        self.assertIn('e', c.__dict__)
1754        self.assertEqual(c.e, 0)
1755
1756    def test_alternate_classmethod_constructor(self):
1757        # Since __post_init__ can't take params, use a classmethod
1758        #  alternate constructor.  This is mostly an example to show
1759        #  how to use this technique.
1760        @dataclass
1761        class C:
1762            x: int
1763            @classmethod
1764            def from_file(cls, filename):
1765                # In a real example, create a new instance
1766                #  and populate 'x' from contents of a file.
1767                value_in_file = 20
1768                return cls(value_in_file)
1769
1770        self.assertEqual(C.from_file('filename').x, 20)
1771
1772    def test_field_metadata_default(self):
1773        # Make sure the default metadata is read-only and of
1774        #  zero length.
1775        @dataclass
1776        class C:
1777            i: int
1778
1779        self.assertFalse(fields(C)[0].metadata)
1780        self.assertEqual(len(fields(C)[0].metadata), 0)
1781        with self.assertRaisesRegex(TypeError,
1782                                    'does not support item assignment'):
1783            fields(C)[0].metadata['test'] = 3
1784
1785    def test_field_metadata_mapping(self):
1786        # Make sure only a mapping can be passed as metadata
1787        #  zero length.
1788        with self.assertRaises(TypeError):
1789            @dataclass
1790            class C:
1791                i: int = field(metadata=0)
1792
1793        # Make sure an empty dict works.
1794        d = {}
1795        @dataclass
1796        class C:
1797            i: int = field(metadata=d)
1798        self.assertFalse(fields(C)[0].metadata)
1799        self.assertEqual(len(fields(C)[0].metadata), 0)
1800        # Update should work (see bpo-35960).
1801        d['foo'] = 1
1802        self.assertEqual(len(fields(C)[0].metadata), 1)
1803        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1804        with self.assertRaisesRegex(TypeError,
1805                                    'does not support item assignment'):
1806            fields(C)[0].metadata['test'] = 3
1807
1808        # Make sure a non-empty dict works.
1809        d = {'test': 10, 'bar': '42', 3: 'three'}
1810        @dataclass
1811        class C:
1812            i: int = field(metadata=d)
1813        self.assertEqual(len(fields(C)[0].metadata), 3)
1814        self.assertEqual(fields(C)[0].metadata['test'], 10)
1815        self.assertEqual(fields(C)[0].metadata['bar'], '42')
1816        self.assertEqual(fields(C)[0].metadata[3], 'three')
1817        # Update should work.
1818        d['foo'] = 1
1819        self.assertEqual(len(fields(C)[0].metadata), 4)
1820        self.assertEqual(fields(C)[0].metadata['foo'], 1)
1821        with self.assertRaises(KeyError):
1822            # Non-existent key.
1823            fields(C)[0].metadata['baz']
1824        with self.assertRaisesRegex(TypeError,
1825                                    'does not support item assignment'):
1826            fields(C)[0].metadata['test'] = 3
1827
1828    def test_field_metadata_custom_mapping(self):
1829        # Try a custom mapping.
1830        class SimpleNameSpace:
1831            def __init__(self, **kw):
1832                self.__dict__.update(kw)
1833
1834            def __getitem__(self, item):
1835                if item == 'xyzzy':
1836                    return 'plugh'
1837                return getattr(self, item)
1838
1839            def __len__(self):
1840                return self.__dict__.__len__()
1841
1842        @dataclass
1843        class C:
1844            i: int = field(metadata=SimpleNameSpace(a=10))
1845
1846        self.assertEqual(len(fields(C)[0].metadata), 1)
1847        self.assertEqual(fields(C)[0].metadata['a'], 10)
1848        with self.assertRaises(AttributeError):
1849            fields(C)[0].metadata['b']
1850        # Make sure we're still talking to our custom mapping.
1851        self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1852
1853    def test_generic_dataclasses(self):
1854        T = TypeVar('T')
1855
1856        @dataclass
1857        class LabeledBox(Generic[T]):
1858            content: T
1859            label: str = '<unknown>'
1860
1861        box = LabeledBox(42)
1862        self.assertEqual(box.content, 42)
1863        self.assertEqual(box.label, '<unknown>')
1864
1865        # Subscripting the resulting class should work, etc.
1866        Alias = List[LabeledBox[int]]
1867
1868    def test_generic_extending(self):
1869        S = TypeVar('S')
1870        T = TypeVar('T')
1871
1872        @dataclass
1873        class Base(Generic[T, S]):
1874            x: T
1875            y: S
1876
1877        @dataclass
1878        class DataDerived(Base[int, T]):
1879            new_field: str
1880        Alias = DataDerived[str]
1881        c = Alias(0, 'test1', 'test2')
1882        self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1883
1884        class NonDataDerived(Base[int, T]):
1885            def new_method(self):
1886                return self.y
1887        Alias = NonDataDerived[float]
1888        c = Alias(10, 1.0)
1889        self.assertEqual(c.new_method(), 1.0)
1890
1891    def test_generic_dynamic(self):
1892        T = TypeVar('T')
1893
1894        @dataclass
1895        class Parent(Generic[T]):
1896            x: T
1897        Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1898                               bases=(Parent[int], Generic[T]), namespace={'other': 42})
1899        self.assertIs(Child[int](1, 2).z, None)
1900        self.assertEqual(Child[int](1, 2, 3).z, 3)
1901        self.assertEqual(Child[int](1, 2, 3).other, 42)
1902        # Check that type aliases work correctly.
1903        Alias = Child[T]
1904        self.assertEqual(Alias[int](1, 2).x, 1)
1905        # Check MRO resolution.
1906        self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1907
1908    def test_dataclassses_pickleable(self):
1909        global P, Q, R
1910        @dataclass
1911        class P:
1912            x: int
1913            y: int = 0
1914        @dataclass
1915        class Q:
1916            x: int
1917            y: int = field(default=0, init=False)
1918        @dataclass
1919        class R:
1920            x: int
1921            y: List[int] = field(default_factory=list)
1922        q = Q(1)
1923        q.y = 2
1924        samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1925        for sample in samples:
1926            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1927                with self.subTest(sample=sample, proto=proto):
1928                    new_sample = pickle.loads(pickle.dumps(sample, proto))
1929                    self.assertEqual(sample.x, new_sample.x)
1930                    self.assertEqual(sample.y, new_sample.y)
1931                    self.assertIsNot(sample, new_sample)
1932                    new_sample.x = 42
1933                    another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1934                    self.assertEqual(new_sample.x, another_new_sample.x)
1935                    self.assertEqual(sample.y, another_new_sample.y)
1936
1937
class TestFieldNoAnnotation(unittest.TestCase):
    # Error expected whenever field() is assigned to an unannotated name.
    _NO_ANNOTATION_MSG = "'f' is a field but has no type annotation"

    def test_field_without_annotation(self):
        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION_MSG):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # The annotation must be in C itself; one inherited from a
        #  dataclass base does not count.
        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION_MSG):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check, but the annotating base is an ordinary class.
        class B:
            f: int

        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION_MSG):
            @dataclass
            class C(B):
                f = field()
1971
1972
class TestDocString(unittest.TestCase):
    def assertDocStrEqual(self, a, b):
        # The spacing inspect.signature produces has differed between
        #  Python versions (bpo #32108), so compare with all spaces removed.
        without_spaces = {ord(' '): None}
        self.assertEqual(a.translate(without_spaces),
                         b.translate(without_spaces))

    def test_existing_docstring_not_overridden(self):
        # A user-written docstring is kept as-is.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    # In every test below the class has no docstring, so dataclass
    #  supplies one made of the class name and its __init__ signature.

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass
        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int
        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int
        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str
        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3
        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None
        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]
        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory shows up as the '<factory>' placeholder.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)
        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        # Non-builtin classes are shown with their module prefix.
        @dataclass
        class C:
            x: deque
        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)
        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2060
2061
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring C must not raise: dataclass may add an __init__ even
        #  though a base class already defines one.  The generated
        #  __init__ does not call B.__init__, so 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so B.__init__ runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied (with '@'); a bare
        #  dataclass(init=False) call would leave C as a plain class and
        #  the test would check nothing about dataclasses.

        # init=False and no user __init__: C() takes no arguments and the
        #  field's default is read from the class attribute.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # init=False with a user __init__: that __init__ is used as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        #  init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2126
2127
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated __repr__ starts with the class's qualified name
        #  and lists every field, inherited ones first.
        prefix = 'TestRepr.test_repr.<locals>.'

        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps it in its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses show their full dotted name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object's default repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_start = f'{__name__}.TestRepr.test_no_repr.<locals>.C object at'
        self.assertIn(default_start, repr(C(3)))

        # repr=False with a user __repr__: that __repr__ is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A class's own __repr__ is kept whatever repr= says.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            with self.subTest(decorate=decorate):
                @decorate
                class C:
                    x: int
                    def __repr__(self):
                        return 'x'
                self.assertEqual(repr(C(0)), 'x')
2197
2198
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False without a user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: that __eq__ decides.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class's own __eq__ is kept whatever eq= says.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, target in variants:
            with self.subTest(target=target):
                @decorate
                class C:
                    x: int
                    def __eq__(self, other):
                        return other == target
                self.assertEqual(C(1), target)
                self.assertNotEqual(C(1), 1)
2244
2245
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        #  comparisons from a user-supplied __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below can only
                #  pass if this method is really being consulted.
                return self.x >= other

        below, above = -1, 1
        self.assertLess(C(0), below)
        self.assertLessEqual(C(0), below)
        self.assertGreater(C(0), above)
        self.assertGreaterEqual(C(0), above)

    def test_no_order(self):
        # order=False must not add any comparison methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept and no sibling methods are generated.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined comparison and
        #  points the user at functools.total_ordering instead.
        def refuses(name):
            return self.assertRaisesRegex(
                TypeError,
                f'Cannot overwrite attribute {name}'
                '.*using functools.total_ordering')

        with refuses('__lt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with refuses('__le__'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with refuses('__gt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with refuses('__ge__'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2321
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # unsafe_hash=True hashes an instance like the tuple of its fields.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        expected = hash((1, 'foo'))
        self.assertEqual(hash(C(1, 'foo')), expected)

    def test_hash_rules(self):
        def non_bool(value):
            # Same truthiness as value, but never an actual bool.
            if value is None:
                return None
            return (3,) if value else 0

        def check(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                decorate = dataclass(unsafe_hash=unsafe_hash, eq=eq,
                                     frozen=frozen)

                if result == 'exception':
                    # Only a class that defines its own __hash__ can
                    #  conflict with the one dataclass wants to add.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @decorate
                        class C:
                            def __hash__(self):
                                return 0
                    return

                if with_hash:
                    @decorate
                    class C:
                        def __hash__(self):
                            return 0
                else:
                    @decorate
                    class C:
                        pass

                ns = C.__dict__
                if result == 'fn':
                    # dataclass installed a generated __hash__ function.
                    self.assertIn('__hash__', ns)
                    self.assertIsNotNone(ns['__hash__'])
                elif result == 'none':
                    # __hash__ was explicitly set to None.
                    self.assertIn('__hash__', ns)
                    self.assertIsNone(ns['__hash__'])
                elif result == '':
                    # dataclass added nothing.
                    if not with_hash:
                        self.assertNotIn('__hash__', ns)
                else:
                    assert False, f'unknown result {result!r}'

        # All 8 combinations of unsafe_hash / eq / frozen, each with the
        #  expected outcome without and with a class-defined __hash__.
        table = [
            # unsafe_hash, eq,    frozen, no __hash__, own __hash__
            (False,        False, False,  '',          ''),
            (False,        False, True,   '',          ''),
            (False,        True,  False,  'none',      ''),
            (False,        True,  True,   'fn',        ''),
            (True,         False, False,  'fn',        'exception'),
            (True,         False, True,   'fn',        'exception'),
            (True,         True,  False,  'fn',        'exception'),
            (True,         True,  True,   'fn',        'exception'),
        ]
        for case, (unsafe_hash, eq, frozen, res_plain, res_own) in enumerate(table, 1):
            # Run every row with real bools, then with non-bool values of
            #  the same truthiness, to exercise the decorator's lookup.
            for convert in (lambda v: v, non_bool):
                flags = (convert(unsafe_hash), convert(eq), convert(frozen))
                check(case, *flags, False, res_plain)
                check(case, *flags, True, res_own)

    def test_eq_only(self):
        # Defining __eq__ in a class body makes Python set __hash__ to
        #  None on its own; that is ordinary language behavior, and
        #  dataclass must not interfere with it (see bpo=32546).
        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # With unsafe_hash=True the class still works with its own __eq__.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # Even with eq=True, the class's own __eq__ is the one used.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the hash equals that of the empty tuple.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                pass
            self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the hash equals that of the 1-tuple of it.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                x: int
            for value in (4, 42):
                self.assertEqual(hash(C(value)), hash((value,)))

    def test_hash_no_args(self):
        # Pin the default hashing behavior when no hash-related argument
        #  is passed, so a change to the decorator's parameter names or
        #  default hashing rules shows up here.

        class Base:
            def __hash__(self):
                return 301

        # None in the frozen or eq column means "leave that argument out
        #  of the decorator call entirely".
        cases = [
            # frozen, eq,    base,   expected
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
        ]
        for frozen, eq, base, expected in cases:
            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                options = {}
                if frozen is not None:
                    options['frozen'] = frozen
                if eq is not None:
                    options['eq'] = eq
                decorate = dataclass(**options) if options else dataclass

                @decorate
                class C(base):
                    i: int

                if expected == 'unhashable':
                    instance = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(instance)
                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)
                elif expected == 'object':
                    # There is no fixed value to compare object's default
                    #  hash against, so check C uses object's function.
                    self.assertIs(C.__hash__, object.__hash__)
                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))
                else:
                    assert False, f'unknown value for expected={expected!r}'
2540
2541
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        # Assigning to a field of a frozen instance raises
        #  FrozenInstanceError and leaves the value unchanged.
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        # In a frozen subclass of a frozen dataclass, both inherited and
        #  newly declared fields are read-only.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # Test both ways: with an intermediate normal (non-dataclass)
    #  class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class, directly
        #  or through an intermediate ordinary class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                # Check each variant inside its own subTest so a failure
                #  is attributed to the variant that produced it.
                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5
                self.assertEqual(d.i, 10)

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.  An ordinary subclass of a frozen dataclass may
        #  set new, non-field attributes, but the fields stay frozen.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True adds __setattr__ and __delattr__, so a class that
        #  defines either one cannot be frozen.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen, a user __setattr__ is kept, and the generated
        #  __init__ assigns fields through it.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2690
2691
class TestSlots(unittest.TestCase):
    """Dataclasses combined with a hand-written __slots__."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: a field listed in __slots__ used to be mistaken
        #  for one with a default value, because the class attribute is a
        #  slot member descriptor (types.MemberDescriptorType).  x must
        #  remain a required argument.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # Construction and reassignment of the slotted field both work.
        obj = C(10)
        self.assertEqual(obj.x, 10)
        obj.x = 5
        self.assertEqual(obj.x, 5)

        # Only 'x' is slotted, so any other attribute is rejected.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            obj.y = 5

    def test_derived_added_field(self):
        # bpo-33100: a dataclass without __slots__ deriving from a slotted one.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__ of its own, so new attributes can
        #  still be added to its instances.
        derived.z = 10
2727
2728        d = Derived(1, 2)
2729        self.assertEqual((d.x, d.y), (1, 2))
2730
2731        # We can add a new field to the derived instance.
2732        d.z = 10
2733
2734class TestDescriptors(unittest.TestCase):
2735    def test_set_name(self):
2736        # See bpo-33141.
2737
2738        # Create a descriptor.
2739        class D:
2740            def __set_name__(self, owner, name):
2741                self.name = name + 'x'
2742            def __get__(self, instance, owner):
2743                if instance is not None:
2744                    return 1
2745                return self
2746
2747        # This is the case of just normal descriptor behavior, no
2748        #  dataclass code is involved in initializing the descriptor.
2749        @dataclass
2750        class C:
2751            c: int=D()
2752        self.assertEqual(C.c.name, 'cx')
2753
2754        # Now test with a default value and init=False, which is the
2755        #  only time this is really meaningful.  If not using
2756        #  init=False, then the descriptor will be overwritten, anyway.
2757        @dataclass
2758        class C:
2759            c: int=field(default=D(), init=False)
2760        self.assertEqual(C.c.name, 'cx')
2761        self.assertEqual(C().c, 1)
2762
2763    def test_non_descriptor(self):
2764        # PEP 487 says __set_name__ should work on non-descriptors.
2765        # Create a descriptor.
2766
2767        class D:
2768            def __set_name__(self, owner, name):
2769                self.name = name + 'x'
2770
2771        @dataclass
2772        class C:
2773            c: int=field(default=D(), init=False)
2774        self.assertEqual(C.c.name, 'cx')
2775
2776    def test_lookup_on_instance(self):
2777        # See bpo-33175.
2778        class D:
2779            pass
2780
2781        d = D()
2782        # Create an attribute on the instance, not type.
2783        d.__set_name__ = Mock()
2784
2785        # Make sure d.__set_name__ is not called.
2786        @dataclass
2787        class C:
2788            i: int=field(default=d, init=False)
2789
2790        self.assertEqual(d.__set_name__.call_count, 0)
2791
2792    def test_lookup_on_class(self):
2793        # See bpo-33175.
2794        class D:
2795            pass
2796        D.__set_name__ = Mock()
2797
2798        # Make sure D.__set_name__ is called.
2799        @dataclass
2800        class C:
2801            i: int=field(default=D(), init=False)
2802
2803        self.assertEqual(D.__set_name__.call_count, 1)
2804
2805
2806class TestStringAnnotations(unittest.TestCase):
2807    def test_classvar(self):
2808        # Some expressions recognized as ClassVar really aren't.  But
2809        #  if you're using string annotations, it's not an exact
2810        #  science.
2811        # These tests assume that both "import typing" and "from
2812        # typing import *" have been run in this file.
2813        for typestr in ('ClassVar[int]',
2814                        'ClassVar [int]'
2815                        ' ClassVar [int]',
2816                        'ClassVar',
2817                        ' ClassVar ',
2818                        'typing.ClassVar[int]',
2819                        'typing.ClassVar[str]',
2820                        ' typing.ClassVar[str]',
2821                        'typing .ClassVar[str]',
2822                        'typing. ClassVar[str]',
2823                        'typing.ClassVar [str]',
2824                        'typing.ClassVar [ str]',
2825
2826                        # Not syntactically valid, but these will
2827                        #  be treated as ClassVars.
2828                        'typing.ClassVar.[int]',
2829                        'typing.ClassVar+',
2830                        ):
2831            with self.subTest(typestr=typestr):
2832                @dataclass
2833                class C:
2834                    x: typestr
2835
2836                # x is a ClassVar, so C() takes no args.
2837                C()
2838
2839                # And it won't appear in the class's dict because it doesn't
2840                # have a default.
2841                self.assertNotIn('x', C.__dict__)
2842
2843    def test_isnt_classvar(self):
2844        for typestr in ('CV',
2845                        't.ClassVar',
2846                        't.ClassVar[int]',
2847                        'typing..ClassVar[int]',
2848                        'Classvar',
2849                        'Classvar[int]',
2850                        'typing.ClassVarx[int]',
2851                        'typong.ClassVar[int]',
2852                        'dataclasses.ClassVar[int]',
2853                        'typingxClassVar[str]',
2854                        ):
2855            with self.subTest(typestr=typestr):
2856                @dataclass
2857                class C:
2858                    x: typestr
2859
2860                # x is not a ClassVar, so C() takes one arg.
2861                self.assertEqual(C(10).x, 10)
2862
2863    def test_initvar(self):
2864        # These tests assume that both "import dataclasses" and "from
2865        #  dataclasses import *" have been run in this file.
2866        for typestr in ('InitVar[int]',
2867                        'InitVar [int]'
2868                        ' InitVar [int]',
2869                        'InitVar',
2870                        ' InitVar ',
2871                        'dataclasses.InitVar[int]',
2872                        'dataclasses.InitVar[str]',
2873                        ' dataclasses.InitVar[str]',
2874                        'dataclasses .InitVar[str]',
2875                        'dataclasses. InitVar[str]',
2876                        'dataclasses.InitVar [str]',
2877                        'dataclasses.InitVar [ str]',
2878
2879                        # Not syntactically valid, but these will
2880                        #  be treated as InitVars.
2881                        'dataclasses.InitVar.[int]',
2882                        'dataclasses.InitVar+',
2883                        ):
2884            with self.subTest(typestr=typestr):
2885                @dataclass
2886                class C:
2887                    x: typestr
2888
2889                # x is an InitVar, so doesn't create a member.
2890                with self.assertRaisesRegex(AttributeError,
2891                                            "object has no attribute 'x'"):
2892                    C(1).x
2893
2894    def test_isnt_initvar(self):
2895        for typestr in ('IV',
2896                        'dc.InitVar',
2897                        'xdataclasses.xInitVar',
2898                        'typing.xInitVar[int]',
2899                        ):
2900            with self.subTest(typestr=typestr):
2901                @dataclass
2902                class C:
2903                    x: typestr
2904
2905                # x is not an InitVar, so there will be a member x.
2906                self.assertEqual(C(10).x, 10)
2907
2908    def test_classvar_module_level_import(self):
2909        from test import dataclass_module_1
2910        from test import dataclass_module_1_str
2911        from test import dataclass_module_2
2912        from test import dataclass_module_2_str
2913
2914        for m in (dataclass_module_1, dataclass_module_1_str,
2915                  dataclass_module_2, dataclass_module_2_str,
2916                  ):
2917            with self.subTest(m=m):
2918                # There's a difference in how the ClassVars are
2919                # interpreted when using string annotations or
2920                # not. See the imported modules for details.
2921                if m.USING_STRINGS:
2922                    c = m.CV(10)
2923                else:
2924                    c = m.CV()
2925                self.assertEqual(c.cv0, 20)
2926
2927
2928                # There's a difference in how the InitVars are
2929                # interpreted when using string annotations or
2930                # not. See the imported modules for details.
2931                c = m.IV(0, 1, 2, 3, 4)
2932
2933                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2934                    with self.subTest(field_name=field_name):
2935                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2936                            # Since field_name is an InitVar, it's
2937                            # not an instance field.
2938                            getattr(c, field_name)
2939
2940                if m.USING_STRINGS:
2941                    # iv4 is interpreted as a normal field.
2942                    self.assertIn('not_iv4', c.__dict__)
2943                    self.assertEqual(c.not_iv4, 4)
2944                else:
2945                    # iv4 is interpreted as an InitVar, so it
2946                    # won't exist on the instance.
2947                    self.assertNotIn('not_iv4', c.__dict__)
2948
2949    def test_text_annotations(self):
2950        from test import dataclass_textanno
2951
2952        self.assertEqual(
2953            get_type_hints(dataclass_textanno.Bar),
2954            {'foo': dataclass_textanno.Foo})
2955        self.assertEqual(
2956            get_type_hints(dataclass_textanno.Bar.__init__),
2957            {'foo': dataclass_textanno.Foo,
2958             'return': type(None)})
2959
2960
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional way to build a dataclass."""

    def test_simple(self):
        # A required field, a field() with a default, and a method
        #  injected through the namespace argument.
        spec = [('x', int),
                ('y', int, field(default=5))]
        C = make_dataclass('C', spec,
                           namespace={'add_one': lambda self: self.x + 1})
        inst = C(10)
        self.assertEqual((inst.x, inst.y), (10, 5))
        self.assertEqual(inst.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The namespace dict passed in must come back unchanged.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        inst = C(2)
        # The result derives from every listed base.
        for klass in (C, Base1, Base2):
            self.assertIsInstance(inst, klass)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C', [('y', int)], bases=(Base1, Base2))

        # The inherited field x comes first, so C(2) leaves y missing.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)

        inst = C(1, 2)
        for klass in (C, Base1, Base2):
            self.assertIsInstance(inst, klass)
        self.assertEqual((inst.x, inst.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int), ('y', InitVar[int])],
                           namespace={'__post_init__': post_init})
        inst = C(2, 3)
        # y reached __post_init__ but is neither stored nor a field.
        self.assertEqual(vars(inst), {'x': 6})
        self.assertEqual(len(fields(inst)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))])
        inst = C(1)
        # ClassVars are not instance fields; their defaults end up on
        #  the class itself.
        self.assertEqual(vars(inst), {'x': 1})
        self.assertEqual(len(fields(inst)), 1)
        self.assertEqual((C.y, C.z), (10, 20))

    def test_other_params(self):
        # dataclass() options such as init=False are accepted.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))],
                           init=False)
        # No __init__ was generated, but __repr__ still was.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        # A bare field name is annotated with the string 'typing.Any'.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        inst = C(1, 2, 3)
        self.assertEqual(vars(inst), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__,
                         dict.fromkeys(('x', 'y', 'z'), 'typing.Any'))

        # Bare names and (name, type) pairs may be mixed.
        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        inst = C(1, 2, 3)
        self.assertEqual(vars(inst), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        cases = [
            ((), r'Invalid field: '),
            ((1, 2, 3, 4), r'Invalid field: '),
            # Things with no len().
            (float, r'has no len\(\)'),
            (lambda x: x, r'has no len\(\)'),
        ]
        for bad_field, pattern in cases:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                # The keyword is rejected in every position.
                for names in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', names)

    def test_non_identifier_field_names(self):
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                # The bad name is rejected in every position.
                for names in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', names)

    def test_underscore_field_names(self):
        # Unlike namedtuple, dataclass field names may start or end with
        #  an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # Odd class names are allowed, since types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3124
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace(), plus recursive __repr__ checks.

    The recursive-repr tests embed this class's name in their expected
    qualnames ("TestReplace.test_...<locals>.C"), so they cannot be moved
    to another class without updating those strings.
    """

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # Any init=False field in the changes is an error.  Each case
        #  gets its own context manager: a second call placed after a
        #  raising one inside the same `with` would never execute.
        for changes in ({'x': 3, 'z': 20, 't': 50},
                        {'z': 20},
                        {'t': 50}):
            with self.subTest(changes=changes):
                with self.assertRaisesRegex(ValueError, 'init=False'):
                    replace(c, **changes)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        #  if we're also replacing one that does exist.  Test this
        #  here, because setting attributes on frozen instances is
        #  handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                             "keyword argument 'a'"):
            c1 = replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'z'"):
            c1 = replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar without a default has no stored value to copy, so
        #  replace() requires it to be passed explicitly.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    "specified with replace()"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3335
if __name__ == '__main__':
    # When executed directly, run the tests defined in this module.
    unittest.main(module='__main__')
3338