1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
import scipy.sparse as sp
import numpy as np
import sys
from sklearn.externals.six.moves import cStringIO as StringIO
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.testing import assert_raises_regex, assert_true
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.estimator_checks import check_estimators_unfitted
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.utils.validation import check_X_y, check_array
class CorrectNotFittedError(ValueError):
    """Test-only error signalling that an estimator was used before ``fit``.

    It mirrors NotFittedError in deriving from ValueError, but it
    deliberately does not derive from AttributeError.
    """
class BaseBadClassifier(BaseEstimator, ClassifierMixin):
    """Classifier stub that performs no input validation at all.

    Used as a deliberately broken estimator and as the base class for
    the other broken estimators in this module.
    """

    def fit(self, X, y):
        """Ignore ``X`` and ``y`` entirely and return the estimator."""
        return self

    def predict(self, X):
        """Return a vector of ones with one entry per row of ``X``."""
        n_rows = X.shape[0]
        return np.ones(n_rows)
class NoCheckinPredict(BaseBadClassifier):
    """Estimator that validates its input in ``fit`` only.

    ``predict`` is inherited unchanged from BaseBadClassifier, so it
    does no validation of ``X``.
    """

    def fit(self, X, y):
        """Run ``check_X_y`` on the training data and return the estimator."""
        # The validated arrays are not needed afterwards; the call is made
        # purely so that invalid input raises.
        check_X_y(X, y)
        return self
class NoSparseClassifier(BaseBadClassifier):
    """Estimator that claims sparse support but fails on sparse input.

    ``fit`` lets CSR/CSC matrices through validation and then raises an
    uninformative error for them. ``predict`` validates ``X`` but never
    checks whether the estimator has been fitted.
    """

    def fit(self, X, y):
        """Validate the data, then raise ValueError if ``X`` is sparse."""
        checked_X, _ = check_X_y(X, y, accept_sparse=['csr', 'csc'])
        if not sp.issparse(checked_X):
            return self
        raise ValueError("Nonsensical Error")

    def predict(self, X):
        """Validate ``X`` and return a vector of ones, one per sample."""
        checked_X = check_array(X)
        n_rows = checked_X.shape[0]
        return np.ones(n_rows)
class CorrectNotFittedErrorClassifier(BaseBadClassifier):
    """Estimator that raises CorrectNotFittedError when used unfitted.

    ``fit`` stores a ``coef_`` attribute; ``predict`` uses the presence
    of that attribute to decide whether the estimator has been fitted.
    """

    def fit(self, X, y):
        """Validate the data, store ``coef_`` and return the estimator."""
        checked_X, _ = check_X_y(X, y)
        n_features = checked_X.shape[1]
        self.coef_ = np.ones(n_features)
        return self

    def predict(self, X):
        """Return ones for each sample, or raise if ``fit`` was not called."""
        if hasattr(self, 'coef_'):
            checked_X = check_array(X)
            return np.ones(checked_X.shape[0])
        # The fitted check comes before any validation of X.
        raise CorrectNotFittedError("estimator is not fitted yet")
def test_check_estimator():
    """Check that ``check_estimator`` fails on deliberately bad estimators.

    This is not a complete test of all checks, which are very extensive.
    Each broken estimator defined in this module is expected to trip one
    specific check; two real estimators are expected to pass them all.
    """
    # check that the estimator implements get_params
    msg = "it does not implement a 'get_params' methods"
    assert_raises_regex(TypeError, msg, check_estimator, object)

    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)

    # check that fit does input validation
    msg = "TypeError not raised by fit"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier)

    # check that predict does input validation (rejects NaN and inf)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict)

    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = ("Estimator " + name
           + " doesn't seem to fail gracefully on sparse data")
    # the check for sparse input handling prints to the stdout,
    # instead of raising an error, so as not to remove the original
    # traceback. that means we need to jump through some hoops to catch it.
    old_stdout = sys.stdout
    string_buffer = StringIO()
    sys.stdout = string_buffer
    try:
        check_estimator(NoSparseClassifier)
    except Exception:
        # Only the printed message is of interest, so whatever the check
        # raises is ignored. Catching Exception rather than using a bare
        # except lets KeyboardInterrupt and SystemExit propagate.
        pass
    finally:
        # Always restore the real stdout, even if something unexpected
        # escaped the handler above.
        sys.stdout = old_stdout
    assert_true(msg in string_buffer.getvalue())

    # doesn't error on actual estimator
    check_estimator(AdaBoostClassifier)
    check_estimator(MultiTaskElasticNet)
def test_check_estimators_unfitted():
    """Exercise ``check_estimators_unfitted`` on a bad and a good estimator."""
    # NoSparseClassifier.predict never checks for a prior fit, so the check
    # must complain that neither ValueError nor AttributeError was raised.
    expected_message = "AttributeError or ValueError not raised by predict"
    assert_raises_regex(
        AssertionError,
        expected_message,
        check_estimators_unfitted,
        "estimator",
        NoSparseClassifier,
    )

    # CorrectNotFittedError derives from ValueError, so an estimator that
    # raises it when unfitted must pass the check.
    check_estimators_unfitted("estimator", CorrectNotFittedErrorClassifier)
|