1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
|
from __future__ import annotations
from collections.abc import Mapping
from types import ModuleType as Namespace
from typing import (
TYPE_CHECKING,
Literal,
Protocol,
TypeAlias,
TypedDict,
TypeVar,
final,
)
if TYPE_CHECKING:
from _typeshed import Incomplete
SupportsBufferProtocol: TypeAlias = Incomplete
Array: TypeAlias = Incomplete
Device: TypeAlias = Incomplete
DType: TypeAlias = Incomplete
else:
SupportsBufferProtocol = object
Array = object
Device = object
DType = object
_T_co = TypeVar("_T_co", covariant=True)
# These "Just" types are equivalent to the `Just` type from the `optype` library,
# apart from them not being `@runtime_checkable`.
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
    """Static type intended to match objects whose class is exactly ``int``.

    ``__class__`` is declared as a read/write property typed with
    ``type[int]``, which makes the attribute invariant. An object therefore
    matches only if its ``__class__`` can be both read as and assigned a
    ``type[int]``. This is how optype's ``Just`` rejects subclasses such as
    ``bool``.

    Static typing only: the protocol is not ``@runtime_checkable``, so it
    cannot be used with ``isinstance``. ``@final`` forbids subclassing the
    protocol itself.
    """

    @property
    def __class__(self, /) -> type[int]: ...
    # The setter is what makes `__class__` invariant; pyright reports it as an
    # incompatible override of `object.__class__`, hence the targeted ignore.
    @__class__.setter
    def __class__(self, value: type[int], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustFloat(Protocol):
    """Static type intended to match objects whose class is exactly ``float``.

    ``__class__`` is an invariant read/write property of type
    ``type[float]``. Like optype's ``Just[float]``, this is meant to reject
    ``int``, which type checkers otherwise accept where ``float`` is expected.

    Static typing only: the protocol is not ``@runtime_checkable``.
    """

    @property
    def __class__(self, /) -> type[float]: ...
    # Setter makes `__class__` invariant; pyright flags the override of
    # `object.__class__`, hence the targeted ignore.
    @__class__.setter
    def __class__(self, value: type[float], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
@final
class JustComplex(Protocol):
    """Static type intended to match objects whose class is exactly ``complex``.

    ``__class__`` is an invariant read/write property of type
    ``type[complex]``. Like optype's ``Just[complex]``, this is meant to
    reject ``float``/``int``, which type checkers otherwise accept where
    ``complex`` is expected.

    Static typing only: the protocol is not ``@runtime_checkable``.
    """

    @property
    def __class__(self, /) -> type[complex]: ...
    # Setter makes `__class__` invariant; pyright flags the override of
    # `object.__class__`, hence the targeted ignore.
    @__class__.setter
    def __class__(self, value: type[complex], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
#
class NestedSequence(Protocol[_T_co]):
    """Sized, integer-indexable container of arbitrary nesting depth.

    Indexing one level yields either a leaf of type ``_T_co`` or another
    ``NestedSequence[_T_co]``. For example, ``[[1, 2], [3]]`` fits
    ``NestedSequence[int]``.
    """

    def __len__(self, /) -> int: ...

    def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]: ...
class SupportsArrayNamespace(Protocol[_T_co]):
    """Object exposing the ``__array_namespace__`` hook.

    The type parameter is the type of the namespace object the hook returns.
    ``api_version`` is keyword-only and may be ``None``.
    """

    def __array_namespace__(
        self,
        /,
        *,
        api_version: None | str,
    ) -> _T_co: ...
class HasShape(Protocol[_T_co]):
    """Object with a read-only ``shape`` attribute of the parameterised type."""

    @property
    def shape(self, /) -> _T_co:
        ...
# Return type of `__array_namespace_info__.capabilities`.
# The functional TypedDict syntax is required because the keys contain spaces
# and hyphens, which are not valid identifiers in the class-based syntax.
Capabilities = TypedDict(
    "Capabilities",
    {
        "boolean indexing": bool,
        "data-dependent shapes": bool,
        "max dimensions": int,
    },
)
# Return type of `__array_namespace_info__.default_dtypes`.
# Functional TypedDict syntax because the keys contain spaces; each value is a
# library-specific dtype object (typed as `DType`).
DefaultDTypes = TypedDict(
    "DefaultDTypes",
    {
        "real floating": DType,
        "complex floating": DType,
        "integral": DType,
        "indexing": DType,
    },
)
# A single dtype-category name; each one has a matching `DTypes*` TypedDict.
_DTypeKind: TypeAlias = Literal[
    "bool",
    "signed integer",
    "unsigned integer",
    "integral",
    "real floating",
    "complex floating",
    "numeric",
]
# Type of the `kind` parameter in `__array_namespace_info__.dtypes`:
# either one category name or a tuple of category names.
DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
# `__array_namespace_info__.dtypes(kind="bool")`
class DTypesBool(TypedDict):
    """Boolean dtype, keyed by its canonical name."""

    # Annotation only (no value is bound), so the builtin `bool` is not
    # shadowed in the class namespace.
    bool: DType
# `__array_namespace_info__.dtypes(kind="signed integer")`
class DTypesSigned(TypedDict):
    """Signed integer dtypes, keyed by their canonical names."""

    int8: DType
    int16: DType
    int32: DType
    int64: DType
# `__array_namespace_info__.dtypes(kind="unsigned integer")`
class DTypesUnsigned(TypedDict):
    """Unsigned integer dtypes, keyed by their canonical names."""

    uint8: DType
    uint16: DType
    uint32: DType
    uint64: DType
# `__array_namespace_info__.dtypes(kind="integral")`
class DTypesIntegral(DTypesSigned, DTypesUnsigned):
    """All integer dtypes: the signed keys ``int8``-``int64`` plus the
    unsigned keys ``uint8``-``uint64``."""
# `__array_namespace_info__.dtypes(kind="real floating")`
class DTypesReal(TypedDict):
    """Real floating-point dtypes, keyed by their canonical names."""

    float32: DType
    float64: DType
# `__array_namespace_info__.dtypes(kind="complex floating")`
class DTypesComplex(TypedDict):
    """Complex floating-point dtypes, keyed by their canonical names."""

    complex64: DType
    complex128: DType
# `__array_namespace_info__.dtypes(kind="numeric")`
class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex):
    """Every numeric dtype: the integral, real-floating and
    complex-floating keys combined."""
# `__array_namespace_info__.dtypes(kind=None)` (default)
class DTypesAll(DTypesBool, DTypesNumeric):
    """Complete dtype mapping: ``bool`` plus every numeric dtype."""
# `__array_namespace_info__.dtypes(kind=?)` (fallback)
# Generic name -> dtype mapping, for when none of the specific `DTypes*`
# TypedDicts above describes the result.
DTypesAny: TypeAlias = Mapping[str, DType]
# Public API of this module (used by `from ... import *` and by `__dir__`).
# Kept as a literal list because static analyzers only recognise literal
# forms of `__all__` when determining a module's exports.
__all__ = [
    "Array",
    "Capabilities",
    "DType",
    "DTypeKind",
    "DTypesAny",
    "DTypesAll",
    "DTypesBool",
    "DTypesNumeric",
    "DTypesIntegral",
    "DTypesSigned",
    "DTypesUnsigned",
    "DTypesReal",
    "DTypesComplex",
    "DefaultDTypes",
    "Device",
    "HasShape",
    "Namespace",
    "JustInt",
    "JustFloat",
    "JustComplex",
    "NestedSequence",
    "SupportsArrayNamespace",
    "SupportsBufferProtocol",
]
def __dir__() -> list[str]:
    """Return this module's public names, used by ``dir(module)`` (PEP 562).

    Returns a fresh copy of ``__all__``. Previously the list object itself
    was returned, so a caller that mutated the result (sorting or extending
    it in place) would silently change the module's export list.
    """
    return list(__all__)
|