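"""Common tests for the plotting Display classes.

These tests cover the error handling and constructor behaviour shared across
`CalibrationDisplay`, `ConfusionMatrixDisplay`, `DetCurveDisplay`,
`PrecisionRecallDisplay`, `PredictionErrorDisplay`, and `RocCurveDisplay`.
"""
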
import numpy as np
import pytest

from sklearn.base import ClassifierMixin, clone
from sklearn.calibration import CalibrationDisplay
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
ConfusionMatrixDisplay,
DetCurveDisplay,
PrecisionRecallDisplay,
PredictionErrorDisplay,
RocCurveDisplay,
)
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor


@pytest.fixture(scope="module")
def data():
return load_iris(return_X_y=True)


@pytest.fixture(scope="module")
def data_binary(data):
X, y = data
return X[y < 2], y[y < 2]


@pytest.mark.parametrize(
"Display",
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
)
def test_display_curve_error_classifier(pyplot, data, data_binary, Display):
"""Check that a proper error is raised when only binary classification is
supported."""
X, y = data
X_binary, y_binary = data_binary
clf = DecisionTreeClassifier().fit(X, y)

    # Case 1: multiclass classifier with multiclass target
msg = "Expected 'estimator' to be a binary classifier. Got 3 classes instead."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X, y)

    # Case 2: multiclass classifier with binary target
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X_binary, y_binary)

    # Case 3: binary classifier with multiclass target
clf = DecisionTreeClassifier().fit(X_binary, y_binary)
msg = "The target y is not binary. Got multiclass type of target."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(clf, X, y)


@pytest.mark.parametrize(
"Display",
[CalibrationDisplay, DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay],
)
def test_display_curve_error_regression(pyplot, data_binary, Display):
"""Check that we raise an error with regressor."""

    # Case 1: regressor
X, y = data_binary
regressor = DecisionTreeRegressor().fit(X, y)
msg = "Expected 'estimator' to be a binary classifier. Got DecisionTreeRegressor"
with pytest.raises(ValueError, match=msg):
Display.from_estimator(regressor, X, y)

    # Case 2: regression target
classifier = DecisionTreeClassifier().fit(X, y)
# Force `y_true` to be seen as a regression problem
y = y + 0.5
msg = "The target y is not binary. Got continuous type of target."
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y)
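
    # `from_predictions` should raise the same error for a continuous `y_true`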
with pytest.raises(ValueError, match=msg):
Display.from_predictions(y, regressor.fit(X, y).predict(X))


@pytest.mark.parametrize(
"response_method, msg",
[
(
"predict_proba",
"MyClassifier has none of the following attributes: predict_proba.",
),
(
"decision_function",
"MyClassifier has none of the following attributes: decision_function.",
),
(
"auto",
(
"MyClassifier has none of the following attributes: predict_proba,"
" decision_function."
),
),
(
"bad_method",
"MyClassifier has none of the following attributes: bad_method.",
),
],
)
@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_no_response(
pyplot,
data_binary,
response_method,
msg,
Display,
):
"""Check that a proper error is raised when the response method requested
is not defined for the given trained classifier."""
X, y = data_binary
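
    # Minimal "classifier": it exposes `fit` and `classes_` but neither
    # `predict_proba` nor `decision_function`, so any requested response
    # method is missing.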
class MyClassifier(ClassifierMixin):
def fit(self, X, y):
self.classes_ = [0, 1]
return self

    clf = MyClassifier().fit(X, y)
with pytest.raises(AttributeError, match=msg):
Display.from_estimator(clf, X, y, response_method=response_method)


@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_display_curve_estimator_name_multiple_calls(
pyplot,
data_binary,
Display,
constructor_name,
):
"""Check that passing `name` when calling `plot` will overwrite the original name
in the legend."""
X, y = data_binary
clf_name = "my hand-crafted name"
clf = LogisticRegression().fit(X, y)
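    # probability estimates of the positive class, as `from_predictions` expects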
y_pred = clf.predict_proba(X)[:, 1]
    # safeguard for the binary if/else construction
assert constructor_name in ("from_estimator", "from_predictions")
if constructor_name == "from_estimator":
disp = Display.from_estimator(clf, X, y, name=clf_name)
else:
disp = Display.from_predictions(y, y_pred, name=clf_name)
assert disp.estimator_name == clf_name
pyplot.close("all")
disp.plot()
assert clf_name in disp.line_.get_label()
pyplot.close("all")
clf_name = "another_name"
disp.plot(name=clf_name)
assert clf_name in disp.line_.get_label()


@pytest.mark.parametrize(
"clf",
[
LogisticRegression(),
make_pipeline(StandardScaler(), LogisticRegression()),
make_pipeline(
make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
),
],
)
@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
"""Check that a proper error is raised when the classifier is not
fitted."""
X, y = data_binary
    # Clone since the test is parametrized and the classifier would otherwise
    # already be fitted when testing the second and subsequent displays.
model = clone(clf)
with pytest.raises(NotFittedError):
Display.from_estimator(model, X, y)
model.fit(X, y)
disp = Display.from_estimator(model, X, y)
assert model.__class__.__name__ in disp.line_.get_label()
assert disp.estimator_name == model.__class__.__name__


@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_n_samples_consistency(pyplot, data_binary, Display):
"""Check the error raised when `y_pred` or `sample_weight` have inconsistent
length."""
X, y = data_binary
classifier = DecisionTreeClassifier().fit(X, y)
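
    # Truncating `X`, `y`, or `sample_weight` should trigger the consistency check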
msg = "Found input variables with inconsistent numbers of samples"
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X[:-2], y)
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y[:-2])
with pytest.raises(ValueError, match=msg):
Display.from_estimator(classifier, X, y, sample_weight=np.ones(X.shape[0] - 2))


@pytest.mark.parametrize(
"Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_pos_label(pyplot, data_binary, Display):
"""Check consistence of error message when `pos_label` should be specified."""
X, y = data_binary
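    # Shift the labels so the classes become {10, 11}: the positive class can
    # no longer be inferred, so an explicit `pos_label` is required.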
y = y + 10
classifier = DecisionTreeClassifier().fit(X, y)
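    # probability estimates for the class with the greater label, i.e. 11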
y_pred = classifier.predict_proba(X)[:, -1]
msg = r"y_true takes value in {10, 11} and pos_label is not specified"
with pytest.raises(ValueError, match=msg):
Display.from_predictions(y, y_pred)


@pytest.mark.parametrize(
"Display",
[
CalibrationDisplay,
DetCurveDisplay,
PrecisionRecallDisplay,
RocCurveDisplay,
PredictionErrorDisplay,
ConfusionMatrixDisplay,
],
)
@pytest.mark.parametrize(
"constructor",
["from_predictions", "from_estimator"],
)
def test_classifier_display_curve_named_constructor_return_type(
pyplot, data_binary, Display, constructor
):
"""Check that named constructors return the correct type when subclassed.
Non-regression test for:
https://github.com/scikit-learn/scikit-learn/pull/27675
"""
X, y = data_binary

    # The predictions can be anything: we only check the named constructors'
    # return type, so the sole requirement is that the class instantiates
    # without error.
y_pred = y
classifier = LogisticRegression().fit(X, y)
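
    # A subclass without overrides: its named constructors must return an
    # instance of the subclass, not of the parent Display.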
class SubclassOfDisplay(Display):
pass
if constructor == "from_predictions":
curve = SubclassOfDisplay.from_predictions(y, y_pred)
else: # constructor == "from_estimator"
curve = SubclassOfDisplay.from_estimator(classifier, X, y)

    assert isinstance(curve, SubclassOfDisplay)