1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
|
# mypy: ignore-errors
""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
"""
from __future__ import annotations
import functools
import inspect
import operator
import typing
import torch
from . import _dtypes, _dtypes_impl, _util
ArrayLike = typing.TypeVar("ArrayLike")
Scalar = typing.Union[int, float, complex, bool]
ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]
DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")
# OutArray is to annotate the out= array argument.
#
# This one is special is several respects:
# First, It needs to be an NDArray, and we need to preserve the `result is out`
# semantics. Therefore, we cannot just extract the Tensor from the out array.
# So we never pass the out array to implementer functions and handle it in the
# `normalizer` below.
# Second, the out= argument can be either keyword or positional argument, and
# as a positional arg, it can be anywhere in the signature.
# To handle all this, we define a special `OutArray` annotation and dispatch on it.
#
OutArray = typing.TypeVar("OutArray")
try:
from typing import NotImplementedType
except ImportError:
NotImplementedType = typing.TypeVar("NotImplementedType")
def normalize_array_like(x, parm=None):
from ._ndarray import asarray
return asarray(x).tensor
def normalize_array_like_or_scalar(x, parm=None):
if _dtypes_impl.is_scalar_or_symbolic(x):
return x
return normalize_array_like(x, parm)
def normalize_optional_array_like_or_scalar(x, parm=None):
if x is None:
return None
return normalize_array_like_or_scalar(x, parm)
def normalize_optional_array_like(x, parm=None):
# This explicit normalizer is needed because otherwise normalize_array_like
# does not run for a parameter annotated as Optional[ArrayLike]
return None if x is None else normalize_array_like(x, parm)
def normalize_seq_array_like(x, parm=None):
return tuple(normalize_array_like(value) for value in x)
def normalize_dtype(dtype, parm=None):
# cf _decorators.dtype_to_torch
torch_dtype = None
if dtype is not None:
dtype = _dtypes.dtype(dtype)
torch_dtype = dtype.torch_dtype
return torch_dtype
def normalize_not_implemented(arg, parm):
if arg != parm.default:
raise NotImplementedError(f"'{parm.name}' parameter is not supported.")
def normalize_axis_like(arg, parm=None):
from ._ndarray import ndarray
if isinstance(arg, ndarray):
arg = operator.index(arg)
return arg
def normalize_ndarray(arg, parm=None):
# check the arg is an ndarray, extract its tensor attribute
if arg is None:
return arg
from ._ndarray import ndarray
if not isinstance(arg, ndarray):
raise TypeError(f"'{parm.name}' must be an array")
return arg.tensor
def normalize_outarray(arg, parm=None):
# almost normalize_ndarray, only return the array, not its tensor
if arg is None:
return arg
from ._ndarray import ndarray
# Dynamo can pass torch tensors as out arguments,
# wrap it in an ndarray before processing
if isinstance(arg, torch.Tensor):
arg = ndarray(arg)
if not isinstance(arg, ndarray):
raise TypeError(f"'{parm.name}' must be an array")
return arg
def normalize_casting(arg, parm=None):
if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]:
raise ValueError(
f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')"
)
return arg
normalizers = {
"ArrayLike": normalize_array_like,
"ArrayLikeOrScalar": normalize_array_like_or_scalar,
"Optional[ArrayLike]": normalize_optional_array_like,
"Sequence[ArrayLike]": normalize_seq_array_like,
"Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar,
"Optional[NDArray]": normalize_ndarray,
"Optional[OutArray]": normalize_outarray,
"NDArray": normalize_ndarray,
"Optional[DTypeLike]": normalize_dtype,
"AxisLike": normalize_axis_like,
"NotImplementedType": normalize_not_implemented,
"Optional[CastingModes]": normalize_casting,
}
def maybe_normalize(arg, parm):
"""Normalize arg if a normalizer is registered."""
normalizer = normalizers.get(parm.annotation, None)
return normalizer(arg, parm) if normalizer else arg
# ### Return value helpers ###
def maybe_copy_to(out, result, promote_scalar_result=False):
# NB: here out is either an ndarray or None
if out is None:
return result
elif isinstance(result, torch.Tensor):
if result.shape != out.shape:
can_fit = result.numel() == 1 and out.ndim == 0
if promote_scalar_result and can_fit:
result = result.squeeze()
else:
raise ValueError(
f"Bad size of the out array: out.shape = {out.shape}"
f" while result.shape = {result.shape}."
)
out.tensor.copy_(result)
return out
elif isinstance(result, (tuple, list)):
return type(result)(
maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result)
)
else:
raise AssertionError # We should never hit this path
def wrap_tensors(result):
from ._ndarray import ndarray
if isinstance(result, torch.Tensor):
return ndarray(result)
elif isinstance(result, (tuple, list)):
result = type(result)(wrap_tensors(x) for x in result)
return result
def array_or_scalar(values, py_type=float, return_scalar=False):
if return_scalar:
return py_type(values.item())
else:
from ._ndarray import ndarray
return ndarray(values)
# ### The main decorator to normalize arguments / postprocess the output ###
def normalizer(_func=None, *, promote_scalar_result=False):
def normalizer_inner(func):
@functools.wraps(func)
def wrapped(*args, **kwds):
sig = inspect.signature(func)
params = sig.parameters
first_param = next(iter(params.values()))
# NumPy's API does not have positional args before variadic positional args
if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
args = [maybe_normalize(arg, first_param) for arg in args]
else:
# NB: extra unknown arguments: pass through, will raise in func(*args) below
args = (
tuple(
maybe_normalize(arg, parm)
for arg, parm in zip(args, params.values())
)
+ args[len(params.values()) :]
)
kwds = {
name: maybe_normalize(arg, params[name]) if name in params else arg
for name, arg in kwds.items()
}
result = func(*args, **kwds)
# keepdims
bound_args = None
if "keepdims" in params and params["keepdims"].annotation == "KeepDims":
# keepdims can be in any position so we need sig.bind
bound_args = sig.bind(*args, **kwds).arguments
if bound_args.get("keepdims", False):
# In this case the first arg is the initial tensor and
# the second arg is (optionally) the axis
tensor = args[0]
axis = bound_args.get("axis")
result = _util.apply_keepdims(result, axis, tensor.ndim)
# out
if "out" in params:
# out can be in any position so we need sig.bind
if bound_args is None:
bound_args = sig.bind(*args, **kwds).arguments
out = bound_args.get("out")
result = maybe_copy_to(out, result, promote_scalar_result)
result = wrap_tensors(result)
return result
return wrapped
if _func is None:
return normalizer_inner
else:
return normalizer_inner(_func)
|