1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
|
"""
Testing for mean shift clustering methods
"""
import warnings
import numpy as np
import pytest
from sklearn.cluster import MeanShift, estimate_bandwidth, get_bin_seeds, mean_shift
from sklearn.datasets import make_blobs
from sklearn.metrics import v_measure_score
from sklearn.utils._testing import assert_allclose, assert_array_equal
# Shared fixture data for the tests below: 300 two-dimensional samples drawn
# around three blob centers (each offset by +10), generated with a fixed seed.
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
n_clusters = 3
X, _ = make_blobs(
    centers=centers,
    n_samples=300,
    n_features=2,
    cluster_std=0.4,
    random_state=11,
    shuffle=True,
)
def test_convergence_of_1d_constant_data():
    """Check that fitting constant 1D data stops before ``max_iter``.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/28926
    """
    # Ten identical samples in a single feature column.
    constant_column = np.ones((10, 1))
    estimator = MeanShift()
    estimator.fit(constant_column)
    assert estimator.n_iter_ < estimator.max_iter
def test_estimate_bandwidth():
    """Check that the bandwidth estimated on the shared blob data is in range."""
    estimated = estimate_bandwidth(X, n_samples=200)
    # Accepted window for the estimate on the module-level data.
    lower, upper = 0.9, 1.5
    assert lower <= estimated <= upper
def test_estimate_bandwidth_1sample(global_dtype):
    """Check ``estimate_bandwidth`` with ``n_samples=1`` and ``quantile < 1``.

    In that configuration ``n_neighbors`` is set to 1, and the estimated
    bandwidth is expected to be (approximately) zero.
    """
    data = X.astype(global_dtype, copy=False)
    estimated = estimate_bandwidth(data, n_samples=1, quantile=0.3)

    # NOTE(review): the dtype is compared against the module-level ``X``,
    # not against ``global_dtype`` -- confirm this is intended.
    assert estimated.dtype == X.dtype
    assert estimated == pytest.approx(0.0, abs=1e-5)
@pytest.mark.parametrize(
    "bandwidth, cluster_all, expected, first_cluster_label",
    [(1.2, True, 3, 0), (1.2, False, 4, -1)],
)
def test_mean_shift(
    global_dtype, bandwidth, cluster_all, expected, first_cluster_label
):
    """Check the ``MeanShift`` estimator and the ``mean_shift`` function.

    Both are run on the shared blob data cast to ``global_dtype``. Each must
    yield ``expected`` distinct labels, have ``first_cluster_label`` as its
    smallest label, and produce cluster centers of dtype ``global_dtype``.
    """
    data = X.astype(global_dtype, copy=False)

    # Estimator API, with an explicit bandwidth.
    estimator = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
    estimator.fit(data)
    # np.unique returns sorted values, so index 0 is the smallest label.
    estimator_labels = np.unique(estimator.labels_)
    assert len(estimator_labels) == expected
    assert estimator_labels[0] == first_cluster_label
    assert estimator.cluster_centers_.dtype == global_dtype

    # Function API; note that no bandwidth is passed here.
    function_centers, function_labels = mean_shift(data, cluster_all=cluster_all)
    function_labels = np.unique(function_labels)
    assert len(function_labels) == expected
    assert function_labels[0] == first_cluster_label
    assert function_centers.dtype == global_dtype
def test_parallel(global_dtype, global_random_seed):
    """Check that fitting with ``n_jobs=2`` matches the default fit.

    Cluster centers (values and dtype) and labels must agree between the two
    estimators on a small blob dataset seeded with ``global_random_seed``.
    """
    blob_centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
    samples, _ = make_blobs(
        centers=blob_centers,
        n_samples=50,
        n_features=2,
        cluster_std=0.4,
        random_state=global_random_seed,
        shuffle=True,
    )
    samples = samples.astype(global_dtype, copy=False)

    parallel_fit = MeanShift(n_jobs=2).fit(samples)
    default_fit = MeanShift().fit(samples)

    assert_allclose(parallel_fit.cluster_centers_, default_fit.cluster_centers_)
    assert parallel_fit.cluster_centers_.dtype == default_fit.cluster_centers_.dtype
    assert_array_equal(parallel_fit.labels_, default_fit.labels_)
def test_meanshift_predict(global_dtype):
    """Check that ``predict`` on the training data matches ``fit_predict``."""
    data = X.astype(global_dtype, copy=False)
    estimator = MeanShift(bandwidth=1.2)

    labels_from_fit = estimator.fit_predict(data)
    labels_from_predict = estimator.predict(data)

    assert_array_equal(labels_from_fit, labels_from_predict)
def test_meanshift_all_orphans():
    """Check that seeds placed away from the data make ``fit`` raise.

    The estimator is initialised with seeds away from the shared data and a
    small bandwidth; ``fit`` is expected to fail with a sensible
    ``ValueError`` instead of returning clusters.
    """
    ms = MeanShift(bandwidth=0.1, seeds=[[-9, -9], [-10, -10]])
    # ``match`` is applied as a regular expression, so the dot is escaped to
    # require a literal "0.1" rather than "0", any character, then "1".
    msg = r"No point was within bandwidth=0\.1"
    with pytest.raises(ValueError, match=msg):
        ms.fit(X)
