File: test_wishart.py

package info (click to toggle)
python-bayespy 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,132 kB
  • sloc: python: 22,402; makefile: 156
file content (126 lines) | stat: -rw-r--r-- 3,781 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
################################################################################
# Copyright (C) 2015 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################


"""
Unit tests for `wishart` module.
"""

import numpy as np

from scipy import special

from .. import gaussian
from bayespy.nodes import (Gaussian, 
                           Wishart)

from ...vmp import VB

from bayespy.utils import misc
from bayespy.utils import linalg
from bayespy.utils import random

from bayespy.utils.misc import TestCase

def _student_logpdf(y, mu, Cov, nu):
    """
    Evaluate the log-density of a multivariate Student-t distribution.

    Parameters
    ----------
    y : array
        Point(s) at which to evaluate; the last axis is the dimension D.
    mu : array
        Location vector, subtracted from `y`.
    Cov : array
        Scale matrix; its inverse and log-determinant are used.
    nu : float
        Degrees of freedom.

    Returns
    -------
    The log-density, with the last axis of `y` contracted away.
    """
    D = np.shape(y)[-1]
    half_dim = 0.5 * D

    # Quadratic form (y-mu)^T Cov^{-1} (y-mu), contracted over the last axes
    deviation = y - mu
    quad_form = np.einsum('...i,...ij,...j->...',
                          deviation,
                          np.linalg.inv(Cov),
                          deviation)

    # Normalisation terms, subtracted in a fixed order below
    gamma_ratio = special.gammaln((nu+D)/2) - special.gammaln(nu/2)
    log_nu_term = half_dim * np.log(nu)
    log_pi_term = half_dim * np.log(np.pi)
    log_det_term = 0.5 * np.linalg.slogdet(Cov)[1]

    # Heavy-tailed kernel term
    kernel_term = 0.5 * (nu+D) * np.log(1 + 1/nu*quad_form)

    return (gamma_ratio
            - log_nu_term
            - log_pi_term
            - log_det_term
            - kernel_term)


class TestWishart(TestCase):
    """
    Unit tests for the `Wishart` node.
    """

    def _check_moments(self, moments, dof, scale, dim):
        """
        Assert that `moments` match the closed-form expected values.

        The first moment is compared against ``dof * inv(scale)`` and the
        second against ``sum(digamma((dof - arange(dim))/2)) + dim*log(2)
        - log|scale|``.
        """
        expected_mean = dof * np.linalg.inv(scale)
        self.assertAllClose(moments[0],
                            expected_mean,
                            msg='Mean incorrect')

        digamma_sum = np.sum(special.digamma((dof - np.arange(dim)) / 2))
        expected_logdet = (digamma_sum
                           + dim * np.log(2)
                           - np.linalg.slogdet(scale)[1])
        self.assertAllClose(moments[1],
                            expected_logdet,
                            msg='Log determinant incorrect')

    def test_lower_bound(self):
        """
        Test the Wishart VB lower bound
        """

        #
        # With the Wishart node as the only latent node, VB gives exact
        # results, so the VB lower bound is the true marginal log likelihood.
        # The true marginal likelihood is the multivariate Student-t
        # distribution, so the two values are compared directly.
        #

        np.random.seed(42)

        # Build the model; the order of random draws is kept fixed so the
        # seeded values are reproducible.
        dim = 3
        dof = (dim - 1) + np.random.uniform(0.1, 0.5)
        scale = random.covariance(dim)
        precision = Wishart(dof, scale)
        mean = np.random.randn(dim)
        observed = Gaussian(mean, precision)
        data = np.random.randn(dim)
        observed.observe(data)
        precision.update()

        # Total lower bound: observation term plus Wishart term
        bound_from_data = observed.lower_bound_contribution()
        bound_from_precision = precision.lower_bound_contribution()
        lower_bound = bound_from_data + bound_from_precision

        # Parameters of the Student-t marginal
        student_dof = dof + 1 - dim
        student_scale = scale / student_dof
        expected = _student_logpdf(data, mean, student_scale, student_dof)

        self.assertAllClose(lower_bound, expected)

    def test_moments(self):
        """
        Test the moments of Wishart node
        """

        np.random.seed(42)

        # Test prior moments
        dim = 3
        dof = (dim - 1) + np.random.uniform(0.1, 2)
        scale = random.covariance(dim)
        precision = Wishart(dof, scale)
        precision.update()
        self._check_moments(precision.get_moments(), dof, scale, dim)

        # Test posterior moments
        dim = 3
        dof = (dim - 1) + np.random.uniform(0.1, 2)
        scale = random.covariance(dim)
        precision = Wishart(dof, scale)
        mean = np.random.randn(dim)
        observed = Gaussian(mean, precision)
        data = np.random.randn(dim)
        observed.observe(data)
        precision.update()
        posterior_moments = precision.get_moments()

        # One observation adds one degree of freedom and the outer product
        # of the residual to the matrix parameter.
        residual = data - mean
        posterior_dof = dof + 1
        posterior_scale = scale + np.outer(residual, residual)
        self._check_moments(posterior_moments,
                            posterior_dof,
                            posterior_scale,
                            dim)