1# Deliberately use "from dataclasses import *". Every name in __all__ 2# is tested, so they all must be present. This is a way to catch 3# missing ones. 4 5from dataclasses import * 6 7import abc 8import io 9import pickle 10import inspect 11import builtins 12import types 13import weakref 14import traceback 15import unittest 16from unittest.mock import Mock 17from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol, DefaultDict 18from typing import get_type_hints 19from collections import deque, OrderedDict, namedtuple, defaultdict 20from copy import deepcopy 21from functools import total_ordering 22 23import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 24import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 25 26from test import support 27 28# Just any custom exception we can catch. 29class CustomError(Exception): pass 30 31class TestCase(unittest.TestCase): 32 def test_no_fields(self): 33 @dataclass 34 class C: 35 pass 36 37 o = C() 38 self.assertEqual(len(fields(C)), 0) 39 40 def test_no_fields_but_member_variable(self): 41 @dataclass 42 class C: 43 i = 0 44 45 o = C() 46 self.assertEqual(len(fields(C)), 0) 47 48 def test_one_field_no_default(self): 49 @dataclass 50 class C: 51 x: int 52 53 o = C(42) 54 self.assertEqual(o.x, 42) 55 56 def test_field_default_default_factory_error(self): 57 msg = "cannot specify both default and default_factory" 58 with self.assertRaisesRegex(ValueError, msg): 59 @dataclass 60 class C: 61 x: int = field(default=1, default_factory=int) 62 63 def test_field_repr(self): 64 int_field = field(default=1, init=True, repr=False) 65 int_field.name = "id" 66 repr_output = repr(int_field) 67 expected_output = "Field(name='id',type=None," \ 68 f"default=1,default_factory={MISSING!r}," \ 69 "init=True,repr=False,hash=None," \ 70 "compare=True,metadata=mappingproxy({})," \ 71 f"kw_only={MISSING!r}," \ 72 "_field_type=None)" 73 74 self.assertEqual(repr_output, expected_output) 75 76 def test_field_recursive_repr(self): 77 rec_field = field() 78 rec_field.type = rec_field 79 rec_field.name = "id" 80 repr_output = repr(rec_field) 81 82 self.assertIn(",type=...,", repr_output) 83 84 def test_recursive_annotation(self): 85 class C: 86 pass 87 88 @dataclass 89 class D: 90 C: C = field() 91 92 self.assertIn(",type=...,", repr(D.__dataclass_fields__["C"])) 93 94 def test_dataclass_params_repr(self): 95 # Even though this is testing an internal implementation detail, 96 # it's testing a feature we want to make sure is correctly implemented 97 # for the sake of dataclasses itself 98 @dataclass(slots=True, frozen=True) 99 class Some: pass 100 101 repr_output = repr(Some.__dataclass_params__) 102 expected_output = "_DataclassParams(init=True,repr=True," \ 103 "eq=True,order=False,unsafe_hash=False,frozen=True," \ 104 "match_args=True,kw_only=False," \ 105 "slots=True,weakref_slot=False)" 106 self.assertEqual(repr_output, expected_output) 107 108 def test_dataclass_params_signature(self): 109 # Even though this is testing an internal implementation detail, 110 # it's testing a feature we want to make sure is correctly implemented 111 # for the sake of dataclasses itself 112 @dataclass 113 class Some: pass 114 115 for param in inspect.signature(dataclass).parameters: 116 if param == 'cls': 117 continue 118 self.assertTrue(hasattr(Some.__dataclass_params__, param), msg=param) 119 120 def test_named_init_params(self): 121 @dataclass 122 class C: 123 x: int 124 125 o = C(x=32) 126 self.assertEqual(o.x, 32) 127 128 def test_two_fields_one_default(self): 129 @dataclass 130 class C: 131 x: int 132 y: int = 0 133 134 o = C(3) 135 self.assertEqual((o.x, o.y), (3, 0)) 136 137 # Non-defaults following defaults. 138 with self.assertRaisesRegex(TypeError, 139 "non-default argument 'y' follows " 140 "default argument 'x'"): 141 @dataclass 142 class C: 143 x: int = 0 144 y: int 145 146 # A derived class adds a non-default field after a default one. 147 with self.assertRaisesRegex(TypeError, 148 "non-default argument 'y' follows " 149 "default argument 'x'"): 150 @dataclass 151 class B: 152 x: int = 0 153 154 @dataclass 155 class C(B): 156 y: int 157 158 # Override a base class field and add a default to 159 # a field which didn't use to have a default. 160 with self.assertRaisesRegex(TypeError, 161 "non-default argument 'y' follows " 162 "default argument 'x'"): 163 @dataclass 164 class B: 165 x: int 166 y: int 167 168 @dataclass 169 class C(B): 170 x: int = 0 171 172 def test_overwrite_hash(self): 173 # Test that declaring this class isn't an error. It should 174 # use the user-provided __hash__. 175 @dataclass(frozen=True) 176 class C: 177 x: int 178 def __hash__(self): 179 return 301 180 self.assertEqual(hash(C(100)), 301) 181 182 # Test that declaring this class isn't an error. It should 183 # use the generated __hash__. 184 @dataclass(frozen=True) 185 class C: 186 x: int 187 def __eq__(self, other): 188 return False 189 self.assertEqual(hash(C(100)), hash((100,))) 190 191 # But this one should generate an exception, because with 192 # unsafe_hash=True, it's an error to have a __hash__ defined. 193 with self.assertRaisesRegex(TypeError, 194 'Cannot overwrite attribute __hash__'): 195 @dataclass(unsafe_hash=True) 196 class C: 197 def __hash__(self): 198 pass 199 200 # Creating this class should not generate an exception, 201 # because even though __hash__ exists before @dataclass is 202 # called, (due to __eq__ being defined), since it's None 203 # that's okay. 204 @dataclass(unsafe_hash=True) 205 class C: 206 x: int 207 def __eq__(self): 208 pass 209 # The generated hash function works as we'd expect. 210 self.assertEqual(hash(C(10)), hash((10,))) 211 212 # Creating this class should generate an exception, because 213 # __hash__ exists and is not None, which it would be if it 214 # had been auto-generated due to __eq__ being defined. 215 with self.assertRaisesRegex(TypeError, 216 'Cannot overwrite attribute __hash__'): 217 @dataclass(unsafe_hash=True) 218 class C: 219 x: int 220 def __eq__(self): 221 pass 222 def __hash__(self): 223 pass 224 225 def test_overwrite_fields_in_derived_class(self): 226 # Note that x from C1 replaces x in Base, but the order remains 227 # the same as defined in Base. 228 @dataclass 229 class Base: 230 x: Any = 15.0 231 y: int = 0 232 233 @dataclass 234 class C1(Base): 235 z: int = 10 236 x: int = 15 237 238 o = Base() 239 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 240 241 o = C1() 242 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 243 244 o = C1(x=5) 245 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 246 247 def test_field_named_self(self): 248 @dataclass 249 class C: 250 self: str 251 c=C('foo') 252 self.assertEqual(c.self, 'foo') 253 254 # Make sure the first parameter is not named 'self'. 255 sig = inspect.signature(C.__init__) 256 first = next(iter(sig.parameters)) 257 self.assertNotEqual('self', first) 258 259 # But we do use 'self' if no field named self. 260 @dataclass 261 class C: 262 selfx: str 263 264 # Make sure the first parameter is named 'self'. 265 sig = inspect.signature(C.__init__) 266 first = next(iter(sig.parameters)) 267 self.assertEqual('self', first) 268 269 def test_field_named_object(self): 270 @dataclass 271 class C: 272 object: str 273 c = C('foo') 274 self.assertEqual(c.object, 'foo') 275 276 def test_field_named_object_frozen(self): 277 @dataclass(frozen=True) 278 class C: 279 object: str 280 c = C('foo') 281 self.assertEqual(c.object, 'foo') 282 283 def test_field_named_BUILTINS_frozen(self): 284 # gh-96151 285 @dataclass(frozen=True) 286 class C: 287 BUILTINS: int 288 c = C(5) 289 self.assertEqual(c.BUILTINS, 5) 290 291 def test_field_with_special_single_underscore_names(self): 292 # gh-98886 293 294 @dataclass 295 class X: 296 x: int = field(default_factory=lambda: 111) 297 _dflt_x: int = field(default_factory=lambda: 222) 298 299 X() 300 301 @dataclass 302 class Y: 303 y: int = field(default_factory=lambda: 111) 304 _HAS_DEFAULT_FACTORY: int = 222 305 306 assert Y(y=222).y == 222 307 308 def test_field_named_like_builtin(self): 309 # Attribute names can shadow built-in names 310 # since code generation is used. 311 # Ensure that this is not happening. 312 exclusions = {'None', 'True', 'False'} 313 builtins_names = sorted( 314 b for b in builtins.__dict__.keys() 315 if not b.startswith('__') and b not in exclusions 316 ) 317 attributes = [(name, str) for name in builtins_names] 318 C = make_dataclass('C', attributes) 319 320 c = C(*[name for name in builtins_names]) 321 322 for name in builtins_names: 323 self.assertEqual(getattr(c, name), name) 324 325 def test_field_named_like_builtin_frozen(self): 326 # Attribute names can shadow built-in names 327 # since code generation is used. 328 # Ensure that this is not happening 329 # for frozen data classes. 330 exclusions = {'None', 'True', 'False'} 331 builtins_names = sorted( 332 b for b in builtins.__dict__.keys() 333 if not b.startswith('__') and b not in exclusions 334 ) 335 attributes = [(name, str) for name in builtins_names] 336 C = make_dataclass('C', attributes, frozen=True) 337 338 c = C(*[name for name in builtins_names]) 339 340 for name in builtins_names: 341 self.assertEqual(getattr(c, name), name) 342 343 def test_0_field_compare(self): 344 # Ensure that order=False is the default. 345 @dataclass 346 class C0: 347 pass 348 349 @dataclass(order=False) 350 class C1: 351 pass 352 353 for cls in [C0, C1]: 354 with self.subTest(cls=cls): 355 self.assertEqual(cls(), cls()) 356 for idx, fn in enumerate([lambda a, b: a < b, 357 lambda a, b: a <= b, 358 lambda a, b: a > b, 359 lambda a, b: a >= b]): 360 with self.subTest(idx=idx): 361 with self.assertRaisesRegex(TypeError, 362 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 363 fn(cls(), cls()) 364 365 @dataclass(order=True) 366 class C: 367 pass 368 self.assertLessEqual(C(), C()) 369 self.assertGreaterEqual(C(), C()) 370 371 def test_1_field_compare(self): 372 # Ensure that order=False is the default. 373 @dataclass 374 class C0: 375 x: int 376 377 @dataclass(order=False) 378 class C1: 379 x: int 380 381 for cls in [C0, C1]: 382 with self.subTest(cls=cls): 383 self.assertEqual(cls(1), cls(1)) 384 self.assertNotEqual(cls(0), cls(1)) 385 for idx, fn in enumerate([lambda a, b: a < b, 386 lambda a, b: a <= b, 387 lambda a, b: a > b, 388 lambda a, b: a >= b]): 389 with self.subTest(idx=idx): 390 with self.assertRaisesRegex(TypeError, 391 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 392 fn(cls(0), cls(0)) 393 394 @dataclass(order=True) 395 class C: 396 x: int 397 self.assertLess(C(0), C(1)) 398 self.assertLessEqual(C(0), C(1)) 399 self.assertLessEqual(C(1), C(1)) 400 self.assertGreater(C(1), C(0)) 401 self.assertGreaterEqual(C(1), C(0)) 402 self.assertGreaterEqual(C(1), C(1)) 403 404 def test_simple_compare(self): 405 # Ensure that order=False is the default. 406 @dataclass 407 class C0: 408 x: int 409 y: int 410 411 @dataclass(order=False) 412 class C1: 413 x: int 414 y: int 415 416 for cls in [C0, C1]: 417 with self.subTest(cls=cls): 418 self.assertEqual(cls(0, 0), cls(0, 0)) 419 self.assertEqual(cls(1, 2), cls(1, 2)) 420 self.assertNotEqual(cls(1, 0), cls(0, 0)) 421 self.assertNotEqual(cls(1, 0), cls(1, 1)) 422 for idx, fn in enumerate([lambda a, b: a < b, 423 lambda a, b: a <= b, 424 lambda a, b: a > b, 425 lambda a, b: a >= b]): 426 with self.subTest(idx=idx): 427 with self.assertRaisesRegex(TypeError, 428 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 429 fn(cls(0, 0), cls(0, 0)) 430 431 @dataclass(order=True) 432 class C: 433 x: int 434 y: int 435 436 for idx, fn in enumerate([lambda a, b: a == b, 437 lambda a, b: a <= b, 438 lambda a, b: a >= b]): 439 with self.subTest(idx=idx): 440 self.assertTrue(fn(C(0, 0), C(0, 0))) 441 442 for idx, fn in enumerate([lambda a, b: a < b, 443 lambda a, b: a <= b, 444 lambda a, b: a != b]): 445 with self.subTest(idx=idx): 446 self.assertTrue(fn(C(0, 0), C(0, 1))) 447 self.assertTrue(fn(C(0, 1), C(1, 0))) 448 self.assertTrue(fn(C(1, 0), C(1, 1))) 449 450 for idx, fn in enumerate([lambda a, b: a > b, 451 lambda a, b: a >= b, 452 lambda a, b: a != b]): 453 with self.subTest(idx=idx): 454 self.assertTrue(fn(C(0, 1), C(0, 0))) 455 self.assertTrue(fn(C(1, 0), C(0, 1))) 456 self.assertTrue(fn(C(1, 1), C(1, 0))) 457 458 def test_compare_subclasses(self): 459 # Comparisons fail for subclasses, even if no fields 460 # are added. 461 @dataclass 462 class B: 463 i: int 464 465 @dataclass 466 class C(B): 467 pass 468 469 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 470 (lambda a, b: a != b, True)]): 471 with self.subTest(idx=idx): 472 self.assertEqual(fn(B(0), C(0)), expected) 473 474 for idx, fn in enumerate([lambda a, b: a < b, 475 lambda a, b: a <= b, 476 lambda a, b: a > b, 477 lambda a, b: a >= b]): 478 with self.subTest(idx=idx): 479 with self.assertRaisesRegex(TypeError, 480 "not supported between instances of 'B' and 'C'"): 481 fn(B(0), C(0)) 482 483 def test_eq_order(self): 484 # Test combining eq and order. 485 for (eq, order, result ) in [ 486 (False, False, 'neither'), 487 (False, True, 'exception'), 488 (True, False, 'eq_only'), 489 (True, True, 'both'), 490 ]: 491 with self.subTest(eq=eq, order=order): 492 if result == 'exception': 493 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 494 @dataclass(eq=eq, order=order) 495 class C: 496 pass 497 else: 498 @dataclass(eq=eq, order=order) 499 class C: 500 pass 501 502 if result == 'neither': 503 self.assertNotIn('__eq__', C.__dict__) 504 self.assertNotIn('__lt__', C.__dict__) 505 self.assertNotIn('__le__', C.__dict__) 506 self.assertNotIn('__gt__', C.__dict__) 507 self.assertNotIn('__ge__', C.__dict__) 508 elif result == 'both': 509 self.assertIn('__eq__', C.__dict__) 510 self.assertIn('__lt__', C.__dict__) 511 self.assertIn('__le__', C.__dict__) 512 self.assertIn('__gt__', C.__dict__) 513 self.assertIn('__ge__', C.__dict__) 514 elif result == 'eq_only': 515 self.assertIn('__eq__', C.__dict__) 516 self.assertNotIn('__lt__', C.__dict__) 517 self.assertNotIn('__le__', C.__dict__) 518 self.assertNotIn('__gt__', C.__dict__) 519 self.assertNotIn('__ge__', C.__dict__) 520 else: 521 assert False, f'unknown result {result!r}' 522 523 def test_field_no_default(self): 524 @dataclass 525 class C: 526 x: int = field() 527 528 self.assertEqual(C(5).x, 5) 529 530 with self.assertRaisesRegex(TypeError, 531 r"__init__\(\) missing 1 required " 532 "positional argument: 'x'"): 533 C() 534 535 def test_field_default(self): 536 default = object() 537 @dataclass 538 class C: 539 x: object = field(default=default) 540 541 self.assertIs(C.x, default) 542 c = C(10) 543 self.assertEqual(c.x, 10) 544 545 # If we delete the instance attribute, we should then see the 546 # class attribute. 547 del c.x 548 self.assertIs(c.x, default) 549 550 self.assertIs(C().x, default) 551 552 def test_not_in_repr(self): 553 @dataclass 554 class C: 555 x: int = field(repr=False) 556 with self.assertRaises(TypeError): 557 C() 558 c = C(10) 559 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 560 561 @dataclass 562 class C: 563 x: int = field(repr=False) 564 y: int 565 c = C(10, 20) 566 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 567 568 def test_not_in_compare(self): 569 @dataclass 570 class C: 571 x: int = 0 572 y: int = field(compare=False, default=4) 573 574 self.assertEqual(C(), C(0, 20)) 575 self.assertEqual(C(1, 10), C(1, 20)) 576 self.assertNotEqual(C(3), C(4, 10)) 577 self.assertNotEqual(C(3, 10), C(4, 10)) 578 579 def test_no_unhashable_default(self): 580 # See bpo-44674. 581 class Unhashable: 582 __hash__ = None 583 584 unhashable_re = 'mutable default .* for field a is not allowed' 585 with self.assertRaisesRegex(ValueError, unhashable_re): 586 @dataclass 587 class A: 588 a: dict = {} 589 590 with self.assertRaisesRegex(ValueError, unhashable_re): 591 @dataclass 592 class A: 593 a: Any = Unhashable() 594 595 # Make sure that the machinery looking for hashability is using the 596 # class's __hash__, not the instance's __hash__. 597 with self.assertRaisesRegex(ValueError, unhashable_re): 598 unhashable = Unhashable() 599 # This shouldn't make the variable hashable. 600 unhashable.__hash__ = lambda: 0 601 @dataclass 602 class A: 603 a: Any = unhashable 604 605 def test_hash_field_rules(self): 606 # Test all 6 cases of: 607 # hash=True/False/None 608 # compare=True/False 609 for (hash_, compare, result ) in [ 610 (True, False, 'field' ), 611 (True, True, 'field' ), 612 (False, False, 'absent'), 613 (False, True, 'absent'), 614 (None, False, 'absent'), 615 (None, True, 'field' ), 616 ]: 617 with self.subTest(hash=hash_, compare=compare): 618 @dataclass(unsafe_hash=True) 619 class C: 620 x: int = field(compare=compare, hash=hash_, default=5) 621 622 if result == 'field': 623 # __hash__ contains the field. 624 self.assertEqual(hash(C(5)), hash((5,))) 625 elif result == 'absent': 626 # The field is not present in the hash. 627 self.assertEqual(hash(C(5)), hash(())) 628 else: 629 assert False, f'unknown result {result!r}' 630 631 def test_init_false_no_default(self): 632 # If init=False and no default value, then the field won't be 633 # present in the instance. 634 @dataclass 635 class C: 636 x: int = field(init=False) 637 638 self.assertNotIn('x', C().__dict__) 639 640 @dataclass 641 class C: 642 x: int 643 y: int = 0 644 z: int = field(init=False) 645 t: int = 10 646 647 self.assertNotIn('z', C(0).__dict__) 648 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 649 650 def test_class_marker(self): 651 @dataclass 652 class C: 653 x: int 654 y: str = field(init=False, default=None) 655 z: str = field(repr=False) 656 657 the_fields = fields(C) 658 # the_fields is a tuple of 3 items, each value 659 # is in __annotations__. 660 self.assertIsInstance(the_fields, tuple) 661 for f in the_fields: 662 self.assertIs(type(f), Field) 663 self.assertIn(f.name, C.__annotations__) 664 665 self.assertEqual(len(the_fields), 3) 666 667 self.assertEqual(the_fields[0].name, 'x') 668 self.assertEqual(the_fields[0].type, int) 669 self.assertFalse(hasattr(C, 'x')) 670 self.assertTrue (the_fields[0].init) 671 self.assertTrue (the_fields[0].repr) 672 self.assertEqual(the_fields[1].name, 'y') 673 self.assertEqual(the_fields[1].type, str) 674 self.assertIsNone(getattr(C, 'y')) 675 self.assertFalse(the_fields[1].init) 676 self.assertTrue (the_fields[1].repr) 677 self.assertEqual(the_fields[2].name, 'z') 678 self.assertEqual(the_fields[2].type, str) 679 self.assertFalse(hasattr(C, 'z')) 680 self.assertTrue (the_fields[2].init) 681 self.assertFalse(the_fields[2].repr) 682 683 def test_field_order(self): 684 @dataclass 685 class B: 686 a: str = 'B:a' 687 b: str = 'B:b' 688 c: str = 'B:c' 689 690 @dataclass 691 class C(B): 692 b: str = 'C:b' 693 694 self.assertEqual([(f.name, f.default) for f in fields(C)], 695 [('a', 'B:a'), 696 ('b', 'C:b'), 697 ('c', 'B:c')]) 698 699 @dataclass 700 class D(B): 701 c: str = 'D:c' 702 703 self.assertEqual([(f.name, f.default) for f in fields(D)], 704 [('a', 'B:a'), 705 ('b', 'B:b'), 706 ('c', 'D:c')]) 707 708 @dataclass 709 class E(D): 710 a: str = 'E:a' 711 d: str = 'E:d' 712 713 self.assertEqual([(f.name, f.default) for f in fields(E)], 714 [('a', 'E:a'), 715 ('b', 'B:b'), 716 ('c', 'D:c'), 717 ('d', 'E:d')]) 718 719 def test_class_attrs(self): 720 # We only have a class attribute if a default value is 721 # specified, either directly or via a field with a default. 722 default = object() 723 @dataclass 724 class C: 725 x: int 726 y: int = field(repr=False) 727 z: object = default 728 t: int = field(default=100) 729 730 self.assertFalse(hasattr(C, 'x')) 731 self.assertFalse(hasattr(C, 'y')) 732 self.assertIs (C.z, default) 733 self.assertEqual(C.t, 100) 734 735 def test_disallowed_mutable_defaults(self): 736 # For the known types, don't allow mutable default values. 737 for typ, empty, non_empty in [(list, [], [1]), 738 (dict, {}, {0:1}), 739 (set, set(), set([1])), 740 ]: 741 with self.subTest(typ=typ): 742 # Can't use a zero-length value. 743 with self.assertRaisesRegex(ValueError, 744 f'mutable default {typ} for field ' 745 'x is not allowed'): 746 @dataclass 747 class Point: 748 x: typ = empty 749 750 751 # Nor a non-zero-length value 752 with self.assertRaisesRegex(ValueError, 753 f'mutable default {typ} for field ' 754 'y is not allowed'): 755 @dataclass 756 class Point: 757 y: typ = non_empty 758 759 # Check subtypes also fail. 760 class Subclass(typ): pass 761 762 with self.assertRaisesRegex(ValueError, 763 "mutable default .*Subclass'>" 764 " for field z is not allowed" 765 ): 766 @dataclass 767 class Point: 768 z: typ = Subclass() 769 770 # Because this is a ClassVar, it can be mutable. 771 @dataclass 772 class C: 773 z: ClassVar[typ] = typ() 774 775 # Because this is a ClassVar, it can be mutable. 776 @dataclass 777 class C: 778 x: ClassVar[typ] = Subclass() 779 780 def test_deliberately_mutable_defaults(self): 781 # If a mutable default isn't in the known list of 782 # (list, dict, set), then it's okay. 783 class Mutable: 784 def __init__(self): 785 self.l = [] 786 787 @dataclass 788 class C: 789 x: Mutable 790 791 # These 2 instances will share this value of x. 792 lst = Mutable() 793 o1 = C(lst) 794 o2 = C(lst) 795 self.assertEqual(o1, o2) 796 o1.x.l.extend([1, 2]) 797 self.assertEqual(o1, o2) 798 self.assertEqual(o1.x.l, [1, 2]) 799 self.assertIs(o1.x, o2.x) 800 801 def test_no_options(self): 802 # Call with dataclass(). 803 @dataclass() 804 class C: 805 x: int 806 807 self.assertEqual(C(42).x, 42) 808 809 def test_not_tuple(self): 810 # Make sure we can't be compared to a tuple. 811 @dataclass 812 class Point: 813 x: int 814 y: int 815 self.assertNotEqual(Point(1, 2), (1, 2)) 816 817 # And that we can't compare to another unrelated dataclass. 818 @dataclass 819 class C: 820 x: int 821 y: int 822 self.assertNotEqual(Point(1, 3), C(1, 3)) 823 824 def test_not_other_dataclass(self): 825 # Test that some of the problems with namedtuple don't happen 826 # here. 827 @dataclass 828 class Point3D: 829 x: int 830 y: int 831 z: int 832 833 @dataclass 834 class Date: 835 year: int 836 month: int 837 day: int 838 839 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 840 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 841 842 # Make sure we can't unpack. 843 with self.assertRaisesRegex(TypeError, 'unpack'): 844 x, y, z = Point3D(4, 5, 6) 845 846 # Make sure another class with the same field names isn't 847 # equal. 848 @dataclass 849 class Point3Dv1: 850 x: int = 0 851 y: int = 0 852 z: int = 0 853 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 854 855 def test_function_annotations(self): 856 # Some dummy class and instance to use as a default. 857 class F: 858 pass 859 f = F() 860 861 def validate_class(cls): 862 # First, check __annotations__, even though they're not 863 # function annotations. 864 self.assertEqual(cls.__annotations__['i'], int) 865 self.assertEqual(cls.__annotations__['j'], str) 866 self.assertEqual(cls.__annotations__['k'], F) 867 self.assertEqual(cls.__annotations__['l'], float) 868 self.assertEqual(cls.__annotations__['z'], complex) 869 870 # Verify __init__. 871 872 signature = inspect.signature(cls.__init__) 873 # Check the return type, should be None. 874 self.assertIs(signature.return_annotation, None) 875 876 # Check each parameter. 877 params = iter(signature.parameters.values()) 878 param = next(params) 879 # This is testing an internal name, and probably shouldn't be tested. 880 self.assertEqual(param.name, 'self') 881 param = next(params) 882 self.assertEqual(param.name, 'i') 883 self.assertIs (param.annotation, int) 884 self.assertEqual(param.default, inspect.Parameter.empty) 885 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 886 param = next(params) 887 self.assertEqual(param.name, 'j') 888 self.assertIs (param.annotation, str) 889 self.assertEqual(param.default, inspect.Parameter.empty) 890 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 891 param = next(params) 892 self.assertEqual(param.name, 'k') 893 self.assertIs (param.annotation, F) 894 # Don't test for the default, since it's set to MISSING. 895 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 896 param = next(params) 897 self.assertEqual(param.name, 'l') 898 self.assertIs (param.annotation, float) 899 # Don't test for the default, since it's set to MISSING. 900 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 901 self.assertRaises(StopIteration, next, params) 902 903 904 @dataclass 905 class C: 906 i: int 907 j: str 908 k: F = f 909 l: float=field(default=None) 910 z: complex=field(default=3+4j, init=False) 911 912 validate_class(C) 913 914 # Now repeat with __hash__. 915 @dataclass(frozen=True, unsafe_hash=True) 916 class C: 917 i: int 918 j: str 919 k: F = f 920 l: float=field(default=None) 921 z: complex=field(default=3+4j, init=False) 922 923 validate_class(C) 924 925 def test_missing_default(self): 926 # Test that MISSING works the same as a default not being 927 # specified. 928 @dataclass 929 class C: 930 x: int=field(default=MISSING) 931 with self.assertRaisesRegex(TypeError, 932 r'__init__\(\) missing 1 required ' 933 'positional argument'): 934 C() 935 self.assertNotIn('x', C.__dict__) 936 937 @dataclass 938 class D: 939 x: int 940 with self.assertRaisesRegex(TypeError, 941 r'__init__\(\) missing 1 required ' 942 'positional argument'): 943 D() 944 self.assertNotIn('x', D.__dict__) 945 946 def test_missing_default_factory(self): 947 # Test that MISSING works the same as a default factory not 948 # being specified (which is really the same as a default not 949 # being specified, too). 950 @dataclass 951 class C: 952 x: int=field(default_factory=MISSING) 953 with self.assertRaisesRegex(TypeError, 954 r'__init__\(\) missing 1 required ' 955 'positional argument'): 956 C() 957 self.assertNotIn('x', C.__dict__) 958 959 @dataclass 960 class D: 961 x: int=field(default=MISSING, default_factory=MISSING) 962 with self.assertRaisesRegex(TypeError, 963 r'__init__\(\) missing 1 required ' 964 'positional argument'): 965 D() 966 self.assertNotIn('x', D.__dict__) 967 968 def test_missing_repr(self): 969 self.assertIn('MISSING_TYPE object', repr(MISSING)) 970 971 def test_dont_include_other_annotations(self): 972 @dataclass 973 class C: 974 i: int 975 def foo(self) -> int: 976 return 4 977 @property 978 def bar(self) -> int: 979 return 5 980 self.assertEqual(list(C.__annotations__), ['i']) 981 self.assertEqual(C(10).foo(), 4) 982 self.assertEqual(C(10).bar, 5) 983 self.assertEqual(C(10).i, 10) 984 985 def test_post_init(self): 986 # Just make sure it gets called 987 @dataclass 988 class C: 989 def __post_init__(self): 990 raise CustomError() 991 with self.assertRaises(CustomError): 992 C() 993 994 @dataclass 995 class C: 996 i: int = 10 997 def __post_init__(self): 998 if self.i == 10: 999 raise CustomError() 1000 with self.assertRaises(CustomError): 1001 C() 1002 # post-init gets called, but doesn't raise. This is just 1003 # checking that self is used correctly. 1004 C(5) 1005 1006 # If there's not an __init__, then post-init won't get called. 1007 @dataclass(init=False) 1008 class C: 1009 def __post_init__(self): 1010 raise CustomError() 1011 # Creating the class won't raise 1012 C() 1013 1014 @dataclass 1015 class C: 1016 x: int = 0 1017 def __post_init__(self): 1018 self.x *= 2 1019 self.assertEqual(C().x, 0) 1020 self.assertEqual(C(2).x, 4) 1021 1022 # Make sure that if we're frozen, post-init can't set 1023 # attributes. 1024 @dataclass(frozen=True) 1025 class C: 1026 x: int = 0 1027 def __post_init__(self): 1028 self.x *= 2 1029 with self.assertRaises(FrozenInstanceError): 1030 C() 1031 1032 def test_post_init_super(self): 1033 # Make sure super() post-init isn't called by default. 1034 class B: 1035 def __post_init__(self): 1036 raise CustomError() 1037 1038 @dataclass 1039 class C(B): 1040 def __post_init__(self): 1041 self.x = 5 1042 1043 self.assertEqual(C().x, 5) 1044 1045 # Now call super(), and it will raise. 1046 @dataclass 1047 class C(B): 1048 def __post_init__(self): 1049 super().__post_init__() 1050 1051 with self.assertRaises(CustomError): 1052 C() 1053 1054 # Make sure post-init is called, even if not defined in our 1055 # class. 1056 @dataclass 1057 class C(B): 1058 pass 1059 1060 with self.assertRaises(CustomError): 1061 C() 1062 1063 def test_post_init_staticmethod(self): 1064 flag = False 1065 @dataclass 1066 class C: 1067 x: int 1068 y: int 1069 @staticmethod 1070 def __post_init__(): 1071 nonlocal flag 1072 flag = True 1073 1074 self.assertFalse(flag) 1075 c = C(3, 4) 1076 self.assertEqual((c.x, c.y), (3, 4)) 1077 self.assertTrue(flag) 1078 1079 def test_post_init_classmethod(self): 1080 @dataclass 1081 class C: 1082 flag = False 1083 x: int 1084 y: int 1085 @classmethod 1086 def __post_init__(cls): 1087 cls.flag = True 1088 1089 self.assertFalse(C.flag) 1090 c = C(3, 4) 1091 self.assertEqual((c.x, c.y), (3, 4)) 1092 self.assertTrue(C.flag) 1093 1094 def test_post_init_not_auto_added(self): 1095 # See bpo-46757, which had proposed always adding __post_init__. As 1096 # Raymond Hettinger pointed out, that would be a breaking change. So, 1097 # add a test to make sure that the current behavior doesn't change. 1098 1099 @dataclass 1100 class A0: 1101 pass 1102 1103 @dataclass 1104 class B0: 1105 b_called: bool = False 1106 def __post_init__(self): 1107 self.b_called = True 1108 1109 @dataclass 1110 class C0(A0, B0): 1111 c_called: bool = False 1112 def __post_init__(self): 1113 super().__post_init__() 1114 self.c_called = True 1115 1116 # Since A0 has no __post_init__, and one wasn't automatically added 1117 # (because that's the rule: it's never added by @dataclass, it's only 1118 # the class author that can add it), then B0.__post_init__ is called. 1119 # Verify that. 1120 c = C0() 1121 self.assertTrue(c.b_called) 1122 self.assertTrue(c.c_called) 1123 1124 ###################################### 1125 # Now, the same thing, except A1 defines __post_init__. 1126 @dataclass 1127 class A1: 1128 def __post_init__(self): 1129 pass 1130 1131 @dataclass 1132 class B1: 1133 b_called: bool = False 1134 def __post_init__(self): 1135 self.b_called = True 1136 1137 @dataclass 1138 class C1(A1, B1): 1139 c_called: bool = False 1140 def __post_init__(self): 1141 super().__post_init__() 1142 self.c_called = True 1143 1144 # This time, B1.__post_init__ isn't being called. This mimics what 1145 # would happen if A1.__post_init__ had been automatically added, 1146 # instead of manually added as we see here. This test isn't really 1147 # needed, but I'm including it just to demonstrate the changed 1148 # behavior when A1 does define __post_init__. 1149 c = C1() 1150 self.assertFalse(c.b_called) 1151 self.assertTrue(c.c_called) 1152 1153 def test_class_var(self): 1154 # Make sure ClassVars are ignored in __init__, __repr__, etc. 1155 @dataclass 1156 class C: 1157 x: int 1158 y: int = 10 1159 z: ClassVar[int] = 1000 1160 w: ClassVar[int] = 2000 1161 t: ClassVar[int] = 3000 1162 s: ClassVar = 4000 1163 1164 c = C(5) 1165 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 1166 self.assertEqual(len(fields(C)), 2) # We have 2 fields. 1167 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 1168 self.assertEqual(c.z, 1000) 1169 self.assertEqual(c.w, 2000) 1170 self.assertEqual(c.t, 3000) 1171 self.assertEqual(c.s, 4000) 1172 C.z += 1 1173 self.assertEqual(c.z, 1001) 1174 c = C(20) 1175 self.assertEqual((c.x, c.y), (20, 10)) 1176 self.assertEqual(c.z, 1001) 1177 self.assertEqual(c.w, 2000) 1178 self.assertEqual(c.t, 3000) 1179 self.assertEqual(c.s, 4000) 1180 1181 def test_class_var_no_default(self): 1182 # If a ClassVar has no default value, it should not be set on the class. 1183 @dataclass 1184 class C: 1185 x: ClassVar[int] 1186 1187 self.assertNotIn('x', C.__dict__) 1188 1189 def test_class_var_default_factory(self): 1190 # It makes no sense for a ClassVar to have a default factory. When 1191 # would it be called? Call it yourself, since it's class-wide. 1192 with self.assertRaisesRegex(TypeError, 1193 'cannot have a default factory'): 1194 @dataclass 1195 class C: 1196 x: ClassVar[int] = field(default_factory=int) 1197 1198 self.assertNotIn('x', C.__dict__) 1199 1200 def test_class_var_with_default(self): 1201 # If a ClassVar has a default value, it should be set on the class. 1202 @dataclass 1203 class C: 1204 x: ClassVar[int] = 10 1205 self.assertEqual(C.x, 10) 1206 1207 @dataclass 1208 class C: 1209 x: ClassVar[int] = field(default=10) 1210 self.assertEqual(C.x, 10) 1211 1212 def test_class_var_frozen(self): 1213 # Make sure ClassVars work even if we're frozen. 1214 @dataclass(frozen=True) 1215 class C: 1216 x: int 1217 y: int = 10 1218 z: ClassVar[int] = 1000 1219 w: ClassVar[int] = 2000 1220 t: ClassVar[int] = 3000 1221 1222 c = C(5) 1223 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 1224 self.assertEqual(len(fields(C)), 2) # We have 2 fields 1225 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 1226 self.assertEqual(c.z, 1000) 1227 self.assertEqual(c.w, 2000) 1228 self.assertEqual(c.t, 3000) 1229 # We can still modify the ClassVar, it's only instances that are 1230 # frozen. 1231 C.z += 1 1232 self.assertEqual(c.z, 1001) 1233 c = C(20) 1234 self.assertEqual((c.x, c.y), (20, 10)) 1235 self.assertEqual(c.z, 1001) 1236 self.assertEqual(c.w, 2000) 1237 self.assertEqual(c.t, 3000) 1238 1239 def test_init_var_no_default(self): 1240 # If an InitVar has no default value, it should not be set on the class. 1241 @dataclass 1242 class C: 1243 x: InitVar[int] 1244 1245 self.assertNotIn('x', C.__dict__) 1246 1247 def test_init_var_default_factory(self): 1248 # It makes no sense for an InitVar to have a default factory. When 1249 # would it be called? Call it yourself, since it's class-wide. 1250 with self.assertRaisesRegex(TypeError, 1251 'cannot have a default factory'): 1252 @dataclass 1253 class C: 1254 x: InitVar[int] = field(default_factory=int) 1255 1256 self.assertNotIn('x', C.__dict__) 1257 1258 def test_init_var_with_default(self): 1259 # If an InitVar has a default value, it should be set on the class. 1260 @dataclass 1261 class C: 1262 x: InitVar[int] = 10 1263 self.assertEqual(C.x, 10) 1264 1265 @dataclass 1266 class C: 1267 x: InitVar[int] = field(default=10) 1268 self.assertEqual(C.x, 10) 1269 1270 def test_init_var(self): 1271 @dataclass 1272 class C: 1273 x: int = None 1274 init_param: InitVar[int] = None 1275 1276 def __post_init__(self, init_param): 1277 if self.x is None: 1278 self.x = init_param*2 1279 1280 c = C(init_param=10) 1281 self.assertEqual(c.x, 20) 1282 1283 def test_init_var_preserve_type(self): 1284 self.assertEqual(InitVar[int].type, int) 1285 1286 # Make sure the repr is correct. 1287 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') 1288 self.assertEqual(repr(InitVar[List[int]]), 1289 'dataclasses.InitVar[typing.List[int]]') 1290 self.assertEqual(repr(InitVar[list[int]]), 1291 'dataclasses.InitVar[list[int]]') 1292 self.assertEqual(repr(InitVar[int|str]), 1293 'dataclasses.InitVar[int | str]') 1294 1295 def test_init_var_inheritance(self): 1296 # Note that this deliberately tests that a dataclass need not 1297 # have a __post_init__ function if it has an InitVar field. 1298 # It could just be used in a derived class, as shown here. 1299 @dataclass 1300 class Base: 1301 x: int 1302 init_base: InitVar[int] 1303 1304 # We can instantiate by passing the InitVar, even though 1305 # it's not used. 1306 b = Base(0, 10) 1307 self.assertEqual(vars(b), {'x': 0}) 1308 1309 @dataclass 1310 class C(Base): 1311 y: int 1312 init_derived: InitVar[int] 1313 1314 def __post_init__(self, init_base, init_derived): 1315 self.x = self.x + init_base 1316 self.y = self.y + init_derived 1317 1318 c = C(10, 11, 50, 51) 1319 self.assertEqual(vars(c), {'x': 21, 'y': 101}) 1320 1321 def test_init_var_name_shadowing(self): 1322 # Because dataclasses rely exclusively on `__annotations__` for 1323 # handling InitVar and `__annotations__` preserves shadowed definitions, 1324 # you can actually shadow an InitVar with a method or property. 1325 # 1326 # This only works when there is no default value; `dataclasses` uses the 1327 # actual name (which will be bound to the shadowing method) for default 1328 # values. 1329 @dataclass 1330 class C: 1331 shadowed: InitVar[int] 1332 _shadowed: int = field(init=False) 1333 1334 def __post_init__(self, shadowed): 1335 self._shadowed = shadowed * 2 1336 1337 @property 1338 def shadowed(self): 1339 return self._shadowed * 3 1340 1341 c = C(5) 1342 self.assertEqual(c.shadowed, 30) 1343 1344 def test_default_factory(self): 1345 # Test a factory that returns a new list. 1346 @dataclass 1347 class C: 1348 x: int 1349 y: list = field(default_factory=list) 1350 1351 c0 = C(3) 1352 c1 = C(3) 1353 self.assertEqual(c0.x, 3) 1354 self.assertEqual(c0.y, []) 1355 self.assertEqual(c0, c1) 1356 self.assertIsNot(c0.y, c1.y) 1357 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1358 1359 # Test a factory that returns a shared list. 1360 l = [] 1361 @dataclass 1362 class C: 1363 x: int 1364 y: list = field(default_factory=lambda: l) 1365 1366 c0 = C(3) 1367 c1 = C(3) 1368 self.assertEqual(c0.x, 3) 1369 self.assertEqual(c0.y, []) 1370 self.assertEqual(c0, c1) 1371 self.assertIs(c0.y, c1.y) 1372 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1373 1374 # Test various other field flags. 1375 # repr 1376 @dataclass 1377 class C: 1378 x: list = field(default_factory=list, repr=False) 1379 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 1380 self.assertEqual(C().x, []) 1381 1382 # hash 1383 @dataclass(unsafe_hash=True) 1384 class C: 1385 x: list = field(default_factory=list, hash=False) 1386 self.assertEqual(astuple(C()), ([],)) 1387 self.assertEqual(hash(C()), hash(())) 1388 1389 # init (see also test_default_factory_with_no_init) 1390 @dataclass 1391 class C: 1392 x: list = field(default_factory=list, init=False) 1393 self.assertEqual(astuple(C()), ([],)) 1394 1395 # compare 1396 @dataclass 1397 class C: 1398 x: list = field(default_factory=list, compare=False) 1399 self.assertEqual(C(), C([1])) 1400 1401 def test_default_factory_with_no_init(self): 1402 # We need a factory with a side effect. 1403 factory = Mock() 1404 1405 @dataclass 1406 class C: 1407 x: list = field(default_factory=factory, init=False) 1408 1409 # Make sure the default factory is called for each new instance. 1410 C().x 1411 self.assertEqual(factory.call_count, 1) 1412 C().x 1413 self.assertEqual(factory.call_count, 2) 1414 1415 def test_default_factory_not_called_if_value_given(self): 1416 # We need a factory that we can test if it's been called. 1417 factory = Mock() 1418 1419 @dataclass 1420 class C: 1421 x: int = field(default_factory=factory) 1422 1423 # Make sure that if a field has a default factory function, 1424 # it's not called if a value is specified. 1425 C().x 1426 self.assertEqual(factory.call_count, 1) 1427 self.assertEqual(C(10).x, 10) 1428 self.assertEqual(factory.call_count, 1) 1429 C().x 1430 self.assertEqual(factory.call_count, 2) 1431 1432 def test_default_factory_derived(self): 1433 # See bpo-32896. 1434 @dataclass 1435 class Foo: 1436 x: dict = field(default_factory=dict) 1437 1438 @dataclass 1439 class Bar(Foo): 1440 y: int = 1 1441 1442 self.assertEqual(Foo().x, {}) 1443 self.assertEqual(Bar().x, {}) 1444 self.assertEqual(Bar().y, 1) 1445 1446 @dataclass 1447 class Baz(Foo): 1448 pass 1449 self.assertEqual(Baz().x, {}) 1450 1451 def test_intermediate_non_dataclass(self): 1452 # Test that an intermediate class that defines 1453 # annotations does not define fields. 1454 1455 @dataclass 1456 class A: 1457 x: int 1458 1459 class B(A): 1460 y: int 1461 1462 @dataclass 1463 class C(B): 1464 z: int 1465 1466 c = C(1, 3) 1467 self.assertEqual((c.x, c.z), (1, 3)) 1468 1469 # .y was not initialized. 1470 with self.assertRaisesRegex(AttributeError, 1471 'object has no attribute'): 1472 c.y 1473 1474 # And if we again derive a non-dataclass, no fields are added. 1475 class D(C): 1476 t: int 1477 d = D(4, 5) 1478 self.assertEqual((d.x, d.z), (4, 5)) 1479 1480 def test_classvar_default_factory(self): 1481 # It's an error for a ClassVar to have a factory function. 1482 with self.assertRaisesRegex(TypeError, 1483 'cannot have a default factory'): 1484 @dataclass 1485 class C: 1486 x: ClassVar[int] = field(default_factory=int) 1487 1488 def test_is_dataclass(self): 1489 class NotDataClass: 1490 pass 1491 1492 self.assertFalse(is_dataclass(0)) 1493 self.assertFalse(is_dataclass(int)) 1494 self.assertFalse(is_dataclass(NotDataClass)) 1495 self.assertFalse(is_dataclass(NotDataClass())) 1496 1497 @dataclass 1498 class C: 1499 x: int 1500 1501 @dataclass 1502 class D: 1503 d: C 1504 e: int 1505 1506 c = C(10) 1507 d = D(c, 4) 1508 1509 self.assertTrue(is_dataclass(C)) 1510 self.assertTrue(is_dataclass(c)) 1511 self.assertFalse(is_dataclass(c.x)) 1512 self.assertTrue(is_dataclass(d.d)) 1513 self.assertFalse(is_dataclass(d.e)) 1514 1515 def test_is_dataclass_when_getattr_always_returns(self): 1516 # See bpo-37868. 1517 class A: 1518 def __getattr__(self, key): 1519 return 0 1520 self.assertFalse(is_dataclass(A)) 1521 a = A() 1522 1523 # Also test for an instance attribute. 1524 class B: 1525 pass 1526 b = B() 1527 b.__dataclass_fields__ = [] 1528 1529 for obj in a, b: 1530 with self.subTest(obj=obj): 1531 self.assertFalse(is_dataclass(obj)) 1532 1533 # Indirect tests for _is_dataclass_instance(). 1534 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1535 asdict(obj) 1536 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1537 astuple(obj) 1538 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1539 replace(obj, x=0) 1540 1541 def test_is_dataclass_genericalias(self): 1542 @dataclass 1543 class A(types.GenericAlias): 1544 origin: type 1545 args: type 1546 self.assertTrue(is_dataclass(A)) 1547 a = A(list, int) 1548 self.assertTrue(is_dataclass(type(a))) 1549 self.assertTrue(is_dataclass(a)) 1550 1551 def test_is_dataclass_inheritance(self): 1552 @dataclass 1553 class X: 1554 y: int 1555 1556 class Z(X): 1557 pass 1558 1559 self.assertTrue(is_dataclass(X), "X should be a dataclass") 1560 self.assertTrue( 1561 is_dataclass(Z), 1562 "Z should be a dataclass because it inherits from X", 1563 ) 1564 z_instance = Z(y=5) 1565 self.assertTrue( 1566 is_dataclass(z_instance), 1567 "z_instance should be a dataclass because it is an instance of Z", 1568 ) 1569 1570 def test_helper_fields_with_class_instance(self): 1571 # Check that we can call fields() on either a class or instance, 1572 # and get back the same thing. 1573 @dataclass 1574 class C: 1575 x: int 1576 y: float 1577 1578 self.assertEqual(fields(C), fields(C(0, 0.0))) 1579 1580 def test_helper_fields_exception(self): 1581 # Check that TypeError is raised if not passed a dataclass or 1582 # instance. 1583 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1584 fields(0) 1585 1586 class C: pass 1587 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1588 fields(C) 1589 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1590 fields(C()) 1591 1592 def test_clean_traceback_from_fields_exception(self): 1593 stdout = io.StringIO() 1594 try: 1595 fields(object) 1596 except TypeError as exc: 1597 traceback.print_exception(exc, file=stdout) 1598 printed_traceback = stdout.getvalue() 1599 self.assertNotIn("AttributeError", printed_traceback) 1600 self.assertNotIn("__dataclass_fields__", printed_traceback) 1601 1602 def test_helper_asdict(self): 1603 # Basic tests for asdict(), it should return a new dictionary. 1604 @dataclass 1605 class C: 1606 x: int 1607 y: int 1608 c = C(1, 2) 1609 1610 self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 1611 self.assertEqual(asdict(c), asdict(c)) 1612 self.assertIsNot(asdict(c), asdict(c)) 1613 c.x = 42 1614 self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 1615 self.assertIs(type(asdict(c)), dict) 1616 1617 def test_helper_asdict_raises_on_classes(self): 1618 # asdict() should raise on a class object. 1619 @dataclass 1620 class C: 1621 x: int 1622 y: int 1623 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1624 asdict(C) 1625 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1626 asdict(int) 1627 1628 def test_helper_asdict_copy_values(self): 1629 @dataclass 1630 class C: 1631 x: int 1632 y: List[int] = field(default_factory=list) 1633 initial = [] 1634 c = C(1, initial) 1635 d = asdict(c) 1636 self.assertEqual(d['y'], initial) 1637 self.assertIsNot(d['y'], initial) 1638 c = C(1) 1639 d = asdict(c) 1640 d['y'].append(1) 1641 self.assertEqual(c.y, []) 1642 1643 def test_helper_asdict_nested(self): 1644 @dataclass 1645 class UserId: 1646 token: int 1647 group: int 1648 @dataclass 1649 class User: 1650 name: str 1651 id: UserId 1652 u = User('Joe', UserId(123, 1)) 1653 d = asdict(u) 1654 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 1655 self.assertIsNot(asdict(u), asdict(u)) 1656 u.id.group = 2 1657 self.assertEqual(asdict(u), {'name': 'Joe', 1658 'id': {'token': 123, 'group': 2}}) 1659 1660 def test_helper_asdict_builtin_containers(self): 1661 @dataclass 1662 class User: 1663 name: str 1664 id: int 1665 @dataclass 1666 class GroupList: 1667 id: int 1668 users: List[User] 1669 @dataclass 1670 class GroupTuple: 1671 id: int 1672 users: Tuple[User, ...] 1673 @dataclass 1674 class GroupDict: 1675 id: int 1676 users: Dict[str, User] 1677 a = User('Alice', 1) 1678 b = User('Bob', 2) 1679 gl = GroupList(0, [a, b]) 1680 gt = GroupTuple(0, (a, b)) 1681 gd = GroupDict(0, {'first': a, 'second': b}) 1682 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 1683 {'name': 'Bob', 'id': 2}]}) 1684 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 1685 {'name': 'Bob', 'id': 2})}) 1686 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 1687 'second': {'name': 'Bob', 'id': 2}}}) 1688 1689 def test_helper_asdict_builtin_object_containers(self): 1690 @dataclass 1691 class Child: 1692 d: object 1693 1694 @dataclass 1695 class Parent: 1696 child: Child 1697 1698 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 1699 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 1700 1701 def test_helper_asdict_factory(self): 1702 @dataclass 1703 class C: 1704 x: int 1705 y: int 1706 c = C(1, 2) 1707 d = asdict(c, dict_factory=OrderedDict) 1708 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 1709 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 1710 c.x = 42 1711 d = asdict(c, dict_factory=OrderedDict) 1712 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 1713 self.assertIs(type(d), OrderedDict) 1714 1715 def test_helper_asdict_namedtuple(self): 1716 T = namedtuple('T', 'a b c') 1717 @dataclass 1718 class C: 1719 x: str 1720 y: T 1721 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1722 1723 d = asdict(c) 1724 self.assertEqual(d, {'x': 'outer', 1725 'y': T(1, 1726 {'x': 'inner', 1727 'y': T(11, 12, 13)}, 1728 2), 1729 } 1730 ) 1731 1732 # Now with a dict_factory. OrderedDict is convenient, but 1733 # since it compares to dicts, we also need to have separate 1734 # assertIs tests. 1735 d = asdict(c, dict_factory=OrderedDict) 1736 self.assertEqual(d, {'x': 'outer', 1737 'y': T(1, 1738 {'x': 'inner', 1739 'y': T(11, 12, 13)}, 1740 2), 1741 } 1742 ) 1743 1744 # Make sure that the returned dicts are actually OrderedDicts. 1745 self.assertIs(type(d), OrderedDict) 1746 self.assertIs(type(d['y'][1]), OrderedDict) 1747 1748 def test_helper_asdict_namedtuple_key(self): 1749 # Ensure that a field that contains a dict which has a 1750 # namedtuple as a key works with asdict(). 1751 1752 @dataclass 1753 class C: 1754 f: dict 1755 T = namedtuple('T', 'a') 1756 1757 c = C({T('an a'): 0}) 1758 1759 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 1760 1761 def test_helper_asdict_namedtuple_derived(self): 1762 class T(namedtuple('Tbase', 'a')): 1763 def my_a(self): 1764 return self.a 1765 1766 @dataclass 1767 class C: 1768 f: T 1769 1770 t = T(6) 1771 c = C(t) 1772 1773 d = asdict(c) 1774 self.assertEqual(d, {'f': T(a=6)}) 1775 # Make sure that t has been copied, not used directly. 1776 self.assertIsNot(d['f'], t) 1777 self.assertEqual(d['f'].my_a(), 6) 1778 1779 def test_helper_asdict_defaultdict(self): 1780 # Ensure asdict() does not throw exceptions when a 1781 # defaultdict is a member of a dataclass 1782 @dataclass 1783 class C: 1784 mp: DefaultDict[str, List] 1785 1786 dd = defaultdict(list) 1787 dd["x"].append(12) 1788 c = C(mp=dd) 1789 d = asdict(c) 1790 1791 self.assertEqual(d, {"mp": {"x": [12]}}) 1792 self.assertTrue(d["mp"] is not c.mp) # make sure defaultdict is copied 1793 1794 def test_helper_astuple(self): 1795 # Basic tests for astuple(), it should return a new tuple. 1796 @dataclass 1797 class C: 1798 x: int 1799 y: int = 0 1800 c = C(1) 1801 1802 self.assertEqual(astuple(c), (1, 0)) 1803 self.assertEqual(astuple(c), astuple(c)) 1804 self.assertIsNot(astuple(c), astuple(c)) 1805 c.y = 42 1806 self.assertEqual(astuple(c), (1, 42)) 1807 self.assertIs(type(astuple(c)), tuple) 1808 1809 def test_helper_astuple_raises_on_classes(self): 1810 # astuple() should raise on a class object. 1811 @dataclass 1812 class C: 1813 x: int 1814 y: int 1815 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1816 astuple(C) 1817 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1818 astuple(int) 1819 1820 def test_helper_astuple_copy_values(self): 1821 @dataclass 1822 class C: 1823 x: int 1824 y: List[int] = field(default_factory=list) 1825 initial = [] 1826 c = C(1, initial) 1827 t = astuple(c) 1828 self.assertEqual(t[1], initial) 1829 self.assertIsNot(t[1], initial) 1830 c = C(1) 1831 t = astuple(c) 1832 t[1].append(1) 1833 self.assertEqual(c.y, []) 1834 1835 def test_helper_astuple_nested(self): 1836 @dataclass 1837 class UserId: 1838 token: int 1839 group: int 1840 @dataclass 1841 class User: 1842 name: str 1843 id: UserId 1844 u = User('Joe', UserId(123, 1)) 1845 t = astuple(u) 1846 self.assertEqual(t, ('Joe', (123, 1))) 1847 self.assertIsNot(astuple(u), astuple(u)) 1848 u.id.group = 2 1849 self.assertEqual(astuple(u), ('Joe', (123, 2))) 1850 1851 def test_helper_astuple_builtin_containers(self): 1852 @dataclass 1853 class User: 1854 name: str 1855 id: int 1856 @dataclass 1857 class GroupList: 1858 id: int 1859 users: List[User] 1860 @dataclass 1861 class GroupTuple: 1862 id: int 1863 users: Tuple[User, ...] 1864 @dataclass 1865 class GroupDict: 1866 id: int 1867 users: Dict[str, User] 1868 a = User('Alice', 1) 1869 b = User('Bob', 2) 1870 gl = GroupList(0, [a, b]) 1871 gt = GroupTuple(0, (a, b)) 1872 gd = GroupDict(0, {'first': a, 'second': b}) 1873 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 1874 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 1875 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 1876 1877 def test_helper_astuple_builtin_object_containers(self): 1878 @dataclass 1879 class Child: 1880 d: object 1881 1882 @dataclass 1883 class Parent: 1884 child: Child 1885 1886 self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 1887 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 1888 1889 def test_helper_astuple_factory(self): 1890 @dataclass 1891 class C: 1892 x: int 1893 y: int 1894 NT = namedtuple('NT', 'x y') 1895 def nt(lst): 1896 return NT(*lst) 1897 c = C(1, 2) 1898 t = astuple(c, tuple_factory=nt) 1899 self.assertEqual(t, NT(1, 2)) 1900 self.assertIsNot(t, astuple(c, tuple_factory=nt)) 1901 c.x = 42 1902 t = astuple(c, tuple_factory=nt) 1903 self.assertEqual(t, NT(42, 2)) 1904 self.assertIs(type(t), NT) 1905 1906 def test_helper_astuple_namedtuple(self): 1907 T = namedtuple('T', 'a b c') 1908 @dataclass 1909 class C: 1910 x: str 1911 y: T 1912 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1913 1914 t = astuple(c) 1915 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 1916 1917 # Now, using a tuple_factory. list is convenient here. 1918 t = astuple(c, tuple_factory=list) 1919 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 1920 1921 def test_helper_astuple_defaultdict(self): 1922 # Ensure astuple() does not throw exceptions when a 1923 # defaultdict is a member of a dataclass 1924 @dataclass 1925 class C: 1926 mp: DefaultDict[str, List] 1927 1928 dd = defaultdict(list) 1929 dd["x"].append(12) 1930 c = C(mp=dd) 1931 t = astuple(c) 1932 1933 self.assertEqual(t, ({"x": [12]},)) 1934 self.assertTrue(t[0] is not dd) # make sure defaultdict is copied 1935 1936 def test_dynamic_class_creation(self): 1937 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1938 } 1939 1940 # Create the class. 1941 cls = type('C', (), cls_dict) 1942 1943 # Make it a dataclass. 1944 cls1 = dataclass(cls) 1945 1946 self.assertEqual(cls1, cls) 1947 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 1948 1949 def test_dynamic_class_creation_using_field(self): 1950 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1951 'y': field(default=5), 1952 } 1953 1954 # Create the class. 1955 cls = type('C', (), cls_dict) 1956 1957 # Make it a dataclass. 1958 cls1 = dataclass(cls) 1959 1960 self.assertEqual(cls1, cls) 1961 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 1962 1963 def test_init_in_order(self): 1964 @dataclass 1965 class C: 1966 a: int 1967 b: int = field() 1968 c: list = field(default_factory=list, init=False) 1969 d: list = field(default_factory=list) 1970 e: int = field(default=4, init=False) 1971 f: int = 4 1972 1973 calls = [] 1974 def setattr(self, name, value): 1975 calls.append((name, value)) 1976 1977 C.__setattr__ = setattr 1978 c = C(0, 1) 1979 self.assertEqual(('a', 0), calls[0]) 1980 self.assertEqual(('b', 1), calls[1]) 1981 self.assertEqual(('c', []), calls[2]) 1982 self.assertEqual(('d', []), calls[3]) 1983 self.assertNotIn(('e', 4), calls) 1984 self.assertEqual(('f', 4), calls[4]) 1985 1986 def test_items_in_dicts(self): 1987 @dataclass 1988 class C: 1989 a: int 1990 b: list = field(default_factory=list, init=False) 1991 c: list = field(default_factory=list) 1992 d: int = field(default=4, init=False) 1993 e: int = 0 1994 1995 c = C(0) 1996 # Class dict 1997 self.assertNotIn('a', C.__dict__) 1998 self.assertNotIn('b', C.__dict__) 1999 self.assertNotIn('c', C.__dict__) 2000 self.assertIn('d', C.__dict__) 2001 self.assertEqual(C.d, 4) 2002 self.assertIn('e', C.__dict__) 2003 self.assertEqual(C.e, 0) 2004 # Instance dict 2005 self.assertIn('a', c.__dict__) 2006 self.assertEqual(c.a, 0) 2007 self.assertIn('b', c.__dict__) 2008 self.assertEqual(c.b, []) 2009 self.assertIn('c', c.__dict__) 2010 self.assertEqual(c.c, []) 2011 self.assertNotIn('d', c.__dict__) 2012 self.assertIn('e', c.__dict__) 2013 self.assertEqual(c.e, 0) 2014 2015 def test_alternate_classmethod_constructor(self): 2016 # Since __post_init__ can't take params, use a classmethod 2017 # alternate constructor. This is mostly an example to show 2018 # how to use this technique. 2019 @dataclass 2020 class C: 2021 x: int 2022 @classmethod 2023 def from_file(cls, filename): 2024 # In a real example, create a new instance 2025 # and populate 'x' from contents of a file. 2026 value_in_file = 20 2027 return cls(value_in_file) 2028 2029 self.assertEqual(C.from_file('filename').x, 20) 2030 2031 def test_field_metadata_default(self): 2032 # Make sure the default metadata is read-only and of 2033 # zero length. 2034 @dataclass 2035 class C: 2036 i: int 2037 2038 self.assertFalse(fields(C)[0].metadata) 2039 self.assertEqual(len(fields(C)[0].metadata), 0) 2040 with self.assertRaisesRegex(TypeError, 2041 'does not support item assignment'): 2042 fields(C)[0].metadata['test'] = 3 2043 2044 def test_field_metadata_mapping(self): 2045 # Make sure only a mapping can be passed as metadata 2046 # zero length. 2047 with self.assertRaises(TypeError): 2048 @dataclass 2049 class C: 2050 i: int = field(metadata=0) 2051 2052 # Make sure an empty dict works. 2053 d = {} 2054 @dataclass 2055 class C: 2056 i: int = field(metadata=d) 2057 self.assertFalse(fields(C)[0].metadata) 2058 self.assertEqual(len(fields(C)[0].metadata), 0) 2059 # Update should work (see bpo-35960). 2060 d['foo'] = 1 2061 self.assertEqual(len(fields(C)[0].metadata), 1) 2062 self.assertEqual(fields(C)[0].metadata['foo'], 1) 2063 with self.assertRaisesRegex(TypeError, 2064 'does not support item assignment'): 2065 fields(C)[0].metadata['test'] = 3 2066 2067 # Make sure a non-empty dict works. 2068 d = {'test': 10, 'bar': '42', 3: 'three'} 2069 @dataclass 2070 class C: 2071 i: int = field(metadata=d) 2072 self.assertEqual(len(fields(C)[0].metadata), 3) 2073 self.assertEqual(fields(C)[0].metadata['test'], 10) 2074 self.assertEqual(fields(C)[0].metadata['bar'], '42') 2075 self.assertEqual(fields(C)[0].metadata[3], 'three') 2076 # Update should work. 2077 d['foo'] = 1 2078 self.assertEqual(len(fields(C)[0].metadata), 4) 2079 self.assertEqual(fields(C)[0].metadata['foo'], 1) 2080 with self.assertRaises(KeyError): 2081 # Non-existent key. 2082 fields(C)[0].metadata['baz'] 2083 with self.assertRaisesRegex(TypeError, 2084 'does not support item assignment'): 2085 fields(C)[0].metadata['test'] = 3 2086 2087 def test_field_metadata_custom_mapping(self): 2088 # Try a custom mapping. 2089 class SimpleNameSpace: 2090 def __init__(self, **kw): 2091 self.__dict__.update(kw) 2092 2093 def __getitem__(self, item): 2094 if item == 'xyzzy': 2095 return 'plugh' 2096 return getattr(self, item) 2097 2098 def __len__(self): 2099 return self.__dict__.__len__() 2100 2101 @dataclass 2102 class C: 2103 i: int = field(metadata=SimpleNameSpace(a=10)) 2104 2105 self.assertEqual(len(fields(C)[0].metadata), 1) 2106 self.assertEqual(fields(C)[0].metadata['a'], 10) 2107 with self.assertRaises(AttributeError): 2108 fields(C)[0].metadata['b'] 2109 # Make sure we're still talking to our custom mapping. 2110 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 2111 2112 def test_generic_dataclasses(self): 2113 T = TypeVar('T') 2114 2115 @dataclass 2116 class LabeledBox(Generic[T]): 2117 content: T 2118 label: str = '<unknown>' 2119 2120 box = LabeledBox(42) 2121 self.assertEqual(box.content, 42) 2122 self.assertEqual(box.label, '<unknown>') 2123 2124 # Subscripting the resulting class should work, etc. 2125 Alias = List[LabeledBox[int]] 2126 2127 def test_generic_extending(self): 2128 S = TypeVar('S') 2129 T = TypeVar('T') 2130 2131 @dataclass 2132 class Base(Generic[T, S]): 2133 x: T 2134 y: S 2135 2136 @dataclass 2137 class DataDerived(Base[int, T]): 2138 new_field: str 2139 Alias = DataDerived[str] 2140 c = Alias(0, 'test1', 'test2') 2141 self.assertEqual(astuple(c), (0, 'test1', 'test2')) 2142 2143 class NonDataDerived(Base[int, T]): 2144 def new_method(self): 2145 return self.y 2146 Alias = NonDataDerived[float] 2147 c = Alias(10, 1.0) 2148 self.assertEqual(c.new_method(), 1.0) 2149 2150 def test_generic_dynamic(self): 2151 T = TypeVar('T') 2152 2153 @dataclass 2154 class Parent(Generic[T]): 2155 x: T 2156 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 2157 bases=(Parent[int], Generic[T]), namespace={'other': 42}) 2158 self.assertIs(Child[int](1, 2).z, None) 2159 self.assertEqual(Child[int](1, 2, 3).z, 3) 2160 self.assertEqual(Child[int](1, 2, 3).other, 42) 2161 # Check that type aliases work correctly. 2162 Alias = Child[T] 2163 self.assertEqual(Alias[int](1, 2).x, 1) 2164 # Check MRO resolution. 2165 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 2166 2167 def test_dataclasses_pickleable(self): 2168 global P, Q, R 2169 @dataclass 2170 class P: 2171 x: int 2172 y: int = 0 2173 @dataclass 2174 class Q: 2175 x: int 2176 y: int = field(default=0, init=False) 2177 @dataclass 2178 class R: 2179 x: int 2180 y: List[int] = field(default_factory=list) 2181 q = Q(1) 2182 q.y = 2 2183 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 2184 for sample in samples: 2185 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 2186 with self.subTest(sample=sample, proto=proto): 2187 new_sample = pickle.loads(pickle.dumps(sample, proto)) 2188 self.assertEqual(sample.x, new_sample.x) 2189 self.assertEqual(sample.y, new_sample.y) 2190 self.assertIsNot(sample, new_sample) 2191 new_sample.x = 42 2192 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 2193 self.assertEqual(new_sample.x, another_new_sample.x) 2194 self.assertEqual(sample.y, another_new_sample.y) 2195 2196 def test_dataclasses_qualnames(self): 2197 @dataclass(order=True, unsafe_hash=True, frozen=True) 2198 class A: 2199 x: int 2200 y: int 2201 2202 self.assertEqual(A.__init__.__name__, "__init__") 2203 for function in ( 2204 '__eq__', 2205 '__lt__', 2206 '__le__', 2207 '__gt__', 2208 '__ge__', 2209 '__hash__', 2210 '__init__', 2211 '__repr__', 2212 '__setattr__', 2213 '__delattr__', 2214 ): 2215 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}") 2216 2217 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"): 2218 A() 2219 2220 2221class TestFieldNoAnnotation(unittest.TestCase): 2222 def test_field_without_annotation(self): 2223 with self.assertRaisesRegex(TypeError, 2224 "'f' is a field but has no type annotation"): 2225 @dataclass 2226 class C: 2227 f = field() 2228 2229 def test_field_without_annotation_but_annotation_in_base(self): 2230 @dataclass 2231 class B: 2232 f: int 2233 2234 with self.assertRaisesRegex(TypeError, 2235 "'f' is a field but has no type annotation"): 2236 # This is still an error: make sure we don't pick up the 2237 # type annotation in the base class. 2238 @dataclass 2239 class C(B): 2240 f = field() 2241 2242 def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 2243 # Same test, but with the base class not a dataclass. 2244 class B: 2245 f: int 2246 2247 with self.assertRaisesRegex(TypeError, 2248 "'f' is a field but has no type annotation"): 2249 # This is still an error: make sure we don't pick up the 2250 # type annotation in the base class. 2251 @dataclass 2252 class C(B): 2253 f = field() 2254 2255 2256class TestDocString(unittest.TestCase): 2257 def assertDocStrEqual(self, a, b): 2258 # Because 3.6 and 3.7 differ in how inspect.signature work 2259 # (see bpo #32108), for the time being just compare them with 2260 # whitespace stripped. 2261 self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 2262 2263 @support.requires_docstrings 2264 def test_existing_docstring_not_overridden(self): 2265 @dataclass 2266 class C: 2267 """Lorem ipsum""" 2268 x: int 2269 2270 self.assertEqual(C.__doc__, "Lorem ipsum") 2271 2272 def test_docstring_no_fields(self): 2273 @dataclass 2274 class C: 2275 pass 2276 2277 self.assertDocStrEqual(C.__doc__, "C()") 2278 2279 def test_docstring_one_field(self): 2280 @dataclass 2281 class C: 2282 x: int 2283 2284 self.assertDocStrEqual(C.__doc__, "C(x:int)") 2285 2286 def test_docstring_two_fields(self): 2287 @dataclass 2288 class C: 2289 x: int 2290 y: int 2291 2292 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 2293 2294 def test_docstring_three_fields(self): 2295 @dataclass 2296 class C: 2297 x: int 2298 y: int 2299 z: str 2300 2301 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 2302 2303 def test_docstring_one_field_with_default(self): 2304 @dataclass 2305 class C: 2306 x: int = 3 2307 2308 self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 2309 2310 def test_docstring_one_field_with_default_none(self): 2311 @dataclass 2312 class C: 2313 x: Union[int, type(None)] = None 2314 2315 self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") 2316 2317 def test_docstring_list_field(self): 2318 @dataclass 2319 class C: 2320 x: List[int] 2321 2322 self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 2323 2324 def test_docstring_list_field_with_default_factory(self): 2325 @dataclass 2326 class C: 2327 x: List[int] = field(default_factory=list) 2328 2329 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 2330 2331 def test_docstring_deque_field(self): 2332 @dataclass 2333 class C: 2334 x: deque 2335 2336 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 2337 2338 def test_docstring_deque_field_with_default_factory(self): 2339 @dataclass 2340 class C: 2341 x: deque = field(default_factory=deque) 2342 2343 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 2344 2345 def test_docstring_with_no_signature(self): 2346 # See https://github.com/python/cpython/issues/103449 2347 class Meta(type): 2348 __call__ = dict 2349 class Base(metaclass=Meta): 2350 pass 2351 2352 @dataclass 2353 class C(Base): 2354 pass 2355 2356 self.assertDocStrEqual(C.__doc__, "C") 2357 2358 2359class TestInit(unittest.TestCase): 2360 def test_base_has_init(self): 2361 class B: 2362 def __init__(self): 2363 self.z = 100 2364 2365 # Make sure that declaring this class doesn't raise an error. 2366 # The issue is that we can't override __init__ in our class, 2367 # but it should be okay to add __init__ to us if our base has 2368 # an __init__. 2369 @dataclass 2370 class C(B): 2371 x: int = 0 2372 c = C(10) 2373 self.assertEqual(c.x, 10) 2374 self.assertNotIn('z', vars(c)) 2375 2376 # Make sure that if we don't add an init, the base __init__ 2377 # gets called. 2378 @dataclass(init=False) 2379 class C(B): 2380 x: int = 10 2381 c = C() 2382 self.assertEqual(c.x, 10) 2383 self.assertEqual(c.z, 100) 2384 2385 def test_no_init(self): 2386 @dataclass(init=False) 2387 class C: 2388 i: int = 0 2389 self.assertEqual(C().i, 0) 2390 2391 @dataclass(init=False) 2392 class C: 2393 i: int = 2 2394 def __init__(self): 2395 self.i = 3 2396 self.assertEqual(C().i, 3) 2397 2398 def test_overwriting_init(self): 2399 # If the class has __init__, use it no matter the value of 2400 # init=. 2401 2402 @dataclass 2403 class C: 2404 x: int 2405 def __init__(self, x): 2406 self.x = 2 * x 2407 self.assertEqual(C(3).x, 6) 2408 2409 @dataclass(init=True) 2410 class C: 2411 x: int 2412 def __init__(self, x): 2413 self.x = 2 * x 2414 self.assertEqual(C(4).x, 8) 2415 2416 @dataclass(init=False) 2417 class C: 2418 x: int 2419 def __init__(self, x): 2420 self.x = 2 * x 2421 self.assertEqual(C(5).x, 10) 2422 2423 def test_inherit_from_protocol(self): 2424 # Dataclasses inheriting from protocol should preserve their own `__init__`. 2425 # See bpo-45081. 2426 2427 class P(Protocol): 2428 a: int 2429 2430 @dataclass 2431 class C(P): 2432 a: int 2433 2434 self.assertEqual(C(5).a, 5) 2435 2436 @dataclass 2437 class D(P): 2438 def __init__(self, a): 2439 self.a = a * 2 2440 2441 self.assertEqual(D(5).a, 10) 2442 2443 2444class TestRepr(unittest.TestCase): 2445 def test_repr(self): 2446 @dataclass 2447 class B: 2448 x: int 2449 2450 @dataclass 2451 class C(B): 2452 y: int = 10 2453 2454 o = C(4) 2455 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 2456 2457 @dataclass 2458 class D(C): 2459 x: int = 20 2460 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 2461 2462 @dataclass 2463 class C: 2464 @dataclass 2465 class D: 2466 i: int 2467 @dataclass 2468 class E: 2469 pass 2470 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 2471 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 2472 2473 def test_no_repr(self): 2474 # Test a class with no __repr__ and repr=False. 2475 @dataclass(repr=False) 2476 class C: 2477 x: int 2478 self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 2479 repr(C(3))) 2480 2481 # Test a class with a __repr__ and repr=False. 2482 @dataclass(repr=False) 2483 class C: 2484 x: int 2485 def __repr__(self): 2486 return 'C-class' 2487 self.assertEqual(repr(C(3)), 'C-class') 2488 2489 def test_overwriting_repr(self): 2490 # If the class has __repr__, use it no matter the value of 2491 # repr=. 2492 2493 @dataclass 2494 class C: 2495 x: int 2496 def __repr__(self): 2497 return 'x' 2498 self.assertEqual(repr(C(0)), 'x') 2499 2500 @dataclass(repr=True) 2501 class C: 2502 x: int 2503 def __repr__(self): 2504 return 'x' 2505 self.assertEqual(repr(C(0)), 'x') 2506 2507 @dataclass(repr=False) 2508 class C: 2509 x: int 2510 def __repr__(self): 2511 return 'x' 2512 self.assertEqual(repr(C(0)), 'x') 2513 2514 2515class TestEq(unittest.TestCase): 2516 def test_recursive_eq(self): 2517 # Test a class with recursive child 2518 @dataclass 2519 class C: 2520 recursive: object = ... 2521 c = C() 2522 c.recursive = c 2523 self.assertEqual(c, c) 2524 2525 def test_no_eq(self): 2526 # Test a class with no __eq__ and eq=False. 2527 @dataclass(eq=False) 2528 class C: 2529 x: int 2530 self.assertNotEqual(C(0), C(0)) 2531 c = C(3) 2532 self.assertEqual(c, c) 2533 2534 # Test a class with an __eq__ and eq=False. 2535 @dataclass(eq=False) 2536 class C: 2537 x: int 2538 def __eq__(self, other): 2539 return other == 10 2540 self.assertEqual(C(3), 10) 2541 2542 def test_overwriting_eq(self): 2543 # If the class has __eq__, use it no matter the value of 2544 # eq=. 2545 2546 @dataclass 2547 class C: 2548 x: int 2549 def __eq__(self, other): 2550 return other == 3 2551 self.assertEqual(C(1), 3) 2552 self.assertNotEqual(C(1), 1) 2553 2554 @dataclass(eq=True) 2555 class C: 2556 x: int 2557 def __eq__(self, other): 2558 return other == 4 2559 self.assertEqual(C(1), 4) 2560 self.assertNotEqual(C(1), 1) 2561 2562 @dataclass(eq=False) 2563 class C: 2564 x: int 2565 def __eq__(self, other): 2566 return other == 5 2567 self.assertEqual(C(1), 5) 2568 self.assertNotEqual(C(1), 1) 2569 2570 2571class TestOrdering(unittest.TestCase): 2572 def test_functools_total_ordering(self): 2573 # Test that functools.total_ordering works with this class. 2574 @total_ordering 2575 @dataclass 2576 class C: 2577 x: int 2578 def __lt__(self, other): 2579 # Perform the test "backward", just to make 2580 # sure this is being called. 2581 return self.x >= other 2582 2583 self.assertLess(C(0), -1) 2584 self.assertLessEqual(C(0), -1) 2585 self.assertGreater(C(0), 1) 2586 self.assertGreaterEqual(C(0), 1) 2587 2588 def test_no_order(self): 2589 # Test that no ordering functions are added by default. 2590 @dataclass(order=False) 2591 class C: 2592 x: int 2593 # Make sure no order methods are added. 2594 self.assertNotIn('__le__', C.__dict__) 2595 self.assertNotIn('__lt__', C.__dict__) 2596 self.assertNotIn('__ge__', C.__dict__) 2597 self.assertNotIn('__gt__', C.__dict__) 2598 2599 # Test that __lt__ is still called 2600 @dataclass(order=False) 2601 class C: 2602 x: int 2603 def __lt__(self, other): 2604 return False 2605 # Make sure other methods aren't added. 2606 self.assertNotIn('__le__', C.__dict__) 2607 self.assertNotIn('__ge__', C.__dict__) 2608 self.assertNotIn('__gt__', C.__dict__) 2609 2610 def test_overwriting_order(self): 2611 with self.assertRaisesRegex(TypeError, 2612 'Cannot overwrite attribute __lt__' 2613 '.*using functools.total_ordering'): 2614 @dataclass(order=True) 2615 class C: 2616 x: int 2617 def __lt__(self): 2618 pass 2619 2620 with self.assertRaisesRegex(TypeError, 2621 'Cannot overwrite attribute __le__' 2622 '.*using functools.total_ordering'): 2623 @dataclass(order=True) 2624 class C: 2625 x: int 2626 def __le__(self): 2627 pass 2628 2629 with self.assertRaisesRegex(TypeError, 2630 'Cannot overwrite attribute __gt__' 2631 '.*using functools.total_ordering'): 2632 @dataclass(order=True) 2633 class C: 2634 x: int 2635 def __gt__(self): 2636 pass 2637 2638 with self.assertRaisesRegex(TypeError, 2639 'Cannot overwrite attribute __ge__' 2640 '.*using functools.total_ordering'): 2641 @dataclass(order=True) 2642 class C: 2643 x: int 2644 def __ge__(self): 2645 pass 2646 2647class TestHash(unittest.TestCase): 2648 def test_unsafe_hash(self): 2649 @dataclass(unsafe_hash=True) 2650 class C: 2651 x: int 2652 y: str 2653 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) 2654 2655 def test_hash_rules(self): 2656 def non_bool(value): 2657 # Map to something else that's True, but not a bool. 2658 if value is None: 2659 return None 2660 if value: 2661 return (3,) 2662 return 0 2663 2664 def test(case, unsafe_hash, eq, frozen, with_hash, result): 2665 with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, 2666 frozen=frozen): 2667 if result != 'exception': 2668 if with_hash: 2669 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2670 class C: 2671 def __hash__(self): 2672 return 0 2673 else: 2674 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2675 class C: 2676 pass 2677 2678 # See if the result matches what's expected. 2679 if result == 'fn': 2680 # __hash__ contains the function we generated. 2681 self.assertIn('__hash__', C.__dict__) 2682 self.assertIsNotNone(C.__dict__['__hash__']) 2683 2684 elif result == '': 2685 # __hash__ is not present in our class. 2686 if not with_hash: 2687 self.assertNotIn('__hash__', C.__dict__) 2688 2689 elif result == 'none': 2690 # __hash__ is set to None. 2691 self.assertIn('__hash__', C.__dict__) 2692 self.assertIsNone(C.__dict__['__hash__']) 2693 2694 elif result == 'exception': 2695 # Creating the class should cause an exception. 2696 # This only happens with with_hash==True. 2697 assert(with_hash) 2698 with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): 2699 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2700 class C: 2701 def __hash__(self): 2702 return 0 2703 2704 else: 2705 assert False, f'unknown result {result!r}' 2706 2707 # There are 8 cases of: 2708 # unsafe_hash=True/False 2709 # eq=True/False 2710 # frozen=True/False 2711 # And for each of these, a different result if 2712 # __hash__ is defined or not. 2713 for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ 2714 (False, False, False, '', ''), 2715 (False, False, True, '', ''), 2716 (False, True, False, 'none', ''), 2717 (False, True, True, 'fn', ''), 2718 (True, False, False, 'fn', 'exception'), 2719 (True, False, True, 'fn', 'exception'), 2720 (True, True, False, 'fn', 'exception'), 2721 (True, True, True, 'fn', 'exception'), 2722 ], 1): 2723 test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) 2724 test(case, unsafe_hash, eq, frozen, True, res_defined_hash) 2725 2726 # Test non-bool truth values, too. This is just to 2727 # make sure the data-driven table in the decorator 2728 # handles non-bool values. 2729 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) 2730 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) 2731 2732 2733 def test_eq_only(self): 2734 # If a class defines __eq__, __hash__ is automatically added 2735 # and set to None. This is normal Python behavior, not 2736 # related to dataclasses. Make sure we don't interfere with 2737 # that (see bpo=32546). 2738 2739 @dataclass 2740 class C: 2741 i: int 2742 def __eq__(self, other): 2743 return self.i == other.i 2744 self.assertEqual(C(1), C(1)) 2745 self.assertNotEqual(C(1), C(4)) 2746 2747 # And make sure things work in this case if we specify 2748 # unsafe_hash=True. 2749 @dataclass(unsafe_hash=True) 2750 class C: 2751 i: int 2752 def __eq__(self, other): 2753 return self.i == other.i 2754 self.assertEqual(C(1), C(1.0)) 2755 self.assertEqual(hash(C(1)), hash(C(1.0))) 2756 2757 # And check that the classes __eq__ is being used, despite 2758 # specifying eq=True. 2759 @dataclass(unsafe_hash=True, eq=True) 2760 class C: 2761 i: int 2762 def __eq__(self, other): 2763 return self.i == 3 and self.i == other.i 2764 self.assertEqual(C(3), C(3)) 2765 self.assertNotEqual(C(1), C(1)) 2766 self.assertEqual(hash(C(1)), hash(C(1.0))) 2767 2768 def test_0_field_hash(self): 2769 @dataclass(frozen=True) 2770 class C: 2771 pass 2772 self.assertEqual(hash(C()), hash(())) 2773 2774 @dataclass(unsafe_hash=True) 2775 class C: 2776 pass 2777 self.assertEqual(hash(C()), hash(())) 2778 2779 def test_1_field_hash(self): 2780 @dataclass(frozen=True) 2781 class C: 2782 x: int 2783 self.assertEqual(hash(C(4)), hash((4,))) 2784 self.assertEqual(hash(C(42)), hash((42,))) 2785 2786 @dataclass(unsafe_hash=True) 2787 class C: 2788 x: int 2789 self.assertEqual(hash(C(4)), hash((4,))) 2790 self.assertEqual(hash(C(42)), hash((42,))) 2791 2792 def test_hash_no_args(self): 2793 # Test dataclasses with no hash= argument. This exists to 2794 # make sure that if the @dataclass parameter name is changed 2795 # or the non-default hashing behavior changes, the default 2796 # hashability keeps working the same way. 2797 2798 class Base: 2799 def __hash__(self): 2800 return 301 2801 2802 # If frozen or eq is None, then use the default value (do not 2803 # specify any value in the decorator). 2804 for frozen, eq, base, expected in [ 2805 (None, None, object, 'unhashable'), 2806 (None, None, Base, 'unhashable'), 2807 (None, False, object, 'object'), 2808 (None, False, Base, 'base'), 2809 (None, True, object, 'unhashable'), 2810 (None, True, Base, 'unhashable'), 2811 (False, None, object, 'unhashable'), 2812 (False, None, Base, 'unhashable'), 2813 (False, False, object, 'object'), 2814 (False, False, Base, 'base'), 2815 (False, True, object, 'unhashable'), 2816 (False, True, Base, 'unhashable'), 2817 (True, None, object, 'tuple'), 2818 (True, None, Base, 'tuple'), 2819 (True, False, object, 'object'), 2820 (True, False, Base, 'base'), 2821 (True, True, object, 'tuple'), 2822 (True, True, Base, 'tuple'), 2823 ]: 2824 2825 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): 2826 # First, create the class. 2827 if frozen is None and eq is None: 2828 @dataclass 2829 class C(base): 2830 i: int 2831 elif frozen is None: 2832 @dataclass(eq=eq) 2833 class C(base): 2834 i: int 2835 elif eq is None: 2836 @dataclass(frozen=frozen) 2837 class C(base): 2838 i: int 2839 else: 2840 @dataclass(frozen=frozen, eq=eq) 2841 class C(base): 2842 i: int 2843 2844 # Now, make sure it hashes as expected. 2845 if expected == 'unhashable': 2846 c = C(10) 2847 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2848 hash(c) 2849 2850 elif expected == 'base': 2851 self.assertEqual(hash(C(10)), 301) 2852 2853 elif expected == 'object': 2854 # I'm not sure what test to use here. object's 2855 # hash isn't based on id(), so calling hash() 2856 # won't tell us much. So, just check the 2857 # function used is object's. 2858 self.assertIs(C.__hash__, object.__hash__) 2859 2860 elif expected == 'tuple': 2861 self.assertEqual(hash(C(42)), hash((42,))) 2862 2863 else: 2864 assert False, f'unknown value for expected={expected!r}' 2865 2866 2867class TestFrozen(unittest.TestCase): 2868 def test_frozen(self): 2869 @dataclass(frozen=True) 2870 class C: 2871 i: int 2872 2873 c = C(10) 2874 self.assertEqual(c.i, 10) 2875 with self.assertRaises(FrozenInstanceError): 2876 c.i = 5 2877 self.assertEqual(c.i, 10) 2878 2879 def test_frozen_empty(self): 2880 @dataclass(frozen=True) 2881 class C: 2882 pass 2883 2884 c = C() 2885 self.assertFalse(hasattr(c, 'i')) 2886 with self.assertRaises(FrozenInstanceError): 2887 c.i = 5 2888 self.assertFalse(hasattr(c, 'i')) 2889 with self.assertRaises(FrozenInstanceError): 2890 del c.i 2891 2892 def test_inherit(self): 2893 @dataclass(frozen=True) 2894 class C: 2895 i: int 2896 2897 @dataclass(frozen=True) 2898 class D(C): 2899 j: int 2900 2901 d = D(0, 10) 2902 with self.assertRaises(FrozenInstanceError): 2903 d.i = 5 2904 with self.assertRaises(FrozenInstanceError): 2905 d.j = 6 2906 self.assertEqual(d.i, 0) 2907 self.assertEqual(d.j, 10) 2908 2909 def test_inherit_nonfrozen_from_empty_frozen(self): 2910 @dataclass(frozen=True) 2911 class C: 2912 pass 2913 2914 with self.assertRaisesRegex(TypeError, 2915 'cannot inherit non-frozen dataclass from a frozen one'): 2916 @dataclass 2917 class D(C): 2918 j: int 2919 2920 def test_inherit_frozen_mutliple_inheritance(self): 2921 @dataclass 2922 class NotFrozen: 2923 pass 2924 2925 @dataclass(frozen=True) 2926 class Frozen: 2927 pass 2928 2929 class NotDataclass: 2930 pass 2931 2932 for bases in ( 2933 (NotFrozen, Frozen), 2934 (Frozen, NotFrozen), 2935 (Frozen, NotDataclass), 2936 (NotDataclass, Frozen), 2937 ): 2938 with self.subTest(bases=bases): 2939 with self.assertRaisesRegex( 2940 TypeError, 2941 'cannot inherit non-frozen dataclass from a frozen one', 2942 ): 2943 @dataclass 2944 class NotFrozenChild(*bases): 2945 pass 2946 2947 for bases in ( 2948 (NotFrozen, Frozen), 2949 (Frozen, NotFrozen), 2950 (NotFrozen, NotDataclass), 2951 (NotDataclass, NotFrozen), 2952 ): 2953 with self.subTest(bases=bases): 2954 with self.assertRaisesRegex( 2955 TypeError, 2956 'cannot inherit frozen dataclass from a non-frozen one', 2957 ): 2958 @dataclass(frozen=True) 2959 class FrozenChild(*bases): 2960 pass 2961 2962 def test_inherit_frozen_mutliple_inheritance_regular_mixins(self): 2963 @dataclass(frozen=True) 2964 class Frozen: 2965 pass 2966 2967 class NotDataclass: 2968 pass 2969 2970 class C1(Frozen, NotDataclass): 2971 pass 2972 self.assertEqual(C1.__mro__, (C1, Frozen, NotDataclass, object)) 2973 2974 class C2(NotDataclass, Frozen): 2975 pass 2976 self.assertEqual(C2.__mro__, (C2, NotDataclass, Frozen, object)) 2977 2978 @dataclass(frozen=True) 2979 class C3(Frozen, NotDataclass): 2980 pass 2981 self.assertEqual(C3.__mro__, (C3, Frozen, NotDataclass, object)) 2982 2983 @dataclass(frozen=True) 2984 class C4(NotDataclass, Frozen): 2985 pass 2986 self.assertEqual(C4.__mro__, (C4, NotDataclass, Frozen, object)) 2987 2988 def test_multiple_frozen_dataclasses_inheritance(self): 2989 @dataclass(frozen=True) 2990 class FrozenA: 2991 pass 2992 2993 @dataclass(frozen=True) 2994 class FrozenB: 2995 pass 2996 2997 class C1(FrozenA, FrozenB): 2998 pass 2999 self.assertEqual(C1.__mro__, (C1, FrozenA, FrozenB, object)) 3000 3001 class C2(FrozenB, FrozenA): 3002 pass 3003 self.assertEqual(C2.__mro__, (C2, FrozenB, FrozenA, object)) 3004 3005 @dataclass(frozen=True) 3006 class C3(FrozenA, FrozenB): 3007 pass 3008 self.assertEqual(C3.__mro__, (C3, FrozenA, FrozenB, object)) 3009 3010 @dataclass(frozen=True) 3011 class C4(FrozenB, FrozenA): 3012 pass 3013 self.assertEqual(C4.__mro__, (C4, FrozenB, FrozenA, object)) 3014 3015 def test_inherit_nonfrozen_from_empty(self): 3016 @dataclass 3017 class C: 3018 pass 3019 3020 @dataclass 3021 class D(C): 3022 j: int 3023 3024 d = D(3) 3025 self.assertEqual(d.j, 3) 3026 self.assertIsInstance(d, C) 3027 3028 # Test both ways: with an intermediate normal (non-dataclass) 3029 # class and without an intermediate class. 3030 def test_inherit_nonfrozen_from_frozen(self): 3031 for intermediate_class in [True, False]: 3032 with self.subTest(intermediate_class=intermediate_class): 3033 @dataclass(frozen=True) 3034 class C: 3035 i: int 3036 3037 if intermediate_class: 3038 class I(C): pass 3039 else: 3040 I = C 3041 3042 with self.assertRaisesRegex(TypeError, 3043 'cannot inherit non-frozen dataclass from a frozen one'): 3044 @dataclass 3045 class D(I): 3046 pass 3047 3048 def test_inherit_frozen_from_nonfrozen(self): 3049 for intermediate_class in [True, False]: 3050 with self.subTest(intermediate_class=intermediate_class): 3051 @dataclass 3052 class C: 3053 i: int 3054 3055 if intermediate_class: 3056 class I(C): pass 3057 else: 3058 I = C 3059 3060 with self.assertRaisesRegex(TypeError, 3061 'cannot inherit frozen dataclass from a non-frozen one'): 3062 @dataclass(frozen=True) 3063 class D(I): 3064 pass 3065 3066 def test_inherit_from_normal_class(self): 3067 for intermediate_class in [True, False]: 3068 with self.subTest(intermediate_class=intermediate_class): 3069 class C: 3070 pass 3071 3072 if intermediate_class: 3073 class I(C): pass 3074 else: 3075 I = C 3076 3077 @dataclass(frozen=True) 3078 class D(I): 3079 i: int 3080 3081 d = D(10) 3082 with self.assertRaises(FrozenInstanceError): 3083 d.i = 5 3084 3085 def test_non_frozen_normal_derived(self): 3086 # See bpo-32953. 3087 3088 @dataclass(frozen=True) 3089 class D: 3090 x: int 3091 y: int = 10 3092 3093 class S(D): 3094 pass 3095 3096 s = S(3) 3097 self.assertEqual(s.x, 3) 3098 self.assertEqual(s.y, 10) 3099 s.cached = True 3100 3101 # But can't change the frozen attributes. 3102 with self.assertRaises(FrozenInstanceError): 3103 s.x = 5 3104 with self.assertRaises(FrozenInstanceError): 3105 s.y = 5 3106 self.assertEqual(s.x, 3) 3107 self.assertEqual(s.y, 10) 3108 self.assertEqual(s.cached, True) 3109 3110 with self.assertRaises(FrozenInstanceError): 3111 del s.x 3112 self.assertEqual(s.x, 3) 3113 with self.assertRaises(FrozenInstanceError): 3114 del s.y 3115 self.assertEqual(s.y, 10) 3116 del s.cached 3117 self.assertFalse(hasattr(s, 'cached')) 3118 with self.assertRaises(AttributeError) as cm: 3119 del s.cached 3120 self.assertNotIsInstance(cm.exception, FrozenInstanceError) 3121 3122 def test_non_frozen_normal_derived_from_empty_frozen(self): 3123 @dataclass(frozen=True) 3124 class D: 3125 pass 3126 3127 class S(D): 3128 pass 3129 3130 s = S() 3131 self.assertFalse(hasattr(s, 'x')) 3132 s.x = 5 3133 self.assertEqual(s.x, 5) 3134 3135 del s.x 3136 self.assertFalse(hasattr(s, 'x')) 3137 with self.assertRaises(AttributeError) as cm: 3138 del s.x 3139 self.assertNotIsInstance(cm.exception, FrozenInstanceError) 3140 3141 def test_overwriting_frozen(self): 3142 # frozen uses __setattr__ and __delattr__. 3143 with self.assertRaisesRegex(TypeError, 3144 'Cannot overwrite attribute __setattr__'): 3145 @dataclass(frozen=True) 3146 class C: 3147 x: int 3148 def __setattr__(self): 3149 pass 3150 3151 with self.assertRaisesRegex(TypeError, 3152 'Cannot overwrite attribute __delattr__'): 3153 @dataclass(frozen=True) 3154 class C: 3155 x: int 3156 def __delattr__(self): 3157 pass 3158 3159 @dataclass(frozen=False) 3160 class C: 3161 x: int 3162 def __setattr__(self, name, value): 3163 self.__dict__['x'] = value * 2 3164 self.assertEqual(C(10).x, 20) 3165 3166 def test_frozen_hash(self): 3167 @dataclass(frozen=True) 3168 class C: 3169 x: Any 3170 3171 # If x is immutable, we can compute the hash. No exception is 3172 # raised. 3173 hash(C(3)) 3174 3175 # If x is mutable, computing the hash is an error. 3176 with self.assertRaisesRegex(TypeError, 'unhashable type'): 3177 hash(C({})) 3178 3179 def test_frozen_deepcopy_without_slots(self): 3180 # see: https://github.com/python/cpython/issues/89683 3181 @dataclass(frozen=True, slots=False) 3182 class C: 3183 s: str 3184 3185 c = C('hello') 3186 self.assertEqual(deepcopy(c), c) 3187 3188 def test_frozen_deepcopy_with_slots(self): 3189 # see: https://github.com/python/cpython/issues/89683 3190 with self.subTest('generated __slots__'): 3191 @dataclass(frozen=True, slots=True) 3192 class C: 3193 s: str 3194 3195 c = C('hello') 3196 self.assertEqual(deepcopy(c), c) 3197 3198 with self.subTest('user-defined __slots__ and no __{get,set}state__'): 3199 @dataclass(frozen=True, slots=False) 3200 class C: 3201 __slots__ = ('s',) 3202 s: str 3203 3204 # with user-defined slots, __getstate__ and __setstate__ are not 3205 # automatically added, hence the error 3206 err = r"^cannot\ assign\ to\ field\ 's'$" 3207 self.assertRaisesRegex(FrozenInstanceError, err, deepcopy, C('')) 3208 3209 with self.subTest('user-defined __slots__ and __{get,set}state__'): 3210 @dataclass(frozen=True, slots=False) 3211 class C: 3212 __slots__ = ('s',) 3213 __getstate__ = dataclasses._dataclass_getstate 3214 __setstate__ = dataclasses._dataclass_setstate 3215 3216 s: str 3217 3218 c = C('hello') 3219 self.assertEqual(deepcopy(c), c) 3220 3221 3222class TestSlots(unittest.TestCase): 3223 def test_simple(self): 3224 @dataclass 3225 class C: 3226 __slots__ = ('x',) 3227 x: Any 3228 3229 # There was a bug where a variable in a slot was assumed to 3230 # also have a default value (of type 3231 # types.MemberDescriptorType). 3232 with self.assertRaisesRegex(TypeError, 3233 r"__init__\(\) missing 1 required positional argument: 'x'"): 3234 C() 3235 3236 # We can create an instance, and assign to x. 3237 c = C(10) 3238 self.assertEqual(c.x, 10) 3239 c.x = 5 3240 self.assertEqual(c.x, 5) 3241 3242 # We can't assign to anything else. 3243 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"): 3244 c.y = 5 3245 3246 def test_derived_added_field(self): 3247 # See bpo-33100. 3248 @dataclass 3249 class Base: 3250 __slots__ = ('x',) 3251 x: Any 3252 3253 @dataclass 3254 class Derived(Base): 3255 x: int 3256 y: int 3257 3258 d = Derived(1, 2) 3259 self.assertEqual((d.x, d.y), (1, 2)) 3260 3261 # We can add a new field to the derived instance. 3262 d.z = 10 3263 3264 def test_generated_slots(self): 3265 @dataclass(slots=True) 3266 class C: 3267 x: int 3268 y: int 3269 3270 c = C(1, 2) 3271 self.assertEqual((c.x, c.y), (1, 2)) 3272 3273 c.x = 3 3274 c.y = 4 3275 self.assertEqual((c.x, c.y), (3, 4)) 3276 3277 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"): 3278 c.z = 5 3279 3280 def test_add_slots_when_slots_exists(self): 3281 with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'): 3282 @dataclass(slots=True) 3283 class C: 3284 __slots__ = ('x',) 3285 x: int 3286 3287 def test_generated_slots_value(self): 3288 3289 class Root: 3290 __slots__ = {'x'} 3291 3292 class Root2(Root): 3293 __slots__ = {'k': '...', 'j': ''} 3294 3295 class Root3(Root2): 3296 __slots__ = ['h'] 3297 3298 class Root4(Root3): 3299 __slots__ = 'aa' 3300 3301 @dataclass(slots=True) 3302 class Base(Root4): 3303 y: int 3304 j: str 3305 h: str 3306 3307 self.assertEqual(Base.__slots__, ('y', )) 3308 3309 @dataclass(slots=True) 3310 class Derived(Base): 3311 aa: float 3312 x: str 3313 z: int 3314 k: str 3315 h: str 3316 3317 self.assertEqual(Derived.__slots__, ('z', )) 3318 3319 @dataclass 3320 class AnotherDerived(Base): 3321 z: int 3322 3323 self.assertNotIn('__slots__', AnotherDerived.__dict__) 3324 3325 def test_cant_inherit_from_iterator_slots(self): 3326 3327 class Root: 3328 __slots__ = iter(['a']) 3329 3330 class Root2(Root): 3331 __slots__ = ('b', ) 3332 3333 with self.assertRaisesRegex( 3334 TypeError, 3335 "^Slots of 'Root' cannot be determined" 3336 ): 3337 @dataclass(slots=True) 3338 class C(Root2): 3339 x: int 3340 3341 def test_returns_new_class(self): 3342 class A: 3343 x: int 3344 3345 B = dataclass(A, slots=True) 3346 self.assertIsNot(A, B) 3347 3348 self.assertFalse(hasattr(A, "__slots__")) 3349 self.assertTrue(hasattr(B, "__slots__")) 3350 3351 # Can't be local to test_frozen_pickle. 3352 @dataclass(frozen=True, slots=True) 3353 class FrozenSlotsClass: 3354 foo: str 3355 bar: int 3356 3357 @dataclass(frozen=True) 3358 class FrozenWithoutSlotsClass: 3359 foo: str 3360 bar: int 3361 3362 def test_frozen_pickle(self): 3363 # bpo-43999 3364 3365 self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar")) 3366 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3367 with self.subTest(proto=proto): 3368 obj = self.FrozenSlotsClass("a", 1) 3369 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 3370 self.assertIsNot(obj, p) 3371 self.assertEqual(obj, p) 3372 3373 obj = self.FrozenWithoutSlotsClass("a", 1) 3374 p = pickle.loads(pickle.dumps(obj, protocol=proto)) 3375 self.assertIsNot(obj, p) 3376 self.assertEqual(obj, p) 3377 3378 @dataclass(frozen=True, slots=True) 3379 class FrozenSlotsGetStateClass: 3380 foo: str 3381 bar: int 3382 3383 getstate_called: bool = field(default=False, compare=False) 3384 3385 def __getstate__(self): 3386 object.__setattr__(self, 'getstate_called', True) 3387 return [self.foo, self.bar] 3388 3389 @dataclass(frozen=True, slots=True) 3390 class FrozenSlotsSetStateClass: 3391 foo: str 3392 bar: int 3393 3394 setstate_called: bool = field(default=False, compare=False) 3395 3396 def __setstate__(self, state): 3397 object.__setattr__(self, 'setstate_called', True) 3398 object.__setattr__(self, 'foo', state[0]) 3399 object.__setattr__(self, 'bar', state[1]) 3400 3401 @dataclass(frozen=True, slots=True) 3402 class FrozenSlotsAllStateClass: 3403 foo: str 3404 bar: int 3405 3406 getstate_called: bool = field(default=False, compare=False) 3407 setstate_called: bool = field(default=False, compare=False) 3408 3409 def __getstate__(self): 3410 object.__setattr__(self, 'getstate_called', True) 3411 return [self.foo, self.bar] 3412 3413 def __setstate__(self, state): 3414 object.__setattr__(self, 'setstate_called', True) 3415 object.__setattr__(self, 'foo', state[0]) 3416 object.__setattr__(self, 'bar', state[1]) 3417 3418 def test_frozen_slots_pickle_custom_state(self): 3419 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3420 with self.subTest(proto=proto): 3421 obj = self.FrozenSlotsGetStateClass('a', 1) 3422 dumped = pickle.dumps(obj, protocol=proto) 3423 3424 self.assertTrue(obj.getstate_called) 3425 self.assertEqual(obj, pickle.loads(dumped)) 3426 3427 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3428 with self.subTest(proto=proto): 3429 obj = self.FrozenSlotsSetStateClass('a', 1) 3430 obj2 = pickle.loads(pickle.dumps(obj, protocol=proto)) 3431 3432 self.assertTrue(obj2.setstate_called) 3433 self.assertEqual(obj, obj2) 3434 3435 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 3436 with self.subTest(proto=proto): 3437 obj = self.FrozenSlotsAllStateClass('a', 1) 3438 dumped = pickle.dumps(obj, protocol=proto) 3439 3440 self.assertTrue(obj.getstate_called) 3441 3442 obj2 = pickle.loads(dumped) 3443 self.assertTrue(obj2.setstate_called) 3444 self.assertEqual(obj, obj2) 3445 3446 def test_slots_with_default_no_init(self): 3447 # Originally reported in bpo-44649. 3448 @dataclass(slots=True) 3449 class A: 3450 a: str 3451 b: str = field(default='b', init=False) 3452 3453 obj = A("a") 3454 self.assertEqual(obj.a, 'a') 3455 self.assertEqual(obj.b, 'b') 3456 3457 def test_slots_with_default_factory_no_init(self): 3458 # Originally reported in bpo-44649. 3459 @dataclass(slots=True) 3460 class A: 3461 a: str 3462 b: str = field(default_factory=lambda:'b', init=False) 3463 3464 obj = A("a") 3465 self.assertEqual(obj.a, 'a') 3466 self.assertEqual(obj.b, 'b') 3467 3468 def test_slots_no_weakref(self): 3469 @dataclass(slots=True) 3470 class A: 3471 # No weakref. 3472 pass 3473 3474 self.assertNotIn("__weakref__", A.__slots__) 3475 a = A() 3476 with self.assertRaisesRegex(TypeError, 3477 "cannot create weak reference"): 3478 weakref.ref(a) 3479 with self.assertRaises(AttributeError): 3480 a.__weakref__ 3481 3482 def test_slots_weakref(self): 3483 @dataclass(slots=True, weakref_slot=True) 3484 class A: 3485 a: int 3486 3487 self.assertIn("__weakref__", A.__slots__) 3488 a = A(1) 3489 a_ref = weakref.ref(a) 3490 3491 self.assertIs(a.__weakref__, a_ref) 3492 3493 def test_slots_weakref_base_str(self): 3494 class Base: 3495 __slots__ = '__weakref__' 3496 3497 @dataclass(slots=True) 3498 class A(Base): 3499 a: int 3500 3501 # __weakref__ is in the base class, not A. But an A is still weakref-able. 3502 self.assertIn("__weakref__", Base.__slots__) 3503 self.assertNotIn("__weakref__", A.__slots__) 3504 a = A(1) 3505 weakref.ref(a) 3506 3507 def test_slots_weakref_base_tuple(self): 3508 # Same as test_slots_weakref_base, but use a tuple instead of a string 3509 # in the base class. 3510 class Base: 3511 __slots__ = ('__weakref__',) 3512 3513 @dataclass(slots=True) 3514 class A(Base): 3515 a: int 3516 3517 # __weakref__ is in the base class, not A. But an A is still 3518 # weakref-able. 3519 self.assertIn("__weakref__", Base.__slots__) 3520 self.assertNotIn("__weakref__", A.__slots__) 3521 a = A(1) 3522 weakref.ref(a) 3523 3524 def test_weakref_slot_without_slot(self): 3525 with self.assertRaisesRegex(TypeError, 3526 "weakref_slot is True but slots is False"): 3527 @dataclass(weakref_slot=True) 3528 class A: 3529 a: int 3530 3531 def test_weakref_slot_make_dataclass(self): 3532 A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True) 3533 self.assertIn("__weakref__", A.__slots__) 3534 a = A(1) 3535 weakref.ref(a) 3536 3537 # And make sure if raises if slots=True is not given. 3538 with self.assertRaisesRegex(TypeError, 3539 "weakref_slot is True but slots is False"): 3540 B = make_dataclass('B', [('a', int),], weakref_slot=True) 3541 3542 def test_weakref_slot_subclass_weakref_slot(self): 3543 @dataclass(slots=True, weakref_slot=True) 3544 class Base: 3545 field: int 3546 3547 # A *can* also specify weakref_slot=True if it wants to (gh-93521) 3548 @dataclass(slots=True, weakref_slot=True) 3549 class A(Base): 3550 ... 3551 3552 # __weakref__ is in the base class, not A. But an instance of A 3553 # is still weakref-able. 3554 self.assertIn("__weakref__", Base.__slots__) 3555 self.assertNotIn("__weakref__", A.__slots__) 3556 a = A(1) 3557 a_ref = weakref.ref(a) 3558 self.assertIs(a.__weakref__, a_ref) 3559 3560 def test_weakref_slot_subclass_no_weakref_slot(self): 3561 @dataclass(slots=True, weakref_slot=True) 3562 class Base: 3563 field: int 3564 3565 @dataclass(slots=True) 3566 class A(Base): 3567 ... 3568 3569 # __weakref__ is in the base class, not A. Even though A doesn't 3570 # specify weakref_slot, it should still be weakref-able. 3571 self.assertIn("__weakref__", Base.__slots__) 3572 self.assertNotIn("__weakref__", A.__slots__) 3573 a = A(1) 3574 a_ref = weakref.ref(a) 3575 self.assertIs(a.__weakref__, a_ref) 3576 3577 def test_weakref_slot_normal_base_weakref_slot(self): 3578 class Base: 3579 __slots__ = ('__weakref__',) 3580 3581 @dataclass(slots=True, weakref_slot=True) 3582 class A(Base): 3583 field: int 3584 3585 # __weakref__ is in the base class, not A. But an instance of 3586 # A is still weakref-able. 3587 self.assertIn("__weakref__", Base.__slots__) 3588 self.assertNotIn("__weakref__", A.__slots__) 3589 a = A(1) 3590 a_ref = weakref.ref(a) 3591 self.assertIs(a.__weakref__, a_ref) 3592 3593 3594 def test_dataclass_derived_weakref_slot(self): 3595 class A: 3596 pass 3597 3598 @dataclass(slots=True, weakref_slot=True) 3599 class B(A): 3600 pass 3601 3602 self.assertEqual(B.__slots__, ()) 3603 B() 3604 3605 def test_dataclass_derived_generic(self): 3606 T = typing.TypeVar('T') 3607 3608 @dataclass(slots=True, weakref_slot=True) 3609 class A(typing.Generic[T]): 3610 pass 3611 self.assertEqual(A.__slots__, ('__weakref__',)) 3612 self.assertTrue(A.__weakref__) 3613 A() 3614 3615 @dataclass(slots=True, weakref_slot=True) 3616 class B[T2]: 3617 pass 3618 self.assertEqual(B.__slots__, ('__weakref__',)) 3619 self.assertTrue(B.__weakref__) 3620 B() 3621 3622 def test_dataclass_derived_generic_from_base(self): 3623 T = typing.TypeVar('T') 3624 3625 class RawBase: ... 3626 3627 @dataclass(slots=True, weakref_slot=True) 3628 class C1(typing.Generic[T], RawBase): 3629 pass 3630 self.assertEqual(C1.__slots__, ()) 3631 self.assertTrue(C1.__weakref__) 3632 C1() 3633 @dataclass(slots=True, weakref_slot=True) 3634 class C2(RawBase, typing.Generic[T]): 3635 pass 3636 self.assertEqual(C2.__slots__, ()) 3637 self.assertTrue(C2.__weakref__) 3638 C2() 3639 3640 @dataclass(slots=True, weakref_slot=True) 3641 class D[T2](RawBase): 3642 pass 3643 self.assertEqual(D.__slots__, ()) 3644 self.assertTrue(D.__weakref__) 3645 D() 3646 3647 def test_dataclass_derived_generic_from_slotted_base(self): 3648 T = typing.TypeVar('T') 3649 3650 class WithSlots: 3651 __slots__ = ('a', 'b') 3652 3653 @dataclass(slots=True, weakref_slot=True) 3654 class E1(WithSlots, Generic[T]): 3655 pass 3656 self.assertEqual(E1.__slots__, ('__weakref__',)) 3657 self.assertTrue(E1.__weakref__) 3658 E1() 3659 @dataclass(slots=True, weakref_slot=True) 3660 class E2(Generic[T], WithSlots): 3661 pass 3662 self.assertEqual(E2.__slots__, ('__weakref__',)) 3663 self.assertTrue(E2.__weakref__) 3664 E2() 3665 3666 @dataclass(slots=True, weakref_slot=True) 3667 class F[T2](WithSlots): 3668 pass 3669 self.assertEqual(F.__slots__, ('__weakref__',)) 3670 self.assertTrue(F.__weakref__) 3671 F() 3672 3673 def test_dataclass_derived_generic_from_slotted_base(self): 3674 T = typing.TypeVar('T') 3675 3676 class WithWeakrefSlot: 3677 __slots__ = ('__weakref__',) 3678 3679 @dataclass(slots=True, weakref_slot=True) 3680 class G1(WithWeakrefSlot, Generic[T]): 3681 pass 3682 self.assertEqual(G1.__slots__, ()) 3683 self.assertTrue(G1.__weakref__) 3684 G1() 3685 @dataclass(slots=True, weakref_slot=True) 3686 class G2(Generic[T], WithWeakrefSlot): 3687 pass 3688 self.assertEqual(G2.__slots__, ()) 3689 self.assertTrue(G2.__weakref__) 3690 G2() 3691 3692 @dataclass(slots=True, weakref_slot=True) 3693 class H[T2](WithWeakrefSlot): 3694 pass 3695 self.assertEqual(H.__slots__, ()) 3696 self.assertTrue(H.__weakref__) 3697 H() 3698 3699 def test_dataclass_slot_dict(self): 3700 class WithDictSlot: 3701 __slots__ = ('__dict__',) 3702 3703 @dataclass(slots=True) 3704 class A(WithDictSlot): ... 3705 3706 self.assertEqual(A.__slots__, ()) 3707 self.assertEqual(A().__dict__, {}) 3708 A() 3709 3710 @support.cpython_only 3711 def test_dataclass_slot_dict_ctype(self): 3712 # https://github.com/python/cpython/issues/123935 3713 from test.support import import_helper 3714 # Skips test if `_testcapi` is not present: 3715 _testcapi = import_helper.import_module('_testcapi') 3716 3717 @dataclass(slots=True) 3718 class HasDictOffset(_testcapi.HeapCTypeWithDict): 3719 __dict__: dict = {} 3720 self.assertNotEqual(_testcapi.HeapCTypeWithDict.__dictoffset__, 0) 3721 self.assertEqual(HasDictOffset.__slots__, ()) 3722 3723 @dataclass(slots=True) 3724 class DoesNotHaveDictOffset(_testcapi.HeapCTypeWithWeakref): 3725 __dict__: dict = {} 3726 self.assertEqual(_testcapi.HeapCTypeWithWeakref.__dictoffset__, 0) 3727 self.assertEqual(DoesNotHaveDictOffset.__slots__, ('__dict__',)) 3728 3729 @support.cpython_only 3730 def test_slots_with_wrong_init_subclass(self): 3731 # TODO: This test is for a kinda-buggy behavior. 3732 # Ideally, it should be fixed and `__init_subclass__` 3733 # should be fully supported in the future versions. 3734 # See https://github.com/python/cpython/issues/91126 3735 class WrongSuper: 3736 def __init_subclass__(cls, arg): 3737 pass 3738 3739 with self.assertRaisesRegex( 3740 TypeError, 3741 "missing 1 required positional argument: 'arg'", 3742 ): 3743 @dataclass(slots=True) 3744 class WithWrongSuper(WrongSuper, arg=1): 3745 pass 3746 3747 class CorrectSuper: 3748 args = [] 3749 def __init_subclass__(cls, arg="default"): 3750 cls.args.append(arg) 3751 3752 @dataclass(slots=True) 3753 class WithCorrectSuper(CorrectSuper): 3754 pass 3755 3756 # __init_subclass__ is called twice: once for `WithCorrectSuper` 3757 # and once for `WithCorrectSuper__slots__` new class 3758 # that we create internally. 3759 self.assertEqual(CorrectSuper.args, ["default", "default"]) 3760 3761 3762class TestDescriptors(unittest.TestCase): 3763 def test_set_name(self): 3764 # See bpo-33141. 3765 3766 # Create a descriptor. 3767 class D: 3768 def __set_name__(self, owner, name): 3769 self.name = name + 'x' 3770 def __get__(self, instance, owner): 3771 if instance is not None: 3772 return 1 3773 return self 3774 3775 # This is the case of just normal descriptor behavior, no 3776 # dataclass code is involved in initializing the descriptor. 3777 @dataclass 3778 class C: 3779 c: int=D() 3780 self.assertEqual(C.c.name, 'cx') 3781 3782 # Now test with a default value and init=False, which is the 3783 # only time this is really meaningful. If not using 3784 # init=False, then the descriptor will be overwritten, anyway. 3785 @dataclass 3786 class C: 3787 c: int=field(default=D(), init=False) 3788 self.assertEqual(C.c.name, 'cx') 3789 self.assertEqual(C().c, 1) 3790 3791 def test_non_descriptor(self): 3792 # PEP 487 says __set_name__ should work on non-descriptors. 3793 # Create a descriptor. 3794 3795 class D: 3796 def __set_name__(self, owner, name): 3797 self.name = name + 'x' 3798 3799 @dataclass 3800 class C: 3801 c: int=field(default=D(), init=False) 3802 self.assertEqual(C.c.name, 'cx') 3803 3804 def test_lookup_on_instance(self): 3805 # See bpo-33175. 3806 class D: 3807 pass 3808 3809 d = D() 3810 # Create an attribute on the instance, not type. 3811 d.__set_name__ = Mock() 3812 3813 # Make sure d.__set_name__ is not called. 3814 @dataclass 3815 class C: 3816 i: int=field(default=d, init=False) 3817 3818 self.assertEqual(d.__set_name__.call_count, 0) 3819 3820 def test_lookup_on_class(self): 3821 # See bpo-33175. 3822 class D: 3823 pass 3824 D.__set_name__ = Mock() 3825 3826 # Make sure D.__set_name__ is called. 3827 @dataclass 3828 class C: 3829 i: int=field(default=D(), init=False) 3830 3831 self.assertEqual(D.__set_name__.call_count, 1) 3832 3833 def test_init_calls_set(self): 3834 class D: 3835 pass 3836 3837 D.__set__ = Mock() 3838 3839 @dataclass 3840 class C: 3841 i: D = D() 3842 3843 # Make sure D.__set__ is called. 3844 D.__set__.reset_mock() 3845 c = C(5) 3846 self.assertEqual(D.__set__.call_count, 1) 3847 3848 def test_getting_field_calls_get(self): 3849 class D: 3850 pass 3851 3852 D.__set__ = Mock() 3853 D.__get__ = Mock() 3854 3855 @dataclass 3856 class C: 3857 i: D = D() 3858 3859 c = C(5) 3860 3861 # Make sure D.__get__ is called. 3862 D.__get__.reset_mock() 3863 value = c.i 3864 self.assertEqual(D.__get__.call_count, 1) 3865 3866 def test_setting_field_calls_set(self): 3867 class D: 3868 pass 3869 3870 D.__set__ = Mock() 3871 3872 @dataclass 3873 class C: 3874 i: D = D() 3875 3876 c = C(5) 3877 3878 # Make sure D.__set__ is called. 3879 D.__set__.reset_mock() 3880 c.i = 10 3881 self.assertEqual(D.__set__.call_count, 1) 3882 3883 def test_setting_uninitialized_descriptor_field(self): 3884 class D: 3885 pass 3886 3887 D.__set__ = Mock() 3888 3889 @dataclass 3890 class C: 3891 i: D 3892 3893 # D.__set__ is not called because there's no D instance to call it on 3894 D.__set__.reset_mock() 3895 c = C(5) 3896 self.assertEqual(D.__set__.call_count, 0) 3897 3898 # D.__set__ still isn't called after setting i to an instance of D 3899 # because descriptors don't behave like that when stored as instance vars 3900 c.i = D() 3901 c.i = 5 3902 self.assertEqual(D.__set__.call_count, 0) 3903 3904 def test_default_value(self): 3905 class D: 3906 def __get__(self, instance: Any, owner: object) -> int: 3907 if instance is None: 3908 return 100 3909 3910 return instance._x 3911 3912 def __set__(self, instance: Any, value: int) -> None: 3913 instance._x = value 3914 3915 @dataclass 3916 class C: 3917 i: D = D() 3918 3919 c = C() 3920 self.assertEqual(c.i, 100) 3921 3922 c = C(5) 3923 self.assertEqual(c.i, 5) 3924 3925 def test_no_default_value(self): 3926 class D: 3927 def __get__(self, instance: Any, owner: object) -> int: 3928 if instance is None: 3929 raise AttributeError() 3930 3931 return instance._x 3932 3933 def __set__(self, instance: Any, value: int) -> None: 3934 instance._x = value 3935 3936 @dataclass 3937 class C: 3938 i: D = D() 3939 3940 with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'): 3941 c = C() 3942 3943class TestStringAnnotations(unittest.TestCase): 3944 def test_classvar(self): 3945 # Some expressions recognized as ClassVar really aren't. But 3946 # if you're using string annotations, it's not an exact 3947 # science. 3948 # These tests assume that both "import typing" and "from 3949 # typing import *" have been run in this file. 3950 for typestr in ('ClassVar[int]', 3951 'ClassVar [int]', 3952 ' ClassVar [int]', 3953 'ClassVar', 3954 ' ClassVar ', 3955 'typing.ClassVar[int]', 3956 'typing.ClassVar[str]', 3957 ' typing.ClassVar[str]', 3958 'typing .ClassVar[str]', 3959 'typing. ClassVar[str]', 3960 'typing.ClassVar [str]', 3961 'typing.ClassVar [ str]', 3962 3963 # Not syntactically valid, but these will 3964 # be treated as ClassVars. 3965 'typing.ClassVar.[int]', 3966 'typing.ClassVar+', 3967 ): 3968 with self.subTest(typestr=typestr): 3969 @dataclass 3970 class C: 3971 x: typestr 3972 3973 # x is a ClassVar, so C() takes no args. 3974 C() 3975 3976 # And it won't appear in the class's dict because it doesn't 3977 # have a default. 3978 self.assertNotIn('x', C.__dict__) 3979 3980 def test_isnt_classvar(self): 3981 for typestr in ('CV', 3982 't.ClassVar', 3983 't.ClassVar[int]', 3984 'typing..ClassVar[int]', 3985 'Classvar', 3986 'Classvar[int]', 3987 'typing.ClassVarx[int]', 3988 'typong.ClassVar[int]', 3989 'dataclasses.ClassVar[int]', 3990 'typingxClassVar[str]', 3991 ): 3992 with self.subTest(typestr=typestr): 3993 @dataclass 3994 class C: 3995 x: typestr 3996 3997 # x is not a ClassVar, so C() takes one arg. 3998 self.assertEqual(C(10).x, 10) 3999 4000 def test_initvar(self): 4001 # These tests assume that both "import dataclasses" and "from 4002 # dataclasses import *" have been run in this file. 4003 for typestr in ('InitVar[int]', 4004 'InitVar [int]' 4005 ' InitVar [int]', 4006 'InitVar', 4007 ' InitVar ', 4008 'dataclasses.InitVar[int]', 4009 'dataclasses.InitVar[str]', 4010 ' dataclasses.InitVar[str]', 4011 'dataclasses .InitVar[str]', 4012 'dataclasses. InitVar[str]', 4013 'dataclasses.InitVar [str]', 4014 'dataclasses.InitVar [ str]', 4015 4016 # Not syntactically valid, but these will 4017 # be treated as InitVars. 4018 'dataclasses.InitVar.[int]', 4019 'dataclasses.InitVar+', 4020 ): 4021 with self.subTest(typestr=typestr): 4022 @dataclass 4023 class C: 4024 x: typestr 4025 4026 # x is an InitVar, so doesn't create a member. 4027 with self.assertRaisesRegex(AttributeError, 4028 "object has no attribute 'x'"): 4029 C(1).x 4030 4031 def test_isnt_initvar(self): 4032 for typestr in ('IV', 4033 'dc.InitVar', 4034 'xdataclasses.xInitVar', 4035 'typing.xInitVar[int]', 4036 ): 4037 with self.subTest(typestr=typestr): 4038 @dataclass 4039 class C: 4040 x: typestr 4041 4042 # x is not an InitVar, so there will be a member x. 4043 self.assertEqual(C(10).x, 10) 4044 4045 def test_classvar_module_level_import(self): 4046 from test.test_dataclasses import dataclass_module_1 4047 from test.test_dataclasses import dataclass_module_1_str 4048 from test.test_dataclasses import dataclass_module_2 4049 from test.test_dataclasses import dataclass_module_2_str 4050 4051 for m in (dataclass_module_1, dataclass_module_1_str, 4052 dataclass_module_2, dataclass_module_2_str, 4053 ): 4054 with self.subTest(m=m): 4055 # There's a difference in how the ClassVars are 4056 # interpreted when using string annotations or 4057 # not. See the imported modules for details. 4058 if m.USING_STRINGS: 4059 c = m.CV(10) 4060 else: 4061 c = m.CV() 4062 self.assertEqual(c.cv0, 20) 4063 4064 4065 # There's a difference in how the InitVars are 4066 # interpreted when using string annotations or 4067 # not. See the imported modules for details. 4068 c = m.IV(0, 1, 2, 3, 4) 4069 4070 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 4071 with self.subTest(field_name=field_name): 4072 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 4073 # Since field_name is an InitVar, it's 4074 # not an instance field. 4075 getattr(c, field_name) 4076 4077 if m.USING_STRINGS: 4078 # iv4 is interpreted as a normal field. 4079 self.assertIn('not_iv4', c.__dict__) 4080 self.assertEqual(c.not_iv4, 4) 4081 else: 4082 # iv4 is interpreted as an InitVar, so it 4083 # won't exist on the instance. 4084 self.assertNotIn('not_iv4', c.__dict__) 4085 4086 def test_text_annotations(self): 4087 from test.test_dataclasses import dataclass_textanno 4088 4089 self.assertEqual( 4090 get_type_hints(dataclass_textanno.Bar), 4091 {'foo': dataclass_textanno.Foo}) 4092 self.assertEqual( 4093 get_type_hints(dataclass_textanno.Bar.__init__), 4094 {'foo': dataclass_textanno.Foo, 4095 'return': type(None)}) 4096 4097 4098ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)]) 4099ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass', 4100 [('x', int)], 4101 module=__name__) 4102WrongNameMakeDataclass = make_dataclass('Wrong', [('x', int)]) 4103WrongModuleMakeDataclass = make_dataclass('WrongModuleMakeDataclass', 4104 [('x', int)], 4105 module='custom') 4106 4107class TestMakeDataclass(unittest.TestCase): 4108 def test_simple(self): 4109 C = make_dataclass('C', 4110 [('x', int), 4111 ('y', int, field(default=5))], 4112 namespace={'add_one': lambda self: self.x + 1}) 4113 c = C(10) 4114 self.assertEqual((c.x, c.y), (10, 5)) 4115 self.assertEqual(c.add_one(), 11) 4116 4117 4118 def test_no_mutate_namespace(self): 4119 # Make sure a provided namespace isn't mutated. 4120 ns = {} 4121 C = make_dataclass('C', 4122 [('x', int), 4123 ('y', int, field(default=5))], 4124 namespace=ns) 4125 self.assertEqual(ns, {}) 4126 4127 def test_base(self): 4128 class Base1: 4129 pass 4130 class Base2: 4131 pass 4132 C = make_dataclass('C', 4133 [('x', int)], 4134 bases=(Base1, Base2)) 4135 c = C(2) 4136 self.assertIsInstance(c, C) 4137 self.assertIsInstance(c, Base1) 4138 self.assertIsInstance(c, Base2) 4139 4140 def test_base_dataclass(self): 4141 @dataclass 4142 class Base1: 4143 x: int 4144 class Base2: 4145 pass 4146 C = make_dataclass('C', 4147 [('y', int)], 4148 bases=(Base1, Base2)) 4149 with self.assertRaisesRegex(TypeError, 'required positional'): 4150 c = C(2) 4151 c = C(1, 2) 4152 self.assertIsInstance(c, C) 4153 self.assertIsInstance(c, Base1) 4154 self.assertIsInstance(c, Base2) 4155 4156 self.assertEqual((c.x, c.y), (1, 2)) 4157 4158 def test_init_var(self): 4159 def post_init(self, y): 4160 self.x *= y 4161 4162 C = make_dataclass('C', 4163 [('x', int), 4164 ('y', InitVar[int]), 4165 ], 4166 namespace={'__post_init__': post_init}, 4167 ) 4168 c = C(2, 3) 4169 self.assertEqual(vars(c), {'x': 6}) 4170 self.assertEqual(len(fields(c)), 1) 4171 4172 def test_class_var(self): 4173 C = make_dataclass('C', 4174 [('x', int), 4175 ('y', ClassVar[int], 10), 4176 ('z', ClassVar[int], field(default=20)), 4177 ]) 4178 c = C(1) 4179 self.assertEqual(vars(c), {'x': 1}) 4180 self.assertEqual(len(fields(c)), 1) 4181 self.assertEqual(C.y, 10) 4182 self.assertEqual(C.z, 20) 4183 4184 def test_other_params(self): 4185 C = make_dataclass('C', 4186 [('x', int), 4187 ('y', ClassVar[int], 10), 4188 ('z', ClassVar[int], field(default=20)), 4189 ], 4190 init=False) 4191 # Make sure we have a repr, but no init. 4192 self.assertNotIn('__init__', vars(C)) 4193 self.assertIn('__repr__', vars(C)) 4194 4195 # Make sure random other params don't work. 4196 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 4197 C = make_dataclass('C', 4198 [], 4199 xxinit=False) 4200 4201 def test_no_types(self): 4202 C = make_dataclass('Point', ['x', 'y', 'z']) 4203 c = C(1, 2, 3) 4204 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 4205 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 4206 'y': 'typing.Any', 4207 'z': 'typing.Any'}) 4208 4209 C = make_dataclass('Point', ['x', ('y', int), 'z']) 4210 c = C(1, 2, 3) 4211 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 4212 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 4213 'y': int, 4214 'z': 'typing.Any'}) 4215 4216 def test_module_attr(self): 4217 self.assertEqual(ByMakeDataClass.__module__, __name__) 4218 self.assertEqual(ByMakeDataClass(1).__module__, __name__) 4219 self.assertEqual(WrongModuleMakeDataclass.__module__, "custom") 4220 Nested = make_dataclass('Nested', []) 4221 self.assertEqual(Nested.__module__, __name__) 4222 self.assertEqual(Nested().__module__, __name__) 4223 4224 def test_pickle_support(self): 4225 for klass in [ByMakeDataClass, ManualModuleMakeDataClass]: 4226 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 4227 with self.subTest(proto=proto): 4228 self.assertEqual( 4229 pickle.loads(pickle.dumps(klass, proto)), 4230 klass, 4231 ) 4232 self.assertEqual( 4233 pickle.loads(pickle.dumps(klass(1), proto)), 4234 klass(1), 4235 ) 4236 4237 def test_cannot_be_pickled(self): 4238 for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]: 4239 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 4240 with self.subTest(proto=proto): 4241 with self.assertRaises(pickle.PickleError): 4242 pickle.dumps(klass, proto) 4243 with self.assertRaises(pickle.PickleError): 4244 pickle.dumps(klass(1), proto) 4245 4246 def test_invalid_type_specification(self): 4247 for bad_field in [(), 4248 (1, 2, 3, 4), 4249 ]: 4250 with self.subTest(bad_field=bad_field): 4251 with self.assertRaisesRegex(TypeError, r'Invalid field: '): 4252 make_dataclass('C', ['a', bad_field]) 4253 4254 # And test for things with no len(). 4255 for bad_field in [float, 4256 lambda x:x, 4257 ]: 4258 with self.subTest(bad_field=bad_field): 4259 with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 4260 make_dataclass('C', ['a', bad_field]) 4261 4262 def test_duplicate_field_names(self): 4263 for field in ['a', 'ab']: 4264 with self.subTest(field=field): 4265 with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 4266 make_dataclass('C', [field, 'a', field]) 4267 4268 def test_keyword_field_names(self): 4269 for field in ['for', 'async', 'await', 'as']: 4270 with self.subTest(field=field): 4271 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 4272 make_dataclass('C', ['a', field]) 4273 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 4274 make_dataclass('C', [field]) 4275 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 4276 make_dataclass('C', [field, 'a']) 4277 4278 def test_non_identifier_field_names(self): 4279 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 4280 with self.subTest(field=field): 4281 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 4282 make_dataclass('C', ['a', field]) 4283 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 4284 make_dataclass('C', [field]) 4285 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 4286 make_dataclass('C', [field, 'a']) 4287 4288 def test_underscore_field_names(self): 4289 # Unlike namedtuple, it's okay if dataclass field names have 4290 # an underscore. 4291 make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 4292 4293 def test_funny_class_names_names(self): 4294 # No reason to prevent weird class names, since 4295 # types.new_class allows them. 4296 for classname in ['()', 'x,y', '*', '2@3', '']: 4297 with self.subTest(classname=classname): 4298 C = make_dataclass(classname, ['a', 'b']) 4299 self.assertEqual(C.__name__, classname) 4300 4301class TestReplace(unittest.TestCase): 4302 def test(self): 4303 @dataclass(frozen=True) 4304 class C: 4305 x: int 4306 y: int 4307 4308 c = C(1, 2) 4309 c1 = replace(c, x=3) 4310 self.assertEqual(c1.x, 3) 4311 self.assertEqual(c1.y, 2) 4312 4313 def test_frozen(self): 4314 @dataclass(frozen=True) 4315 class C: 4316 x: int 4317 y: int 4318 z: int = field(init=False, default=10) 4319 t: int = field(init=False, default=100) 4320 4321 c = C(1, 2) 4322 c1 = replace(c, x=3) 4323 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 4324 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 4325 4326 4327 with self.assertRaisesRegex(TypeError, 'init=False'): 4328 replace(c, x=3, z=20, t=50) 4329 with self.assertRaisesRegex(TypeError, 'init=False'): 4330 replace(c, z=20) 4331 replace(c, x=3, z=20, t=50) 4332 4333 # Make sure the result is still frozen. 4334 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 4335 c1.x = 3 4336 4337 # Make sure we can't replace an attribute that doesn't exist, 4338 # if we're also replacing one that does exist. Test this 4339 # here, because setting attributes on frozen instances is 4340 # handled slightly differently from non-frozen ones. 4341 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 4342 "keyword argument 'a'"): 4343 c1 = replace(c, x=20, a=5) 4344 4345 def test_invalid_field_name(self): 4346 @dataclass(frozen=True) 4347 class C: 4348 x: int 4349 y: int 4350 4351 c = C(1, 2) 4352 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 4353 "keyword argument 'z'"): 4354 c1 = replace(c, z=3) 4355 4356 def test_invalid_object(self): 4357 @dataclass(frozen=True) 4358 class C: 4359 x: int 4360 y: int 4361 4362 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 4363 replace(C, x=3) 4364 4365 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 4366 replace(0, x=3) 4367 4368 def test_no_init(self): 4369 @dataclass 4370 class C: 4371 x: int 4372 y: int = field(init=False, default=10) 4373 4374 c = C(1) 4375 c.y = 20 4376 4377 # Make sure y gets the default value. 4378 c1 = replace(c, x=5) 4379 self.assertEqual((c1.x, c1.y), (5, 10)) 4380 4381 # Trying to replace y is an error. 4382 with self.assertRaisesRegex(TypeError, 'init=False'): 4383 replace(c, x=2, y=30) 4384 4385 with self.assertRaisesRegex(TypeError, 'init=False'): 4386 replace(c, y=30) 4387 4388 def test_classvar(self): 4389 @dataclass 4390 class C: 4391 x: int 4392 y: ClassVar[int] = 1000 4393 4394 c = C(1) 4395 d = C(2) 4396 4397 self.assertIs(c.y, d.y) 4398 self.assertEqual(c.y, 1000) 4399 4400 # Trying to replace y is an error: can't replace ClassVars. 4401 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 4402 "unexpected keyword argument 'y'"): 4403 replace(c, y=30) 4404 4405 replace(c, x=5) 4406 4407 def test_initvar_is_specified(self): 4408 @dataclass 4409 class C: 4410 x: int 4411 y: InitVar[int] 4412 4413 def __post_init__(self, y): 4414 self.x *= y 4415 4416 c = C(1, 10) 4417 self.assertEqual(c.x, 10) 4418 with self.assertRaisesRegex(TypeError, r"InitVar 'y' must be " 4419 r"specified with replace\(\)"): 4420 replace(c, x=3) 4421 c = replace(c, x=3, y=5) 4422 self.assertEqual(c.x, 15) 4423 4424 def test_initvar_with_default_value(self): 4425 @dataclass 4426 class C: 4427 x: int 4428 y: InitVar[int] = None 4429 z: InitVar[int] = 42 4430 4431 def __post_init__(self, y, z): 4432 if y is not None: 4433 self.x += y 4434 if z is not None: 4435 self.x += z 4436 4437 c = C(x=1, y=10, z=1) 4438 self.assertEqual(replace(c), C(x=12)) 4439 self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42)) 4440 self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1)) 4441 4442 def test_recursive_repr(self): 4443 @dataclass 4444 class C: 4445 f: "C" 4446 4447 c = C(None) 4448 c.f = c 4449 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 4450 4451 def test_recursive_repr_two_attrs(self): 4452 @dataclass 4453 class C: 4454 f: "C" 4455 g: "C" 4456 4457 c = C(None, None) 4458 c.f = c 4459 c.g = c 4460 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 4461 ".<locals>.C(f=..., g=...)") 4462 4463 def test_recursive_repr_indirection(self): 4464 @dataclass 4465 class C: 4466 f: "D" 4467 4468 @dataclass 4469 class D: 4470 f: "C" 4471 4472 c = C(None) 4473 d = D(None) 4474 c.f = d 4475 d.f = c 4476 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 4477 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 4478 ".<locals>.D(f=...))") 4479 4480 def test_recursive_repr_indirection_two(self): 4481 @dataclass 4482 class C: 4483 f: "D" 4484 4485 @dataclass 4486 class D: 4487 f: "E" 4488 4489 @dataclass 4490 class E: 4491 f: "C" 4492 4493 c = C(None) 4494 d = D(None) 4495 e = E(None) 4496 c.f = d 4497 d.f = e 4498 e.f = c 4499 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 4500 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 4501 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 4502 ".<locals>.E(f=...)))") 4503 4504 def test_recursive_repr_misc_attrs(self): 4505 @dataclass 4506 class C: 4507 f: "C" 4508 g: int 4509 4510 c = C(None, 1) 4511 c.f = c 4512 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 4513 ".<locals>.C(f=..., g=1)") 4514 4515 ## def test_initvar(self): 4516 ## @dataclass 4517 ## class C: 4518 ## x: int 4519 ## y: InitVar[int] 4520 4521 ## c = C(1, 10) 4522 ## d = C(2, 20) 4523 4524 ## # In our case, replacing an InitVar is a no-op 4525 ## self.assertEqual(c, replace(c, y=5)) 4526 4527 ## replace(c, x=5) 4528 4529class TestAbstract(unittest.TestCase): 4530 def test_abc_implementation(self): 4531 class Ordered(abc.ABC): 4532 @abc.abstractmethod 4533 def __lt__(self, other): 4534 pass 4535 4536 @abc.abstractmethod 4537 def __le__(self, other): 4538 pass 4539 4540 @dataclass(order=True) 4541 class Date(Ordered): 4542 year: int 4543 month: 'Month' 4544 day: 'int' 4545 4546 self.assertFalse(inspect.isabstract(Date)) 4547 self.assertGreater(Date(2020,12,25), Date(2020,8,31)) 4548 4549 def test_maintain_abc(self): 4550 class A(abc.ABC): 4551 @abc.abstractmethod 4552 def foo(self): 4553 pass 4554 4555 @dataclass 4556 class Date(A): 4557 year: int 4558 month: 'Month' 4559 day: 'int' 4560 4561 self.assertTrue(inspect.isabstract(Date)) 4562 msg = "class Date without an implementation for abstract method 'foo'" 4563 self.assertRaisesRegex(TypeError, msg, Date) 4564 4565 4566class TestMatchArgs(unittest.TestCase): 4567 def test_match_args(self): 4568 @dataclass 4569 class C: 4570 a: int 4571 self.assertEqual(C(42).__match_args__, ('a',)) 4572 4573 def test_explicit_match_args(self): 4574 ma = () 4575 @dataclass 4576 class C: 4577 a: int 4578 __match_args__ = ma 4579 self.assertIs(C(42).__match_args__, ma) 4580 4581 def test_bpo_43764(self): 4582 @dataclass(repr=False, eq=False, init=False) 4583 class X: 4584 a: int 4585 b: int 4586 c: int 4587 self.assertEqual(X.__match_args__, ("a", "b", "c")) 4588 4589 def test_match_args_argument(self): 4590 @dataclass(match_args=False) 4591 class X: 4592 a: int 4593 self.assertNotIn('__match_args__', X.__dict__) 4594 4595 @dataclass(match_args=False) 4596 class Y: 4597 a: int 4598 __match_args__ = ('b',) 4599 self.assertEqual(Y.__match_args__, ('b',)) 4600 4601 @dataclass(match_args=False) 4602 class Z(Y): 4603 z: int 4604 self.assertEqual(Z.__match_args__, ('b',)) 4605 4606 # Ensure parent dataclass __match_args__ is seen, if child class 4607 # specifies match_args=False. 4608 @dataclass 4609 class A: 4610 a: int 4611 z: int 4612 @dataclass(match_args=False) 4613 class B(A): 4614 b: int 4615 self.assertEqual(B.__match_args__, ('a', 'z')) 4616 4617 def test_make_dataclasses(self): 4618 C = make_dataclass('C', [('x', int), ('y', int)]) 4619 self.assertEqual(C.__match_args__, ('x', 'y')) 4620 4621 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True) 4622 self.assertEqual(C.__match_args__, ('x', 'y')) 4623 4624 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False) 4625 self.assertNotIn('__match__args__', C.__dict__) 4626 4627 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)}) 4628 self.assertEqual(C.__match_args__, ('z',)) 4629 4630 4631class TestKeywordArgs(unittest.TestCase): 4632 def test_no_classvar_kwarg(self): 4633 msg = 'field a is a ClassVar but specifies kw_only' 4634 with self.assertRaisesRegex(TypeError, msg): 4635 @dataclass 4636 class A: 4637 a: ClassVar[int] = field(kw_only=True) 4638 4639 with self.assertRaisesRegex(TypeError, msg): 4640 @dataclass 4641 class A: 4642 a: ClassVar[int] = field(kw_only=False) 4643 4644 with self.assertRaisesRegex(TypeError, msg): 4645 @dataclass(kw_only=True) 4646 class A: 4647 a: ClassVar[int] = field(kw_only=False) 4648 4649 def test_field_marked_as_kwonly(self): 4650 ####################### 4651 # Using dataclass(kw_only=True) 4652 @dataclass(kw_only=True) 4653 class A: 4654 a: int 4655 self.assertTrue(fields(A)[0].kw_only) 4656 4657 @dataclass(kw_only=True) 4658 class A: 4659 a: int = field(kw_only=True) 4660 self.assertTrue(fields(A)[0].kw_only) 4661 4662 @dataclass(kw_only=True) 4663 class A: 4664 a: int = field(kw_only=False) 4665 self.assertFalse(fields(A)[0].kw_only) 4666 4667 ####################### 4668 # Using dataclass(kw_only=False) 4669 @dataclass(kw_only=False) 4670 class A: 4671 a: int 4672 self.assertFalse(fields(A)[0].kw_only) 4673 4674 @dataclass(kw_only=False) 4675 class A: 4676 a: int = field(kw_only=True) 4677 self.assertTrue(fields(A)[0].kw_only) 4678 4679 @dataclass(kw_only=False) 4680 class A: 4681 a: int = field(kw_only=False) 4682 self.assertFalse(fields(A)[0].kw_only) 4683 4684 ####################### 4685 # Not specifying dataclass(kw_only) 4686 @dataclass 4687 class A: 4688 a: int 4689 self.assertFalse(fields(A)[0].kw_only) 4690 4691 @dataclass 4692 class A: 4693 a: int = field(kw_only=True) 4694 self.assertTrue(fields(A)[0].kw_only) 4695 4696 @dataclass 4697 class A: 4698 a: int = field(kw_only=False) 4699 self.assertFalse(fields(A)[0].kw_only) 4700 4701 def test_match_args(self): 4702 # kw fields don't show up in __match_args__. 4703 @dataclass(kw_only=True) 4704 class C: 4705 a: int 4706 self.assertEqual(C(a=42).__match_args__, ()) 4707 4708 @dataclass 4709 class C: 4710 a: int 4711 b: int = field(kw_only=True) 4712 self.assertEqual(C(42, b=10).__match_args__, ('a',)) 4713 4714 def test_KW_ONLY(self): 4715 @dataclass 4716 class A: 4717 a: int 4718 _: KW_ONLY 4719 b: int 4720 c: int 4721 A(3, c=5, b=4) 4722 msg = "takes 2 positional arguments but 4 were given" 4723 with self.assertRaisesRegex(TypeError, msg): 4724 A(3, 4, 5) 4725 4726 4727 @dataclass(kw_only=True) 4728 class B: 4729 a: int 4730 _: KW_ONLY 4731 b: int 4732 c: int 4733 B(a=3, b=4, c=5) 4734 msg = "takes 1 positional argument but 4 were given" 4735 with self.assertRaisesRegex(TypeError, msg): 4736 B(3, 4, 5) 4737 4738 # Explicitly make a field that follows KW_ONLY be non-keyword-only. 4739 @dataclass 4740 class C: 4741 a: int 4742 _: KW_ONLY 4743 b: int 4744 c: int = field(kw_only=False) 4745 c = C(1, 2, b=3) 4746 self.assertEqual(c.a, 1) 4747 self.assertEqual(c.b, 3) 4748 self.assertEqual(c.c, 2) 4749 c = C(1, b=3, c=2) 4750 self.assertEqual(c.a, 1) 4751 self.assertEqual(c.b, 3) 4752 self.assertEqual(c.c, 2) 4753 c = C(1, b=3, c=2) 4754 self.assertEqual(c.a, 1) 4755 self.assertEqual(c.b, 3) 4756 self.assertEqual(c.c, 2) 4757 c = C(c=2, b=3, a=1) 4758 self.assertEqual(c.a, 1) 4759 self.assertEqual(c.b, 3) 4760 self.assertEqual(c.c, 2) 4761 4762 def test_KW_ONLY_as_string(self): 4763 @dataclass 4764 class A: 4765 a: int 4766 _: 'dataclasses.KW_ONLY' 4767 b: int 4768 c: int 4769 A(3, c=5, b=4) 4770 msg = "takes 2 positional arguments but 4 were given" 4771 with self.assertRaisesRegex(TypeError, msg): 4772 A(3, 4, 5) 4773 4774 def test_KW_ONLY_twice(self): 4775 msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified" 4776 4777 with self.assertRaisesRegex(TypeError, msg): 4778 @dataclass 4779 class A: 4780 a: int 4781 X: KW_ONLY 4782 Y: KW_ONLY 4783 b: int 4784 c: int 4785 4786 with self.assertRaisesRegex(TypeError, msg): 4787 @dataclass 4788 class A: 4789 a: int 4790 X: KW_ONLY 4791 b: int 4792 Y: KW_ONLY 4793 c: int 4794 4795 with self.assertRaisesRegex(TypeError, msg): 4796 @dataclass 4797 class A: 4798 a: int 4799 X: KW_ONLY 4800 b: int 4801 c: int 4802 Y: KW_ONLY 4803 4804 # But this usage is okay, since it's not using KW_ONLY. 4805 @dataclass 4806 class A: 4807 a: int 4808 _: KW_ONLY 4809 b: int 4810 c: int = field(kw_only=True) 4811 4812 # And if inheriting, it's okay. 4813 @dataclass 4814 class A: 4815 a: int 4816 _: KW_ONLY 4817 b: int 4818 c: int 4819 @dataclass 4820 class B(A): 4821 _: KW_ONLY 4822 d: int 4823 4824 # Make sure the error is raised in a derived class. 4825 with self.assertRaisesRegex(TypeError, msg): 4826 @dataclass 4827 class A: 4828 a: int 4829 _: KW_ONLY 4830 b: int 4831 c: int 4832 @dataclass 4833 class B(A): 4834 X: KW_ONLY 4835 d: int 4836 Y: KW_ONLY 4837 4838 4839 def test_post_init(self): 4840 @dataclass 4841 class A: 4842 a: int 4843 _: KW_ONLY 4844 b: InitVar[int] 4845 c: int 4846 d: InitVar[int] 4847 def __post_init__(self, b, d): 4848 raise CustomError(f'{b=} {d=}') 4849 with self.assertRaisesRegex(CustomError, 'b=3 d=4'): 4850 A(1, c=2, b=3, d=4) 4851 4852 @dataclass 4853 class B: 4854 a: int 4855 _: KW_ONLY 4856 b: InitVar[int] 4857 c: int 4858 d: InitVar[int] 4859 def __post_init__(self, b, d): 4860 self.a = b 4861 self.c = d 4862 b = B(1, c=2, b=3, d=4) 4863 self.assertEqual(asdict(b), {'a': 3, 'c': 4}) 4864 4865 def test_defaults(self): 4866 # For kwargs, make sure we can have defaults after non-defaults. 4867 @dataclass 4868 class A: 4869 a: int = 0 4870 _: KW_ONLY 4871 b: int 4872 c: int = 1 4873 d: int 4874 4875 a = A(d=4, b=3) 4876 self.assertEqual(a.a, 0) 4877 self.assertEqual(a.b, 3) 4878 self.assertEqual(a.c, 1) 4879 self.assertEqual(a.d, 4) 4880 4881 # Make sure we still check for non-kwarg non-defaults not following 4882 # defaults. 4883 err_regex = "non-default argument 'z' follows default argument 'a'" 4884 with self.assertRaisesRegex(TypeError, err_regex): 4885 @dataclass 4886 class A: 4887 a: int = 0 4888 z: int 4889 _: KW_ONLY 4890 b: int 4891 c: int = 1 4892 d: int 4893 4894 def test_make_dataclass(self): 4895 A = make_dataclass("A", ['a'], kw_only=True) 4896 self.assertTrue(fields(A)[0].kw_only) 4897 4898 B = make_dataclass("B", 4899 ['a', ('b', int, field(kw_only=False))], 4900 kw_only=True) 4901 self.assertTrue(fields(B)[0].kw_only) 4902 self.assertFalse(fields(B)[1].kw_only) 4903 4904 4905if __name__ == '__main__': 4906 unittest.main() 4907