File: add.py

package info (click to toggle)
python-bayespy 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,132 kB
  • sloc: python: 22,402; makefile: 156
file content (154 lines) | stat: -rw-r--r-- 4,331 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################


import numpy as np
import functools

from .deterministic import Deterministic
from .gaussian import Gaussian, GaussianMoments

from bayespy.utils import linalg


class Add(Deterministic):
    r"""
    Node for computing sums of Gaussian nodes: :math:`X+Y+Z`.

    Examples
    --------

    >>> import numpy as np
    >>> from bayespy import nodes
    >>> X = nodes.Gaussian(np.zeros(2), np.identity(2), plates=(3,))
    >>> Y = nodes.Gaussian(np.ones(2), np.identity(2))
    >>> Z = nodes.Add(X, Y)
    >>> print("Mean:\n", Z.get_moments()[0])
    Mean:
     [[1. 1.]]
    >>> print("Second moment:\n", Z.get_moments()[1])
    Second moment:
     [[[3. 1.]
      [1. 3.]]]

    Notes
    -----

    Shapes of the nodes must be identical. Plates are broadcasted.

    This node sums nodes that are independent in the posterior
    approximation. However, summing variables puts a strong coupling among the
    variables, which is lost in this construction.  Thus, it is usually better
    to use a single Gaussian node to represent the set of the summed variables
    and use SumMultiply node to compute the sum.  In that way, the correlation
    between the variables is not lost. However, in some cases it is necessary or
    useful to use Add node.

    See also
    --------

    Dot, SumMultiply
    """

    def __init__(self, *nodes, **kwargs):
        """
        Add(X1, X2, ...)
        """

        ndim = None
        for node in nodes:
            try:
                node = self._ensure_moments(node, GaussianMoments, ndim=None)
            except ValueError:
                pass
            else:
                ndim = node._moments.ndim
                break

        nodes = [self._ensure_moments(node, GaussianMoments, ndim=ndim)
                 for node in nodes]

        N = len(nodes)
        if N < 2:
            raise ValueError("Give at least two parents")

        nodes = list(nodes)

        for n in range(N-1):
            if nodes[n].dims != nodes[n+1].dims:
                raise ValueError("Nodes do not have identical shapes")

        ndim = len(nodes[0].dims[0])
        dims = tuple(nodes[0].dims)
        shape = dims[0]

        self._moments = GaussianMoments(shape)
        self._parent_moments = N * [GaussianMoments(shape)]

        self.ndim = ndim
        self.N = N

        super().__init__(*nodes, dims=dims, **kwargs)


    def _compute_moments(self, *u_parents):
        """
        Compute the moments of the sum
        """

        u0 = functools.reduce(np.add,
                              (u_parent[0] for u_parent in u_parents))
        u1 = functools.reduce(np.add,
                              (u_parent[1] for u_parent in u_parents))

        for i in range(self.N):
            for j in range(i+1, self.N):
                xi_xj = linalg.outer(u_parents[i][0], u_parents[j][0], ndim=self.ndim)
                xj_xi = linalg.transpose(xi_xj, ndim=self.ndim)
                u1 = u1 + xi_xj + xj_xi

        return [u0, u1]


    def _compute_message_to_parent(self, index, m, *u_parents):
        """
        Compute the message to a parent node.

        .. math::

           (\sum_i \mathbf{x}_i)^T \mathbf{M}_2 (\sum_j \mathbf{x}_j)
           + (\sum_i \mathbf{x}_i)^T \mathbf{m}_1

        Moments of the parents are

        .. math::

           u_1^{(i)} = \langle \mathbf{x}_i \rangle
           \\
           u_2^{(i)} = \langle \mathbf{x}_i \mathbf{x}_i^T \rangle

        Thus, the message for :math:`i`-th parent is

        .. math::

           \phi_{x_i}^{(1)} = \mathbf{m}_1 + 2 \mathbf{M}_2 \sum_{j\neq i} \mathbf{x}_j
           \\
           \phi_{x_i}^{(2)} = \mathbf{M}_2
        """

        # Remove the moments of the parent that receives the message
        u_parents = u_parents[:index] + u_parents[(index+1):]

        m0 = (m[0] +
              linalg.mvdot(
                  2*m[1],
                  functools.reduce(np.add,
                                   (u_parent[0] for u_parent in u_parents)),
                  ndim=self.ndim))

        m1 = m[1]

        return [m0, m1]