1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
|
import torch
from torch._prims_common import (
Number,
NumberType,
TensorLike,
TensorLikeType,
ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten
from typing import Callable, Sequence, Union, Tuple, NamedTuple
import inspect
from functools import wraps, reduce
import operator
import warnings
from itertools import chain
# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(
    a: Union[TensorLikeType, NumberType, Sequence, None], dtype: torch.dtype
) -> Union[TensorLikeType, NumberType, Sequence, None]:
    """
    Converts ``a`` to ``dtype`` where that makes sense.

    - TensorLike: converted with ``prims.convert_element_type`` when its dtype differs
      from ``dtype``; otherwise ``a`` itself is returned.
    - Number: passed through the Python constructor returned by
      ``utils.dtype_to_type_ctor(dtype)``.
    - Sequence (other than ``str``): every element is converted recursively and the
      results are returned as a tuple.
    - None: returned unchanged.

    Raises:
        ValueError: for any other input, including ``str``.
    """
    import torch._prims as prims

    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            # NOTE: this is incorrect on the CPU
            # See https://github.com/pytorch/pytorch/issues/77553
            return prims.convert_element_type(a, dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)
    # A str is a Sequence whose elements are themselves (1-character) strs, so
    # recursing into it never terminates. Let it fall through to the ValueError.
    if isinstance(a, Sequence) and not isinstance(a, str):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None
    raise ValueError(
        "Received type {0} that is neither a tensor or a number!".format(type(a))
    )
def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    """
    Casts the Python scalar ``a`` to ``typ``, refusing unsafe conversions.

    Raises:
        ValueError: if ``a`` is not a Number, or if ``utils.is_weakly_lesser_type``
            reports that ``type(a)`` is not weakly lesser than ``typ``.
    """
    source_type = type(a)
    if isinstance(a, Number):
        if utils.is_weakly_lesser_type(source_type, typ):
            return typ(a)
        raise ValueError(
            "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
                a, source_type, typ
            )
        )
    raise ValueError(
        "Found unknown type {0} when trying to convert scalars!".format(source_type)
    )
def _annotation_has_type(*, typ, annotation):
    """
    Reports whether ``typ`` appears anywhere inside a (possibly nested) annotation.

    Annotations exposing ``__args__`` (e.g. ``Union[...]``) are searched recursively
    through their arguments. Any other annotation matches only if it *is* ``typ``.
    """
    if not hasattr(annotation, "__args__"):
        return typ is annotation
    return any(
        _annotation_has_type(typ=typ, annotation=arg) for arg in annotation.__args__
    )
class elementwise_type_promotion_wrapper(object):
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Union[Sequence[str], None] = None,
    ):
        # NOTE(review): None is accepted here, but __call__'s wrapper iterates this
        # value, so calling a function wrapped with no names raises TypeError.
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)

            # Promoting argument names the caller actually supplied, in the order
            # given. sig.bind does not apply defaults, so omitted ones are skipped.
            promoting_names = [
                x
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments
            ]

            type_promoting_args = tuple(bound.arguments[x] for x in promoting_names)
            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in promoting_names
            }
            bound.arguments.update(promoted_args)

            # Re-derive the call from the (updated) bound arguments. Passing
            # everything as keywords (fn(**bound.arguments)) breaks for
            # positional-only parameters. It also passes *args / **kwargs parameters
            # as single keywords named after them.
            result = fn(*bound.args, **bound.kwargs)

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape):
    """
    Resizes ``out`` in place to ``shape`` when required, returning it.

    - An ``out`` with zero elements is always resized to ``shape``.
    - A non-empty ``out`` whose element count differs from the product of ``shape``
      is also resized, after a warning is emitted through ``warnings.warn``.
    - Otherwise ``out`` is returned unchanged. Its shape is left alone even if it
      differs from ``shape``, as long as the element counts match.
    """
    if out.numel() != 0:
        required_numel = reduce(operator.mul, shape, 1)
        if out.numel() == required_numel:
            return out
        warnings.warn(
            f"An output with one or more elements was resized since it had shape {out.shape!s} "
            f"which does not match the required output shape {shape!s}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
        )
    return out.resize_(shape)
