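"""Python reference implementations (decompositions) for operators in the
``torch.special`` namespace, expressed in terms of ``torch._prims`` and
``torch._refs`` primitives."""
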
import math
from typing import Optional
import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch import Tensor
from torch._decomp import register_decomposition
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, TensorLikeType
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper
from torch._refs import (
    _make_elementwise_binary_reference,
    _make_elementwise_unary_reference,
)

__all__ = [
    "bessel_j0",
    "bessel_j1",
    "i0e",
    "i1",
    "i1e",
    "logit",
    "multigammaln",
    "spherical_bessel_j0",
    "zeta",
]


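# Bessel function of the first kind of order 0 (torch.special.bessel_j0).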
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_bessel_j0,
)
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j0(a)


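# Bessel function of the first kind of order 1 (torch.special.bessel_j1).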
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_bessel_j1,
)
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j1(a)


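# Exponentially scaled modified Bessel function of the first kind of order 0
# (torch.special.i0e).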
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e
)
def i0e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i0e(a)


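# Modified Bessel function of the first kind of order 1 (torch.special.i1).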
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1
)
def i1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1(a)


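# Exponentially scaled modified Bessel function of the first kind of order 1
# (torch.special.i1e).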
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1e
)
def i1e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1e(a)


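# logit(p) = log(p / (1 - p)), the inverse of the sigmoid function. When eps is
# given, the input is clamped to [eps, 1 - eps] before the transform.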
@register_decomposition(torch.ops.aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("self",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    if eps is None:
        # A default of -1.0 widens the clamp to [-1, 2], which leaves values in
        # the valid (0, 1) domain untouched and still produces NaN outside of
        # it, effectively matching the unclamped behavior of the eager op.
        eps = -1.0
    lo = eps
    hi = 1 - eps
    self = torch.clamp(self, lo, hi)
    return torch.log(torch.true_divide(self, torch.sub(1, self)))


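# Multivariate log-gamma:
#   ln Gamma_p(a) = (p * (p - 1) / 4) * ln(pi) + sum_{j=1..p} ln Gamma(a + (1 - j) / 2)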
@register_decomposition(torch.ops.aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    # Offsets (1 - j) / 2 for j = 1..p, in ascending order.
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c


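# Spherical Bessel function of the first kind of order 0
# (torch.special.spherical_bessel_j0).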
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_spherical_bessel_j0,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.spherical_bessel_j0(a)


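# Hurwitz zeta function zeta(x, q) = sum_{k=0..inf} 1 / (k + q)^x
# (torch.special.zeta), created directly as a binary elementwise reference.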
zeta = _make_elementwise_binary_reference(
    prims.zeta,  # type: ignore[has-type]
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=torch.ops.aten.special_zeta,
)
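

# Usage sketch (illustrative only; these are internal reference implementations,
# and the public entry points are the corresponding torch.special ops):
#
#   >>> p = torch.linspace(0.1, 0.9, 5)
#   >>> torch.special.logit(p)             # log(p / (1 - p))
#   >>> a = torch.linspace(1.0, 3.0, 5)
#   >>> torch.special.multigammaln(a, 2)   # 2-variate log-gamma evaluated at a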