1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
from collections import Counter
import numpy as np
import pytest
from sklearn.cluster import MiniBatchKMeans
from imblearn.over_sampling import (
ADASYN,
SMOTE,
SMOTEN,
SMOTENC,
SVMSMOTE,
BorderlineSMOTE,
KMeansSMOTE,
)
from imblearn.utils.testing import _CustomNearestNeighbors
@pytest.fixture
def numerical_data():
    """Return a (100, 2) Gaussian feature matrix and a binary target.

    The target repeats each entry of a 20-label pattern 5 times, which gives
    40 samples of class 0 and 60 samples of class 1.
    """
    random_state = np.random.RandomState(0)
    features = random_state.randn(100, 2)
    label_pattern = [0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0]
    target = np.repeat(label_pattern, 5)
    return features, target
@pytest.fixture
def categorical_data():
    """Return 60 rows of 3 string-valued features and a string target.

    The target holds 20 ``"not apple"`` labels followed by 40 ``"apple"``
    labels.
    """
    random_state = np.random.RandomState(0)

    # One list per feature; transposed below so that rows are samples.
    columns = [
        ["A"] * 10 + ["B"] * 20 + ["C"] * 30,
        ["A"] * 40 + ["B"] * 20,
        ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10,
    ]
    X = np.array(columns, dtype=object).T
    # In-place shuffle of the samples; the target is built afterwards and is
    # therefore not permuted alongside the features.
    random_state.shuffle(X)

    class_names = np.array(["not apple", "apple"], dtype=object)
    class_indices = np.array([0] * 20 + [1] * 40, dtype=np.int32)
    return X, class_names[class_indices]
@pytest.fixture
def heterogeneous_data():
    """Return a (30, 4) object matrix mixing numerical and categorical columns.

    Columns 0 and 1 hold Gaussian draws, column 2 holds strings drawn from
    ``"a"``, ``"b"`` and ``"c"``, and column 3 holds integers drawn from
    ``[0, 3)``. The target has 10 samples of class 0 followed by 20 samples of
    class 1. The third returned item lists the categorical column indices.
    """
    n_samples = 30
    random_state = np.random.RandomState(42)

    X = np.empty((n_samples, 4), dtype=object)
    # The draw order (randn, choice, randint) determines the generated values.
    X[:, :2] = random_state.randn(n_samples, 2)
    X[:, 2] = random_state.choice(["a", "b", "c"], size=n_samples).astype(object)
    X[:, 3] = random_state.randint(3, size=n_samples)

    y = np.array([0] * 10 + [1] * 20)
    categorical_features = [2, 3]
    return X, y, categorical_features
@pytest.mark.parametrize(
    "smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
)
def test_smote_m_neighbors(numerical_data, smote):
    """Check that m_neighbors is properly set.

    Regression test for:
    https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
    """
    features, target = numerical_data
    smote.fit_resample(features, target)

    # Fitted estimator attribute -> expected ``n_neighbors`` after fitting.
    expected_n_neighbors = {"nn_k_": 6, "nn_m_": 11}
    for attribute, n_neighbors in expected_n_neighbors.items():
        assert getattr(smote, attribute).n_neighbors == n_neighbors
@pytest.mark.parametrize(
    "smote, neighbor_estimator_name",
    [
        (ADASYN(random_state=0), "n_neighbors"),
        (BorderlineSMOTE(random_state=0), "k_neighbors"),
        (
            KMeansSMOTE(
                kmeans_estimator=MiniBatchKMeans(n_init=1, random_state=0),
                random_state=1,
            ),
            "k_neighbors",
        ),
        (SMOTE(random_state=0), "k_neighbors"),
        (SVMSMOTE(random_state=0), "k_neighbors"),
    ],
    ids=["adasyn", "borderline", "kmeans", "smote", "svm"],
)
def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name):
    """Check that numerical samplers accept a custom neighbors estimator.

    The estimator is injected through the parameter named by
    ``neighbor_estimator_name`` and the resampled set must contain at least
    120 samples.
    """
    features, target = numerical_data
    custom_nn = _CustomNearestNeighbors(n_neighbors=5)
    smote.set_params(**{neighbor_estimator_name: custom_nn})

    features_res, _ = smote.fit_resample(features, target)
    n_samples_res = features_res.shape[0]
    assert n_samples_res >= 120
def test_categorical_smote_k_custom_nn(categorical_data):
    """Check that SMOTEN accepts a custom neighbors estimator as k_neighbors."""
    features, target = categorical_data
    custom_nn = _CustomNearestNeighbors(n_neighbors=5)
    sampler = SMOTEN(k_neighbors=custom_nn)

    features_res, target_res = sampler.fit_resample(features, target)

    expected_counts = {"apple": 40, "not apple": 40}
    assert features_res.shape == (80, 3)
    assert Counter(target_res) == expected_counts
def test_heterogeneous_smote_k_custom_nn(heterogeneous_data):
    """Check that SMOTENC accepts a custom neighbors estimator as k_neighbors."""
    features, target, categorical_features = heterogeneous_data
    custom_nn = _CustomNearestNeighbors(n_neighbors=5)
    sampler = SMOTENC(categorical_features, k_neighbors=custom_nn)

    features_res, target_res = sampler.fit_resample(features, target)

    expected_counts = {0: 20, 1: 20}
    assert features_res.shape == (40, 4)
    assert Counter(target_res) == expected_counts
@pytest.mark.parametrize(
    "smote",
    [BorderlineSMOTE(random_state=0), SVMSMOTE(random_state=0)],
    ids=["borderline", "svm"],
)
def test_numerical_smote_extra_custom_nn(numerical_data, smote):
    """Check that m_neighbors accepts a custom neighbors estimator."""
    features, target = numerical_data
    custom_nn = _CustomNearestNeighbors(n_neighbors=5)
    smote.set_params(m_neighbors=custom_nn)

    features_res, target_res = smote.fit_resample(features, target)

    expected_counts = {0: 60, 1: 60}
    assert features_res.shape == (120, 2)
    assert Counter(target_res) == expected_counts
# FIXME: to be removed in 0.12
@pytest.mark.parametrize(
    "sampler",
    [
        ADASYN(random_state=0),
        BorderlineSMOTE(random_state=0),
        SMOTE(random_state=0),
        SMOTEN(random_state=0),
        SMOTENC([0], random_state=0),
        SVMSMOTE(random_state=0),
    ],
)
def test_n_jobs_deprecation_warning(numerical_data, sampler):
    """Check that resampling with ``n_jobs`` set emits a ``FutureWarning``."""
    features, target = numerical_data
    sampler.set_params(n_jobs=2)

    expected_message = "The parameter `n_jobs` has been deprecated"
    with pytest.warns(FutureWarning, match=expected_message):
        sampler.fit_resample(features, target)
|