def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    """
    Copies ``copy_from`` into ``copy_to`` via ``copy_to.copy_`` and returns its result.

    Args:
        copy_from: source tensor.
        copy_to: destination tensor, written in place.
        exact_dtype: if True, the dtypes must be identical. Otherwise the cast from
            ``copy_from.dtype`` to ``copy_to.dtype`` must be accepted by
            ``utils.can_safe_cast_to``.

    Raises:
        RuntimeError: on a device mismatch, or when the dtype check above fails.
    """
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        if copy_from.dtype != copy_to.dtype:
            # Both halves must be f-strings. Previously only the first one was, so
            # the out dtype was rendered literally as "{copy_to.dtype}".
            raise RuntimeError(
                f"Expected out tensor to have dtype {copy_from.dtype} "
                f"but got {copy_to.dtype} instead"
            )
    elif not utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype):
        raise RuntimeError(
            f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!"
        )

    return copy_to.copy_(copy_from)
def out_wrapper(*out_names: str, exact_dtype: bool = False):
    """
    Returns a decorator that adds a keyword-only ``out`` parameter to a Python reference.

    Args:
        *out_names: when empty, the decorated function must return a single tensor and
            ``out`` is a single tensor. Otherwise at least two names are required
            (asserted below). The function must then return a tuple of that many
            tensors, ``out`` must be a tuple of the same length, and the decorated
            function returns a NamedTuple whose fields are ``out_names``.
        exact_dtype: forwarded to ``_safe_copy_out`` when copying results into ``out``.
    """
    is_tensor = len(out_names) == 0
    assert is_tensor or len(out_names) >= 2

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        # Functions accepting both device and dtype are treated as factories. When
        # out= is given, whichever of the two the caller did not pass is taken from out.
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr

            result = fn(*args, **kwargs)
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, tuple)
                and len(result) == len(out_names)
            )

            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, tuple)
                    utils.check(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # The wrapped function may only declare no return annotation, or out_type.
        assert sig.return_annotation in (sig.empty, out_type)
        params = chain(sig.parameters.values(), (out_param,))
        # Advertise the extra `out` parameter and the (possibly NamedTuple) return type.
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )
        # Copy instead of aliasing fn.__annotations__. Otherwise the writes below
        # would also add "out"/"return" to the wrapped function's own annotations.
        _fn.__annotations__ = dict(fn.__annotations__)
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type
        return _fn

    return _out_wrapper
def backwards_not_supported(prim):
    """
    Wraps ``prim`` so it can run on inputs that require grad, but refuses backward.

    When grad mode is enabled and any tensor among the flattened args/kwargs requires
    grad, the call is routed through a custom ``torch.autograd.Function`` whose
    backward raises ``RuntimeError("backwards not supported on prim")``. Otherwise
    ``prim`` is called directly. In both paths the call to ``prim`` is made while a
    ``torch._C._AutoDispatchBelowAutograd()`` object is alive.
    """

    def redispatch_prim(args, kwargs):
        # NOTE(review): the guard's effect appears tied to the object's lifetime,
        # hence the explicit `del` in `finally` once prim returns or raises.
        guard = torch._C._AutoDispatchBelowAutograd()
        try:
            return prim(*args, **kwargs)
        finally:
            del guard

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        needs_autograd = torch.is_grad_enabled() and any(
            isinstance(leaf, torch.Tensor) and leaf.requires_grad for leaf in flat_args
        )
        if not needs_autograd:
            return redispatch_prim(args, kwargs)
        # TODO: There is a subtle bug here: prims like copy_to
        # return their input argument after mutating it; and custom
        # autograd function will incorrectly turn the result into
        # a view which will fail test_python_ref_executor tests.
        # At the moment, we sidestep this by observing that the
        # unit tests don't ever try to run the executor with
        # autograd, so we don't exercise the buggy case, but if
        # you ever want to feed autograd through this, be aware
        # of it! We need a way of properly implementing autograd
        # for mutating operations in Python to do this.
        return BackwardsNotSupported.apply(args_spec, *flat_args)

    return _autograd_impl
# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.

    If the first positional argument is a Number, it is wrapped with ``torch.tensor``
    using the dtype from ``utils.type_to_dtype``, ``fn`` is called, and the resulting
    tensor is turned back into a Python scalar with ``.item()``. Any other call is
    forwarded to ``fn`` unchanged.
    """
    signature = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        scalar_first = bool(args) and isinstance(args[0], Number)
        if not scalar_first:
            return fn(*args, **kwargs)
        first, *rest = args
        as_tensor = torch.tensor(first, dtype=utils.type_to_dtype(type(first)))
        out = fn(as_tensor, *rest, **kwargs)
        assert isinstance(out, torch.Tensor)
        return out.item()

    _fn.__signature__ = signature  # type: ignore[attr-defined]
    return _fn
|