File: test_mean_shift.py

package info (click to toggle)
scikit-learn 0.11.0-2%2Bdeb7u1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 13,900 kB
  • sloc: python: 34,740; ansic: 8,860; cpp: 8,849; pascal: 230; makefile: 211; sh: 14
file content (66 lines) | stat: -rw-r--r-- 2,154 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Testing for mean shift clustering methods

"""

import numpy as np
from numpy.testing import assert_equal
from nose.tools import assert_true

from .. import MeanShift, mean_shift, estimate_bandwidth, get_bin_seeds
from ...datasets.samples_generator import make_blobs

# Shared fixture for the tests below: 500 two-dimensional samples generated
# around three centers that are offset by 10 from the unit-square corners.
# The second value returned by make_blobs is not used here.
n_clusters = 3
centers = 10 + np.array([[1, 1], [-1, -1], [1, -1]])
X, _ = make_blobs(
    n_samples=500,
    n_features=2,
    centers=centers,
    cluster_std=0.4,
    shuffle=True,
    random_state=0)


def test_mean_shift():
    """Check bandwidth estimation and both mean shift entry points.

    Runs on the module-level ``X`` fixture and expects the estimated
    bandwidth to fall in [0.9, 1.5], and both the ``MeanShift`` estimator
    and the ``mean_shift`` function to produce ``n_clusters`` distinct
    labels when given a fixed bandwidth.
    """
    # The estimated bandwidth is only sanity-checked against a range; it is
    # not the value used for the clustering runs below.
    estimated = estimate_bandwidth(X, n_samples=300)
    assert_true(0.9 <= estimated <= 1.5)

    fixed_bandwidth = 1.2

    # Estimator interface: fit, then read the fitted attributes.
    estimator = MeanShift(bandwidth=fixed_bandwidth)
    fitted = estimator.fit(X)
    est_labels = fitted.labels_
    est_centers = estimator.cluster_centers_
    assert_equal(len(np.unique(est_labels)), n_clusters)

    # Function interface: returns the centers and the labels directly.
    fn_centers, fn_labels = mean_shift(X, bandwidth=fixed_bandwidth)
    assert_equal(len(np.unique(fn_labels)), n_clusters)


def test_bin_seeds():
    """
    Test the bin seeding technique which can be used in the mean shift
    algorithm

    ``get_bin_seeds`` is called on six hand-picked points in the plane with
    several bin sizes and minimum bin frequencies. The returned seeds are
    compared with the expected ones using ``assert_equal`` so that a failure
    reports both the actual and the expected values.
    """
    # Data is just 6 points in the plane
    X = np.array([[1., 1.], [1.5, 1.5], [1.8, 1.2],
                  [2., 1.], [2.1, 1.1], [0., 0.]])

    def seeds_as_set(bin_size, min_bin_freq):
        # Each returned seed is converted to a tuple so the seeds can be
        # collected in a set and compared regardless of their order.
        found = get_bin_seeds(X, bin_size, min_bin_freq)
        return set([tuple(p) for p in found])

    # With a bin coarseness of 1.0 and min_bin_freq of 1, 3 bins should be
    # found
    assert_equal(seeds_as_set(1, 1), set([(1., 1.), (2., 1.), (0., 0.)]))

    # With a bin coarseness of 1.0 and min_bin_freq of 2, 2 bins should be
    # found
    assert_equal(seeds_as_set(1, 2), set([(1., 1.), (2., 1.)]))

    # With a bin size of 0.01 and min_bin_freq of 1, 6 bins should be found
    assert_equal(len(seeds_as_set(0.01, 1)), 6)