# mypy: ignore-errors

"""Export torch functions that implement unary ufuncs, renamed or tweaked to match NumPy.

This listing is further exported as public symbols by the `_numpy/_ufuncs.py` module.
"""

import torch
from torch import (  # noqa: F401
    absolute as fabs,
    arccos,
    arccosh,
    arcsin,
    arcsinh,
    arctan,
    arctanh,
    bitwise_not,
    bitwise_not as invert,
    ceil,
    conj_physical as conjugate,
    cos,
    cosh,
    deg2rad,
    deg2rad as radians,
    exp,
    exp2,
    expm1,
    floor,
    isfinite,
    isinf,
    isnan,
    log,
    log10,
    log1p,
    log2,
    logical_not,
    negative,
    rad2deg,
    rad2deg as degrees,
    reciprocal,
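    # numpy.fix rounds toward zero, which is torch.trunc; torch.round
    # rounds half to even and therefore matches numpy.rint only.
    trunc as fix,
    round as rint,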
    sign,
    signbit,
    sin,
    sinh,
    sqrt,
    square,
    tan,
    tanh,
    trunc,
)
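

# Special cases: torch does not export these names, or its function
# needs a workaround to match NumPy.


def cbrt(x):
    # torch.pow(x, 1 / 3) returns nan for negative x, whereas numpy.cbrt
    # returns the real cube root; take the root of |x| and restore the sign.
    return torch.sign(x) * torch.pow(torch.abs(x), 1 / 3)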


def positive(x):
    return +x


def absolute(x):
    # Work around torch.absolute not being implemented for bool tensors;
    # |x| == x for booleans, so return the input unchanged.
    if x.dtype == torch.bool:
        return x
    return torch.absolute(x)


# TODO set __name__ and __qualname__
abs = absolute
conj = conjugate