"""
Testing for the Isolation Forest algorithm (sklearn.ensemble._iforest).
"""
# Authors: Nicolas Goix <nicolas.goix@telecom-paristech.fr>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# License: BSD 3 clause
import pytest
import warnings
import numpy as np
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import assert_allclose
from sklearn.model_selection import ParameterGrid
from sklearn.ensemble import IsolationForest
from sklearn.ensemble._iforest import _average_path_length
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes, load_iris, make_classification
from sklearn.utils import check_random_state
from sklearn.metrics import roc_auc_score
from scipy.sparse import csc_matrix, csr_matrix
from unittest.mock import Mock, patch


# load iris & diabetes datasets
iris = load_iris()
diabetes = load_diabetes()


def test_iforest(global_random_seed):
"""Check Isolation Forest for various parameter settings."""
X_train = np.array([[0, 1], [1, 2]])
X_test = np.array([[2, 1], [1, 1]])
grid = ParameterGrid(
{"n_estimators": [3], "max_samples": [0.5, 1.0, 3], "bootstrap": [True, False]}
)
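    # The grid expands to 1 * 3 * 2 = 6 combinations, covering float
    # (fraction of n_samples) and int (absolute count) values of
    # max_samples, with and without bootstrap resampling.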
with ignore_warnings():
for params in grid:
IsolationForest(random_state=global_random_seed, **params).fit(
X_train
).predict(X_test)


def test_iforest_sparse(global_random_seed):
"""Check IForest for various parameter settings on sparse input."""
rng = check_random_state(global_random_seed)
X_train, X_test = train_test_split(diabetes.data[:50], random_state=rng)
grid = ParameterGrid({"max_samples": [0.5, 1.0], "bootstrap": [True, False]})
for sparse_format in [csc_matrix, csr_matrix]:
X_train_sparse = sparse_format(X_train)
X_test_sparse = sparse_format(X_test)
for params in grid:
# Trained on sparse format
sparse_classifier = IsolationForest(
n_estimators=10, random_state=global_random_seed, **params
).fit(X_train_sparse)
sparse_results = sparse_classifier.predict(X_test_sparse)
# Trained on dense format
dense_classifier = IsolationForest(
n_estimators=10, random_state=global_random_seed, **params
).fit(X_train)
dense_results = dense_classifier.predict(X_test)
assert_array_equal(sparse_results, dense_results)


def test_iforest_error():
"""Test that it gives proper exception on deficient input."""
X = iris.data
    # The dataset has fewer than 256 samples, so explicitly setting
    # max_samples > n_samples should result in a warning. If max_samples is
    # not set explicitly, there should be no warning.
warn_msg = "max_samples will be set to n_samples for estimation"
with pytest.warns(UserWarning, match=warn_msg):
IsolationForest(max_samples=1000).fit(X)
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
IsolationForest(max_samples="auto").fit(X)
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
IsolationForest(max_samples=np.int64(2)).fit(X)
# test X_test n_features match X_train one:
with pytest.raises(ValueError):
IsolationForest().fit(X).predict(X[:, 1:])


def test_recalculate_max_depth():
"""Check max_depth recalculation when max_samples is reset to n_samples"""
X = iris.data
clf = IsolationForest().fit(X)
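    # With the default max_samples="auto", max_samples_ == min(256, n_samples)
    # == n_samples for iris, so each tree's depth cap is ceil(log2(n_samples)).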
for est in clf.estimators_:
assert est.max_depth == int(np.ceil(np.log2(X.shape[0])))


def test_max_samples_attribute():
X = iris.data
clf = IsolationForest().fit(X)
assert clf.max_samples_ == X.shape[0]
clf = IsolationForest(max_samples=500)
warn_msg = "max_samples will be set to n_samples for estimation"
with pytest.warns(UserWarning, match=warn_msg):
clf.fit(X)
assert clf.max_samples_ == X.shape[0]
clf = IsolationForest(max_samples=0.4).fit(X)
assert clf.max_samples_ == 0.4 * X.shape[0]
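

# Illustrative sketch (not part of the original suite), mirroring the
# documented `max_samples` resolution exercised by the test above: "auto"
# means min(256, n_samples), a float is read as a fraction of n_samples,
# and an int larger than n_samples is capped (with a warning) at n_samples.
def _expected_max_samples(max_samples, n_samples):
    if max_samples == "auto":
        return min(256, n_samples)
    if isinstance(max_samples, float):
        return int(max_samples * n_samples)
    return min(max_samples, n_samples)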


def test_iforest_parallel_regression(global_random_seed):
"""Check parallel regression."""
rng = check_random_state(global_random_seed)
X_train, X_test = train_test_split(diabetes.data, random_state=rng)
ensemble = IsolationForest(n_jobs=3, random_state=global_random_seed).fit(X_train)
ensemble.set_params(n_jobs=1)
y1 = ensemble.predict(X_test)
ensemble.set_params(n_jobs=2)
y2 = ensemble.predict(X_test)
assert_array_almost_equal(y1, y2)
ensemble = IsolationForest(n_jobs=1, random_state=global_random_seed).fit(X_train)
y3 = ensemble.predict(X_test)
assert_array_almost_equal(y1, y3)


def test_iforest_performance(global_random_seed):
"""Test Isolation Forest performs well"""
# Generate train/test data
rng = check_random_state(global_random_seed)
X = 0.3 * rng.randn(600, 2)
X = rng.permutation(np.vstack((X + 2, X - 2)))
X_train = X[:1000]
# Generate some abnormal novel observations
X_outliers = rng.uniform(low=-1, high=1, size=(200, 2))
X_test = np.vstack((X[1000:], X_outliers))
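    # label the held-out inliers 0 and the uniform-noise outliers 1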
y_test = np.array([0] * 200 + [1] * 200)
# fit the model
clf = IsolationForest(max_samples=100, random_state=rng).fit(X_train)
# predict scores (the lower, the more normal)
y_pred = -clf.decision_function(X_test)
    # check that the anomaly scores rank outliers above inliers (high ROC AUC)
assert roc_auc_score(y_test, y_pred) > 0.98
@pytest.mark.parametrize("contamination", [0.25, "auto"])
def test_iforest_works(contamination, global_random_seed):
# toy sample (the last two samples are outliers)
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [7, 4], [-5, 9]]
# Test IsolationForest
clf = IsolationForest(random_state=global_random_seed, contamination=contamination)
clf.fit(X)
decision_func = -clf.decision_function(X)
pred = clf.predict(X)
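    # with contamination=0.25, offset_ is set so that ~25% of training samples
    # (here 2 of 8) fall below the decision threshold; with "auto" the offset
    # is fixed at -0.5 as in the original paper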
    # assert that the two outliers are detected:
assert np.min(decision_func[-2:]) > np.max(decision_func[:-2])
assert_array_equal(pred, 6 * [1] + 2 * [-1])


def test_max_samples_consistency():
# Make sure validated max_samples in iforest and BaseBagging are identical
X = iris.data
clf = IsolationForest().fit(X)
assert clf.max_samples_ == clf._max_samples


def test_iforest_subsampled_features():
    # Non-regression test for #5732, which failed at predict.
rng = check_random_state(0)
X_train, X_test, y_train, y_test = train_test_split(
diabetes.data[:50], diabetes.target[:50], random_state=rng
)
clf = IsolationForest(max_features=0.8)
clf.fit(X_train, y_train)
clf.predict(X_test)


def test_iforest_average_path_length():
    # Non-regression test for #8549, which used the wrong average path length
    # formula, strictly for the integer case.
    # Updated to check the average path length when the input is <= 2
    # (issue #11839).
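    # Expected values follow the closed form from Liu et al. (2008):
    #     c(n) = 2 * H(n - 1) - 2 * (n - 1) / n
    # with the harmonic number approximated as H(i) = ln(i) + euler_gamma.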
result_one = 2.0 * (np.log(4.0) + np.euler_gamma) - 2.0 * 4.0 / 5.0
result_two = 2.0 * (np.log(998.0) + np.euler_gamma) - 2.0 * 998.0 / 999.0
assert_allclose(_average_path_length([0]), [0.0])
assert_allclose(_average_path_length([1]), [0.0])
assert_allclose(_average_path_length([2]), [1.0])
assert_allclose(_average_path_length([5]), [result_one])
assert_allclose(_average_path_length([999]), [result_two])
assert_allclose(
_average_path_length(np.array([1, 2, 5, 999])),
[0.0, 1.0, result_one, result_two],
)
# _average_path_length is increasing
avg_path_length = _average_path_length(np.arange(5))
assert_array_equal(avg_path_length, np.sort(avg_path_length))


def test_score_samples():
X_train = [[1, 1], [1, 2], [2, 1]]
clf1 = IsolationForest(contamination=0.1).fit(X_train)
clf2 = IsolationForest().fit(X_train)
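    # score_samples should equal decision_function shifted back by offset_,
    # and should not depend on the contamination parameter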
assert_array_equal(
clf1.score_samples([[2.0, 2.0]]),
clf1.decision_function([[2.0, 2.0]]) + clf1.offset_,
)
assert_array_equal(
clf2.score_samples([[2.0, 2.0]]),
clf2.decision_function([[2.0, 2.0]]) + clf2.offset_,
)
assert_array_equal(
clf1.score_samples([[2.0, 2.0]]), clf2.score_samples([[2.0, 2.0]])
)


def test_iforest_warm_start():
"""Test iterative addition of iTrees to an iForest"""
rng = check_random_state(0)
X = rng.randn(20, 2)
# fit first 10 trees
clf = IsolationForest(
n_estimators=10, max_samples=20, random_state=rng, warm_start=True
)
clf.fit(X)
# remember the 1st tree
tree_1 = clf.estimators_[0]
# fit another 10 trees
clf.set_params(n_estimators=20)
clf.fit(X)
# expecting 20 fitted trees and no overwritten trees
assert len(clf.estimators_) == 20
assert clf.estimators_[0] is tree_1


# mock get_chunk_n_rows to actually test more than one chunk (here one
# chunk has 3 rows):
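# Each call to score_samples consumes one chunk-size lookup: with a float
# contamination, fit also scores the training set to derive offset_, so
# fit + decision_function + predict give 3 calls; with "auto" only 2.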
@patch(
"sklearn.ensemble._iforest.get_chunk_n_rows",
side_effect=Mock(**{"return_value": 3}),
)
@pytest.mark.parametrize("contamination, n_predict_calls", [(0.25, 3), ("auto", 2)])
def test_iforest_chunks_works1(
mocked_get_chunk, contamination, n_predict_calls, global_random_seed
):
test_iforest_works(contamination, global_random_seed)
assert mocked_get_chunk.call_count == n_predict_calls


# idem with chunk_size = 10 rows
@patch(
"sklearn.ensemble._iforest.get_chunk_n_rows",
side_effect=Mock(**{"return_value": 10}),
)
@pytest.mark.parametrize("contamination, n_predict_calls", [(0.25, 3), ("auto", 2)])
def test_iforest_chunks_works2(
mocked_get_chunk, contamination, n_predict_calls, global_random_seed
):
test_iforest_works(contamination, global_random_seed)
assert mocked_get_chunk.call_count == n_predict_calls


def test_iforest_with_uniform_data():
"""Test whether iforest predicts inliers when using uniform data"""
# 2-d array of all 1s
X = np.ones((100, 10))
iforest = IsolationForest()
iforest.fit(X)
rng = np.random.RandomState(0)
assert all(iforest.predict(X) == 1)
assert all(iforest.predict(rng.randn(100, 10)) == 1)
assert all(iforest.predict(X + 1) == 1)
assert all(iforest.predict(X - 1) == 1)
# 2-d array where columns contain the same value across rows
X = np.repeat(rng.randn(1, 10), 100, 0)
iforest = IsolationForest()
iforest.fit(X)
assert all(iforest.predict(X) == 1)
assert all(iforest.predict(rng.randn(100, 10)) == 1)
assert all(iforest.predict(np.ones((100, 10))) == 1)
# Single row
X = rng.randn(1, 10)
iforest = IsolationForest()
iforest.fit(X)
assert all(iforest.predict(X) == 1)
assert all(iforest.predict(rng.randn(100, 10)) == 1)
assert all(iforest.predict(np.ones((100, 10))) == 1)


def test_iforest_with_n_jobs_does_not_segfault():
    """Check that Isolation Forest does not segfault with n_jobs=2.

    Non-regression test for #23252.
    """
X, _ = make_classification(n_samples=85_000, n_features=100, random_state=0)
X = csc_matrix(X)
IsolationForest(n_estimators=10, max_samples=256, n_jobs=2).fit(X)


# TODO(1.4): remove in 1.4
def test_base_estimator_property_deprecated():
X = np.array([[1, 2], [3, 4]])
y = np.array([1, 0])
model = IsolationForest()
model.fit(X, y)
warn_msg = (
"Attribute `base_estimator_` was deprecated in version 1.2 and "
"will be removed in 1.4. Use `estimator_` instead."
)
with pytest.warns(FutureWarning, match=warn_msg):
model.base_estimator_