1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
import math
import torch
import torch.jit
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, lazy_property
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ['VonMises']
def _eval_poly(y, coef):
    """
    Evaluate ``coef[0] + coef[1] * y + coef[2] * y**2 + ...`` at ``y``
    using Horner's rule.

    :param y: evaluation point(s); a tensor or a Python number.
    :param coef: non-empty iterable of coefficients, lowest degree first.
    :returns: the polynomial value.
    :raises IndexError: if ``coef`` is empty.
    """
    coef = list(coef)  # copy, so pop() below never mutates the caller's table
    result = coef.pop()
    while coef:
        result = coef.pop() + y * result
    return result


# Polynomial coefficients (lowest degree first) for the modified Bessel
# functions I_0 and I_1, consumed by ``_log_modified_bessel_fn``:
#   * *_SMALL tables are evaluated at y = (x / 3.75) ** 2 and used for
#     x < 3.75 (for I_1 the polynomial is additionally multiplied by |x|);
#   * *_LARGE tables are evaluated at y = 3.75 / x and used for x >= 3.75,
#     where I(x) ~= exp(x) / sqrt(x) * P(y).
# These appear to be the Abramowitz & Stegun 9.8.1-9.8.4 approximations
# (relative error on the order of 1e-7); do not edit individual values.
_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
                  -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]
_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]
_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,
                  0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]

_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]  # indexed by Bessel order
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]  # indexed by Bessel order


def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.

    Two polynomial approximations are combined elementwise: one in
    ``y = (x / 3.75) ** 2`` for ``x < 3.75``, and the asymptotic form
    ``x - log(x) / 2 + log(P(3.75 / x))`` for ``x >= 3.75``.

    ``torch.where`` evaluates both branches for every element. It also
    backpropagates a zero gradient into the branch it discards. If that
    branch overflowed to ``inf`` (the small-x polynomial for large ``x``, or
    the large-x polynomial in ``3.75 / x`` for tiny ``x``), then
    ``0 * inf = nan`` poisons the gradient. Each branch is therefore evaluated
    on ``x`` clamped into its own domain. Where a branch is selected, the
    clamp is the identity, so values and gradients are unchanged.

    :param torch.Tensor x: input tensor.
    :param int order: Bessel function order, 0 or 1.
    :returns: tensor of the same shape as ``x``.
    :raises ValueError: if ``order`` is not 0 or 1.
    """
    if not (order == 0 or order == 1):
        raise ValueError(f"order must be 0 or 1, got {order!r}")

    x_small = x.clamp(max=3.75)  # identity wherever the small branch is selected
    x_large = x.clamp(min=3.75)  # identity wherever the large branch is selected

    # compute small solution
    y = x_small / 3.75
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x_small.abs() * small
    small = small.log()

    # compute large solution
    y = 3.75 / x_large
    large = x_large - 0.5 * x_large.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    return torch.where(x < 3.75, small, large)
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
    """
    Fill ``x`` with von Mises samples by rejection sampling and return them.

    This is the sampler behind ``VonMises.sample``, whose docstring cites
    Best & Fisher (1979) for the algorithm. ``torch.jit.script_if_tracing``
    compiles this function with TorchScript when it is called while tracing,
    so the body must stay TorchScript-compatible.

    Every round draws fresh uniforms for all elements. Elements whose
    proposal is accepted are overwritten with it. Rounds continue until every
    element has been accepted at least once, so the initial contents of
    ``x`` never reach the output.

    :param torch.Tensor loc: location (mean angle). Must broadcast against
        ``x.shape``. Its dtype and device are used for the random draws.
    :param torch.Tensor concentration: concentration. Must broadcast against
        ``x.shape``.
    :param torch.Tensor proposal_r: proposal parameter ``r``, which
        ``VonMises`` precomputes from ``concentration``.
    :param torch.Tensor x: buffer that fixes the output shape.
    :returns: ``loc`` plus the sampled offsets, wrapped into ``[-pi, pi)``
        (up to floating-point rounding).
    """
    # NOTE(review): there is no iteration cap. If an element can never be
    # accepted (e.g. proposal_r is inf or NaN, which makes c NaN), this loops forever.
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)  # accepted so far
    while not done.all():
        # Three uniforms per element: proposal angle (u1), acceptance (u2), sign (u3).
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        # f is the cosine of the proposed offset angle (its acos() is used below).
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Accept if either test passes. The first looks like the paper's cheap
        # pre-test and the second like the exact test (TODO confirm against the paper).
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # acos(f) is in [0, pi]; the sign comes from u3 - 0.5.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap into [-pi, pi).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
class VonMises(Distribution):
    """
    A circular von Mises distribution.

    This implementation uses polar coordinates. The ``loc`` and ``value`` args
    can be any real number (to facilitate unconstrained optimization), but are
    interpreted as angles modulo 2 pi.

    Example::
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """
    arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        event_shape = torch.Size()

        # Parameters for sampling (Best & Fisher, 1979).
        #
        # The direct form rho = (tau - sqrt(2 * tau)) / (2 * kappa) subtracts
        # two quantities that both tend to 2 as kappa -> 0. In float32, once
        # 1 + 4 * kappa**2 rounds to 1 (kappa below roughly 1e-4), it yields
        # rho == 0 and an infinite proposal_r. Every proposal in
        # _rejection_sample is then NaN, nothing is ever accepted, and
        # sample() never returns.
        # Since tau * (tau - 2) == 4 * kappa**2, the same value is computed
        # without cancellation as
        #     rho = 2 * kappa / (tau + sqrt(2 * tau)).
        kappa = self.concentration
        tau = 1 + (1 + 4 * kappa ** 2).sqrt()
        rho = 2 * kappa / (tau + (2 * tau).sqrt())
        self._proposal_r = (1 + rho ** 2) / (2 * rho)
        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        """
        Log density of ``value`` (radians):
        ``concentration * cos(value - loc) - log(2 * pi) - log(I_0(concentration))``.
        """
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.concentration * torch.cos(value - self.loc)
        log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
        return log_prob

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        The sampling algorithm for the von Mises distribution is based on the following paper:
        Best, D. J., and Nicholas I. Fisher.
        "Efficient simulation of the von Mises distribution." Applied Statistics (1979): 152-157.

        Samples are returned as angles wrapped into ``[-pi, pi)``, with the
        dtype and device of ``loc``.
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
        return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)

    def expand(self, batch_shape):
        """
        Return a new ``VonMises`` whose parameters are expanded to ``batch_shape``.

        Falls back to rebuilding the distribution from expanded parameters when
        the base class does not implement ``expand``.
        """
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            validate_args = self.__dict__.get('_validate_args')
            loc = self.loc.expand(batch_shape)
            concentration = self.concentration.expand(batch_shape)
            return type(self)(loc, concentration, validate_args=validate_args)

    @property
    def mean(self):
        """
        The provided mean is the circular one.
        """
        return self.loc

    @property
    def mode(self):
        """The mode coincides with ``loc``."""
        return self.loc

    @lazy_property
    def variance(self):
        """
        The provided variance is the circular one: ``1 - I_1(kappa) / I_0(kappa)``.
        """
        return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -
                    _log_modified_bessel_fn(self.concentration, order=0)).exp()
|