File: _add_slots.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (65 lines) | stat: -rw-r--r-- 2,364 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# This file is derived from github.com/ericvsmith/dataclasses, and is Apache 2 licensed.
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f188f452/LICENSE.txt
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py
# Changed: takes slots in base classes into account when creating slots

import dataclasses
from itertools import chain, filterfalse
from typing import Any, Mapping, Type, TypeVar

# Type variable so that add_slots is typed as returning the class type it receives.
_T = TypeVar("_T")


def add_slots(cls: Type[_T]) -> Type[_T]:
    """Rebuild the dataclass *cls* with ``__slots__`` for its fields.

    ``__slots__`` only takes effect when a class is created, so this returns a
    *new* class built from the same metaclass, name, bases and namespace, with
    ``__slots__`` added. Field names already declared as slots by any class in
    the MRO are left out so they are not declared twice.

    Args:
        cls: A dataclass whose own namespace does not define ``__slots__``.

    Returns:
        The new slotted class, with the original ``__qualname__`` preserved and
        ``__getstate__`` / ``__setstate__`` installed for pickling.

    Raises:
        TypeError: If *cls* already defines ``__slots__``, or (raised by
            ``dataclasses.fields``) if *cls* is not a dataclass.
    """
    # Need to create a new class, since we can't set __slots__
    #  after a class has been created.

    # Make sure __slots__ isn't already set.
    if "__slots__" in cls.__dict__:
        raise TypeError(f"{cls.__name__} already specifies __slots__")

    # Start from a copy of the original class namespace.
    cls_dict = dict(cls.__dict__)
    field_names = tuple(f.name for f in dataclasses.fields(cls))

    # Collect slots already declared anywhere in the MRO. ``__slots__`` may be a
    # single string naming one slot; iterating that string directly would yield
    # its characters instead (e.g. "ab" -> "a", "b"), so wrap it in a tuple.
    declared = (klass.__dict__.get("__slots__", ()) for klass in cls.mro())
    inherited_slots = set(
        chain.from_iterable(
            (slots,) if isinstance(slots, str) else slots for slots in declared
        )
    )
    cls_dict["__slots__"] = tuple(
        filterfalse(inherited_slots.__contains__, field_names)
    )

    # Drop class attributes named like fields (their default values): a class
    # variable with the same name as a slot makes type() raise ValueError. The
    # defaults remain recorded on the dataclass fields and in the generated
    # __init__.
    for field_name in field_names:
        cls_dict.pop(field_name, None)
    # The old __dict__ and __weakref__ descriptors belong to the old type;
    # type() creates whichever of them the new slotted class actually needs.
    cls_dict.pop("__dict__", None)
    cls_dict.pop("__weakref__", None)

    # Create the class. type() would derive __qualname__ from the bare name,
    # so remember the original (possibly nested) qualified name.
    qualname = getattr(cls, "__qualname__", None)

    # pyre-fixme[9]: cls has type `Type[Variable[_T]]`; used as `_T`.
    # pyre-fixme[19]: Expected 0 positional arguments.
    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
    if qualname is not None:
        cls.__qualname__ = qualname

    # Set __getstate__ and __setstate__ to workaround a bug with pickling frozen
    # dataclasses with slots. See https://bugs.python.org/issue36424

    def __getstate__(self: object) -> Mapping[str, Any]:
        # Only include fields that are currently set (unset slots are skipped).
        return {
            field.name: getattr(self, field.name)
            for field in dataclasses.fields(self)
            if hasattr(self, field.name)
        }

    def __setstate__(self: object, state: Mapping[str, Any]) -> None:
        # object.__setattr__ bypasses a frozen dataclass's __setattr__ guard.
        for fieldname, value in state.items():
            object.__setattr__(self, fieldname, value)

    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__

    return cls