1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
|
# pyright: reportPrivateUsage=false
from __future__ import annotations
from builtins import bool as py_bool
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast
import numpy as np
from .._internal import get_xp
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
if TYPE_CHECKING:
from typing_extensions import Buffer, TypeIs
# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`:
# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
bool = np.bool_
# Basic renames
acos = np.arccos
acosh = np.arccosh
asin = np.arcsin
asinh = np.arcsinh
atan = np.arctan
atan2 = np.arctan2
atanh = np.arctanh
bitwise_left_shift = np.left_shift
bitwise_invert = np.invert
bitwise_right_shift = np.right_shift
concat = np.concatenate
pow = np.power
arange = get_xp(np)(_aliases.arange)
empty = get_xp(np)(_aliases.empty)
empty_like = get_xp(np)(_aliases.empty_like)
eye = get_xp(np)(_aliases.eye)
full = get_xp(np)(_aliases.full)
full_like = get_xp(np)(_aliases.full_like)
linspace = get_xp(np)(_aliases.linspace)
ones = get_xp(np)(_aliases.ones)
ones_like = get_xp(np)(_aliases.ones_like)
zeros = get_xp(np)(_aliases.zeros)
zeros_like = get_xp(np)(_aliases.zeros_like)
UniqueAllResult = get_xp(np)(_aliases.UniqueAllResult)
UniqueCountsResult = get_xp(np)(_aliases.UniqueCountsResult)
UniqueInverseResult = get_xp(np)(_aliases.UniqueInverseResult)
unique_all = get_xp(np)(_aliases.unique_all)
unique_counts = get_xp(np)(_aliases.unique_counts)
unique_inverse = get_xp(np)(_aliases.unique_inverse)
unique_values = get_xp(np)(_aliases.unique_values)
std = get_xp(np)(_aliases.std)
var = get_xp(np)(_aliases.var)
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
clip = get_xp(np)(_aliases.clip)
permute_dims = get_xp(np)(_aliases.permute_dims)
reshape = get_xp(np)(_aliases.reshape)
argsort = get_xp(np)(_aliases.argsort)
sort = get_xp(np)(_aliases.sort)
nonzero = get_xp(np)(_aliases.nonzero)
ceil = get_xp(np)(_aliases.ceil)
floor = get_xp(np)(_aliases.floor)
trunc = get_xp(np)(_aliases.trunc)
matmul = get_xp(np)(_aliases.matmul)
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
sign = get_xp(np)(_aliases.sign)
finfo = get_xp(np)(_aliases.finfo)
iinfo = get_xp(np)(_aliases.iinfo)
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
try:
memoryview(obj) # pyright: ignore[reportArgumentType]
except TypeError:
return False
return True
# asarray also adds the copy keyword, which is not present in numpy 1.0.
# asarray() is different enough between numpy, cupy, and dask, the logic
# complicated enough that it's easier to define it separately for each module
# rather than trying to combine everything into one function in common/
def asarray(
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: _Copy | None = None,
**kwargs: Any,
) -> Array:
"""
Array API compatibility wrapper for asarray().
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
_helpers._check_device(np, device)
if copy is None:
copy = np._CopyMode.IF_NEEDED
elif copy is False:
copy = np._CopyMode.NEVER
elif copy is True:
copy = np._CopyMode.ALWAYS
return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore
def astype(
x: Array,
dtype: DType,
/,
*,
copy: py_bool = True,
device: Device | None = None,
) -> Array:
_helpers._check_device(np, device)
return x.astype(dtype=dtype, copy=copy)
# count_nonzero returns a python int for axis=None and keepdims=False
# https://github.com/numpy/numpy/issues/17562
def count_nonzero(
x: Array,
axis: int | tuple[int, ...] | None = None,
keepdims: py_bool = False,
) -> Array:
# NOTE: this is currently incorrectly typed in numpy, but will be fixed in
# numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue]
if axis is None and not keepdims:
return np.asarray(result)
return result
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
return np.take_along_axis(x, indices, axis=axis)
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np, "vecdot"):
vecdot = np.vecdot
else:
vecdot = get_xp(np)(_aliases.vecdot)
if hasattr(np, "isdtype"):
isdtype = np.isdtype
else:
isdtype = get_xp(np)(_aliases.isdtype)
if hasattr(np, "unstack"):
unstack = np.unstack
else:
unstack = get_xp(np)(_aliases.unstack)
__all__ = [
"__array_namespace_info__",
"asarray",
"astype",
"acos",
"acosh",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"bitwise_left_shift",
"bitwise_invert",
"bitwise_right_shift",
"bool",
"concat",
"count_nonzero",
"pow",
"take_along_axis"
]
__all__ += _aliases.__all__
_all_ignore = ["np", "get_xp"]
def __dir__() -> list[str]:
return __all__
|