• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# This code is backported from the Python 3.10 dataclasses module. Once
# Python 3.10 becomes the minimum supported version, use
# dataclass(slots=True) instead.
3
4from __future__ import annotations
5
6import dataclasses
7import itertools
8from typing import Generator, List, Type, TYPE_CHECKING, TypeVar
9
10
11if TYPE_CHECKING:
12    from _typeshed import DataclassInstance
13
14
# Only the slots helper is part of this module's public API.
__all__ = ["dataclass_slots"]

# Generic type variable for dataclass types. The bound is a string forward
# reference because ``DataclassInstance`` is only imported under
# ``TYPE_CHECKING`` and is therefore not bound at runtime.
_T = TypeVar(
    "_T",
    bound="DataclassInstance",
)
18
19
def dataclass_slots(cls: Type[_T]) -> Type[DataclassInstance]:
    """Return a copy of dataclass *cls* that uses ``__slots__``.

    Backport of ``dataclass(slots=True)``. A new class is built with the same
    name, bases, metaclass and namespace as *cls*, but with ``__slots__`` set
    to the dataclass field names (plus ``__weakref__`` when the dataclass
    parameters request ``weakref_slot``), minus any slots that base classes
    already provide.

    Args:
        cls: A class already processed by ``@dataclasses.dataclass``.

    Returns:
        The newly created slotted class. *cls* itself is not modified.

    Raises:
        AssertionError: If *cls* is not a dataclass (only checked when
            assertions are enabled).
        TypeError: If *cls* already defines ``__slots__``, or if a base
            class's ``__slots__`` cannot be determined (e.g. it is an
            iterator).
    """
    assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses."

    def _get_slots(cls: Type[DataclassInstance]) -> Generator[str, None, None]:
        """Yield the slot names that *cls* itself defines (not its bases)."""
        slots = cls.__dict__.get("__slots__")
        if slots is None:
            # No explicit ``__slots__``. `__dictoffset__` and
            # `__weakrefoffset__` tell us whether the type has dict/weakref
            # slots, in a way that works for both Python classes and C
            # extension types (extension types don't use ``__slots__`` for
            # slot creation).
            if getattr(cls, "__weakrefoffset__", -1) != 0:
                yield "__weakref__"
            if getattr(cls, "__dictoffset__", -1) != 0:
                yield "__dict__"
        elif isinstance(slots, str):
            # A lone string names a single slot; it is not iterated.
            yield slots
        elif not hasattr(slots, "__next__"):
            # Slots may be any iterable, but we cannot handle an iterator
            # because it will already be (partially) consumed.
            yield from slots
        else:
            raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")

    def _add_slots(
        cls: Type[DataclassInstance], is_frozen: bool, weakref_slot: bool
    ) -> Type[DataclassInstance]:
        """Build and return a slotted replacement for dataclass *cls*."""
        # A new class is required: __slots__ cannot be added to a class
        # after it has been created.
        if "__slots__" in cls.__dict__:
            raise TypeError(f"{cls.__name__} already specifies __slots__")

        # Copy the namespace so the original class is left untouched.
        cls_dict = dict(cls.__dict__)
        field_names = tuple(f.name for f in dataclasses.fields(cls))
        # Slots already provided by base classes (excluding cls itself and
        # ``object``) must not be declared again.
        inherited_slots = set(
            itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
        )
        # Our slots: the field names, plus '__weakref__' if weakref_slot was
        # requested, minus anything inherited. gh-93521: '__weakref__' is
        # filtered too, so it is not declared twice when a base has it.
        cls_dict["__slots__"] = tuple(
            itertools.filterfalse(
                inherited_slots.__contains__,
                itertools.chain(
                    field_names,
                    ("__weakref__",) if weakref_slot else (),
                ),
            ),
        )

        for field_name in field_names:
            # Drop class attributes named like fields (e.g. defaults): a name
            # listed in __slots__ cannot also be a class variable.
            cls_dict.pop(field_name, None)

        # The copied `__dict__` descriptor belongs to the previous type.
        cls_dict.pop("__dict__", None)

        # Clear existing `__weakref__` descriptor, it belongs to a previous type:
        cls_dict.pop("__weakref__", None)  # gh-102069

        # Create the class through cls's own metaclass, so that a metaclass
        # declared directly on cls (not inherited from a base) is preserved.
        # __qualname__ is not part of cls.__dict__, so restore it afterwards.
        qualname = getattr(cls, "__qualname__", None)
        cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
        if qualname is not None:
            cls.__qualname__ = qualname

        def _dataclass_getstate(self: _T) -> object:
            """Return the field values, in field order, as pickle state."""
            fields = dataclasses.fields(self)
            return [getattr(self, f.name) for f in fields]

        def _dataclass_setstate(self: _T, state: List[object]) -> None:
            """Restore field values produced by ``_dataclass_getstate``."""
            fields = dataclasses.fields(self)
            for field, value in zip(fields, state):
                # object.__setattr__ bypasses a frozen dataclass's __setattr__.
                object.__setattr__(self, field.name, value)

        if is_frozen:
            # Need this for pickling frozen classes with slots: restoring the
            # state must go through object.__setattr__. Respect any
            # user-defined hooks already present on the class.
            if "__getstate__" not in cls_dict:
                cls.__getstate__ = _dataclass_getstate  # type: ignore[method-assign, assignment]
            if "__setstate__" not in cls_dict:
                cls.__setstate__ = _dataclass_setstate  # type: ignore[attr-defined]

        return cls

    params = getattr(cls, dataclasses._PARAMS)  # type: ignore[attr-defined]
    # `weakref_slot` may be absent from the parameters on older Pythons
    # (it was added in 3.11), hence the getattr default.
    weakref_slot = getattr(params, "weakref_slot", False)
    return _add_slots(cls, params.frozen, weakref_slot)
115