File: test_dpgmm.py

package info (click to toggle)
scikit-learn 0.18-5
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 71,040 kB
  • ctags: 91,142
  • sloc: python: 97,257; ansic: 8,360; cpp: 5,649; makefile: 242; sh: 238
file content (237 lines) | stat: -rw-r--r-- 7,866 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Important note for the deprecation cleaning of 0.20:
# All the functions and classes of this file have been deprecated in 0.18.
# When you remove this file, please also remove the related files:
# - 'sklearn/mixture/dpgmm.py'
# - 'sklearn/mixture/gmm.py'
# - 'sklearn/mixture/tests/test_gmm.py'
import unittest
import sys

import numpy as np

from sklearn.mixture import DPGMM, VBGMM
from sklearn.mixture.dpgmm import log_normalize
from sklearn.datasets import make_blobs
from sklearn.utils.testing import assert_array_less, assert_equal
from sklearn.utils.testing import assert_warns_message, ignore_warnings
from sklearn.mixture.tests.test_gmm import GMMTester
from sklearn.externals.six.moves import cStringIO as StringIO
from sklearn.mixture.dpgmm import digamma, gammaln
from sklearn.mixture.dpgmm import wishart_log_det, wishart_logz


# Ask NumPy to report every kind of floating-point error as a warning.
# This is a module-level call, so it takes effect as soon as this test
# module is imported.
np.seterr(all='warn')


@ignore_warnings(category=DeprecationWarning)
def test_class_weights():
    # Check that the class weights are updated by fitting: components that
    # get at least one sample assigned must carry a weight above 0.1, all
    # the other components must stay below 0.05.
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)
    n_components = 10
    for Model in [DPGMM, VBGMM]:
        dpgmm = Model(n_components=n_components, random_state=1, alpha=20,
                      n_iter=50)
        dpgmm.fit(X)
        # get indices of components that are used:
        indices = np.unique(dpgmm.predict(X))
        # Use the builtin ``bool`` as dtype: the ``np.bool`` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24.
        active = np.zeros(n_components, dtype=bool)
        active[indices] = True
        # used components are important
        assert_array_less(.1, dpgmm.weights_[active])
        # others are not
        assert_array_less(dpgmm.weights_[~active], .05)


@ignore_warnings(category=DeprecationWarning)
def test_verbose_boolean():
    # checks that the output for the verbose output is the same
    # for the flag values '1' and 'True'
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)

    def first_verbose_line(model):
        # Fit ``model`` with stdout redirected to a *fresh* buffer and
        # return the first line that was printed.  A new buffer per fit is
        # essential: when both fits share one buffer, the second fit writes
        # after the first line, so reading the first line again just
        # returns the output of the first fit and the comparison below
        # can never fail.
        old_stdout = sys.stdout
        captured = StringIO()
        sys.stdout = captured
        try:
            model.fit(X)
        finally:
            # always restore the real stdout, even if fit raises
            sys.stdout = old_stdout
        captured.seek(0)
        return captured.readline()

    for Model in [DPGMM, VBGMM]:
        dpgmm_bool = Model(n_components=10, random_state=1, alpha=20,
                           n_iter=50, verbose=True)
        dpgmm_int = Model(n_components=10, random_state=1, alpha=20,
                          n_iter=50, verbose=1)

        # generate output with the boolean flag
        bool_output = first_verbose_line(dpgmm_bool)
        # generate output with the int flag
        int_output = first_verbose_line(dpgmm_int)
        assert_equal(bool_output, int_output)


@ignore_warnings(category=DeprecationWarning)
def test_verbose_first_level():
    # Fit both estimators with verbose=1; the test only requires that
    # fitting completes without raising.
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)
    for Model in (DPGMM, VBGMM):
        estimator = Model(verbose=1, n_iter=50, alpha=20, random_state=1,
                          n_components=10)

        # Send stdout to an in-memory buffer while fitting; the captured
        # text is not inspected.
        saved_stdout, sys.stdout = sys.stdout, StringIO()
        try:
            estimator.fit(X)
        finally:
            # put the real stdout back whatever happened during fit
            sys.stdout = saved_stdout


@ignore_warnings(category=DeprecationWarning)
def test_verbose_second_level():
    # Fit both estimators with verbose=2; the test only requires that
    # fitting completes without raising.
    # simple 3 cluster dataset
    X, y = make_blobs(random_state=1)
    for Model in (DPGMM, VBGMM):
        estimator = Model(verbose=2, n_iter=50, alpha=20, random_state=1,
                          n_components=10)

        # Send stdout to an in-memory buffer while fitting; the captured
        # text is not inspected.
        saved_stdout, sys.stdout = sys.stdout, StringIO()
        try:
            estimator.fit(X)
        finally:
            # put the real stdout back whatever happened during fit
            sys.stdout = saved_stdout


@ignore_warnings(category=DeprecationWarning)
def test_digamma():
    # Calling digamma(3) must emit a DeprecationWarning carrying this text.
    expected_message = ("The function digamma is"
                        " deprecated in 0.18 and will be removed in 0.20. "
                        "Use scipy.special.digamma instead.")
    assert_warns_message(DeprecationWarning, expected_message, digamma, 3)


@ignore_warnings(category=DeprecationWarning)
def test_gammaln():
    # Calling gammaln(3) must emit a DeprecationWarning carrying this text.
    expected_message = ("The function gammaln"
                        " is deprecated in 0.18 and will be removed"
                        " in 0.20. Use scipy.special.gammaln instead.")
    assert_warns_message(DeprecationWarning, expected_message, gammaln, 3)


