1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
# mypy: ignore-errors
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module.
"""
import torch
from torch import ( # noqa: F401
add,
arctan2,
bitwise_and,
bitwise_left_shift as left_shift,
bitwise_or,
bitwise_right_shift as right_shift,
bitwise_xor,
copysign,
divide,
eq as equal,
float_power,
floor_divide,
fmax,
fmin,
fmod,
gcd,
greater,
greater_equal,
heaviside,
hypot,
lcm,
ldexp,
less,
less_equal,
logaddexp,
logaddexp2,
logical_and,
logical_or,
logical_xor,
maximum,
minimum,
multiply,
nextafter,
not_equal,
pow as power,
remainder,
remainder as mod,
subtract,
true_divide,
)
from . import _dtypes_impl, _util
# work around torch limitations w.r.t. numpy
def matmul(x, y):
    """Matrix product of ``x`` and ``y`` with dtype workarounds for torch.

    ``torch.matmul`` is stricter than numpy about operand dtypes; this wrapper
    works around:
      - RuntimeError: expected scalar type Int but found Double
      - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
      - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

    Both operands are cast to a common working dtype, multiplied with
    ``torch.matmul``, and the product is cast back to the dtype reported by
    ``_dtypes_impl.result_type_impl(x, y)`` whenever the working dtype differs.

    Args:
        x, y: operands exposing ``.dtype`` and ``.is_cpu``
            (presumably ``torch.Tensor`` -- confirm against callers).

    Returns:
        The product, in the promoted result dtype of ``x`` and ``y``.
    """
    dtype = _dtypes_impl.result_type_impl(x, y)
    is_bool = dtype == torch.bool
    # Only fall back to float32 when the *result* itself is float16. If the
    # promoted dtype is wider (e.g. float16 @ float64, float16 @ complex),
    # casting both operands to it already avoids the Half kernel, and
    # narrowing them to float32 would lose precision or the imaginary part.
    is_half = dtype == torch.float16 and (x.is_cpu or y.is_cpu)

    work_dtype = dtype
    if is_bool:
        # Accumulate in int64 rather than uint8: uint8 sums wrap modulo 256,
        # so an inner product with a multiple of 256 matching True entries
        # would come out as 0 and be reported as False.
        work_dtype = torch.int64
    if is_half:
        work_dtype = torch.float32

    x = _util.cast_if_needed(x, work_dtype)
    y = _util.cast_if_needed(y, work_dtype)

    result = torch.matmul(x, y)

    # Undo the temporary widening so the caller sees the promoted dtype.
    if work_dtype != dtype:
        result = result.to(dtype)

    return result
# A stub implementation of divmod; it should be improved once
# https://github.com/pytorch/pytorch/issues/90820 is fixed in PyTorch.
def divmod(x, y):
    """Return the pair ``(x // y, x % y)``.

    The quotient and remainder are computed independently through the
    operands' own ``//`` and ``%`` operators, in that order.
    """
    quot = x // y
    rem = x % y
    return quot, rem
|