1# Deliberately use "from dataclasses import *". Every name in __all__ 2# is tested, so they all must be present. This is a way to catch 3# missing ones. 4 5from dataclasses import * 6 7import pickle 8import inspect 9import builtins 10import unittest 11from unittest.mock import Mock 12from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional 13from typing import get_type_hints 14from collections import deque, OrderedDict, namedtuple 15from functools import total_ordering 16 17import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 18import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 19 20# Just any custom exception we can catch. 21class CustomError(Exception): pass 22 23class TestCase(unittest.TestCase): 24 def test_no_fields(self): 25 @dataclass 26 class C: 27 pass 28 29 o = C() 30 self.assertEqual(len(fields(C)), 0) 31 32 def test_no_fields_but_member_variable(self): 33 @dataclass 34 class C: 35 i = 0 36 37 o = C() 38 self.assertEqual(len(fields(C)), 0) 39 40 def test_one_field_no_default(self): 41 @dataclass 42 class C: 43 x: int 44 45 o = C(42) 46 self.assertEqual(o.x, 42) 47 48 def test_field_default_default_factory_error(self): 49 msg = "cannot specify both default and default_factory" 50 with self.assertRaisesRegex(ValueError, msg): 51 @dataclass 52 class C: 53 x: int = field(default=1, default_factory=int) 54 55 def test_field_repr(self): 56 int_field = field(default=1, init=True, repr=False) 57 int_field.name = "id" 58 repr_output = repr(int_field) 59 expected_output = "Field(name='id',type=None," \ 60 f"default=1,default_factory={MISSING!r}," \ 61 "init=True,repr=False,hash=None," \ 62 "compare=True,metadata=mappingproxy({})," \ 63 "_field_type=None)" 64 65 self.assertEqual(repr_output, expected_output) 66 67 def test_named_init_params(self): 68 @dataclass 69 class C: 70 x: int 71 72 o = C(x=32) 73 self.assertEqual(o.x, 32) 74 75 def test_two_fields_one_default(self): 76 @dataclass 77 class C: 78 x: int 79 y: int = 0 80 81 o = C(3) 82 self.assertEqual((o.x, o.y), (3, 0)) 83 84 # Non-defaults following defaults. 85 with self.assertRaisesRegex(TypeError, 86 "non-default argument 'y' follows " 87 "default argument"): 88 @dataclass 89 class C: 90 x: int = 0 91 y: int 92 93 # A derived class adds a non-default field after a default one. 94 with self.assertRaisesRegex(TypeError, 95 "non-default argument 'y' follows " 96 "default argument"): 97 @dataclass 98 class B: 99 x: int = 0 100 101 @dataclass 102 class C(B): 103 y: int 104 105 # Override a base class field and add a default to 106 # a field which didn't use to have a default. 107 with self.assertRaisesRegex(TypeError, 108 "non-default argument 'y' follows " 109 "default argument"): 110 @dataclass 111 class B: 112 x: int 113 y: int 114 115 @dataclass 116 class C(B): 117 x: int = 0 118 119 def test_overwrite_hash(self): 120 # Test that declaring this class isn't an error. It should 121 # use the user-provided __hash__. 122 @dataclass(frozen=True) 123 class C: 124 x: int 125 def __hash__(self): 126 return 301 127 self.assertEqual(hash(C(100)), 301) 128 129 # Test that declaring this class isn't an error. It should 130 # use the generated __hash__. 131 @dataclass(frozen=True) 132 class C: 133 x: int 134 def __eq__(self, other): 135 return False 136 self.assertEqual(hash(C(100)), hash((100,))) 137 138 # But this one should generate an exception, because with 139 # unsafe_hash=True, it's an error to have a __hash__ defined. 140 with self.assertRaisesRegex(TypeError, 141 'Cannot overwrite attribute __hash__'): 142 @dataclass(unsafe_hash=True) 143 class C: 144 def __hash__(self): 145 pass 146 147 # Creating this class should not generate an exception, 148 # because even though __hash__ exists before @dataclass is 149 # called, (due to __eq__ being defined), since it's None 150 # that's okay. 151 @dataclass(unsafe_hash=True) 152 class C: 153 x: int 154 def __eq__(self): 155 pass 156 # The generated hash function works as we'd expect. 157 self.assertEqual(hash(C(10)), hash((10,))) 158 159 # Creating this class should generate an exception, because 160 # __hash__ exists and is not None, which it would be if it 161 # had been auto-generated due to __eq__ being defined. 162 with self.assertRaisesRegex(TypeError, 163 'Cannot overwrite attribute __hash__'): 164 @dataclass(unsafe_hash=True) 165 class C: 166 x: int 167 def __eq__(self): 168 pass 169 def __hash__(self): 170 pass 171 172 def test_overwrite_fields_in_derived_class(self): 173 # Note that x from C1 replaces x in Base, but the order remains 174 # the same as defined in Base. 175 @dataclass 176 class Base: 177 x: Any = 15.0 178 y: int = 0 179 180 @dataclass 181 class C1(Base): 182 z: int = 10 183 x: int = 15 184 185 o = Base() 186 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 187 188 o = C1() 189 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 190 191 o = C1(x=5) 192 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 193 194 def test_field_named_self(self): 195 @dataclass 196 class C: 197 self: str 198 c=C('foo') 199 self.assertEqual(c.self, 'foo') 200 201 # Make sure the first parameter is not named 'self'. 202 sig = inspect.signature(C.__init__) 203 first = next(iter(sig.parameters)) 204 self.assertNotEqual('self', first) 205 206 # But we do use 'self' if no field named self. 207 @dataclass 208 class C: 209 selfx: str 210 211 # Make sure the first parameter is named 'self'. 212 sig = inspect.signature(C.__init__) 213 first = next(iter(sig.parameters)) 214 self.assertEqual('self', first) 215 216 def test_field_named_object(self): 217 @dataclass 218 class C: 219 object: str 220 c = C('foo') 221 self.assertEqual(c.object, 'foo') 222 223 def test_field_named_object_frozen(self): 224 @dataclass(frozen=True) 225 class C: 226 object: str 227 c = C('foo') 228 self.assertEqual(c.object, 'foo') 229 230 def test_field_named_like_builtin(self): 231 # Attribute names can shadow built-in names 232 # since code generation is used. 233 # Ensure that this is not happening. 234 exclusions = {'None', 'True', 'False'} 235 builtins_names = sorted( 236 b for b in builtins.__dict__.keys() 237 if not b.startswith('__') and b not in exclusions 238 ) 239 attributes = [(name, str) for name in builtins_names] 240 C = make_dataclass('C', attributes) 241 242 c = C(*[name for name in builtins_names]) 243 244 for name in builtins_names: 245 self.assertEqual(getattr(c, name), name) 246 247 def test_field_named_like_builtin_frozen(self): 248 # Attribute names can shadow built-in names 249 # since code generation is used. 250 # Ensure that this is not happening 251 # for frozen data classes. 252 exclusions = {'None', 'True', 'False'} 253 builtins_names = sorted( 254 b for b in builtins.__dict__.keys() 255 if not b.startswith('__') and b not in exclusions 256 ) 257 attributes = [(name, str) for name in builtins_names] 258 C = make_dataclass('C', attributes, frozen=True) 259 260 c = C(*[name for name in builtins_names]) 261 262 for name in builtins_names: 263 self.assertEqual(getattr(c, name), name) 264 265 def test_0_field_compare(self): 266 # Ensure that order=False is the default. 267 @dataclass 268 class C0: 269 pass 270 271 @dataclass(order=False) 272 class C1: 273 pass 274 275 for cls in [C0, C1]: 276 with self.subTest(cls=cls): 277 self.assertEqual(cls(), cls()) 278 for idx, fn in enumerate([lambda a, b: a < b, 279 lambda a, b: a <= b, 280 lambda a, b: a > b, 281 lambda a, b: a >= b]): 282 with self.subTest(idx=idx): 283 with self.assertRaisesRegex(TypeError, 284 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 285 fn(cls(), cls()) 286 287 @dataclass(order=True) 288 class C: 289 pass 290 self.assertLessEqual(C(), C()) 291 self.assertGreaterEqual(C(), C()) 292 293 def test_1_field_compare(self): 294 # Ensure that order=False is the default. 295 @dataclass 296 class C0: 297 x: int 298 299 @dataclass(order=False) 300 class C1: 301 x: int 302 303 for cls in [C0, C1]: 304 with self.subTest(cls=cls): 305 self.assertEqual(cls(1), cls(1)) 306 self.assertNotEqual(cls(0), cls(1)) 307 for idx, fn in enumerate([lambda a, b: a < b, 308 lambda a, b: a <= b, 309 lambda a, b: a > b, 310 lambda a, b: a >= b]): 311 with self.subTest(idx=idx): 312 with self.assertRaisesRegex(TypeError, 313 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 314 fn(cls(0), cls(0)) 315 316 @dataclass(order=True) 317 class C: 318 x: int 319 self.assertLess(C(0), C(1)) 320 self.assertLessEqual(C(0), C(1)) 321 self.assertLessEqual(C(1), C(1)) 322 self.assertGreater(C(1), C(0)) 323 self.assertGreaterEqual(C(1), C(0)) 324 self.assertGreaterEqual(C(1), C(1)) 325 326 def test_simple_compare(self): 327 # Ensure that order=False is the default. 328 @dataclass 329 class C0: 330 x: int 331 y: int 332 333 @dataclass(order=False) 334 class C1: 335 x: int 336 y: int 337 338 for cls in [C0, C1]: 339 with self.subTest(cls=cls): 340 self.assertEqual(cls(0, 0), cls(0, 0)) 341 self.assertEqual(cls(1, 2), cls(1, 2)) 342 self.assertNotEqual(cls(1, 0), cls(0, 0)) 343 self.assertNotEqual(cls(1, 0), cls(1, 1)) 344 for idx, fn in enumerate([lambda a, b: a < b, 345 lambda a, b: a <= b, 346 lambda a, b: a > b, 347 lambda a, b: a >= b]): 348 with self.subTest(idx=idx): 349 with self.assertRaisesRegex(TypeError, 350 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 351 fn(cls(0, 0), cls(0, 0)) 352 353 @dataclass(order=True) 354 class C: 355 x: int 356 y: int 357 358 for idx, fn in enumerate([lambda a, b: a == b, 359 lambda a, b: a <= b, 360 lambda a, b: a >= b]): 361 with self.subTest(idx=idx): 362 self.assertTrue(fn(C(0, 0), C(0, 0))) 363 364 for idx, fn in enumerate([lambda a, b: a < b, 365 lambda a, b: a <= b, 366 lambda a, b: a != b]): 367 with self.subTest(idx=idx): 368 self.assertTrue(fn(C(0, 0), C(0, 1))) 369 self.assertTrue(fn(C(0, 1), C(1, 0))) 370 self.assertTrue(fn(C(1, 0), C(1, 1))) 371 372 for idx, fn in enumerate([lambda a, b: a > b, 373 lambda a, b: a >= b, 374 lambda a, b: a != b]): 375 with self.subTest(idx=idx): 376 self.assertTrue(fn(C(0, 1), C(0, 0))) 377 self.assertTrue(fn(C(1, 0), C(0, 1))) 378 self.assertTrue(fn(C(1, 1), C(1, 0))) 379 380 def test_compare_subclasses(self): 381 # Comparisons fail for subclasses, even if no fields 382 # are added. 383 @dataclass 384 class B: 385 i: int 386 387 @dataclass 388 class C(B): 389 pass 390 391 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 392 (lambda a, b: a != b, True)]): 393 with self.subTest(idx=idx): 394 self.assertEqual(fn(B(0), C(0)), expected) 395 396 for idx, fn in enumerate([lambda a, b: a < b, 397 lambda a, b: a <= b, 398 lambda a, b: a > b, 399 lambda a, b: a >= b]): 400 with self.subTest(idx=idx): 401 with self.assertRaisesRegex(TypeError, 402 "not supported between instances of 'B' and 'C'"): 403 fn(B(0), C(0)) 404 405 def test_eq_order(self): 406 # Test combining eq and order. 407 for (eq, order, result ) in [ 408 (False, False, 'neither'), 409 (False, True, 'exception'), 410 (True, False, 'eq_only'), 411 (True, True, 'both'), 412 ]: 413 with self.subTest(eq=eq, order=order): 414 if result == 'exception': 415 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 416 @dataclass(eq=eq, order=order) 417 class C: 418 pass 419 else: 420 @dataclass(eq=eq, order=order) 421 class C: 422 pass 423 424 if result == 'neither': 425 self.assertNotIn('__eq__', C.__dict__) 426 self.assertNotIn('__lt__', C.__dict__) 427 self.assertNotIn('__le__', C.__dict__) 428 self.assertNotIn('__gt__', C.__dict__) 429 self.assertNotIn('__ge__', C.__dict__) 430 elif result == 'both': 431 self.assertIn('__eq__', C.__dict__) 432 self.assertIn('__lt__', C.__dict__) 433 self.assertIn('__le__', C.__dict__) 434 self.assertIn('__gt__', C.__dict__) 435 self.assertIn('__ge__', C.__dict__) 436 elif result == 'eq_only': 437 self.assertIn('__eq__', C.__dict__) 438 self.assertNotIn('__lt__', C.__dict__) 439 self.assertNotIn('__le__', C.__dict__) 440 self.assertNotIn('__gt__', C.__dict__) 441 self.assertNotIn('__ge__', C.__dict__) 442 else: 443 assert False, f'unknown result {result!r}' 444 445 def test_field_no_default(self): 446 @dataclass 447 class C: 448 x: int = field() 449 450 self.assertEqual(C(5).x, 5) 451 452 with self.assertRaisesRegex(TypeError, 453 r"__init__\(\) missing 1 required " 454 "positional argument: 'x'"): 455 C() 456 457 def test_field_default(self): 458 default = object() 459 @dataclass 460 class C: 461 x: object = field(default=default) 462 463 self.assertIs(C.x, default) 464 c = C(10) 465 self.assertEqual(c.x, 10) 466 467 # If we delete the instance attribute, we should then see the 468 # class attribute. 469 del c.x 470 self.assertIs(c.x, default) 471 472 self.assertIs(C().x, default) 473 474 def test_not_in_repr(self): 475 @dataclass 476 class C: 477 x: int = field(repr=False) 478 with self.assertRaises(TypeError): 479 C() 480 c = C(10) 481 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 482 483 @dataclass 484 class C: 485 x: int = field(repr=False) 486 y: int 487 c = C(10, 20) 488 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 489 490 def test_not_in_compare(self): 491 @dataclass 492 class C: 493 x: int = 0 494 y: int = field(compare=False, default=4) 495 496 self.assertEqual(C(), C(0, 20)) 497 self.assertEqual(C(1, 10), C(1, 20)) 498 self.assertNotEqual(C(3), C(4, 10)) 499 self.assertNotEqual(C(3, 10), C(4, 10)) 500 501 def test_hash_field_rules(self): 502 # Test all 6 cases of: 503 # hash=True/False/None 504 # compare=True/False 505 for (hash_, compare, result ) in [ 506 (True, False, 'field' ), 507 (True, True, 'field' ), 508 (False, False, 'absent'), 509 (False, True, 'absent'), 510 (None, False, 'absent'), 511 (None, True, 'field' ), 512 ]: 513 with self.subTest(hash=hash_, compare=compare): 514 @dataclass(unsafe_hash=True) 515 class C: 516 x: int = field(compare=compare, hash=hash_, default=5) 517 518 if result == 'field': 519 # __hash__ contains the field. 520 self.assertEqual(hash(C(5)), hash((5,))) 521 elif result == 'absent': 522 # The field is not present in the hash. 523 self.assertEqual(hash(C(5)), hash(())) 524 else: 525 assert False, f'unknown result {result!r}' 526 527 def test_init_false_no_default(self): 528 # If init=False and no default value, then the field won't be 529 # present in the instance. 530 @dataclass 531 class C: 532 x: int = field(init=False) 533 534 self.assertNotIn('x', C().__dict__) 535 536 @dataclass 537 class C: 538 x: int 539 y: int = 0 540 z: int = field(init=False) 541 t: int = 10 542 543 self.assertNotIn('z', C(0).__dict__) 544 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 545 546 def test_class_marker(self): 547 @dataclass 548 class C: 549 x: int 550 y: str = field(init=False, default=None) 551 z: str = field(repr=False) 552 553 the_fields = fields(C) 554 # the_fields is a tuple of 3 items, each value 555 # is in __annotations__. 556 self.assertIsInstance(the_fields, tuple) 557 for f in the_fields: 558 self.assertIs(type(f), Field) 559 self.assertIn(f.name, C.__annotations__) 560 561 self.assertEqual(len(the_fields), 3) 562 563 self.assertEqual(the_fields[0].name, 'x') 564 self.assertEqual(the_fields[0].type, int) 565 self.assertFalse(hasattr(C, 'x')) 566 self.assertTrue (the_fields[0].init) 567 self.assertTrue (the_fields[0].repr) 568 self.assertEqual(the_fields[1].name, 'y') 569 self.assertEqual(the_fields[1].type, str) 570 self.assertIsNone(getattr(C, 'y')) 571 self.assertFalse(the_fields[1].init) 572 self.assertTrue (the_fields[1].repr) 573 self.assertEqual(the_fields[2].name, 'z') 574 self.assertEqual(the_fields[2].type, str) 575 self.assertFalse(hasattr(C, 'z')) 576 self.assertTrue (the_fields[2].init) 577 self.assertFalse(the_fields[2].repr) 578 579 def test_field_order(self): 580 @dataclass 581 class B: 582 a: str = 'B:a' 583 b: str = 'B:b' 584 c: str = 'B:c' 585 586 @dataclass 587 class C(B): 588 b: str = 'C:b' 589 590 self.assertEqual([(f.name, f.default) for f in fields(C)], 591 [('a', 'B:a'), 592 ('b', 'C:b'), 593 ('c', 'B:c')]) 594 595 @dataclass 596 class D(B): 597 c: str = 'D:c' 598 599 self.assertEqual([(f.name, f.default) for f in fields(D)], 600 [('a', 'B:a'), 601 ('b', 'B:b'), 602 ('c', 'D:c')]) 603 604 @dataclass 605 class E(D): 606 a: str = 'E:a' 607 d: str = 'E:d' 608 609 self.assertEqual([(f.name, f.default) for f in fields(E)], 610 [('a', 'E:a'), 611 ('b', 'B:b'), 612 ('c', 'D:c'), 613 ('d', 'E:d')]) 614 615 def test_class_attrs(self): 616 # We only have a class attribute if a default value is 617 # specified, either directly or via a field with a default. 618 default = object() 619 @dataclass 620 class C: 621 x: int 622 y: int = field(repr=False) 623 z: object = default 624 t: int = field(default=100) 625 626 self.assertFalse(hasattr(C, 'x')) 627 self.assertFalse(hasattr(C, 'y')) 628 self.assertIs (C.z, default) 629 self.assertEqual(C.t, 100) 630 631 def test_disallowed_mutable_defaults(self): 632 # For the known types, don't allow mutable default values. 633 for typ, empty, non_empty in [(list, [], [1]), 634 (dict, {}, {0:1}), 635 (set, set(), set([1])), 636 ]: 637 with self.subTest(typ=typ): 638 # Can't use a zero-length value. 639 with self.assertRaisesRegex(ValueError, 640 f'mutable default {typ} for field ' 641 'x is not allowed'): 642 @dataclass 643 class Point: 644 x: typ = empty 645 646 647 # Nor a non-zero-length value 648 with self.assertRaisesRegex(ValueError, 649 f'mutable default {typ} for field ' 650 'y is not allowed'): 651 @dataclass 652 class Point: 653 y: typ = non_empty 654 655 # Check subtypes also fail. 656 class Subclass(typ): pass 657 658 with self.assertRaisesRegex(ValueError, 659 f"mutable default .*Subclass'>" 660 ' for field z is not allowed' 661 ): 662 @dataclass 663 class Point: 664 z: typ = Subclass() 665 666 # Because this is a ClassVar, it can be mutable. 667 @dataclass 668 class C: 669 z: ClassVar[typ] = typ() 670 671 # Because this is a ClassVar, it can be mutable. 672 @dataclass 673 class C: 674 x: ClassVar[typ] = Subclass() 675 676 def test_deliberately_mutable_defaults(self): 677 # If a mutable default isn't in the known list of 678 # (list, dict, set), then it's okay. 679 class Mutable: 680 def __init__(self): 681 self.l = [] 682 683 @dataclass 684 class C: 685 x: Mutable 686 687 # These 2 instances will share this value of x. 688 lst = Mutable() 689 o1 = C(lst) 690 o2 = C(lst) 691 self.assertEqual(o1, o2) 692 o1.x.l.extend([1, 2]) 693 self.assertEqual(o1, o2) 694 self.assertEqual(o1.x.l, [1, 2]) 695 self.assertIs(o1.x, o2.x) 696 697 def test_no_options(self): 698 # Call with dataclass(). 699 @dataclass() 700 class C: 701 x: int 702 703 self.assertEqual(C(42).x, 42) 704 705 def test_not_tuple(self): 706 # Make sure we can't be compared to a tuple. 707 @dataclass 708 class Point: 709 x: int 710 y: int 711 self.assertNotEqual(Point(1, 2), (1, 2)) 712 713 # And that we can't compare to another unrelated dataclass. 714 @dataclass 715 class C: 716 x: int 717 y: int 718 self.assertNotEqual(Point(1, 3), C(1, 3)) 719 720 def test_not_other_dataclass(self): 721 # Test that some of the problems with namedtuple don't happen 722 # here. 723 @dataclass 724 class Point3D: 725 x: int 726 y: int 727 z: int 728 729 @dataclass 730 class Date: 731 year: int 732 month: int 733 day: int 734 735 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 736 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 737 738 # Make sure we can't unpack. 739 with self.assertRaisesRegex(TypeError, 'unpack'): 740 x, y, z = Point3D(4, 5, 6) 741 742 # Make sure another class with the same field names isn't 743 # equal. 744 @dataclass 745 class Point3Dv1: 746 x: int = 0 747 y: int = 0 748 z: int = 0 749 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 750 751 def test_function_annotations(self): 752 # Some dummy class and instance to use as a default. 753 class F: 754 pass 755 f = F() 756 757 def validate_class(cls): 758 # First, check __annotations__, even though they're not 759 # function annotations. 760 self.assertEqual(cls.__annotations__['i'], int) 761 self.assertEqual(cls.__annotations__['j'], str) 762 self.assertEqual(cls.__annotations__['k'], F) 763 self.assertEqual(cls.__annotations__['l'], float) 764 self.assertEqual(cls.__annotations__['z'], complex) 765 766 # Verify __init__. 767 768 signature = inspect.signature(cls.__init__) 769 # Check the return type, should be None. 770 self.assertIs(signature.return_annotation, None) 771 772 # Check each parameter. 773 params = iter(signature.parameters.values()) 774 param = next(params) 775 # This is testing an internal name, and probably shouldn't be tested. 776 self.assertEqual(param.name, 'self') 777 param = next(params) 778 self.assertEqual(param.name, 'i') 779 self.assertIs (param.annotation, int) 780 self.assertEqual(param.default, inspect.Parameter.empty) 781 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 782 param = next(params) 783 self.assertEqual(param.name, 'j') 784 self.assertIs (param.annotation, str) 785 self.assertEqual(param.default, inspect.Parameter.empty) 786 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 787 param = next(params) 788 self.assertEqual(param.name, 'k') 789 self.assertIs (param.annotation, F) 790 # Don't test for the default, since it's set to MISSING. 791 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 792 param = next(params) 793 self.assertEqual(param.name, 'l') 794 self.assertIs (param.annotation, float) 795 # Don't test for the default, since it's set to MISSING. 796 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 797 self.assertRaises(StopIteration, next, params) 798 799 800 @dataclass 801 class C: 802 i: int 803 j: str 804 k: F = f 805 l: float=field(default=None) 806 z: complex=field(default=3+4j, init=False) 807 808 validate_class(C) 809 810 # Now repeat with __hash__. 811 @dataclass(frozen=True, unsafe_hash=True) 812 class C: 813 i: int 814 j: str 815 k: F = f 816 l: float=field(default=None) 817 z: complex=field(default=3+4j, init=False) 818 819 validate_class(C) 820 821 def test_missing_default(self): 822 # Test that MISSING works the same as a default not being 823 # specified. 824 @dataclass 825 class C: 826 x: int=field(default=MISSING) 827 with self.assertRaisesRegex(TypeError, 828 r'__init__\(\) missing 1 required ' 829 'positional argument'): 830 C() 831 self.assertNotIn('x', C.__dict__) 832 833 @dataclass 834 class D: 835 x: int 836 with self.assertRaisesRegex(TypeError, 837 r'__init__\(\) missing 1 required ' 838 'positional argument'): 839 D() 840 self.assertNotIn('x', D.__dict__) 841 842 def test_missing_default_factory(self): 843 # Test that MISSING works the same as a default factory not 844 # being specified (which is really the same as a default not 845 # being specified, too). 846 @dataclass 847 class C: 848 x: int=field(default_factory=MISSING) 849 with self.assertRaisesRegex(TypeError, 850 r'__init__\(\) missing 1 required ' 851 'positional argument'): 852 C() 853 self.assertNotIn('x', C.__dict__) 854 855 @dataclass 856 class D: 857 x: int=field(default=MISSING, default_factory=MISSING) 858 with self.assertRaisesRegex(TypeError, 859 r'__init__\(\) missing 1 required ' 860 'positional argument'): 861 D() 862 self.assertNotIn('x', D.__dict__) 863 864 def test_missing_repr(self): 865 self.assertIn('MISSING_TYPE object', repr(MISSING)) 866 867 def test_dont_include_other_annotations(self): 868 @dataclass 869 class C: 870 i: int 871 def foo(self) -> int: 872 return 4 873 @property 874 def bar(self) -> int: 875 return 5 876 self.assertEqual(list(C.__annotations__), ['i']) 877 self.assertEqual(C(10).foo(), 4) 878 self.assertEqual(C(10).bar, 5) 879 self.assertEqual(C(10).i, 10) 880 881 def test_post_init(self): 882 # Just make sure it gets called 883 @dataclass 884 class C: 885 def __post_init__(self): 886 raise CustomError() 887 with self.assertRaises(CustomError): 888 C() 889 890 @dataclass 891 class C: 892 i: int = 10 893 def __post_init__(self): 894 if self.i == 10: 895 raise CustomError() 896 with self.assertRaises(CustomError): 897 C() 898 # post-init gets called, but doesn't raise. This is just 899 # checking that self is used correctly. 900 C(5) 901 902 # If there's not an __init__, then post-init won't get called. 903 @dataclass(init=False) 904 class C: 905 def __post_init__(self): 906 raise CustomError() 907 # Creating the class won't raise 908 C() 909 910 @dataclass 911 class C: 912 x: int = 0 913 def __post_init__(self): 914 self.x *= 2 915 self.assertEqual(C().x, 0) 916 self.assertEqual(C(2).x, 4) 917 918 # Make sure that if we're frozen, post-init can't set 919 # attributes. 920 @dataclass(frozen=True) 921 class C: 922 x: int = 0 923 def __post_init__(self): 924 self.x *= 2 925 with self.assertRaises(FrozenInstanceError): 926 C() 927 928 def test_post_init_super(self): 929 # Make sure super() post-init isn't called by default. 930 class B: 931 def __post_init__(self): 932 raise CustomError() 933 934 @dataclass 935 class C(B): 936 def __post_init__(self): 937 self.x = 5 938 939 self.assertEqual(C().x, 5) 940 941 # Now call super(), and it will raise. 942 @dataclass 943 class C(B): 944 def __post_init__(self): 945 super().__post_init__() 946 947 with self.assertRaises(CustomError): 948 C() 949 950 # Make sure post-init is called, even if not defined in our 951 # class. 952 @dataclass 953 class C(B): 954 pass 955 956 with self.assertRaises(CustomError): 957 C() 958 959 def test_post_init_staticmethod(self): 960 flag = False 961 @dataclass 962 class C: 963 x: int 964 y: int 965 @staticmethod 966 def __post_init__(): 967 nonlocal flag 968 flag = True 969 970 self.assertFalse(flag) 971 c = C(3, 4) 972 self.assertEqual((c.x, c.y), (3, 4)) 973 self.assertTrue(flag) 974 975 def test_post_init_classmethod(self): 976 @dataclass 977 class C: 978 flag = False 979 x: int 980 y: int 981 @classmethod 982 def __post_init__(cls): 983 cls.flag = True 984 985 self.assertFalse(C.flag) 986 c = C(3, 4) 987 self.assertEqual((c.x, c.y), (3, 4)) 988 self.assertTrue(C.flag) 989 990 def test_class_var(self): 991 # Make sure ClassVars are ignored in __init__, __repr__, etc. 992 @dataclass 993 class C: 994 x: int 995 y: int = 10 996 z: ClassVar[int] = 1000 997 w: ClassVar[int] = 2000 998 t: ClassVar[int] = 3000 999 s: ClassVar = 4000 1000 1001 c = C(5) 1002 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 1003 self.assertEqual(len(fields(C)), 2) # We have 2 fields. 1004 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 1005 self.assertEqual(c.z, 1000) 1006 self.assertEqual(c.w, 2000) 1007 self.assertEqual(c.t, 3000) 1008 self.assertEqual(c.s, 4000) 1009 C.z += 1 1010 self.assertEqual(c.z, 1001) 1011 c = C(20) 1012 self.assertEqual((c.x, c.y), (20, 10)) 1013 self.assertEqual(c.z, 1001) 1014 self.assertEqual(c.w, 2000) 1015 self.assertEqual(c.t, 3000) 1016 self.assertEqual(c.s, 4000) 1017 1018 def test_class_var_no_default(self): 1019 # If a ClassVar has no default value, it should not be set on the class. 1020 @dataclass 1021 class C: 1022 x: ClassVar[int] 1023 1024 self.assertNotIn('x', C.__dict__) 1025 1026 def test_class_var_default_factory(self): 1027 # It makes no sense for a ClassVar to have a default factory. When 1028 # would it be called? Call it yourself, since it's class-wide. 1029 with self.assertRaisesRegex(TypeError, 1030 'cannot have a default factory'): 1031 @dataclass 1032 class C: 1033 x: ClassVar[int] = field(default_factory=int) 1034 1035 self.assertNotIn('x', C.__dict__) 1036 1037 def test_class_var_with_default(self): 1038 # If a ClassVar has a default value, it should be set on the class. 1039 @dataclass 1040 class C: 1041 x: ClassVar[int] = 10 1042 self.assertEqual(C.x, 10) 1043 1044 @dataclass 1045 class C: 1046 x: ClassVar[int] = field(default=10) 1047 self.assertEqual(C.x, 10) 1048 1049 def test_class_var_frozen(self): 1050 # Make sure ClassVars work even if we're frozen. 1051 @dataclass(frozen=True) 1052 class C: 1053 x: int 1054 y: int = 10 1055 z: ClassVar[int] = 1000 1056 w: ClassVar[int] = 2000 1057 t: ClassVar[int] = 3000 1058 1059 c = C(5) 1060 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 1061 self.assertEqual(len(fields(C)), 2) # We have 2 fields 1062 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 1063 self.assertEqual(c.z, 1000) 1064 self.assertEqual(c.w, 2000) 1065 self.assertEqual(c.t, 3000) 1066 # We can still modify the ClassVar, it's only instances that are 1067 # frozen. 1068 C.z += 1 1069 self.assertEqual(c.z, 1001) 1070 c = C(20) 1071 self.assertEqual((c.x, c.y), (20, 10)) 1072 self.assertEqual(c.z, 1001) 1073 self.assertEqual(c.w, 2000) 1074 self.assertEqual(c.t, 3000) 1075 1076 def test_init_var_no_default(self): 1077 # If an InitVar has no default value, it should not be set on the class. 1078 @dataclass 1079 class C: 1080 x: InitVar[int] 1081 1082 self.assertNotIn('x', C.__dict__) 1083 1084 def test_init_var_default_factory(self): 1085 # It makes no sense for an InitVar to have a default factory. When 1086 # would it be called? Call it yourself, since it's class-wide. 1087 with self.assertRaisesRegex(TypeError, 1088 'cannot have a default factory'): 1089 @dataclass 1090 class C: 1091 x: InitVar[int] = field(default_factory=int) 1092 1093 self.assertNotIn('x', C.__dict__) 1094 1095 def test_init_var_with_default(self): 1096 # If an InitVar has a default value, it should be set on the class. 1097 @dataclass 1098 class C: 1099 x: InitVar[int] = 10 1100 self.assertEqual(C.x, 10) 1101 1102 @dataclass 1103 class C: 1104 x: InitVar[int] = field(default=10) 1105 self.assertEqual(C.x, 10) 1106 1107 def test_init_var(self): 1108 @dataclass 1109 class C: 1110 x: int = None 1111 init_param: InitVar[int] = None 1112 1113 def __post_init__(self, init_param): 1114 if self.x is None: 1115 self.x = init_param*2 1116 1117 c = C(init_param=10) 1118 self.assertEqual(c.x, 20) 1119 1120 def test_init_var_preserve_type(self): 1121 self.assertEqual(InitVar[int].type, int) 1122 1123 # Make sure the repr is correct. 1124 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') 1125 self.assertEqual(repr(InitVar[List[int]]), 1126 'dataclasses.InitVar[typing.List[int]]') 1127 1128 def test_init_var_inheritance(self): 1129 # Note that this deliberately tests that a dataclass need not 1130 # have a __post_init__ function if it has an InitVar field. 1131 # It could just be used in a derived class, as shown here. 1132 @dataclass 1133 class Base: 1134 x: int 1135 init_base: InitVar[int] 1136 1137 # We can instantiate by passing the InitVar, even though 1138 # it's not used. 1139 b = Base(0, 10) 1140 self.assertEqual(vars(b), {'x': 0}) 1141 1142 @dataclass 1143 class C(Base): 1144 y: int 1145 init_derived: InitVar[int] 1146 1147 def __post_init__(self, init_base, init_derived): 1148 self.x = self.x + init_base 1149 self.y = self.y + init_derived 1150 1151 c = C(10, 11, 50, 51) 1152 self.assertEqual(vars(c), {'x': 21, 'y': 101}) 1153 1154 def test_default_factory(self): 1155 # Test a factory that returns a new list. 1156 @dataclass 1157 class C: 1158 x: int 1159 y: list = field(default_factory=list) 1160 1161 c0 = C(3) 1162 c1 = C(3) 1163 self.assertEqual(c0.x, 3) 1164 self.assertEqual(c0.y, []) 1165 self.assertEqual(c0, c1) 1166 self.assertIsNot(c0.y, c1.y) 1167 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1168 1169 # Test a factory that returns a shared list. 1170 l = [] 1171 @dataclass 1172 class C: 1173 x: int 1174 y: list = field(default_factory=lambda: l) 1175 1176 c0 = C(3) 1177 c1 = C(3) 1178 self.assertEqual(c0.x, 3) 1179 self.assertEqual(c0.y, []) 1180 self.assertEqual(c0, c1) 1181 self.assertIs(c0.y, c1.y) 1182 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1183 1184 # Test various other field flags. 1185 # repr 1186 @dataclass 1187 class C: 1188 x: list = field(default_factory=list, repr=False) 1189 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 1190 self.assertEqual(C().x, []) 1191 1192 # hash 1193 @dataclass(unsafe_hash=True) 1194 class C: 1195 x: list = field(default_factory=list, hash=False) 1196 self.assertEqual(astuple(C()), ([],)) 1197 self.assertEqual(hash(C()), hash(())) 1198 1199 # init (see also test_default_factory_with_no_init) 1200 @dataclass 1201 class C: 1202 x: list = field(default_factory=list, init=False) 1203 self.assertEqual(astuple(C()), ([],)) 1204 1205 # compare 1206 @dataclass 1207 class C: 1208 x: list = field(default_factory=list, compare=False) 1209 self.assertEqual(C(), C([1])) 1210 1211 def test_default_factory_with_no_init(self): 1212 # We need a factory with a side effect. 1213 factory = Mock() 1214 1215 @dataclass 1216 class C: 1217 x: list = field(default_factory=factory, init=False) 1218 1219 # Make sure the default factory is called for each new instance. 1220 C().x 1221 self.assertEqual(factory.call_count, 1) 1222 C().x 1223 self.assertEqual(factory.call_count, 2) 1224 1225 def test_default_factory_not_called_if_value_given(self): 1226 # We need a factory that we can test if it's been called. 1227 factory = Mock() 1228 1229 @dataclass 1230 class C: 1231 x: int = field(default_factory=factory) 1232 1233 # Make sure that if a field has a default factory function, 1234 # it's not called if a value is specified. 1235 C().x 1236 self.assertEqual(factory.call_count, 1) 1237 self.assertEqual(C(10).x, 10) 1238 self.assertEqual(factory.call_count, 1) 1239 C().x 1240 self.assertEqual(factory.call_count, 2) 1241 1242 def test_default_factory_derived(self): 1243 # See bpo-32896. 1244 @dataclass 1245 class Foo: 1246 x: dict = field(default_factory=dict) 1247 1248 @dataclass 1249 class Bar(Foo): 1250 y: int = 1 1251 1252 self.assertEqual(Foo().x, {}) 1253 self.assertEqual(Bar().x, {}) 1254 self.assertEqual(Bar().y, 1) 1255 1256 @dataclass 1257 class Baz(Foo): 1258 pass 1259 self.assertEqual(Baz().x, {}) 1260 1261 def test_intermediate_non_dataclass(self): 1262 # Test that an intermediate class that defines 1263 # annotations does not define fields. 1264 1265 @dataclass 1266 class A: 1267 x: int 1268 1269 class B(A): 1270 y: int 1271 1272 @dataclass 1273 class C(B): 1274 z: int 1275 1276 c = C(1, 3) 1277 self.assertEqual((c.x, c.z), (1, 3)) 1278 1279 # .y was not initialized. 1280 with self.assertRaisesRegex(AttributeError, 1281 'object has no attribute'): 1282 c.y 1283 1284 # And if we again derive a non-dataclass, no fields are added. 1285 class D(C): 1286 t: int 1287 d = D(4, 5) 1288 self.assertEqual((d.x, d.z), (4, 5)) 1289 1290 def test_classvar_default_factory(self): 1291 # It's an error for a ClassVar to have a factory function. 1292 with self.assertRaisesRegex(TypeError, 1293 'cannot have a default factory'): 1294 @dataclass 1295 class C: 1296 x: ClassVar[int] = field(default_factory=int) 1297 1298 def test_is_dataclass(self): 1299 class NotDataClass: 1300 pass 1301 1302 self.assertFalse(is_dataclass(0)) 1303 self.assertFalse(is_dataclass(int)) 1304 self.assertFalse(is_dataclass(NotDataClass)) 1305 self.assertFalse(is_dataclass(NotDataClass())) 1306 1307 @dataclass 1308 class C: 1309 x: int 1310 1311 @dataclass 1312 class D: 1313 d: C 1314 e: int 1315 1316 c = C(10) 1317 d = D(c, 4) 1318 1319 self.assertTrue(is_dataclass(C)) 1320 self.assertTrue(is_dataclass(c)) 1321 self.assertFalse(is_dataclass(c.x)) 1322 self.assertTrue(is_dataclass(d.d)) 1323 self.assertFalse(is_dataclass(d.e)) 1324 1325 def test_is_dataclass_when_getattr_always_returns(self): 1326 # See bpo-37868. 1327 class A: 1328 def __getattr__(self, key): 1329 return 0 1330 self.assertFalse(is_dataclass(A)) 1331 a = A() 1332 1333 # Also test for an instance attribute. 1334 class B: 1335 pass 1336 b = B() 1337 b.__dataclass_fields__ = [] 1338 1339 for obj in a, b: 1340 with self.subTest(obj=obj): 1341 self.assertFalse(is_dataclass(obj)) 1342 1343 # Indirect tests for _is_dataclass_instance(). 1344 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1345 asdict(obj) 1346 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1347 astuple(obj) 1348 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): 1349 replace(obj, x=0) 1350 1351 def test_helper_fields_with_class_instance(self): 1352 # Check that we can call fields() on either a class or instance, 1353 # and get back the same thing. 1354 @dataclass 1355 class C: 1356 x: int 1357 y: float 1358 1359 self.assertEqual(fields(C), fields(C(0, 0.0))) 1360 1361 def test_helper_fields_exception(self): 1362 # Check that TypeError is raised if not passed a dataclass or 1363 # instance. 1364 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1365 fields(0) 1366 1367 class C: pass 1368 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1369 fields(C) 1370 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1371 fields(C()) 1372 1373 def test_helper_asdict(self): 1374 # Basic tests for asdict(), it should return a new dictionary. 1375 @dataclass 1376 class C: 1377 x: int 1378 y: int 1379 c = C(1, 2) 1380 1381 self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 1382 self.assertEqual(asdict(c), asdict(c)) 1383 self.assertIsNot(asdict(c), asdict(c)) 1384 c.x = 42 1385 self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 1386 self.assertIs(type(asdict(c)), dict) 1387 1388 def test_helper_asdict_raises_on_classes(self): 1389 # asdict() should raise on a class object. 1390 @dataclass 1391 class C: 1392 x: int 1393 y: int 1394 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1395 asdict(C) 1396 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1397 asdict(int) 1398 1399 def test_helper_asdict_copy_values(self): 1400 @dataclass 1401 class C: 1402 x: int 1403 y: List[int] = field(default_factory=list) 1404 initial = [] 1405 c = C(1, initial) 1406 d = asdict(c) 1407 self.assertEqual(d['y'], initial) 1408 self.assertIsNot(d['y'], initial) 1409 c = C(1) 1410 d = asdict(c) 1411 d['y'].append(1) 1412 self.assertEqual(c.y, []) 1413 1414 def test_helper_asdict_nested(self): 1415 @dataclass 1416 class UserId: 1417 token: int 1418 group: int 1419 @dataclass 1420 class User: 1421 name: str 1422 id: UserId 1423 u = User('Joe', UserId(123, 1)) 1424 d = asdict(u) 1425 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 1426 self.assertIsNot(asdict(u), asdict(u)) 1427 u.id.group = 2 1428 self.assertEqual(asdict(u), {'name': 'Joe', 1429 'id': {'token': 123, 'group': 2}}) 1430 1431 def test_helper_asdict_builtin_containers(self): 1432 @dataclass 1433 class User: 1434 name: str 1435 id: int 1436 @dataclass 1437 class GroupList: 1438 id: int 1439 users: List[User] 1440 @dataclass 1441 class GroupTuple: 1442 id: int 1443 users: Tuple[User, ...] 1444 @dataclass 1445 class GroupDict: 1446 id: int 1447 users: Dict[str, User] 1448 a = User('Alice', 1) 1449 b = User('Bob', 2) 1450 gl = GroupList(0, [a, b]) 1451 gt = GroupTuple(0, (a, b)) 1452 gd = GroupDict(0, {'first': a, 'second': b}) 1453 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 1454 {'name': 'Bob', 'id': 2}]}) 1455 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 1456 {'name': 'Bob', 'id': 2})}) 1457 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 1458 'second': {'name': 'Bob', 'id': 2}}}) 1459 1460 def test_helper_asdict_builtin_object_containers(self): 1461 @dataclass 1462 class Child: 1463 d: object 1464 1465 @dataclass 1466 class Parent: 1467 child: Child 1468 1469 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 1470 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 1471 1472 def test_helper_asdict_factory(self): 1473 @dataclass 1474 class C: 1475 x: int 1476 y: int 1477 c = C(1, 2) 1478 d = asdict(c, dict_factory=OrderedDict) 1479 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 1480 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 1481 c.x = 42 1482 d = asdict(c, dict_factory=OrderedDict) 1483 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 1484 self.assertIs(type(d), OrderedDict) 1485 1486 def test_helper_asdict_namedtuple(self): 1487 T = namedtuple('T', 'a b c') 1488 @dataclass 1489 class C: 1490 x: str 1491 y: T 1492 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1493 1494 d = asdict(c) 1495 self.assertEqual(d, {'x': 'outer', 1496 'y': T(1, 1497 {'x': 'inner', 1498 'y': T(11, 12, 13)}, 1499 2), 1500 } 1501 ) 1502 1503 # Now with a dict_factory. OrderedDict is convenient, but 1504 # since it compares to dicts, we also need to have separate 1505 # assertIs tests. 1506 d = asdict(c, dict_factory=OrderedDict) 1507 self.assertEqual(d, {'x': 'outer', 1508 'y': T(1, 1509 {'x': 'inner', 1510 'y': T(11, 12, 13)}, 1511 2), 1512 } 1513 ) 1514 1515 # Make sure that the returned dicts are actually OrderedDicts. 1516 self.assertIs(type(d), OrderedDict) 1517 self.assertIs(type(d['y'][1]), OrderedDict) 1518 1519 def test_helper_asdict_namedtuple_key(self): 1520 # Ensure that a field that contains a dict which has a 1521 # namedtuple as a key works with asdict(). 1522 1523 @dataclass 1524 class C: 1525 f: dict 1526 T = namedtuple('T', 'a') 1527 1528 c = C({T('an a'): 0}) 1529 1530 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 1531 1532 def test_helper_asdict_namedtuple_derived(self): 1533 class T(namedtuple('Tbase', 'a')): 1534 def my_a(self): 1535 return self.a 1536 1537 @dataclass 1538 class C: 1539 f: T 1540 1541 t = T(6) 1542 c = C(t) 1543 1544 d = asdict(c) 1545 self.assertEqual(d, {'f': T(a=6)}) 1546 # Make sure that t has been copied, not used directly. 1547 self.assertIsNot(d['f'], t) 1548 self.assertEqual(d['f'].my_a(), 6) 1549 1550 def test_helper_astuple(self): 1551 # Basic tests for astuple(), it should return a new tuple. 1552 @dataclass 1553 class C: 1554 x: int 1555 y: int = 0 1556 c = C(1) 1557 1558 self.assertEqual(astuple(c), (1, 0)) 1559 self.assertEqual(astuple(c), astuple(c)) 1560 self.assertIsNot(astuple(c), astuple(c)) 1561 c.y = 42 1562 self.assertEqual(astuple(c), (1, 42)) 1563 self.assertIs(type(astuple(c)), tuple) 1564 1565 def test_helper_astuple_raises_on_classes(self): 1566 # astuple() should raise on a class object. 1567 @dataclass 1568 class C: 1569 x: int 1570 y: int 1571 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1572 astuple(C) 1573 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1574 astuple(int) 1575 1576 def test_helper_astuple_copy_values(self): 1577 @dataclass 1578 class C: 1579 x: int 1580 y: List[int] = field(default_factory=list) 1581 initial = [] 1582 c = C(1, initial) 1583 t = astuple(c) 1584 self.assertEqual(t[1], initial) 1585 self.assertIsNot(t[1], initial) 1586 c = C(1) 1587 t = astuple(c) 1588 t[1].append(1) 1589 self.assertEqual(c.y, []) 1590 1591 def test_helper_astuple_nested(self): 1592 @dataclass 1593 class UserId: 1594 token: int 1595 group: int 1596 @dataclass 1597 class User: 1598 name: str 1599 id: UserId 1600 u = User('Joe', UserId(123, 1)) 1601 t = astuple(u) 1602 self.assertEqual(t, ('Joe', (123, 1))) 1603 self.assertIsNot(astuple(u), astuple(u)) 1604 u.id.group = 2 1605 self.assertEqual(astuple(u), ('Joe', (123, 2))) 1606 1607 def test_helper_astuple_builtin_containers(self): 1608 @dataclass 1609 class User: 1610 name: str 1611 id: int 1612 @dataclass 1613 class GroupList: 1614 id: int 1615 users: List[User] 1616 @dataclass 1617 class GroupTuple: 1618 id: int 1619 users: Tuple[User, ...] 1620 @dataclass 1621 class GroupDict: 1622 id: int 1623 users: Dict[str, User] 1624 a = User('Alice', 1) 1625 b = User('Bob', 2) 1626 gl = GroupList(0, [a, b]) 1627 gt = GroupTuple(0, (a, b)) 1628 gd = GroupDict(0, {'first': a, 'second': b}) 1629 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 1630 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 1631 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 1632 1633 def test_helper_astuple_builtin_object_containers(self): 1634 @dataclass 1635 class Child: 1636 d: object 1637 1638 @dataclass 1639 class Parent: 1640 child: Child 1641 1642 self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 1643 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 1644 1645 def test_helper_astuple_factory(self): 1646 @dataclass 1647 class C: 1648 x: int 1649 y: int 1650 NT = namedtuple('NT', 'x y') 1651 def nt(lst): 1652 return NT(*lst) 1653 c = C(1, 2) 1654 t = astuple(c, tuple_factory=nt) 1655 self.assertEqual(t, NT(1, 2)) 1656 self.assertIsNot(t, astuple(c, tuple_factory=nt)) 1657 c.x = 42 1658 t = astuple(c, tuple_factory=nt) 1659 self.assertEqual(t, NT(42, 2)) 1660 self.assertIs(type(t), NT) 1661 1662 def test_helper_astuple_namedtuple(self): 1663 T = namedtuple('T', 'a b c') 1664 @dataclass 1665 class C: 1666 x: str 1667 y: T 1668 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1669 1670 t = astuple(c) 1671 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 1672 1673 # Now, using a tuple_factory. list is convenient here. 1674 t = astuple(c, tuple_factory=list) 1675 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 1676 1677 def test_dynamic_class_creation(self): 1678 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1679 } 1680 1681 # Create the class. 1682 cls = type('C', (), cls_dict) 1683 1684 # Make it a dataclass. 1685 cls1 = dataclass(cls) 1686 1687 self.assertEqual(cls1, cls) 1688 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 1689 1690 def test_dynamic_class_creation_using_field(self): 1691 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1692 'y': field(default=5), 1693 } 1694 1695 # Create the class. 1696 cls = type('C', (), cls_dict) 1697 1698 # Make it a dataclass. 1699 cls1 = dataclass(cls) 1700 1701 self.assertEqual(cls1, cls) 1702 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 1703 1704 def test_init_in_order(self): 1705 @dataclass 1706 class C: 1707 a: int 1708 b: int = field() 1709 c: list = field(default_factory=list, init=False) 1710 d: list = field(default_factory=list) 1711 e: int = field(default=4, init=False) 1712 f: int = 4 1713 1714 calls = [] 1715 def setattr(self, name, value): 1716 calls.append((name, value)) 1717 1718 C.__setattr__ = setattr 1719 c = C(0, 1) 1720 self.assertEqual(('a', 0), calls[0]) 1721 self.assertEqual(('b', 1), calls[1]) 1722 self.assertEqual(('c', []), calls[2]) 1723 self.assertEqual(('d', []), calls[3]) 1724 self.assertNotIn(('e', 4), calls) 1725 self.assertEqual(('f', 4), calls[4]) 1726 1727 def test_items_in_dicts(self): 1728 @dataclass 1729 class C: 1730 a: int 1731 b: list = field(default_factory=list, init=False) 1732 c: list = field(default_factory=list) 1733 d: int = field(default=4, init=False) 1734 e: int = 0 1735 1736 c = C(0) 1737 # Class dict 1738 self.assertNotIn('a', C.__dict__) 1739 self.assertNotIn('b', C.__dict__) 1740 self.assertNotIn('c', C.__dict__) 1741 self.assertIn('d', C.__dict__) 1742 self.assertEqual(C.d, 4) 1743 self.assertIn('e', C.__dict__) 1744 self.assertEqual(C.e, 0) 1745 # Instance dict 1746 self.assertIn('a', c.__dict__) 1747 self.assertEqual(c.a, 0) 1748 self.assertIn('b', c.__dict__) 1749 self.assertEqual(c.b, []) 1750 self.assertIn('c', c.__dict__) 1751 self.assertEqual(c.c, []) 1752 self.assertNotIn('d', c.__dict__) 1753 self.assertIn('e', c.__dict__) 1754 self.assertEqual(c.e, 0) 1755 1756 def test_alternate_classmethod_constructor(self): 1757 # Since __post_init__ can't take params, use a classmethod 1758 # alternate constructor. This is mostly an example to show 1759 # how to use this technique. 1760 @dataclass 1761 class C: 1762 x: int 1763 @classmethod 1764 def from_file(cls, filename): 1765 # In a real example, create a new instance 1766 # and populate 'x' from contents of a file. 1767 value_in_file = 20 1768 return cls(value_in_file) 1769 1770 self.assertEqual(C.from_file('filename').x, 20) 1771 1772 def test_field_metadata_default(self): 1773 # Make sure the default metadata is read-only and of 1774 # zero length. 1775 @dataclass 1776 class C: 1777 i: int 1778 1779 self.assertFalse(fields(C)[0].metadata) 1780 self.assertEqual(len(fields(C)[0].metadata), 0) 1781 with self.assertRaisesRegex(TypeError, 1782 'does not support item assignment'): 1783 fields(C)[0].metadata['test'] = 3 1784 1785 def test_field_metadata_mapping(self): 1786 # Make sure only a mapping can be passed as metadata 1787 # zero length. 1788 with self.assertRaises(TypeError): 1789 @dataclass 1790 class C: 1791 i: int = field(metadata=0) 1792 1793 # Make sure an empty dict works. 1794 d = {} 1795 @dataclass 1796 class C: 1797 i: int = field(metadata=d) 1798 self.assertFalse(fields(C)[0].metadata) 1799 self.assertEqual(len(fields(C)[0].metadata), 0) 1800 # Update should work (see bpo-35960). 1801 d['foo'] = 1 1802 self.assertEqual(len(fields(C)[0].metadata), 1) 1803 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1804 with self.assertRaisesRegex(TypeError, 1805 'does not support item assignment'): 1806 fields(C)[0].metadata['test'] = 3 1807 1808 # Make sure a non-empty dict works. 1809 d = {'test': 10, 'bar': '42', 3: 'three'} 1810 @dataclass 1811 class C: 1812 i: int = field(metadata=d) 1813 self.assertEqual(len(fields(C)[0].metadata), 3) 1814 self.assertEqual(fields(C)[0].metadata['test'], 10) 1815 self.assertEqual(fields(C)[0].metadata['bar'], '42') 1816 self.assertEqual(fields(C)[0].metadata[3], 'three') 1817 # Update should work. 1818 d['foo'] = 1 1819 self.assertEqual(len(fields(C)[0].metadata), 4) 1820 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1821 with self.assertRaises(KeyError): 1822 # Non-existent key. 1823 fields(C)[0].metadata['baz'] 1824 with self.assertRaisesRegex(TypeError, 1825 'does not support item assignment'): 1826 fields(C)[0].metadata['test'] = 3 1827 1828 def test_field_metadata_custom_mapping(self): 1829 # Try a custom mapping. 1830 class SimpleNameSpace: 1831 def __init__(self, **kw): 1832 self.__dict__.update(kw) 1833 1834 def __getitem__(self, item): 1835 if item == 'xyzzy': 1836 return 'plugh' 1837 return getattr(self, item) 1838 1839 def __len__(self): 1840 return self.__dict__.__len__() 1841 1842 @dataclass 1843 class C: 1844 i: int = field(metadata=SimpleNameSpace(a=10)) 1845 1846 self.assertEqual(len(fields(C)[0].metadata), 1) 1847 self.assertEqual(fields(C)[0].metadata['a'], 10) 1848 with self.assertRaises(AttributeError): 1849 fields(C)[0].metadata['b'] 1850 # Make sure we're still talking to our custom mapping. 1851 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 1852 1853 def test_generic_dataclasses(self): 1854 T = TypeVar('T') 1855 1856 @dataclass 1857 class LabeledBox(Generic[T]): 1858 content: T 1859 label: str = '<unknown>' 1860 1861 box = LabeledBox(42) 1862 self.assertEqual(box.content, 42) 1863 self.assertEqual(box.label, '<unknown>') 1864 1865 # Subscripting the resulting class should work, etc. 1866 Alias = List[LabeledBox[int]] 1867 1868 def test_generic_extending(self): 1869 S = TypeVar('S') 1870 T = TypeVar('T') 1871 1872 @dataclass 1873 class Base(Generic[T, S]): 1874 x: T 1875 y: S 1876 1877 @dataclass 1878 class DataDerived(Base[int, T]): 1879 new_field: str 1880 Alias = DataDerived[str] 1881 c = Alias(0, 'test1', 'test2') 1882 self.assertEqual(astuple(c), (0, 'test1', 'test2')) 1883 1884 class NonDataDerived(Base[int, T]): 1885 def new_method(self): 1886 return self.y 1887 Alias = NonDataDerived[float] 1888 c = Alias(10, 1.0) 1889 self.assertEqual(c.new_method(), 1.0) 1890 1891 def test_generic_dynamic(self): 1892 T = TypeVar('T') 1893 1894 @dataclass 1895 class Parent(Generic[T]): 1896 x: T 1897 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 1898 bases=(Parent[int], Generic[T]), namespace={'other': 42}) 1899 self.assertIs(Child[int](1, 2).z, None) 1900 self.assertEqual(Child[int](1, 2, 3).z, 3) 1901 self.assertEqual(Child[int](1, 2, 3).other, 42) 1902 # Check that type aliases work correctly. 1903 Alias = Child[T] 1904 self.assertEqual(Alias[int](1, 2).x, 1) 1905 # Check MRO resolution. 1906 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 1907 1908 def test_dataclassses_pickleable(self): 1909 global P, Q, R 1910 @dataclass 1911 class P: 1912 x: int 1913 y: int = 0 1914 @dataclass 1915 class Q: 1916 x: int 1917 y: int = field(default=0, init=False) 1918 @dataclass 1919 class R: 1920 x: int 1921 y: List[int] = field(default_factory=list) 1922 q = Q(1) 1923 q.y = 2 1924 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 1925 for sample in samples: 1926 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1927 with self.subTest(sample=sample, proto=proto): 1928 new_sample = pickle.loads(pickle.dumps(sample, proto)) 1929 self.assertEqual(sample.x, new_sample.x) 1930 self.assertEqual(sample.y, new_sample.y) 1931 self.assertIsNot(sample, new_sample) 1932 new_sample.x = 42 1933 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 1934 self.assertEqual(new_sample.x, another_new_sample.x) 1935 self.assertEqual(sample.y, another_new_sample.y) 1936 1937 1938class TestFieldNoAnnotation(unittest.TestCase): 1939 def test_field_without_annotation(self): 1940 with self.assertRaisesRegex(TypeError, 1941 "'f' is a field but has no type annotation"): 1942 @dataclass 1943 class C: 1944 f = field() 1945 1946 def test_field_without_annotation_but_annotation_in_base(self): 1947 @dataclass 1948 class B: 1949 f: int 1950 1951 with self.assertRaisesRegex(TypeError, 1952 "'f' is a field but has no type annotation"): 1953 # This is still an error: make sure we don't pick up the 1954 # type annotation in the base class. 1955 @dataclass 1956 class C(B): 1957 f = field() 1958 1959 def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 1960 # Same test, but with the base class not a dataclass. 1961 class B: 1962 f: int 1963 1964 with self.assertRaisesRegex(TypeError, 1965 "'f' is a field but has no type annotation"): 1966 # This is still an error: make sure we don't pick up the 1967 # type annotation in the base class. 1968 @dataclass 1969 class C(B): 1970 f = field() 1971 1972 1973class TestDocString(unittest.TestCase): 1974 def assertDocStrEqual(self, a, b): 1975 # Because 3.6 and 3.7 differ in how inspect.signature work 1976 # (see bpo #32108), for the time being just compare them with 1977 # whitespace stripped. 1978 self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 1979 1980 def test_existing_docstring_not_overridden(self): 1981 @dataclass 1982 class C: 1983 """Lorem ipsum""" 1984 x: int 1985 1986 self.assertEqual(C.__doc__, "Lorem ipsum") 1987 1988 def test_docstring_no_fields(self): 1989 @dataclass 1990 class C: 1991 pass 1992 1993 self.assertDocStrEqual(C.__doc__, "C()") 1994 1995 def test_docstring_one_field(self): 1996 @dataclass 1997 class C: 1998 x: int 1999 2000 self.assertDocStrEqual(C.__doc__, "C(x:int)") 2001 2002 def test_docstring_two_fields(self): 2003 @dataclass 2004 class C: 2005 x: int 2006 y: int 2007 2008 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 2009 2010 def test_docstring_three_fields(self): 2011 @dataclass 2012 class C: 2013 x: int 2014 y: int 2015 z: str 2016 2017 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 2018 2019 def test_docstring_one_field_with_default(self): 2020 @dataclass 2021 class C: 2022 x: int = 3 2023 2024 self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 2025 2026 def test_docstring_one_field_with_default_none(self): 2027 @dataclass 2028 class C: 2029 x: Union[int, type(None)] = None 2030 2031 self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") 2032 2033 def test_docstring_list_field(self): 2034 @dataclass 2035 class C: 2036 x: List[int] 2037 2038 self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 2039 2040 def test_docstring_list_field_with_default_factory(self): 2041 @dataclass 2042 class C: 2043 x: List[int] = field(default_factory=list) 2044 2045 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 2046 2047 def test_docstring_deque_field(self): 2048 @dataclass 2049 class C: 2050 x: deque 2051 2052 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 2053 2054 def test_docstring_deque_field_with_default_factory(self): 2055 @dataclass 2056 class C: 2057 x: deque = field(default_factory=deque) 2058 2059 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 2060 2061 2062class TestInit(unittest.TestCase): 2063 def test_base_has_init(self): 2064 class B: 2065 def __init__(self): 2066 self.z = 100 2067 pass 2068 2069 # Make sure that declaring this class doesn't raise an error. 2070 # The issue is that we can't override __init__ in our class, 2071 # but it should be okay to add __init__ to us if our base has 2072 # an __init__. 2073 @dataclass 2074 class C(B): 2075 x: int = 0 2076 c = C(10) 2077 self.assertEqual(c.x, 10) 2078 self.assertNotIn('z', vars(c)) 2079 2080 # Make sure that if we don't add an init, the base __init__ 2081 # gets called. 2082 @dataclass(init=False) 2083 class C(B): 2084 x: int = 10 2085 c = C() 2086 self.assertEqual(c.x, 10) 2087 self.assertEqual(c.z, 100) 2088 2089 def test_no_init(self): 2090 dataclass(init=False) 2091 class C: 2092 i: int = 0 2093 self.assertEqual(C().i, 0) 2094 2095 dataclass(init=False) 2096 class C: 2097 i: int = 2 2098 def __init__(self): 2099 self.i = 3 2100 self.assertEqual(C().i, 3) 2101 2102 def test_overwriting_init(self): 2103 # If the class has __init__, use it no matter the value of 2104 # init=. 2105 2106 @dataclass 2107 class C: 2108 x: int 2109 def __init__(self, x): 2110 self.x = 2 * x 2111 self.assertEqual(C(3).x, 6) 2112 2113 @dataclass(init=True) 2114 class C: 2115 x: int 2116 def __init__(self, x): 2117 self.x = 2 * x 2118 self.assertEqual(C(4).x, 8) 2119 2120 @dataclass(init=False) 2121 class C: 2122 x: int 2123 def __init__(self, x): 2124 self.x = 2 * x 2125 self.assertEqual(C(5).x, 10) 2126 2127 2128class TestRepr(unittest.TestCase): 2129 def test_repr(self): 2130 @dataclass 2131 class B: 2132 x: int 2133 2134 @dataclass 2135 class C(B): 2136 y: int = 10 2137 2138 o = C(4) 2139 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 2140 2141 @dataclass 2142 class D(C): 2143 x: int = 20 2144 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 2145 2146 @dataclass 2147 class C: 2148 @dataclass 2149 class D: 2150 i: int 2151 @dataclass 2152 class E: 2153 pass 2154 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 2155 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 2156 2157 def test_no_repr(self): 2158 # Test a class with no __repr__ and repr=False. 2159 @dataclass(repr=False) 2160 class C: 2161 x: int 2162 self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 2163 repr(C(3))) 2164 2165 # Test a class with a __repr__ and repr=False. 2166 @dataclass(repr=False) 2167 class C: 2168 x: int 2169 def __repr__(self): 2170 return 'C-class' 2171 self.assertEqual(repr(C(3)), 'C-class') 2172 2173 def test_overwriting_repr(self): 2174 # If the class has __repr__, use it no matter the value of 2175 # repr=. 2176 2177 @dataclass 2178 class C: 2179 x: int 2180 def __repr__(self): 2181 return 'x' 2182 self.assertEqual(repr(C(0)), 'x') 2183 2184 @dataclass(repr=True) 2185 class C: 2186 x: int 2187 def __repr__(self): 2188 return 'x' 2189 self.assertEqual(repr(C(0)), 'x') 2190 2191 @dataclass(repr=False) 2192 class C: 2193 x: int 2194 def __repr__(self): 2195 return 'x' 2196 self.assertEqual(repr(C(0)), 'x') 2197 2198 2199class TestEq(unittest.TestCase): 2200 def test_no_eq(self): 2201 # Test a class with no __eq__ and eq=False. 2202 @dataclass(eq=False) 2203 class C: 2204 x: int 2205 self.assertNotEqual(C(0), C(0)) 2206 c = C(3) 2207 self.assertEqual(c, c) 2208 2209 # Test a class with an __eq__ and eq=False. 2210 @dataclass(eq=False) 2211 class C: 2212 x: int 2213 def __eq__(self, other): 2214 return other == 10 2215 self.assertEqual(C(3), 10) 2216 2217 def test_overwriting_eq(self): 2218 # If the class has __eq__, use it no matter the value of 2219 # eq=. 2220 2221 @dataclass 2222 class C: 2223 x: int 2224 def __eq__(self, other): 2225 return other == 3 2226 self.assertEqual(C(1), 3) 2227 self.assertNotEqual(C(1), 1) 2228 2229 @dataclass(eq=True) 2230 class C: 2231 x: int 2232 def __eq__(self, other): 2233 return other == 4 2234 self.assertEqual(C(1), 4) 2235 self.assertNotEqual(C(1), 1) 2236 2237 @dataclass(eq=False) 2238 class C: 2239 x: int 2240 def __eq__(self, other): 2241 return other == 5 2242 self.assertEqual(C(1), 5) 2243 self.assertNotEqual(C(1), 1) 2244 2245 2246class TestOrdering(unittest.TestCase): 2247 def test_functools_total_ordering(self): 2248 # Test that functools.total_ordering works with this class. 2249 @total_ordering 2250 @dataclass 2251 class C: 2252 x: int 2253 def __lt__(self, other): 2254 # Perform the test "backward", just to make 2255 # sure this is being called. 2256 return self.x >= other 2257 2258 self.assertLess(C(0), -1) 2259 self.assertLessEqual(C(0), -1) 2260 self.assertGreater(C(0), 1) 2261 self.assertGreaterEqual(C(0), 1) 2262 2263 def test_no_order(self): 2264 # Test that no ordering functions are added by default. 2265 @dataclass(order=False) 2266 class C: 2267 x: int 2268 # Make sure no order methods are added. 2269 self.assertNotIn('__le__', C.__dict__) 2270 self.assertNotIn('__lt__', C.__dict__) 2271 self.assertNotIn('__ge__', C.__dict__) 2272 self.assertNotIn('__gt__', C.__dict__) 2273 2274 # Test that __lt__ is still called 2275 @dataclass(order=False) 2276 class C: 2277 x: int 2278 def __lt__(self, other): 2279 return False 2280 # Make sure other methods aren't added. 2281 self.assertNotIn('__le__', C.__dict__) 2282 self.assertNotIn('__ge__', C.__dict__) 2283 self.assertNotIn('__gt__', C.__dict__) 2284 2285 def test_overwriting_order(self): 2286 with self.assertRaisesRegex(TypeError, 2287 'Cannot overwrite attribute __lt__' 2288 '.*using functools.total_ordering'): 2289 @dataclass(order=True) 2290 class C: 2291 x: int 2292 def __lt__(self): 2293 pass 2294 2295 with self.assertRaisesRegex(TypeError, 2296 'Cannot overwrite attribute __le__' 2297 '.*using functools.total_ordering'): 2298 @dataclass(order=True) 2299 class C: 2300 x: int 2301 def __le__(self): 2302 pass 2303 2304 with self.assertRaisesRegex(TypeError, 2305 'Cannot overwrite attribute __gt__' 2306 '.*using functools.total_ordering'): 2307 @dataclass(order=True) 2308 class C: 2309 x: int 2310 def __gt__(self): 2311 pass 2312 2313 with self.assertRaisesRegex(TypeError, 2314 'Cannot overwrite attribute __ge__' 2315 '.*using functools.total_ordering'): 2316 @dataclass(order=True) 2317 class C: 2318 x: int 2319 def __ge__(self): 2320 pass 2321 2322class TestHash(unittest.TestCase): 2323 def test_unsafe_hash(self): 2324 @dataclass(unsafe_hash=True) 2325 class C: 2326 x: int 2327 y: str 2328 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) 2329 2330 def test_hash_rules(self): 2331 def non_bool(value): 2332 # Map to something else that's True, but not a bool. 2333 if value is None: 2334 return None 2335 if value: 2336 return (3,) 2337 return 0 2338 2339 def test(case, unsafe_hash, eq, frozen, with_hash, result): 2340 with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, 2341 frozen=frozen): 2342 if result != 'exception': 2343 if with_hash: 2344 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2345 class C: 2346 def __hash__(self): 2347 return 0 2348 else: 2349 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2350 class C: 2351 pass 2352 2353 # See if the result matches what's expected. 2354 if result == 'fn': 2355 # __hash__ contains the function we generated. 2356 self.assertIn('__hash__', C.__dict__) 2357 self.assertIsNotNone(C.__dict__['__hash__']) 2358 2359 elif result == '': 2360 # __hash__ is not present in our class. 2361 if not with_hash: 2362 self.assertNotIn('__hash__', C.__dict__) 2363 2364 elif result == 'none': 2365 # __hash__ is set to None. 2366 self.assertIn('__hash__', C.__dict__) 2367 self.assertIsNone(C.__dict__['__hash__']) 2368 2369 elif result == 'exception': 2370 # Creating the class should cause an exception. 2371 # This only happens with with_hash==True. 2372 assert(with_hash) 2373 with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): 2374 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2375 class C: 2376 def __hash__(self): 2377 return 0 2378 2379 else: 2380 assert False, f'unknown result {result!r}' 2381 2382 # There are 8 cases of: 2383 # unsafe_hash=True/False 2384 # eq=True/False 2385 # frozen=True/False 2386 # And for each of these, a different result if 2387 # __hash__ is defined or not. 2388 for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ 2389 (False, False, False, '', ''), 2390 (False, False, True, '', ''), 2391 (False, True, False, 'none', ''), 2392 (False, True, True, 'fn', ''), 2393 (True, False, False, 'fn', 'exception'), 2394 (True, False, True, 'fn', 'exception'), 2395 (True, True, False, 'fn', 'exception'), 2396 (True, True, True, 'fn', 'exception'), 2397 ], 1): 2398 test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) 2399 test(case, unsafe_hash, eq, frozen, True, res_defined_hash) 2400 2401 # Test non-bool truth values, too. This is just to 2402 # make sure the data-driven table in the decorator 2403 # handles non-bool values. 2404 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) 2405 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) 2406 2407 2408 def test_eq_only(self): 2409 # If a class defines __eq__, __hash__ is automatically added 2410 # and set to None. This is normal Python behavior, not 2411 # related to dataclasses. Make sure we don't interfere with 2412 # that (see bpo=32546). 2413 2414 @dataclass 2415 class C: 2416 i: int 2417 def __eq__(self, other): 2418 return self.i == other.i 2419 self.assertEqual(C(1), C(1)) 2420 self.assertNotEqual(C(1), C(4)) 2421 2422 # And make sure things work in this case if we specify 2423 # unsafe_hash=True. 2424 @dataclass(unsafe_hash=True) 2425 class C: 2426 i: int 2427 def __eq__(self, other): 2428 return self.i == other.i 2429 self.assertEqual(C(1), C(1.0)) 2430 self.assertEqual(hash(C(1)), hash(C(1.0))) 2431 2432 # And check that the classes __eq__ is being used, despite 2433 # specifying eq=True. 2434 @dataclass(unsafe_hash=True, eq=True) 2435 class C: 2436 i: int 2437 def __eq__(self, other): 2438 return self.i == 3 and self.i == other.i 2439 self.assertEqual(C(3), C(3)) 2440 self.assertNotEqual(C(1), C(1)) 2441 self.assertEqual(hash(C(1)), hash(C(1.0))) 2442 2443 def test_0_field_hash(self): 2444 @dataclass(frozen=True) 2445 class C: 2446 pass 2447 self.assertEqual(hash(C()), hash(())) 2448 2449 @dataclass(unsafe_hash=True) 2450 class C: 2451 pass 2452 self.assertEqual(hash(C()), hash(())) 2453 2454 def test_1_field_hash(self): 2455 @dataclass(frozen=True) 2456 class C: 2457 x: int 2458 self.assertEqual(hash(C(4)), hash((4,))) 2459 self.assertEqual(hash(C(42)), hash((42,))) 2460 2461 @dataclass(unsafe_hash=True) 2462 class C: 2463 x: int 2464 self.assertEqual(hash(C(4)), hash((4,))) 2465 self.assertEqual(hash(C(42)), hash((42,))) 2466 2467 def test_hash_no_args(self): 2468 # Test dataclasses with no hash= argument. This exists to 2469 # make sure that if the @dataclass parameter name is changed 2470 # or the non-default hashing behavior changes, the default 2471 # hashability keeps working the same way. 2472 2473 class Base: 2474 def __hash__(self): 2475 return 301 2476 2477 # If frozen or eq is None, then use the default value (do not 2478 # specify any value in the decorator). 2479 for frozen, eq, base, expected in [ 2480 (None, None, object, 'unhashable'), 2481 (None, None, Base, 'unhashable'), 2482 (None, False, object, 'object'), 2483 (None, False, Base, 'base'), 2484 (None, True, object, 'unhashable'), 2485 (None, True, Base, 'unhashable'), 2486 (False, None, object, 'unhashable'), 2487 (False, None, Base, 'unhashable'), 2488 (False, False, object, 'object'), 2489 (False, False, Base, 'base'), 2490 (False, True, object, 'unhashable'), 2491 (False, True, Base, 'unhashable'), 2492 (True, None, object, 'tuple'), 2493 (True, None, Base, 'tuple'), 2494 (True, False, object, 'object'), 2495 (True, False, Base, 'base'), 2496 (True, True, object, 'tuple'), 2497 (True, True, Base, 'tuple'), 2498 ]: 2499 2500 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): 2501 # First, create the class. 2502 if frozen is None and eq is None: 2503 @dataclass 2504 class C(base): 2505 i: int 2506 elif frozen is None: 2507 @dataclass(eq=eq) 2508 class C(base): 2509 i: int 2510 elif eq is None: 2511 @dataclass(frozen=frozen) 2512 class C(base): 2513 i: int 2514 else: 2515 @dataclass(frozen=frozen, eq=eq) 2516 class C(base): 2517 i: int 2518 2519 # Now, make sure it hashes as expected. 2520 if expected == 'unhashable': 2521 c = C(10) 2522 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2523 hash(c) 2524 2525 elif expected == 'base': 2526 self.assertEqual(hash(C(10)), 301) 2527 2528 elif expected == 'object': 2529 # I'm not sure what test to use here. object's 2530 # hash isn't based on id(), so calling hash() 2531 # won't tell us much. So, just check the 2532 # function used is object's. 2533 self.assertIs(C.__hash__, object.__hash__) 2534 2535 elif expected == 'tuple': 2536 self.assertEqual(hash(C(42)), hash((42,))) 2537 2538 else: 2539 assert False, f'unknown value for expected={expected!r}' 2540 2541 2542class TestFrozen(unittest.TestCase): 2543 def test_frozen(self): 2544 @dataclass(frozen=True) 2545 class C: 2546 i: int 2547 2548 c = C(10) 2549 self.assertEqual(c.i, 10) 2550 with self.assertRaises(FrozenInstanceError): 2551 c.i = 5 2552 self.assertEqual(c.i, 10) 2553 2554 def test_inherit(self): 2555 @dataclass(frozen=True) 2556 class C: 2557 i: int 2558 2559 @dataclass(frozen=True) 2560 class D(C): 2561 j: int 2562 2563 d = D(0, 10) 2564 with self.assertRaises(FrozenInstanceError): 2565 d.i = 5 2566 with self.assertRaises(FrozenInstanceError): 2567 d.j = 6 2568 self.assertEqual(d.i, 0) 2569 self.assertEqual(d.j, 10) 2570 2571 # Test both ways: with an intermediate normal (non-dataclass) 2572 # class and without an intermediate class. 2573 def test_inherit_nonfrozen_from_frozen(self): 2574 for intermediate_class in [True, False]: 2575 with self.subTest(intermediate_class=intermediate_class): 2576 @dataclass(frozen=True) 2577 class C: 2578 i: int 2579 2580 if intermediate_class: 2581 class I(C): pass 2582 else: 2583 I = C 2584 2585 with self.assertRaisesRegex(TypeError, 2586 'cannot inherit non-frozen dataclass from a frozen one'): 2587 @dataclass 2588 class D(I): 2589 pass 2590 2591 def test_inherit_frozen_from_nonfrozen(self): 2592 for intermediate_class in [True, False]: 2593 with self.subTest(intermediate_class=intermediate_class): 2594 @dataclass 2595 class C: 2596 i: int 2597 2598 if intermediate_class: 2599 class I(C): pass 2600 else: 2601 I = C 2602 2603 with self.assertRaisesRegex(TypeError, 2604 'cannot inherit frozen dataclass from a non-frozen one'): 2605 @dataclass(frozen=True) 2606 class D(I): 2607 pass 2608 2609 def test_inherit_from_normal_class(self): 2610 for intermediate_class in [True, False]: 2611 with self.subTest(intermediate_class=intermediate_class): 2612 class C: 2613 pass 2614 2615 if intermediate_class: 2616 class I(C): pass 2617 else: 2618 I = C 2619 2620 @dataclass(frozen=True) 2621 class D(I): 2622 i: int 2623 2624 d = D(10) 2625 with self.assertRaises(FrozenInstanceError): 2626 d.i = 5 2627 2628 def test_non_frozen_normal_derived(self): 2629 # See bpo-32953. 2630 2631 @dataclass(frozen=True) 2632 class D: 2633 x: int 2634 y: int = 10 2635 2636 class S(D): 2637 pass 2638 2639 s = S(3) 2640 self.assertEqual(s.x, 3) 2641 self.assertEqual(s.y, 10) 2642 s.cached = True 2643 2644 # But can't change the frozen attributes. 2645 with self.assertRaises(FrozenInstanceError): 2646 s.x = 5 2647 with self.assertRaises(FrozenInstanceError): 2648 s.y = 5 2649 self.assertEqual(s.x, 3) 2650 self.assertEqual(s.y, 10) 2651 self.assertEqual(s.cached, True) 2652 2653 def test_overwriting_frozen(self): 2654 # frozen uses __setattr__ and __delattr__. 2655 with self.assertRaisesRegex(TypeError, 2656 'Cannot overwrite attribute __setattr__'): 2657 @dataclass(frozen=True) 2658 class C: 2659 x: int 2660 def __setattr__(self): 2661 pass 2662 2663 with self.assertRaisesRegex(TypeError, 2664 'Cannot overwrite attribute __delattr__'): 2665 @dataclass(frozen=True) 2666 class C: 2667 x: int 2668 def __delattr__(self): 2669 pass 2670 2671 @dataclass(frozen=False) 2672 class C: 2673 x: int 2674 def __setattr__(self, name, value): 2675 self.__dict__['x'] = value * 2 2676 self.assertEqual(C(10).x, 20) 2677 2678 def test_frozen_hash(self): 2679 @dataclass(frozen=True) 2680 class C: 2681 x: Any 2682 2683 # If x is immutable, we can compute the hash. No exception is 2684 # raised. 2685 hash(C(3)) 2686 2687 # If x is mutable, computing the hash is an error. 2688 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2689 hash(C({})) 2690 2691 2692class TestSlots(unittest.TestCase): 2693 def test_simple(self): 2694 @dataclass 2695 class C: 2696 __slots__ = ('x',) 2697 x: Any 2698 2699 # There was a bug where a variable in a slot was assumed to 2700 # also have a default value (of type 2701 # types.MemberDescriptorType). 2702 with self.assertRaisesRegex(TypeError, 2703 r"__init__\(\) missing 1 required positional argument: 'x'"): 2704 C() 2705 2706 # We can create an instance, and assign to x. 2707 c = C(10) 2708 self.assertEqual(c.x, 10) 2709 c.x = 5 2710 self.assertEqual(c.x, 5) 2711 2712 # We can't assign to anything else. 2713 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"): 2714 c.y = 5 2715 2716 def test_derived_added_field(self): 2717 # See bpo-33100. 2718 @dataclass 2719 class Base: 2720 __slots__ = ('x',) 2721 x: Any 2722 2723 @dataclass 2724 class Derived(Base): 2725 x: int 2726 y: int 2727 2728 d = Derived(1, 2) 2729 self.assertEqual((d.x, d.y), (1, 2)) 2730 2731 # We can add a new field to the derived instance. 2732 d.z = 10 2733 2734class TestDescriptors(unittest.TestCase): 2735 def test_set_name(self): 2736 # See bpo-33141. 2737 2738 # Create a descriptor. 2739 class D: 2740 def __set_name__(self, owner, name): 2741 self.name = name + 'x' 2742 def __get__(self, instance, owner): 2743 if instance is not None: 2744 return 1 2745 return self 2746 2747 # This is the case of just normal descriptor behavior, no 2748 # dataclass code is involved in initializing the descriptor. 2749 @dataclass 2750 class C: 2751 c: int=D() 2752 self.assertEqual(C.c.name, 'cx') 2753 2754 # Now test with a default value and init=False, which is the 2755 # only time this is really meaningful. If not using 2756 # init=False, then the descriptor will be overwritten, anyway. 2757 @dataclass 2758 class C: 2759 c: int=field(default=D(), init=False) 2760 self.assertEqual(C.c.name, 'cx') 2761 self.assertEqual(C().c, 1) 2762 2763 def test_non_descriptor(self): 2764 # PEP 487 says __set_name__ should work on non-descriptors. 2765 # Create a descriptor. 2766 2767 class D: 2768 def __set_name__(self, owner, name): 2769 self.name = name + 'x' 2770 2771 @dataclass 2772 class C: 2773 c: int=field(default=D(), init=False) 2774 self.assertEqual(C.c.name, 'cx') 2775 2776 def test_lookup_on_instance(self): 2777 # See bpo-33175. 2778 class D: 2779 pass 2780 2781 d = D() 2782 # Create an attribute on the instance, not type. 2783 d.__set_name__ = Mock() 2784 2785 # Make sure d.__set_name__ is not called. 2786 @dataclass 2787 class C: 2788 i: int=field(default=d, init=False) 2789 2790 self.assertEqual(d.__set_name__.call_count, 0) 2791 2792 def test_lookup_on_class(self): 2793 # See bpo-33175. 2794 class D: 2795 pass 2796 D.__set_name__ = Mock() 2797 2798 # Make sure D.__set_name__ is called. 2799 @dataclass 2800 class C: 2801 i: int=field(default=D(), init=False) 2802 2803 self.assertEqual(D.__set_name__.call_count, 1) 2804 2805 2806class TestStringAnnotations(unittest.TestCase): 2807 def test_classvar(self): 2808 # Some expressions recognized as ClassVar really aren't. But 2809 # if you're using string annotations, it's not an exact 2810 # science. 2811 # These tests assume that both "import typing" and "from 2812 # typing import *" have been run in this file. 2813 for typestr in ('ClassVar[int]', 2814 'ClassVar [int]' 2815 ' ClassVar [int]', 2816 'ClassVar', 2817 ' ClassVar ', 2818 'typing.ClassVar[int]', 2819 'typing.ClassVar[str]', 2820 ' typing.ClassVar[str]', 2821 'typing .ClassVar[str]', 2822 'typing. ClassVar[str]', 2823 'typing.ClassVar [str]', 2824 'typing.ClassVar [ str]', 2825 2826 # Not syntactically valid, but these will 2827 # be treated as ClassVars. 2828 'typing.ClassVar.[int]', 2829 'typing.ClassVar+', 2830 ): 2831 with self.subTest(typestr=typestr): 2832 @dataclass 2833 class C: 2834 x: typestr 2835 2836 # x is a ClassVar, so C() takes no args. 2837 C() 2838 2839 # And it won't appear in the class's dict because it doesn't 2840 # have a default. 2841 self.assertNotIn('x', C.__dict__) 2842 2843 def test_isnt_classvar(self): 2844 for typestr in ('CV', 2845 't.ClassVar', 2846 't.ClassVar[int]', 2847 'typing..ClassVar[int]', 2848 'Classvar', 2849 'Classvar[int]', 2850 'typing.ClassVarx[int]', 2851 'typong.ClassVar[int]', 2852 'dataclasses.ClassVar[int]', 2853 'typingxClassVar[str]', 2854 ): 2855 with self.subTest(typestr=typestr): 2856 @dataclass 2857 class C: 2858 x: typestr 2859 2860 # x is not a ClassVar, so C() takes one arg. 2861 self.assertEqual(C(10).x, 10) 2862 2863 def test_initvar(self): 2864 # These tests assume that both "import dataclasses" and "from 2865 # dataclasses import *" have been run in this file. 2866 for typestr in ('InitVar[int]', 2867 'InitVar [int]' 2868 ' InitVar [int]', 2869 'InitVar', 2870 ' InitVar ', 2871 'dataclasses.InitVar[int]', 2872 'dataclasses.InitVar[str]', 2873 ' dataclasses.InitVar[str]', 2874 'dataclasses .InitVar[str]', 2875 'dataclasses. InitVar[str]', 2876 'dataclasses.InitVar [str]', 2877 'dataclasses.InitVar [ str]', 2878 2879 # Not syntactically valid, but these will 2880 # be treated as InitVars. 2881 'dataclasses.InitVar.[int]', 2882 'dataclasses.InitVar+', 2883 ): 2884 with self.subTest(typestr=typestr): 2885 @dataclass 2886 class C: 2887 x: typestr 2888 2889 # x is an InitVar, so doesn't create a member. 2890 with self.assertRaisesRegex(AttributeError, 2891 "object has no attribute 'x'"): 2892 C(1).x 2893 2894 def test_isnt_initvar(self): 2895 for typestr in ('IV', 2896 'dc.InitVar', 2897 'xdataclasses.xInitVar', 2898 'typing.xInitVar[int]', 2899 ): 2900 with self.subTest(typestr=typestr): 2901 @dataclass 2902 class C: 2903 x: typestr 2904 2905 # x is not an InitVar, so there will be a member x. 2906 self.assertEqual(C(10).x, 10) 2907 2908 def test_classvar_module_level_import(self): 2909 from test import dataclass_module_1 2910 from test import dataclass_module_1_str 2911 from test import dataclass_module_2 2912 from test import dataclass_module_2_str 2913 2914 for m in (dataclass_module_1, dataclass_module_1_str, 2915 dataclass_module_2, dataclass_module_2_str, 2916 ): 2917 with self.subTest(m=m): 2918 # There's a difference in how the ClassVars are 2919 # interpreted when using string annotations or 2920 # not. See the imported modules for details. 2921 if m.USING_STRINGS: 2922 c = m.CV(10) 2923 else: 2924 c = m.CV() 2925 self.assertEqual(c.cv0, 20) 2926 2927 2928 # There's a difference in how the InitVars are 2929 # interpreted when using string annotations or 2930 # not. See the imported modules for details. 2931 c = m.IV(0, 1, 2, 3, 4) 2932 2933 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 2934 with self.subTest(field_name=field_name): 2935 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 2936 # Since field_name is an InitVar, it's 2937 # not an instance field. 2938 getattr(c, field_name) 2939 2940 if m.USING_STRINGS: 2941 # iv4 is interpreted as a normal field. 2942 self.assertIn('not_iv4', c.__dict__) 2943 self.assertEqual(c.not_iv4, 4) 2944 else: 2945 # iv4 is interpreted as an InitVar, so it 2946 # won't exist on the instance. 2947 self.assertNotIn('not_iv4', c.__dict__) 2948 2949 def test_text_annotations(self): 2950 from test import dataclass_textanno 2951 2952 self.assertEqual( 2953 get_type_hints(dataclass_textanno.Bar), 2954 {'foo': dataclass_textanno.Foo}) 2955 self.assertEqual( 2956 get_type_hints(dataclass_textanno.Bar.__init__), 2957 {'foo': dataclass_textanno.Foo, 2958 'return': type(None)}) 2959 2960 2961class TestMakeDataclass(unittest.TestCase): 2962 def test_simple(self): 2963 C = make_dataclass('C', 2964 [('x', int), 2965 ('y', int, field(default=5))], 2966 namespace={'add_one': lambda self: self.x + 1}) 2967 c = C(10) 2968 self.assertEqual((c.x, c.y), (10, 5)) 2969 self.assertEqual(c.add_one(), 11) 2970 2971 2972 def test_no_mutate_namespace(self): 2973 # Make sure a provided namespace isn't mutated. 2974 ns = {} 2975 C = make_dataclass('C', 2976 [('x', int), 2977 ('y', int, field(default=5))], 2978 namespace=ns) 2979 self.assertEqual(ns, {}) 2980 2981 def test_base(self): 2982 class Base1: 2983 pass 2984 class Base2: 2985 pass 2986 C = make_dataclass('C', 2987 [('x', int)], 2988 bases=(Base1, Base2)) 2989 c = C(2) 2990 self.assertIsInstance(c, C) 2991 self.assertIsInstance(c, Base1) 2992 self.assertIsInstance(c, Base2) 2993 2994 def test_base_dataclass(self): 2995 @dataclass 2996 class Base1: 2997 x: int 2998 class Base2: 2999 pass 3000 C = make_dataclass('C', 3001 [('y', int)], 3002 bases=(Base1, Base2)) 3003 with self.assertRaisesRegex(TypeError, 'required positional'): 3004 c = C(2) 3005 c = C(1, 2) 3006 self.assertIsInstance(c, C) 3007 self.assertIsInstance(c, Base1) 3008 self.assertIsInstance(c, Base2) 3009 3010 self.assertEqual((c.x, c.y), (1, 2)) 3011 3012 def test_init_var(self): 3013 def post_init(self, y): 3014 self.x *= y 3015 3016 C = make_dataclass('C', 3017 [('x', int), 3018 ('y', InitVar[int]), 3019 ], 3020 namespace={'__post_init__': post_init}, 3021 ) 3022 c = C(2, 3) 3023 self.assertEqual(vars(c), {'x': 6}) 3024 self.assertEqual(len(fields(c)), 1) 3025 3026 def test_class_var(self): 3027 C = make_dataclass('C', 3028 [('x', int), 3029 ('y', ClassVar[int], 10), 3030 ('z', ClassVar[int], field(default=20)), 3031 ]) 3032 c = C(1) 3033 self.assertEqual(vars(c), {'x': 1}) 3034 self.assertEqual(len(fields(c)), 1) 3035 self.assertEqual(C.y, 10) 3036 self.assertEqual(C.z, 20) 3037 3038 def test_other_params(self): 3039 C = make_dataclass('C', 3040 [('x', int), 3041 ('y', ClassVar[int], 10), 3042 ('z', ClassVar[int], field(default=20)), 3043 ], 3044 init=False) 3045 # Make sure we have a repr, but no init. 3046 self.assertNotIn('__init__', vars(C)) 3047 self.assertIn('__repr__', vars(C)) 3048 3049 # Make sure random other params don't work. 3050 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 3051 C = make_dataclass('C', 3052 [], 3053 xxinit=False) 3054 3055 def test_no_types(self): 3056 C = make_dataclass('Point', ['x', 'y', 'z']) 3057 c = C(1, 2, 3) 3058 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3059 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3060 'y': 'typing.Any', 3061 'z': 'typing.Any'}) 3062 3063 C = make_dataclass('Point', ['x', ('y', int), 'z']) 3064 c = C(1, 2, 3) 3065 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3066 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3067 'y': int, 3068 'z': 'typing.Any'}) 3069 3070 def test_invalid_type_specification(self): 3071 for bad_field in [(), 3072 (1, 2, 3, 4), 3073 ]: 3074 with self.subTest(bad_field=bad_field): 3075 with self.assertRaisesRegex(TypeError, r'Invalid field: '): 3076 make_dataclass('C', ['a', bad_field]) 3077 3078 # And test for things with no len(). 3079 for bad_field in [float, 3080 lambda x:x, 3081 ]: 3082 with self.subTest(bad_field=bad_field): 3083 with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 3084 make_dataclass('C', ['a', bad_field]) 3085 3086 def test_duplicate_field_names(self): 3087 for field in ['a', 'ab']: 3088 with self.subTest(field=field): 3089 with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 3090 make_dataclass('C', [field, 'a', field]) 3091 3092 def test_keyword_field_names(self): 3093 for field in ['for', 'async', 'await', 'as']: 3094 with self.subTest(field=field): 3095 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3096 make_dataclass('C', ['a', field]) 3097 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3098 make_dataclass('C', [field]) 3099 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3100 make_dataclass('C', [field, 'a']) 3101 3102 def test_non_identifier_field_names(self): 3103 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 3104 with self.subTest(field=field): 3105 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3106 make_dataclass('C', ['a', field]) 3107 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3108 make_dataclass('C', [field]) 3109 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): 3110 make_dataclass('C', [field, 'a']) 3111 3112 def test_underscore_field_names(self): 3113 # Unlike namedtuple, it's okay if dataclass field names have 3114 # an underscore. 3115 make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 3116 3117 def test_funny_class_names_names(self): 3118 # No reason to prevent weird class names, since 3119 # types.new_class allows them. 3120 for classname in ['()', 'x,y', '*', '2@3', '']: 3121 with self.subTest(classname=classname): 3122 C = make_dataclass(classname, ['a', 'b']) 3123 self.assertEqual(C.__name__, classname) 3124 3125class TestReplace(unittest.TestCase): 3126 def test(self): 3127 @dataclass(frozen=True) 3128 class C: 3129 x: int 3130 y: int 3131 3132 c = C(1, 2) 3133 c1 = replace(c, x=3) 3134 self.assertEqual(c1.x, 3) 3135 self.assertEqual(c1.y, 2) 3136 3137 def test_frozen(self): 3138 @dataclass(frozen=True) 3139 class C: 3140 x: int 3141 y: int 3142 z: int = field(init=False, default=10) 3143 t: int = field(init=False, default=100) 3144 3145 c = C(1, 2) 3146 c1 = replace(c, x=3) 3147 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 3148 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 3149 3150 3151 with self.assertRaisesRegex(ValueError, 'init=False'): 3152 replace(c, x=3, z=20, t=50) 3153 with self.assertRaisesRegex(ValueError, 'init=False'): 3154 replace(c, z=20) 3155 replace(c, x=3, z=20, t=50) 3156 3157 # Make sure the result is still frozen. 3158 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 3159 c1.x = 3 3160 3161 # Make sure we can't replace an attribute that doesn't exist, 3162 # if we're also replacing one that does exist. Test this 3163 # here, because setting attributes on frozen instances is 3164 # handled slightly differently from non-frozen ones. 3165 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3166 "keyword argument 'a'"): 3167 c1 = replace(c, x=20, a=5) 3168 3169 def test_invalid_field_name(self): 3170 @dataclass(frozen=True) 3171 class C: 3172 x: int 3173 y: int 3174 3175 c = C(1, 2) 3176 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3177 "keyword argument 'z'"): 3178 c1 = replace(c, z=3) 3179 3180 def test_invalid_object(self): 3181 @dataclass(frozen=True) 3182 class C: 3183 x: int 3184 y: int 3185 3186 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3187 replace(C, x=3) 3188 3189 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3190 replace(0, x=3) 3191 3192 def test_no_init(self): 3193 @dataclass 3194 class C: 3195 x: int 3196 y: int = field(init=False, default=10) 3197 3198 c = C(1) 3199 c.y = 20 3200 3201 # Make sure y gets the default value. 3202 c1 = replace(c, x=5) 3203 self.assertEqual((c1.x, c1.y), (5, 10)) 3204 3205 # Trying to replace y is an error. 3206 with self.assertRaisesRegex(ValueError, 'init=False'): 3207 replace(c, x=2, y=30) 3208 3209 with self.assertRaisesRegex(ValueError, 'init=False'): 3210 replace(c, y=30) 3211 3212 def test_classvar(self): 3213 @dataclass 3214 class C: 3215 x: int 3216 y: ClassVar[int] = 1000 3217 3218 c = C(1) 3219 d = C(2) 3220 3221 self.assertIs(c.y, d.y) 3222 self.assertEqual(c.y, 1000) 3223 3224 # Trying to replace y is an error: can't replace ClassVars. 3225 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 3226 "unexpected keyword argument 'y'"): 3227 replace(c, y=30) 3228 3229 replace(c, x=5) 3230 3231 def test_initvar_is_specified(self): 3232 @dataclass 3233 class C: 3234 x: int 3235 y: InitVar[int] 3236 3237 def __post_init__(self, y): 3238 self.x *= y 3239 3240 c = C(1, 10) 3241 self.assertEqual(c.x, 10) 3242 with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " 3243 "specified with replace()"): 3244 replace(c, x=3) 3245 c = replace(c, x=3, y=5) 3246 self.assertEqual(c.x, 15) 3247 3248 def test_recursive_repr(self): 3249 @dataclass 3250 class C: 3251 f: "C" 3252 3253 c = C(None) 3254 c.f = c 3255 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 3256 3257 def test_recursive_repr_two_attrs(self): 3258 @dataclass 3259 class C: 3260 f: "C" 3261 g: "C" 3262 3263 c = C(None, None) 3264 c.f = c 3265 c.g = c 3266 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 3267 ".<locals>.C(f=..., g=...)") 3268 3269 def test_recursive_repr_indirection(self): 3270 @dataclass 3271 class C: 3272 f: "D" 3273 3274 @dataclass 3275 class D: 3276 f: "C" 3277 3278 c = C(None) 3279 d = D(None) 3280 c.f = d 3281 d.f = c 3282 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 3283 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 3284 ".<locals>.D(f=...))") 3285 3286 def test_recursive_repr_indirection_two(self): 3287 @dataclass 3288 class C: 3289 f: "D" 3290 3291 @dataclass 3292 class D: 3293 f: "E" 3294 3295 @dataclass 3296 class E: 3297 f: "C" 3298 3299 c = C(None) 3300 d = D(None) 3301 e = E(None) 3302 c.f = d 3303 d.f = e 3304 e.f = c 3305 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 3306 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 3307 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 3308 ".<locals>.E(f=...)))") 3309 3310 def test_recursive_repr_misc_attrs(self): 3311 @dataclass 3312 class C: 3313 f: "C" 3314 g: int 3315 3316 c = C(None, 1) 3317 c.f = c 3318 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 3319 ".<locals>.C(f=..., g=1)") 3320 3321 ## def test_initvar(self): 3322 ## @dataclass 3323 ## class C: 3324 ## x: int 3325 ## y: InitVar[int] 3326 3327 ## c = C(1, 10) 3328 ## d = C(2, 20) 3329 3330 ## # In our case, replacing an InitVar is a no-op 3331 ## self.assertEqual(c, replace(c, y=5)) 3332 3333 ## replace(c, x=5) 3334 3335 3336if __name__ == '__main__': 3337 unittest.main() 3338