File: concatenate.py

package info (click to toggle)
python-bayespy 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,132 kB
  • sloc: python: 22,402; makefile: 156
file content (167 lines) | stat: -rw-r--r-- 5,957 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################


import numpy as np

from bayespy.utils import misc

from .deterministic import Deterministic
from .node import Moments

class Concatenate(Deterministic):
    """
    Concatenate similar nodes along a plate axis.

    Nodes must be of same type and dimensionality. Also, plates must be
    identical except for the plate axis along which the concatenation is
    performed.

    Parameters
    ----------
    *nodes
        Nodes to concatenate. At least one of them must provide moments
        (a ``_moments`` attribute); every node is then converted to that
        moments class with ``_ensure_moments``.
    axis : int, optional
        Negative index of the plate axis along which to concatenate
        (default ``-1``). Non-negative values are rejected.
    **kwargs
        Passed on to ``Deterministic.__init__``.

    Raises
    ------
    ValueError
        If ``axis`` is non-negative, if no node provides moments, if the
        nodes cannot be converted to common moments, or if their
        dimensionalities differ.

    See also
    --------
    numpy.concatenate
    """


    def __init__(self, *nodes, axis=-1, **kwargs):
        if axis >= 0:
            raise ValueError("Currently, only negative axis indeces "
                             "are allowed.")
        self._axis = axis

        # Take the moments of the first node that has any. Inputs without a
        # ``_moments`` attribute are skipped here and converted below.
        parent_moments = None
        for node in nodes:
            try:
                parent_moments = node._moments
            except AttributeError:
                pass
            else:
                break
        if parent_moments is None:
            raise ValueError("Couldn't determine parent moments")

        # All parents must have same moments
        self._parent_moments = (parent_moments,) * len(nodes)
        self._moments = parent_moments

        # Convert every node to the common moments class
        try:
            conversion_kwargs = parent_moments.get_instance_conversion_kwargs()
            nodes = [
                self._ensure_moments(
                    node,
                    parent_moments.__class__,
                    **conversion_kwargs
                )
                for node in nodes
            ]
        except Moments.NoConverterError as err:
            raise ValueError("Parents have different moments") from err

        # Dimensionality of the node: all parents must agree
        dims = tuple(nodes[0].dims)
        for node in nodes:
            if node.dims != dims:
                raise ValueError("Parents have different dimensionalities")

        super().__init__(
            *nodes,
            dims=dims,
            allow_dependent_parents=True, # because parent plates are kept separate
            **kwargs
        )

        # Boundaries of each parent on the concatenated plate axis: parent i
        # occupies the half-open range [_indices[i], _indices[i+1]).
        self._indices = np.zeros(len(nodes)+1, dtype=np.int64)
        self._indices[1:] = np.cumsum([int(parent.plates[axis])
                                       for parent in self.parents])
        self._lengths = [parent.plates[axis] for parent in self.parents]


    def _get_id_list(self):
        """
        Parents don't need to be independent for this node so remove duplicates.

        Note that converting through a ``set`` does not preserve the order
        of the ids returned by the parent class.
        """
        return list(set(super()._get_id_list()))


    def _compute_plates_to_parent(self, index, plates):
        """
        Map this node's plates to parent ``index``'s plates.

        The concatenated axis is replaced by that parent's own plate length;
        all other plates are passed through unchanged.
        """
        plates = list(plates)
        plates[self._axis] = self.parents[index].plates[self._axis]
        return tuple(plates)


    def _compute_plates_from_parent(self, index, plates):
        """
        Map a parent's plates to this node's plates.

        The concatenated axis becomes the sum of all parents' lengths along
        that axis; all other plates are passed through unchanged.
        """
        plates = list(plates)
        plates[self._axis] = sum(parent.plates[self._axis]
                                 for parent in self.parents)
        return tuple(plates)


    def _plates_multiplier_from_parent(self, index):
        """
        Return an empty plate multiplier.

        Raises
        ------
        ValueError
            If any parent has a plate multiplier other than 1, because
            concatenation does not support plate multipliers.
        """
        multipliers = [parent.plates_multiplier for parent in self.parents]
        for m in multipliers:
            if np.any(np.array(m) != 1):
                raise ValueError("Concatenation node does not support plate "
                                 "multipliers.")
        return ()


    def _compute_weights_to_parent(self, index, weights):
        """
        Extract parent ``index``'s portion of ``weights``.

        If ``weights`` has more than one element along the concatenated axis,
        the slice ``[_indices[index], _indices[index+1])`` is returned.
        Otherwise the weights are broadcast along that axis and are returned
        unchanged.
        """
        axis = self._axis
        # Start AND end of this parent's range: two split points make
        # np.split return three arrays (before, this parent, after).
        indices = self._indices[index:(index+2)]
        if np.ndim(weights) >= abs(axis) and np.shape(weights)[axis] > 1:
            # Take the middle one of the returned three arrays
            return np.split(weights, indices, axis=axis)[1]
        else:
            return weights


    def _compute_message_to_parent(self, index, m, *u_parents):
        """
        Slice each message array to the part belonging to parent ``index``.

        For message ``m[i]``, the plate axis is shifted left by the number of
        variable dimensions ``len(self.dims[i])`` to get the array axis. A
        message with length <= 1 along that axis is passed through unchanged.
        """
        msg = []
        indices = self._indices[index:(index+2)]
        for i in range(len(m)):
            # Fix plate axis to array axis
            axis = self._axis - len(self.dims[i])
            # Find the slice from the message
            if np.ndim(m[i]) >= abs(axis) and np.shape(m[i])[axis] > 1:
                mi = np.split(m[i], indices, axis=axis)[1]
            else:
                mi = m[i]
            msg.append(mi)
        return msg


    def _compute_moments(self, *u_parents):
        """
        Concatenate the parents' moment arrays along the plate axis.

        ``np.concatenate`` does not broadcast, but moment arrays may rely on
        broadcasting. As a workaround, each parent's moment array is
        explicitly broadcast to a common shape on all axes except the
        concatenated one. Along the concatenated axis, each array is expanded
        to that parent's own plate length.
        """
        u = []
        for i, dims_i in enumerate(self.dims):
            # Fix plate axis to array axis
            axis = self._axis - len(dims_i)
            ui_parents = [u_parent[i] for u_parent in u_parents]
            # Common broadcasted shape, ignoring the concatenated axis (set
            # to 1) because the parents legitimately differ there
            shapes = [list(np.shape(uip)) for uip in ui_parents]
            for shape in shapes:
                if len(shape) >= abs(axis):
                    shape[axis] = 1
            bc_shape = misc.broadcasted_shape(*shapes)
            # Concatenated axis must be broadcasted explicitly
            bc_shapes = [misc.broadcasted_shape(bc_shape,
                                                (length,) + (1,)*(abs(axis)-1))
                         for length in self._lengths]
            # Broadcast explicitly
            ui_parents = [uip * np.ones(shape)
                          for (uip, shape) in zip(ui_parents, bc_shapes)]
            # Concatenate
            u.append(np.concatenate(ui_parents, axis=axis))

        return u