File: annotations.py

package info (click to toggle)
python-cyclopts 3.12.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 3,288 kB
  • sloc: python: 11,445; makefile: 24
file content (182 lines) | stat: -rw-r--r-- 5,493 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import inspect
import sys
from inspect import isclass
from typing import Annotated, Any, Optional, Union, get_args, get_origin

import attrs

# Flag for the Python 3.8 special case in ``is_typeddict``: 3.8 TypedDict
# classes do not unconditionally define ``__required_keys__`` /
# ``__optional_keys__``, so that check is relaxed there.
_IS_PYTHON_3_8 = sys.version_info[:2] == (3, 8)

# ``X | Y`` unions are ``types.UnionType`` instances on >=3.10. On older
# interpreters substitute a fresh sentinel object: no hint can be identical
# to it, so every ``... is UnionType`` check simply evaluates False.
if sys.version_info >= (3, 10):  # pragma: no cover
    from types import UnionType
else:
    UnionType = object()

# ``Required`` / ``NotRequired`` are imported from ``typing`` on >=3.11 and
# from the ``typing_extensions`` backport on older interpreters.
if sys.version_info < (3, 11):  # pragma: no cover
    from typing_extensions import NotRequired, Required
else:  # pragma: no cover
    from typing import NotRequired, Required

# ``types.NoneType`` only exists on >=3.10, so derive the class portably.
NoneType = type(None)
# Runtime class of ``Annotated[...]`` aliases; compared by identity in
# ``is_annotated`` / ``resolve_annotated`` as a cheap "is this Annotated?" test.
AnnotatedType = type(Annotated[int, 0])


def is_nonetype(hint):
    """Return ``True`` only when ``hint`` is exactly the ``NoneType`` class.

    Note that the ``None`` value itself is *not* ``NoneType`` and yields ``False``.
    """
    return type(None) is hint


def is_union(type_: Optional[type]) -> bool:
    """Report whether ``type_`` is a union (``typing.Union`` or ``X | Y``).

    Both the bare union objects and subscripted unions such as
    ``Union[int, str]`` / ``int | str`` are recognised.
    """
    # Identity tests are the cheapest possible check, so handle the bare
    # union objects before anything else.
    if type_ is Union or type_ is UnionType:
        return True

    # ``get_origin`` is comparatively expensive; short-circuit the plain
    # builtins and ``Annotated`` hints that frequently reach this helper.
    trivially_not_union = (
        type_ is str
        or type_ is int
        or type_ is float
        or type_ is bool
        or is_annotated(type_)
    )
    if trivially_not_union:
        return False

    origin = get_origin(type_)
    if origin is Union:
        return True
    return origin is UnionType


def is_pydantic(hint) -> bool:
    """Detect Pydantic models by duck-typing on ``__pydantic_core_schema__``."""
    marker = "__pydantic_core_schema__"
    return hasattr(hint, marker)


def is_dataclass(hint) -> bool:
    """Report whether ``hint`` exposes dataclass field metadata.

    Only checks for a ``__dataclass_fields__`` attribute, so instances of a
    dataclass match as well as the class itself.
    """
    fields_attr = "__dataclass_fields__"
    return hasattr(hint, fields_attr)


def is_namedtuple(hint) -> bool:
    """Report whether ``hint`` is a namedtuple *class*.

    A namedtuple class is taken to be a ``tuple`` subclass that defines
    ``_fields``; instances and plain ``tuple`` do not qualify.
    """
    if not isclass(hint):
        return False
    if not issubclass(hint, tuple):
        return False
    return hasattr(hint, "_fields")


def is_attrs(hint) -> bool:
    """Report whether ``hint`` is an ``attrs`` class.

    Delegates entirely to ``attrs.has``.
    """
    has_attrs_fields = attrs.has
    return has_attrs_fields(hint)


def is_annotated(hint) -> bool:
    """Return ``True`` when ``hint`` is an ``Annotated[...]`` alias.

    Uses an exact type-identity test against ``AnnotatedType`` (not
    ``isinstance``), which keeps the check as cheap as possible.
    """
    hint_cls = type(hint)
    return hint_cls is AnnotatedType


def contains_hint(hint, target_type) -> bool:
    """Indicates if ``target_type`` is in a possibly annotated/unioned ``hint``.

    E.g. ``contains_hint(Union[int, str], str) == True``

    ``hint`` is first simplified with ``resolve``. Union members are searched
    recursively. A non-union hint matches when it is a class and a subclass
    of ``target_type``.
    """
    resolved = resolve(hint)
    if not is_union(resolved):
        return isclass(resolved) and issubclass(resolved, target_type)

    for member in get_args(resolved):
        if contains_hint(member, target_type):
            return True
    return False


