1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
################################################################################
# Copyright (C) 2014 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################
"""
Unit tests for `bernoulli` module.
"""
import numpy as np
import scipy
from bayespy.nodes import (Bernoulli,
Beta,
Mixture)
from bayespy.utils import random
from bayespy.utils.misc import TestCase
class TestBernoulli(TestCase):
    """
    Unit tests for the Bernoulli node.
    """

    def test_init(self):
        """
        Check construction of Bernoulli nodes: plate handling and the
        rejection of invalid arguments.
        """
        # Construction from a scalar probability and from a Beta parent
        Bernoulli(0.5)
        Bernoulli(Beta([2, 3]))

        # Plates can be given explicitly, taken from the shape of the
        # probability array, or taken from the plates of a Beta parent
        expected_plates = (4, 3)
        node = Bernoulli(0.7, plates=(4, 3))
        self.assertEqual(node.plates, expected_plates)
        node = Bernoulli(0.7 * np.ones((4, 3)))
        self.assertEqual(node.plates, expected_plates)
        node = Bernoulli(Beta([4, 3], plates=(4, 3)))
        self.assertEqual(node.plates, expected_plates)

        # A negative probability and a probability above one must raise
        for bad_probability in (-0.5, 1.5):
            self.assertRaises(ValueError, Bernoulli, bad_probability)

        # Explicit plates must be rejected when they are inconsistent
        # with the probability array's shape, (3,), or too small, (1,)
        for bad_plates in ((3,), (1,)):
            self.assertRaises(ValueError,
                              Bernoulli,
                              0.5 * np.ones(4),
                              plates=bad_plates)

    def test_moments(self):
        """
        Check the moments that Bernoulli nodes send to their children.
        """
        # Constant probability: exactly one moment, equal to the probability
        moments = Bernoulli(0.7)._message_to_child()
        self.assertEqual(len(moments), 1)
        self.assertAllClose(moments[0], 0.7)

        # Plates coming from an array of probabilities
        probs = np.random.rand(3)
        moments = Bernoulli(probs)._message_to_child()
        self.assertAllClose(moments[0], probs)

        # Beta parent: the expected moment is built from the exponentials
        # of the two components of the Beta node's first moment
        parent = Beta([7, 3])
        log_terms = parent._message_to_child()[0]
        exp_terms = np.exp(log_terms)
        expected = exp_terms[0] / (exp_terms[0] + exp_terms[1])
        node = Bernoulli(parent)
        moments = node._message_to_child()
        self.assertAllClose(moments[0], expected)

        # Plates of the Beta parent broadcast to the child; the same
        # expected value should hold for every plate
        parent = Beta([7, 3], plates=(10,))
        node = Bernoulli(parent)
        moments = node._message_to_child()
        self.assertAllClose(moments[0] * np.ones(node.get_shape(0)),
                            expected * np.ones(10))

    def test_mixture(self):
        """
        Check the moments of a mixture of Bernoulli distributions.
        """
        # Cluster assignments [2, 0, 0] pick probabilities 0.3, 0.1, 0.1
        mixture = Mixture([2, 0, 0], Bernoulli, [0.1, 0.2, 0.3])
        moments = mixture._message_to_child()
        self.assertEqual(len(moments), 1)
        self.assertAllClose(moments[0], [0.3, 0.1, 0.1])

    def test_observed(self):
        """
        Check that a Bernoulli node accepts a boolean observation.
        """
        node = Bernoulli(0.3)
        node.observe(True)

    def test_random(self):
        """
        Check random sampling with the degenerate probabilities 1 and 0,
        for which every sample is deterministic.
        """
        probs = [1.0, 0.0]
        # Suppress NumPy divide-by-zero warnings, which the zero
        # probability presumably triggers inside the sampler
        with np.errstate(divide='ignore'):
            sample = Bernoulli(probs, plates=(3, 2)).random()
        self.assertArrayEqual(sample, np.ones((3, 2)) * probs)
|