File: concat_gaussian.py

package info (click to toggle)
python-bayespy 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,132 kB
  • sloc: python: 22,402; makefile: 156
file content (116 lines) | stat: -rw-r--r-- 3,889 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np

from bayespy.utils import misc
from bayespy.utils import linalg
from .gaussian import GaussianMoments
from .deterministic import Deterministic


class ConcatGaussian(Deterministic):
    """Concatenate Gaussian vectors along the variable axis (not plate axis)

    NOTE: This concatenates on the variable axis! That is, the dimensionality
    of the resulting Gaussian vector is the sum of the dimensionalities of the
    input Gaussian vectors.

    The node stores ``slices``, a tuple of cumulative offsets: parent ``i``
    occupies the index range ``slices[i]:slices[i+1]`` of the concatenated
    vector.

    TODO: Add support for Gaussian arrays and arbitrary concatenation axis.
    """


    def __init__(self, *nodes, **kwargs):
        """Create the concatenation node.

        Parameters
        ----------
        *nodes
            Parents to concatenate, in order. Each one is passed through
            ``self._ensure_moments(node, GaussianMoments, ndim=1)``.
        **kwargs
            Forwarded unchanged to the parent class constructor.

        Raises
        ------
        ValueError
            If any parent does not have a one-dimensional first moment
            (that is, ``len(node.dims[0]) != 1``).
        """

        # This is stuff that will be useful when implementing arbitrary
        # concatenation. That is, first determine ndim.
        #
        # # Convert nodes to Gaussians (if they are not nodes, don't worry)
        # nodes_gaussian = []
        # for node in nodes:
        #     try:
        #         node_gaussian = node._convert(GaussianMoments)
        #     except AttributeError: # Moments.NoConverterError:
        #         nodes_gaussian.append(node)
        #     else:
        #         nodes_gaussian.append(node_gaussian)
        # nodes = nodes_gaussian
        #
        # # Determine shape from the first Gaussian node
        # shape = None
        # for node in nodes:
        #     try:
        #         shape = node.dims[0]
        #     except AttributeError:
        #         pass
        #     else:
        #         break
        # if shape is None:
        #     raise ValueError("Couldn't determine shape from the input nodes")
        #
        # ndim = len(shape)

        nodes = [self._ensure_moments(node, GaussianMoments, ndim=1)
                 for node in nodes]

        # Make sure all parents are Gaussian vectors. This must happen before
        # any ``dims[0][0]`` lookup below, so that a parent of the wrong rank
        # is reported with this error instead of an IndexError (or silently
        # using only its first axis).
        if any(len(node.dims[0]) != 1 for node in nodes):
            raise ValueError("Input nodes must be (Gaussian) vectors")

        # Length of each parent vector and the total concatenated length
        sizes = [node.dims[0][0] for node in nodes]
        D = sum(sizes)

        # Cumulative offsets: parent i lives in slices[i]:slices[i+1]
        self.slices = tuple(np.cumsum([0] + sizes))

        self._moments = GaussianMoments((D,))

        self._parent_moments = [node._moments for node in nodes]

        return super().__init__(*nodes, dims=((D,), (D, D)), **kwargs)


    def _compute_moments(self, *u_nodes):
        """Compute the moments of the concatenated vector.

        Parameters
        ----------
        *u_nodes
            Moments of the parents, one entry per parent. ``u[0]`` is used as
            the first moment (vector on the last axis) and ``u[1]`` as the
            second moment (matrix on the last two axes).

        Returns
        -------
        list
            ``[x, xx]`` where ``x`` is the concatenation of the parents'
            ``u[0]`` along the last axis, and ``xx`` has the parents' ``u[1]``
            as diagonal blocks and outer products of the parents' ``u[0]`` as
            off-diagonal blocks.
        """
        means = [u[0] for u in u_nodes]

        x = misc.concatenate(*means, axis=-1)
        xx = misc.block_diag(*[u[1] for u in u_nodes])

        # Explicitly broadcast xx to plates of x. The multiplication also
        # yields a fresh array, which is then filled in place below.
        x_plates = np.shape(x)[:-1]
        xx = np.ones(x_plates)[..., None, None] * xx

        # Start/end offsets of each parent's block in the concatenated vector
        offsets = np.cumsum([0] + [np.shape(mean)[-1] for mean in means])

        # Compute the cross-covariance terms using the means of each variable
        # (because covariances are zero for factorized nodes in the VB
        # approximation)
        for (m, mean_m) in enumerate(means):
            rows = slice(offsets[m], offsets[m+1])
            # Only pairs n < m are visited; the transposed block is written at
            # the same time so both triangles are filled.
            for n in range(m):
                cols = slice(offsets[n], offsets[n+1])
                xm_xn = linalg.outer(mean_m, means[n], ndim=1)
                xx[..., rows, cols] = xm_xn
                xx[..., cols, rows] = misc.T(xm_xn)

        return [x, xx]


    def _compute_message_to_parent(self, i, m, *u_nodes):
        """Compute the message to parent ``i``.

        Parameters
        ----------
        i : int
            Index of the parent that receives the message.
        m
            Incoming message for the concatenated vector; ``m[0]`` is indexed
            as a vector on the last axis and ``m[1]`` as a matrix on the last
            two axes.
        *u_nodes
            Moments of all parents; only ``u[0]`` of the other parents is
            read.

        Returns
        -------
        list
            ``[m0, m1]``: the parts of ``m[0]`` and ``m[1]`` that belong to
            parent ``i``, with the cross terms against the other parents
            folded into ``m0``.
        """
        r = self.slices
        own = slice(r[i], r[i+1])

        # Pick the proper parts from the message array
        m0 = m[0][..., own]
        m1 = m[1][..., own, own]

        # Handle cross-covariance terms by using the mean of the covariate node
        #
        # NOTE(review): the factor 2 merges the (i, j) and (j, i) blocks of
        # m[1] into one term, which is exact only if m[1] is symmetric --
        # presumably guaranteed by the senders; confirm.
        for (j, u) in enumerate(u_nodes):
            if j != i:
                m0 = m0 + 2 * np.einsum(
                    '...ij,...j->...i',
                    m[1][..., own, r[j]:r[j+1]],
                    u[0]
                )

        return [m0, m1]