import math
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property
__all__ = ['LowRankMultivariateNormal']
def _batch_capacitance_tril(W, D):
r"""
Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
and a batch of vectors :math:`D`.
"""
m = W.size(-1)
Wt_Dinv = W.mT / D.unsqueeze(-2)
K = torch.matmul(Wt_Dinv, W).contiguous()
K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K
return torch.linalg.cholesky(K)
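

# A minimal sketch (not part of the module API) of what `_batch_capacitance_tril`
# computes: the returned factor L satisfies L @ L.mT == I + W.T @ inv(diag(D)) @ W.
# The shapes below (5 events, rank 2) are arbitrary, chosen only for illustration.
def _demo_capacitance_tril():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5  # strictly positive diagonal
    L = _batch_capacitance_tril(W, D)
    expected = torch.eye(2) + W.mT @ torch.diag(D.reciprocal()) @ W
    assert torch.allclose(L @ L.mT, expected, atol=1e-5)
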
def _batch_lowrank_logdet(W, D, capacitance_tril):
r"""
Uses "matrix determinant lemma"::
log|W @ W.T + D| = log|C| + log|D|,
where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
the log determinant.
"""
return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)
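

# Hypothetical sanity check (illustrative only, not part of the module API): the
# lemma-based log-determinant should agree with torch.logdet applied to the dense
# covariance W @ W.T + diag(D).
def _demo_lowrank_logdet():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    lemma = _batch_lowrank_logdet(W, D, _batch_capacitance_tril(W, D))
    dense = torch.logdet(W @ W.mT + torch.diag(D))
    assert torch.allclose(lemma, dense, atol=1e-5)
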
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
r"""
Uses "Woodbury matrix identity"::
inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
"""
Wt_Dinv = W.mT / D.unsqueeze(-2)
Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
mahalanobis_term1 = (x.pow(2) / D).sum(-1)
mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
return mahalanobis_term1 - mahalanobis_term2
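

# Illustrative check (hypothetical shapes, not part of the module API): the
# Woodbury-based squared Mahalanobis distance should match the direct computation
# x.T @ inv(W @ W.T + diag(D)) @ x.
def _demo_lowrank_mahalanobis():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    x = torch.randn(5)
    woodbury = _batch_lowrank_mahalanobis(W, D, x, _batch_capacitance_tril(W, D))
    direct = x @ torch.linalg.solve(W @ W.mT + torch.diag(D), x)
    assert torch.allclose(woodbury, direct, atol=1e-4)
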
class LowRankMultivariateNormal(Distribution):
r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        Computing the determinant and inverse of the covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to the `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        With these formulas, only the determinant and inverse of the small
        "capacitance" matrix need to be computed::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
arg_constraints = {"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),
"cov_diag": constraints.independent(constraints.positive, 1)}
support = constraints.real_vector
has_rsample = True
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")
event_shape = loc.shape[-1:]
if cov_factor.dim() < 2:
raise ValueError("cov_factor must be at least two-dimensional, "
"with optional leading batch dimensions")
if cov_factor.shape[-2:-1] != event_shape:
raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
.format(event_shape[0]))
if cov_diag.shape[-1:] != event_shape:
raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))
loc_ = loc.unsqueeze(-1)
cov_diag_ = cov_diag.unsqueeze(-1)
try:
loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
except RuntimeError as e:
raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
.format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
self.loc = loc_[..., 0]
self.cov_diag = cov_diag_[..., 0]
batch_shape = self.loc.shape[:-1]
self._unbroadcasted_cov_factor = cov_factor
self._unbroadcasted_cov_diag = cov_diag
self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
batch_shape = torch.Size(batch_shape)
loc_shape = batch_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new.cov_diag = self.cov_diag.expand(loc_shape)
new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
new._capacitance_tril = self._capacitance_tril
super(LowRankMultivariateNormal, new).__init__(batch_shape,
self.event_shape,
validate_args=False)
new._validate_args = self._validate_args
return new
@property
def mean(self):
return self.loc
@property
def mode(self):
return self.loc
@lazy_property
def variance(self):
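        # The marginal variances are the diagonal of W @ W.T + diag(D), i.e.
        # Var[x_i] = sum_j W[i, j]**2 + D[i], computed without materializing
        # the full covariance matrix.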
return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
+ self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)
@lazy_property
def scale_tril(self):
        # The following identity is used to improve the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^(1/2) @ (I + D^(-1/2) @ W @ W.T @ D^(-1/2)) @ D^(1/2)
        # The matrix I + D^(-1/2) @ W @ W.T @ D^(-1/2) has eigenvalues bounded below by 1,
        # so it is well-conditioned and its Cholesky decomposition is safe to compute.
n = self._event_shape[0]
cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K
scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
@lazy_property
def covariance_matrix(self):
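        # Materializes the dense covariance W @ W.T + diag(D); this costs O(n^2)
        # memory, so prefer the lazy low-rank form when only samples or
        # log-probabilities are needed.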
covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_factor.mT)
+ torch.diag_embed(self._unbroadcasted_cov_diag))
return covariance_matrix.expand(self._batch_shape + self._event_shape +
self._event_shape)
@lazy_property
def precision_matrix(self):
# We use "Woodbury matrix identity" to take advantage of low rank form::
# inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
# where :math:`C` is the capacitance matrix.
Wt_Dinv = (self._unbroadcasted_cov_factor.mT
/ self._unbroadcasted_cov_diag.unsqueeze(-2))
A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
precision_matrix = torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
return precision_matrix.expand(self._batch_shape + self._event_shape +
self._event_shape)
def rsample(self, sample_shape=torch.Size()):
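        # Reparameterized sampling: x = loc + W @ eps_W + sqrt(D) * eps_D with
        # eps_W, eps_D ~ N(0, I) has covariance W @ W.T + diag(D), and gradients
        # can flow through loc, W and D.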
shape = self._extended_shape(sample_shape)
W_shape = shape[:-1] + self.cov_factor.shape[-1:]
eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+ self._unbroadcasted_cov_diag.sqrt() * eps_D)
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
diff = value - self.loc
M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
diff,
self._capacitance_tril)
log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
self._capacitance_tril)
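        # log N(x; loc, Sigma) = -0.5 * (k * log(2*pi) + log|Sigma| + M), where k is
        # the event dimension, log|Sigma| comes from the matrix determinant lemma
        # and the squared Mahalanobis term M from the Woodbury identity above.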
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
def entropy(self):
log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
self._capacitance_tril)
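        # Differential entropy of a k-dimensional Gaussian:
        # H = 0.5 * (k * (1 + log(2*pi)) + log|Sigma|).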
H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
if len(self._batch_shape) == 0:
return H
else:
return H.expand(self._batch_shape)
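

# A hedged end-to-end sketch of how this distribution might be used; the shapes,
# tolerances and the comparison against MultivariateNormal are illustrative
# assumptions, not part of this module's API.
if __name__ == "__main__":
    from torch.distributions.multivariate_normal import MultivariateNormal

    torch.manual_seed(0)
    loc = torch.zeros(3)
    cov_factor = torch.randn(3, 1)
    cov_diag = torch.rand(3) + 0.5
    low_rank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

    # A dense-covariance distribution defines the same density, just less efficiently.
    dense = MultivariateNormal(loc, covariance_matrix=low_rank.covariance_matrix)

    x = low_rank.rsample()
    assert torch.allclose(low_rank.log_prob(x), dense.log_prob(x), atol=1e-5)
    assert torch.allclose(low_rank.entropy(), dense.entropy(), atol=1e-5)
    # scale_tril is a valid Cholesky-like factor of the covariance matrix.
    assert torch.allclose(
        low_rank.scale_tril @ low_rank.scale_tril.mT,
        low_rank.covariance_matrix,
        atol=1e-5,
    )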