File: test_gmm_bayes.py

package info (click to toggle)
astroml 1.0.2-6
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 932 kB
  • sloc: python: 5,731; makefile: 3
file content (74 lines) | stat: -rw-r--r-- 1,749 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Tests of the GMM Bayes classifier"""
import numpy as np
from numpy.testing import assert_allclose
import pytest
from astroML.classification import GMMBayes


def test_gmm1d():
    """GMMBayes with one component per class separates two 1-D Gaussians.

    Two normal samples (means 0 and 10, unit variance) are labelled 0 and 1;
    the fitted classifier must reproduce those labels on the training data.
    """
    # Seeded local generator so the test data are reproducible run to run.
    # The classes are ten standard deviations apart, so the outcome should
    # not hinge on the particular seed.
    rng = np.random.RandomState(0)
    x1 = rng.normal(0, 1, size=100)
    x2 = rng.normal(10, 1, size=100)
    # Lay the 1-D samples out as a (n_samples, 1) column, as the original
    # reshape((200, 1)) did.
    X = np.concatenate((x1, x2))[:, np.newaxis]
    y = np.zeros(len(X))
    y[len(x1):] = 1

    ncm = 1
    clf = GMMBayes(ncm)
    clf.fit(X, y)

    predicted = clf.predict(X)
    # assert_allclose(actual, desired): the prediction is the actual value.
    assert_allclose(predicted, y)


def test_gmm2d():
    """GMMBayes separates two 2-D Gaussian clouds for several component counts.

    Class 0 is centred at (0, 0) and class 1 at (10, 10), both with unit
    variance; for 1, 2 and 3 components per class the classifier must
    reproduce the training labels exactly.
    """
    # Seeded local generator so the test data are reproducible run to run.
    rng = np.random.RandomState(0)
    x1 = rng.normal(0, 1, size=(100, 2))
    x2 = rng.normal(10, 1, size=(100, 2))
    X = np.vstack((x1, x2))
    y = np.zeros(len(X))
    y[len(x1):] = 1

    for ncm in (1, 2, 3):
        clf = GMMBayes(ncm)
        clf.fit(X, y)

        predicted = clf.predict(X)
        # Report which component count failed, since all share one test.
        assert_allclose(predicted, y,
                        err_msg="n_components={0}".format(ncm))


def test_incompatible_shapes_exception():
    """fit() must reject an X and y whose sample counts differ."""
    X = np.random.normal(0, 1, size=(100, 2))
    y = np.zeros(99)  # one label short of the 100 rows in X

    ncm = 1
    clf = GMMBayes(ncm)

    # Call fit() directly. Wrapping it in ``assert`` adds nothing: if fit()
    # returned a falsy value without raising, the AssertionError would itself
    # satisfy pytest.raises(Exception) and hide the real "did not raise".
    with pytest.raises(Exception) as excinfo:
        clf.fit(X, y)

    assert str(excinfo.value) == "X and y have incompatible shapes"


def test_incompatible_number_of_components_exception():
    """fit() must reject a per-class component list of the wrong length.

    All labels are 0 (a single class), but three component counts are given.
    """
    X = np.random.normal(0, 1, size=(100, 2))
    y = np.zeros(100)

    ncm = [1, 2, 3]
    clf = GMMBayes(ncm)

    # Call fit() directly. Wrapping it in ``assert`` adds nothing: if fit()
    # returned a falsy value without raising, the AssertionError would itself
    # satisfy pytest.raises(Exception) and hide the real "did not raise".
    with pytest.raises(Exception) as excinfo:
        clf.fit(X, y)

    assert str(excinfo.value) == ("n_components must be compatible with "
                                  "the number of classes")


def test_too_many_components_warning():
    """Fitting with more mixture components than samples warns the user."""
    n_points = 3
    features = np.random.normal(0, 1, size=(n_points, 2))
    labels = np.zeros(n_points)  # every sample belongs to a single class

    # Five components requested for only three training points.
    classifier = GMMBayes(5)
    expected = "Expected n_samples >= n_components but got "

    with pytest.warns(UserWarning, match=expected):
        classifier.fit(features, labels)