1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
import inspect
import sys
from inspect import isclass
from typing import Annotated, Any, Optional, Union, get_args, get_origin
import attrs
_IS_PYTHON_3_8 = sys.version_info[:2] == (3, 8)
if sys.version_info >= (3, 10): # pragma: no cover
from types import UnionType
else:
UnionType = object()
if sys.version_info < (3, 11): # pragma: no cover
from typing_extensions import NotRequired, Required
else: # pragma: no cover
from typing import NotRequired, Required
# from types import NoneType is available >=3.10
NoneType = type(None)
AnnotatedType = type(Annotated[int, 0])
def is_nonetype(hint):
return hint is NoneType
def is_union(type_: Optional[type]) -> bool:
"""Checks if a type is a union."""
# Direct checks are faster than checking if the type is in a set that contains the union-types.
if type_ is Union or type_ is UnionType:
return True
# The ``get_origin`` call is relatively expensive, so we'll check common types
# that are passed in here to see if we can avoid calling ``get_origin``.
if type_ is str or type_ is int or type_ is float or type_ is bool or is_annotated(type_):
return False
origin = get_origin(type_)
return origin is Union or origin is UnionType
def is_pydantic(hint) -> bool:
return hasattr(hint, "__pydantic_core_schema__")
def is_dataclass(hint) -> bool:
return hasattr(hint, "__dataclass_fields__")
def is_namedtuple(hint) -> bool:
return isclass(hint) and issubclass(hint, tuple) and hasattr(hint, "_fields")
def is_attrs(hint) -> bool:
return attrs.has(hint)
def is_annotated(hint) -> bool:
return type(hint) is AnnotatedType
def contains_hint(hint, target_type) -> bool:
"""Indicates if ``target_type`` is in a possibly annotated/unioned ``hint``.
E.g. ``contains_hint(Union[int, str], str) == True``
"""
hint = resolve(hint)
if is_union(hint):
return any(contains_hint(x, target_type) for x in get_args(hint))
else:
return isclass(hint) and issubclass(hint, target_type)
def is_typeddict(hint) -> bool:
"""Determine if a type annotation is a TypedDict.
This is surprisingly hard! Modified from Beartype's implementation:
https://github.com/beartype/beartype/blob/main/beartype/_util/hint/pep/proposal/utilpep589.py
"""
hint = resolve(hint)
if is_union(hint):
return any(is_typeddict(x) for x in get_args(hint))
if not (isclass(hint) and issubclass(hint, dict)):
return False
return (
# This "dict" subclass defines these "TypedDict" attributes *AND*...
hasattr(hint, "__annotations__")
and hasattr(hint, "__total__")
and
# Either...
(
# The active Python interpreter targets exactly Python 3.8 and
# thus fails to unconditionally define the remaining attributes
# *OR*...
_IS_PYTHON_3_8
or
# The active Python interpreter targets any other Python version
# and thus unconditionally defines the remaining attributes.
(hasattr(hint, "__required_keys__") and hasattr(hint, "__optional_keys__"))
)
)
def resolve(type_: Any) -> type:
"""Perform all simplifying resolutions."""
if type_ is inspect.Parameter.empty:
return str
type_prev = None
while type_ != type_prev:
type_prev = type_
type_ = resolve_annotated(type_)
type_ = resolve_optional(type_)
type_ = resolve_required(type_)
type_ = resolve_new_type(type_)
return type_
def resolve_optional(type_: Any) -> Any:
"""Only resolves Union's of None + one other type (i.e. Optional)."""
# Python will automatically flatten out nested unions when possible.
# So we don't need to loop over resolution.
if not is_union(type_):
return type_
non_none_types = [t for t in get_args(type_) if t is not NoneType]
if not non_none_types: # pragma: no cover
# This should never happen; python simplifies:
# ``Union[None, None] -> NoneType``
raise ValueError("Union type cannot be all NoneType")
elif len(non_none_types) == 1:
type_ = non_none_types[0]
elif len(non_none_types) > 1:
return Union[tuple(resolve_optional(x) for x in non_none_types)] # pyright: ignore
else:
raise NotImplementedError
return type_
def resolve_annotated(type_: Any) -> type:
if type(type_) is AnnotatedType:
type_ = get_args(type_)[0]
return type_
def resolve_required(type_: Any) -> type:
if get_origin(type_) in (Required, NotRequired):
type_ = get_args(type_)[0]
return type_
def resolve_new_type(type_: Any) -> type:
try:
return resolve_new_type(type_.__supertype__)
except AttributeError:
return type_
def get_hint_name(hint) -> str:
if isinstance(hint, str):
return hint
if is_nonetype(hint):
return "None"
if hint is Any:
return "Any"
if is_union(hint):
return "|".join(get_hint_name(arg) for arg in get_args(hint))
if origin := get_origin(hint):
out = get_hint_name(origin)
if args := get_args(hint):
out += "[" + ", ".join(get_hint_name(arg) for arg in args) + "]"
return out
if hasattr(hint, "__name__"):
return hint.__name__
if getattr(hint, "_name", None) is not None:
return hint._name
return str(hint)
|