@ignore_warnings(category=DeprecationWarning)
def test_log_normalize():
    # The input is log(2 * v) for a vector v that sums to one; the call
    # must warn about the deprecation and return v again (within rtol).
    expected = np.array([0.1, 0.8, 0.01, 0.09])
    log_scaled = np.log(2 * expected)
    expected_message = ("The function "
                        "log_normalize is deprecated in 0.18 and"
                        " will be removed in 0.20.")
    normalized = assert_warns_message(DeprecationWarning, expected_message,
                                      log_normalize, log_scaled)
    assert np.allclose(expected, normalized, rtol=0.01)


@ignore_warnings(category=DeprecationWarning)
def test_wishart_log_det():
    # Only the deprecation warning is checked; the value returned by
    # wishart_log_det is not inspected.
    first = np.array([0.1, 0.8, 0.01, 0.09])
    second = np.array([0.2, 0.7, 0.05, 0.1])
    expected_message = ("The function "
                        "wishart_log_det is deprecated in 0.18 and"
                        " will be removed in 0.20.")
    assert_warns_message(DeprecationWarning, expected_message,
                         wishart_log_det, first, second, 2, 4)


@ignore_warnings(category=DeprecationWarning)
def test_wishart_logz():
    # Only the deprecation warning is checked; the value returned by
    # wishart_logz is not inspected.
    expected_message = ("The function "
                        "wishart_logz is deprecated in 0.18 and "
                        "will be removed in 0.20.")
    identity = np.identity(3)
    assert_warns_message(DeprecationWarning, expected_message, wishart_logz,
                         3, identity, 1, 3)


@ignore_warnings(category=DeprecationWarning)
def test_DPGMM_deprecation():
    # Instantiating DPGMM without arguments must emit this warning.
    expected_message = (
        "The `DPGMM` class is not working correctly and "
        "it's better to use `sklearn.mixture.BayesianGaussianMixture` class "
        "with parameter `weight_concentration_prior_type='dirichlet_process'` "
        "instead. DPGMM is deprecated in 0.18 and will be removed in 0.20.")
    assert_warns_message(DeprecationWarning, expected_message, DPGMM)


def do_model(self, **kwds):
    # Build a VBGMM with verbose output switched off, forwarding every
    # keyword argument unchanged.  ``self`` is accepted but not used.
    estimator = VBGMM(verbose=False, **kwds)
    return estimator


class DPGMMTester(GMMTester):
    # GMMTester specialisation that uses DPGMM as the model under test.
    do_test_eval = False
    model = DPGMM

    def score(self, g, train_obs):
        # Score ``g`` by its lower bound on ``train_obs``, computed from the
        # second value returned by ``score_samples``.
        _unused, second_output = g.score_samples(train_obs)
        bound = g.lower_bound(train_obs, second_output)
        return bound


class TestDPGMMWithSphericalCovars(unittest.TestCase, DPGMMTester):
    # DPGMMTester checks run with covariance_type='spherical'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'spherical'


class TestDPGMMWithDiagCovars(unittest.TestCase, DPGMMTester):
    # DPGMMTester checks run with covariance_type='diag'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'diag'


class TestDPGMMWithTiedCovars(unittest.TestCase, DPGMMTester):
    # DPGMMTester checks run with covariance_type='tied'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'tied'


class TestDPGMMWithFullCovars(unittest.TestCase, DPGMMTester):
    # DPGMMTester checks run with covariance_type='full'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'full'


def test_VBGMM_deprecation():
    # Instantiating VBGMM without arguments must emit this warning.
    expected_message = (
        "The `VBGMM` class is not working correctly and "
        "it's better to use `sklearn.mixture.BayesianGaussianMixture` class "
        "with parameter `weight_concentration_prior_type="
        "'dirichlet_distribution'` instead. VBGMM is deprecated "
        "in 0.18 and will be removed in 0.20.")
    assert_warns_message(DeprecationWarning, expected_message, VBGMM)


class VBGMMTester(GMMTester):
    # GMMTester specialisation whose model is built by ``do_model``, i.e. a
    # VBGMM created with verbose=False.
    do_test_eval = False
    model = do_model

    def score(self, g, train_obs):
        # Score ``g`` by its lower bound on ``train_obs``, computed from the
        # second value returned by ``score_samples``.
        _unused, second_output = g.score_samples(train_obs)
        bound = g.lower_bound(train_obs, second_output)
        return bound


class TestVBGMMWithSphericalCovars(unittest.TestCase, VBGMMTester):
    # VBGMMTester checks run with covariance_type='spherical'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'spherical'


class TestVBGMMWithDiagCovars(unittest.TestCase, VBGMMTester):
    # VBGMMTester checks run with covariance_type='diag'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'diag'


class TestVBGMMWithTiedCovars(unittest.TestCase, VBGMMTester):
    # VBGMMTester checks run with covariance_type='tied'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'tied'


class TestVBGMMWithFullCovars(unittest.TestCase, VBGMMTester):
    # VBGMMTester checks run with covariance_type='full'; the fixture
    # set-up is reused from GMMTester.
    setUp = GMMTester._setUp
    covariance_type = 'full'


@ignore_warnings(category=DeprecationWarning)
def test_vbgmm_no_modify_alpha():
    # The constructor must store ``alpha`` exactly as given, and the fitted
    # ``alpha_`` attribute must equal ``alpha / n_components``.
    # VBGMM is deprecated, so its DeprecationWarning is silenced here the
    # same way as in the other tests of this file that build the estimator.
    alpha = 2.
    n_components = 3
    X, y = make_blobs(random_state=1)
    vbgmm = VBGMM(n_components=n_components, alpha=alpha, n_iter=1)
    assert_equal(vbgmm.alpha, alpha)
    assert_equal(vbgmm.fit(X).alpha_, float(alpha) / n_components)