def test_unfitted():
    """Non-regression: before ``fit``, no fitted attributes should exist."""
    estimator = MeanShift()
    for fitted_attribute in ("cluster_centers_", "labels_"):
        assert not hasattr(estimator, fitted_attribute)
def test_cluster_intensity_tie(global_dtype):
    """Check the labels when the same two point groups are given in either order.

    The two datasets contain the same six points; only the order of the two
    three-point groups is swapped between them.
    """
    group_a = [[1, 1], [2, 1], [1, 0]]
    group_b = [[4, 7], [3, 5], [3, 6]]

    a_then_b = np.array(group_a + group_b, dtype=global_dtype)
    b_then_a = np.array(group_b + group_a, dtype=global_dtype)

    fit_a_then_b = MeanShift(bandwidth=2).fit(a_then_b)
    fit_b_then_a = MeanShift(bandwidth=2).fit(b_then_a)

    # In both orderings the points of ``group_b`` are expected to get label 0.
    assert_array_equal(fit_a_then_b.labels_, [1, 1, 1, 0, 0, 0])
    assert_array_equal(fit_b_then_a.labels_, [0, 0, 0, 1, 1, 1])
def test_bin_seeds(global_dtype):
    """Check the bin seeding technique usable in the mean shift algorithm."""
    # The data is just six points in the plane.
    points = np.array(
        [[1.0, 1.0], [1.4, 1.4], [1.8, 1.2], [2.0, 1.0], [2.1, 1.1], [0.0, 0.0]],
        dtype=global_dtype,
    )

    # Bin coarseness 1.0 with min_bin_freq 1: three bins should be found.
    expected_seeds = {(1.0, 1.0), (2.0, 1.0), (0.0, 0.0)}
    found_seeds = {tuple(seed) for seed in get_bin_seeds(points, 1, 1)}
    assert found_seeds == expected_seeds

    # Bin coarseness 1.0 with min_bin_freq 2: two bins should be found.
    expected_seeds = {(1.0, 1.0), (2.0, 1.0)}
    found_seeds = {tuple(seed) for seed in get_bin_seeds(points, 1, 2)}
    assert found_seeds == expected_seeds

    # Bin size 0.01 with min_bin_freq 1: six bins would be found, so binning
    # bails out and the whole data is used. Warnings raised by the call are
    # recorded and discarded.
    with warnings.catch_warnings(record=True):
        fallback_seeds = get_bin_seeds(points, 0.01, 1)
    assert_allclose(fallback_seeds, points)

    # Tight clusters around [0, 0] and [1, 1] should only give two bins.
    blobs, _ = make_blobs(
        centers=[[0, 0], [1, 1]],
        n_samples=100,
        n_features=2,
        cluster_std=0.1,
        random_state=0,
    )
    blobs = blobs.astype(global_dtype, copy=False)
    assert_array_equal(get_bin_seeds(blobs, 1), [[0, 0], [1, 1]])
@pytest.mark.parametrize("max_iter", [1, 100])
def test_max_iter(max_iter):
    """Check that ``mean_shift`` and ``MeanShift`` agree for a given ``max_iter``.

    The estimator must not report more iterations than ``max_iter``, and both
    APIs must return the same number of (pairwise close) cluster centers.
    """
    function_centers, _ = mean_shift(X, max_iter=max_iter)
    estimator = MeanShift(max_iter=max_iter).fit(X)
    estimator_centers = estimator.cluster_centers_

    assert estimator.n_iter_ <= estimator.max_iter
    assert len(function_centers) == len(estimator_centers)
    assert all(
        np.allclose(from_function, from_estimator)
        for from_function, from_estimator in zip(function_centers, estimator_centers)
    )
def test_mean_shift_zero_bandwidth(global_dtype):
    """Check that mean shift works when the estimated bandwidth is 0."""
    values = [1, 1, 1, 2, 2, 2, 3, 3]
    data = np.array(values, dtype=global_dtype).reshape(-1, 1)

    # With default arguments the bandwidth estimate is 0 on this dataset.
    zero_bandwidth = estimate_bandwidth(data)
    assert zero_bandwidth == 0

    # A bin size of 0 should hand back the dataset itself (same object).
    assert get_bin_seeds(data, bin_size=zero_bandwidth) is data

    # Binning with a 0 estimated bandwidth should be equivalent to no binning.
    binned = MeanShift(bin_seeding=True, bandwidth=None).fit(data)
    unbinned = MeanShift(bin_seeding=False).fit(data)

    reference_labels = np.array([0, 0, 0, 1, 1, 1, 2, 2])
    for fitted in (binned, unbinned):
        assert v_measure_score(fitted.labels_, reference_labels) == pytest.approx(1)
    assert_allclose(binned.cluster_centers_, unbinned.cluster_centers_)
|