1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################
import numpy as np
from .deterministic import Deterministic
from .node import Moments
from bayespy.utils import misc
class Take(Deterministic):
    """
    Choose elements/sub-arrays along a plate axis

    Basically, applies `np.take` on a plate axis. Allows advanced mapping of
    plates: the selected plate axis of the parent is replaced by axes of
    shape ``np.shape(indices)`` in the child.

    Parameters
    ----------

    node : Node
        A node to apply the take operation on.
    indices : array of integers
        Plate elements to pick along a plate axis. Negative indices are
        allowed as in `np.take`, but every index must lie in
        ``[-L, L)`` where ``L`` is the length of the chosen plate axis.
    plate_axis : int (negative)
        The plate axis to pick elements from (default: -1).

    Raises
    ------

    ValueError
        If `plate_axis` is not an integer, is non-negative, or is out of
        bounds for ``node.plates``; if `indices` are not integers; or if any
        index is out of bounds for the chosen plate axis.

    See also
    --------

    numpy.take

    Examples
    --------

    >>> from bayespy.nodes import Gamma, Take
    >>> alpha = Gamma([1, 2, 3], [1, 1, 1])
    >>> x = Take(alpha, [1, 1, 2, 2, 1, 0])
    >>> x.get_moments()[0]
    array([2., 2., 3., 3., 2., 1.])
    """

    def __init__(self, node, indices, plate_axis=-1, **kwargs):
        # The child has the same kind of moments as the parent; only the
        # plates are re-arranged.
        self._moments = node._moments
        self._parent_moments = (node._moments,)
        self._indices = np.array(indices)
        self._plate_axis = plate_axis

        # Validate the plate axis BEFORE using it to index node.plates, so
        # that an invalid value raises the documented ValueError rather than
        # a TypeError/IndexError from tuple indexing.
        if not misc.is_scalar_integer(plate_axis):
            raise ValueError("Plate axis must be integer")
        if plate_axis >= 0:
            raise ValueError("plate_axis must be negative index")
        if plate_axis < -len(node.plates):
            raise ValueError("plate_axis out of bounds")

        # Length of the parent's plate axis that is being indexed.
        self._original_length = node.plates[plate_axis]

        if not issubclass(self._indices.dtype.type, np.integer):
            raise ValueError("Indices must be integers")
        if (np.any(self._indices < -self._original_length) or
                np.any(self._indices >= self._original_length)):
            raise ValueError("Index out of bounds")

        super().__init__(node, dims=node.dims, **kwargs)

    def _compute_moments(self, u_parent):
        """
        Apply `np.take` along the plate axis of every parent moment array.
        """
        u = []
        for (ui, dimi) in zip(u_parent, self.dims):
            # Plate axes precede the len(dimi) variable axes of the moment
            # array, so shift the (negative) plate axis past them.
            axis = self._plate_axis - len(dimi)
            # Just in case the taken axis is using broadcasting and has unit
            # length in u_parent, force it to have the correct length along
            # the axis in order to avoid errors in np.take. The broadcaster
            # has the full axis length followed by singleton axes for every
            # trailing axis after it.
            broadcaster = np.ones((self._original_length,) + (-axis-1)*(1,))
            u.append(np.take(ui*broadcaster, self._indices, axis=axis))
        return u

    def _compute_message_to_parent(self, index, m_child, u_parent):
        """
        Map the child's message back onto the parent's plate axis.

        Each message array is scattered with `misc.put_simple` to positions
        `indices` of an axis of length `self._original_length`.
        NOTE(review): handling of repeated indices (presumably accumulated)
        is determined by misc.put_simple — confirm there.
        """
        m = [
            misc.put_simple(
                mi,
                self._indices,
                axis=self._plate_axis-len(dimi),
                length=self._original_length,
            )
            for (mi, dimi) in zip(m_child, self.dims)
        ]
        return m

    def _compute_weights_to_parent(self, index, weights):
        """
        Map child plate weights back onto the parent's plate axis.

        Unlike the moment messages, no variable-dimension shift is applied
        to the axis here; the plate axis is used as is.
        """
        return misc.put_simple(
            weights,
            self._indices,
            axis=self._plate_axis,
            length=self._original_length,
        )

    def _compute_plates_to_parent(self, index, plates):
        """
        Invert `_compute_plates_from_parent`.

        In the child, the N axes produced by the index array sit where the
        parent's plate axis was; they are collapsed back into one axis of
        length `self._original_length`.
        """
        # Number of axes created by take operation
        N = np.ndim(self._indices)
        if self._plate_axis >= 0:
            raise RuntimeError("Plate axis should be negative")
        # Child plates are laid out as
        #   plates[:end_before] + <N index axes> + plates[start_after:]
        # using negative positions counted from the end.
        end_before = self._plate_axis - N + 1
        start_after = self._plate_axis + 1
        if end_before == 0:
            # Only reachable when N == 0 and plate_axis == -1: a slice
            # ending at 0 would wrongly drop all leading plates.
            return plates + (self._original_length,)
        elif start_after == 0:
            # plate_axis == -1: nothing follows the taken axes.
            return plates[:end_before] + (self._original_length,)
        return (plates[:end_before]
                + (self._original_length,)
                + plates[start_after:])

    def _compute_plates_from_parent(self, index, parent_plates):
        """
        Replace the parent's plate axis by the shape of the index array.
        """
        plates = parent_plates[:self._plate_axis] + np.shape(self._indices)
        if self._plate_axis != -1:
            # Re-attach the plates that follow the taken axis; for
            # plate_axis == -1 there are none (and -1+1 == 0 would slice
            # the whole tuple).
            plates = plates + parent_plates[(self._plate_axis+1):]
        return plates

    def _compute_plates_multiplier_from_parent(self, index, parent_multiplier):
        """
        Pass through a trivial plate multiplier; non-trivial ones are not
        supported yet.
        """
        if any(p != 1 for p in parent_multiplier):
            raise NotImplementedError(
                "Take node doesn't yet support plate multipliers {0}"
                .format(parent_multiplier)
            )
        return parent_multiplier
|