File: take.py

package info (click to toggle)
python-bayespy 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,132 kB
  • sloc: python: 22,402; makefile: 156
file content (140 lines) | stat: -rw-r--r-- 4,420 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################

import numpy as np

from .deterministic import Deterministic
from .node import Moments
from bayespy.utils import misc


class Take(Deterministic):
    """
    Choose elements/sub-arrays along a plate axis

    Basically, applies `np.take` on a plate axis. Allows advanced mapping of
    plates.

    Parameters
    ----------
    node : Node
        A node to apply the take operation on.
    indices : array of integers
        Plate elements to pick along a plate axis. Negative indices are
        allowed and count from the end of the axis, as in `np.take`. Every
        index must lie in ``[-L, L)`` where ``L`` is the length of the chosen
        plate axis of `node`. The chosen plate axis is replaced by
        ``np.shape(indices)`` in the plates of this node, so a
        multi-dimensional index array creates several plate axes.
    plate_axis : int (negative)
        The plate axis to pick elements from (default: -1).

    Raises
    ------
    ValueError
        If `plate_axis` is not an integer, is non-negative, or is outside the
        plates of `node`; if `indices` are not integers; or if any index is
        out of bounds for the chosen plate axis.

    See also
    --------
    numpy.take

    Examples
    --------

    >>> from bayespy.nodes import Gamma, Take
    >>> alpha = Gamma([1, 2, 3], [1, 1, 1])
    >>> x = Take(alpha, [1, 1, 2, 2, 1, 0])
    >>> x.get_moments()[0]
    array([2., 2., 3., 3., 2., 1.])
    """


    def __init__(self, node, indices, plate_axis=-1, **kwargs):
        self._moments = node._moments
        self._parent_moments = (node._moments,)
        self._indices = np.array(indices)
        self._plate_axis = plate_axis

        # Validate the plate axis BEFORE using it to index ``node.plates``.
        # Otherwise an invalid axis surfaces as an IndexError/TypeError from
        # the tuple lookup and the intended ValueErrors are never reached.
        if not misc.is_scalar_integer(plate_axis):
            raise ValueError("Plate axis must be integer")
        if plate_axis >= 0:
            raise ValueError("plate_axis must be negative index")
        if plate_axis < -len(node.plates):
            raise ValueError("plate_axis out of bounds")

        # Length of the parent's plate axis. It is needed to scatter messages
        # and weights back to the parent and to restore the parent's plates.
        self._original_length = node.plates[plate_axis]

        # Validate the indices against the (now known) axis length.
        if not issubclass(self._indices.dtype.type, np.integer):
            raise ValueError("Indices must be integers")
        if (np.any(self._indices < -self._original_length) or
                np.any(self._indices >= self._original_length)):
            raise ValueError("Index out of bounds")

        super().__init__(node, dims=node.dims, **kwargs)


    def _compute_moments(self, u_parent):
        """
        Apply `np.take` to each parent moment along the taken plate axis.

        For moment ``i`` the plate axis is offset by ``len(self.dims[i])``
        trailing variable axes, hence ``axis = plate_axis - len(dimi)``.
        """
        u = []
        for (ui, dimi) in zip(u_parent, self.dims):
            axis = self._plate_axis - len(dimi)
            # Just in case the taken axis is using broadcasting and has unit
            # length in u_parent, force it to have the correct length along the
            # axis in order to avoid errors in np.take.
            broadcaster = np.ones((self._original_length,) + (-axis-1)*(1,))
            u.append(np.take(ui*broadcaster, self._indices, axis=axis))
        return u


    def _compute_message_to_parent(self, index, m_child, u_parent):
        """
        Map each child message back onto the parent's plate axis.

        This is the inverse of the take in `_compute_moments`, delegated to
        `misc.put_simple`. Presumably duplicate indices accumulate there
        (not verifiable from this file).
        """
        m = [
            misc.put_simple(
                mi,
                self._indices,
                axis=self._plate_axis-len(dimi),
                length=self._original_length,
            )
            for (mi, dimi) in zip(m_child, self.dims)
        ]
        return m


    def _compute_weights_to_parent(self, index, weights):
        """
        Map plate weights back onto the parent's plate axis.

        Weights carry no variable axes, so the plate axis is used directly.
        """
        return misc.put_simple(
            weights,
            self._indices,
            axis=self._plate_axis,
            length=self._original_length,
        )


    def _compute_plates_to_parent(self, index, plates):
        """
        Convert this node's plates back into the parent's plates.

        The ``np.ndim(self._indices)`` axes created by the take operation
        are collapsed into a single axis of the original length.
        """
        if self._plate_axis >= 0:
            raise RuntimeError("Plate axis should be negative")

        # Number of axes created by take operation
        N = np.ndim(self._indices)

        # Slice bounds (negative, counted from the end) of the plates that
        # precede and follow the created axes. A bound of 0 means "nothing
        # to cut on that side", which a plain slice with 0 would get wrong.
        end_before = self._plate_axis - N + 1
        start_after = self._plate_axis + 1

        head = plates[:end_before] if end_before != 0 else plates
        tail = plates[start_after:] if start_after != 0 else ()
        return head + (self._original_length,) + tail


    def _compute_plates_from_parent(self, index, parent_plates):
        """
        Replace the parent's taken plate axis with ``np.shape(indices)``.
        """
        # A slice starting at 0 would return all plates, so the last axis
        # (plate_axis == -1) has no trailing plates.
        if self._plate_axis != -1:
            tail = parent_plates[(self._plate_axis+1):]
        else:
            tail = ()
        return parent_plates[:self._plate_axis] + np.shape(self._indices) + tail


    def _compute_plates_multiplier_from_parent(self, index, parent_multiplier):
        """
        Pass through the parent's plate multiplier.

        Only the trivial multiplier (all ones) is supported.
        """
        if any(p != 1 for p in parent_multiplier):
            raise NotImplementedError(
                "Take node doesn't yet support plate multipliers {0}"
                .format(parent_multiplier)
            )

        return parent_multiplier