def is_typeddict(hint) -> bool:
    """Determine if a type annotation is a TypedDict.

    This is surprisingly hard! Modified from Beartype's implementation:

        https://github.com/beartype/beartype/blob/main/beartype/_util/hint/pep/proposal/utilpep589.py

    ``hint`` is simplified with ``resolve`` first. For a union, the result is
    ``True`` if any member is a TypedDict.
    """
    hint = resolve(hint)
    if is_union(hint):
        return any(is_typeddict(member) for member in get_args(hint))

    # Anything that is not a ``dict`` subclass is rejected outright.
    if not isclass(hint) or not issubclass(hint, dict):
        return False

    # Require the TypedDict marker attributes.
    if not (hasattr(hint, "__annotations__") and hasattr(hint, "__total__")):
        return False

    # Python 3.8 does not unconditionally define the key-set attributes, so
    # the checks above are the most that can be verified there.
    if _IS_PYTHON_3_8:
        return True

    # Every other Python version defines both key-set attributes.
    return hasattr(hint, "__required_keys__") and hasattr(hint, "__optional_keys__")


def resolve(type_: Any) -> type:
    """Apply every simplifying resolution until ``type_`` stops changing.

    Each pass does the following, in order:
      * unwraps an ``Annotated`` layer;
      * strips ``None`` from unions;
      * unwraps ``Required`` / ``NotRequired``;
      * follows ``NewType`` aliases.

    A missing annotation (``inspect.Parameter.empty``) is treated as ``str``.
    """
    if type_ is inspect.Parameter.empty:
        return str

    simplifiers = (resolve_annotated, resolve_optional, resolve_required, resolve_new_type)
    previous = None
    # Fixed-point iteration: stop once a full pass leaves the hint unchanged.
    while type_ != previous:
        previous = type_
        for simplify in simplifiers:
            type_ = simplify(type_)
    return type_


def resolve_optional(type_: Any) -> Any:
    """Strip ``None`` members from a union hint.

    * ``Optional[X]`` / ``X | None`` resolves to ``X``.
    * A union with several non-``None`` members resolves to a ``Union`` of
      those members, e.g. ``Union[int, str, None] -> Union[int, str]``.
    * Non-union hints are returned unchanged.

    Python automatically flattens nested unions, so one pass over
    ``get_args`` sees every member and no loop over resolution is needed.

    Raises:
        ValueError: If the union's members are all ``NoneType`` (not
            expected to happen; Python collapses such unions).
    """
    if not is_union(type_):
        return type_

    non_none_types = [t for t in get_args(type_) if t is not NoneType]
    if not non_none_types:  # pragma: no cover
        # This should never happen; python simplifies:
        #    ``Union[None, None] -> NoneType``
        raise ValueError("Union type cannot be all NoneType")
    if len(non_none_types) == 1:
        return non_none_types[0]

    # Two or more members remain; rebuild them as a ``typing.Union`` (this
    # also applies when the input was an ``X | Y`` ``types.UnionType``).
    # Members of a flattened union are not unions themselves, so the
    # recursive call is defensive and normally a no-op.
    return Union[tuple(resolve_optional(x) for x in non_none_types)]  # pyright: ignore


def resolve_annotated(type_: Any) -> type:
    """Unwrap ``Annotated[T, ...]`` to ``T``; any other hint passes through."""
    if type(type_) is not AnnotatedType:
        return type_
    return get_args(type_)[0]


def resolve_required(type_: Any) -> type:
    if get_origin(type_) in (Required, NotRequired):
        type_ = get_args(type_)[0]
    return type_


def resolve_new_type(type_: Any) -> type:
    """Follow ``__supertype__`` links (``typing.NewType``) to the root type.

    Anything without a ``__supertype__`` attribute is returned unchanged.
    """
    while True:
        try:
            type_ = type_.__supertype__
        except AttributeError:
            return type_


def get_hint_name(hint) -> str:
    """Render a type hint as a short, human-readable string.

    Examples of the output format: ``int``, ``int|None``, ``list[str]``,
    ``tuple[int, ...]``, ``Callable[[int, str], bool]``.

    Strings are returned verbatim. Unions are joined with ``|``. Subscripted
    generics render as ``origin[arg, ...]``. Otherwise the hint's
    ``__name__``, then ``_name``, then ``str(hint)`` is used.
    """
    if isinstance(hint, str):
        return hint
    if hint is Ellipsis:
        # ``tuple[int, ...]`` / ``Callable[..., T]``: show the literal rather
        # than ``str(Ellipsis)`` ("Ellipsis").
        return "..."
    if isinstance(hint, list):
        # ``get_args(Callable[[A, B], R])`` returns the parameter list as a
        # plain ``list``; render it as ``[A, B]`` instead of its ``repr``.
        return "[" + ", ".join(get_hint_name(arg) for arg in hint) + "]"
    if is_nonetype(hint):
        return "None"
    if hint is Any:
        return "Any"
    if is_union(hint):
        return "|".join(get_hint_name(arg) for arg in get_args(hint))
    if origin := get_origin(hint):
        out = get_hint_name(origin)
        if args := get_args(hint):
            out += "[" + ", ".join(get_hint_name(arg) for arg in args) + "]"
        return out
    if hasattr(hint, "__name__"):
        return hint.__name__
    if getattr(hint, "_name", None) is not None:
        return hint._name
    return str(hint)