# Tests checking that deprecated scikit-learn utilities emit a FutureWarning.
import pytest
import types
import numpy as np
import warnings
from sklearn.dummy import DummyClassifier
from sklearn.utils import all_estimators
from sklearn.utils.estimator_checks import choose_check_classifiers_labels
from sklearn.utils.estimator_checks import NotAnArray
from sklearn.utils.estimator_checks import enforce_estimator_tags_y
from sklearn.utils.estimator_checks import is_public_parameter
from sklearn.utils.estimator_checks import pairwise_estimator_convert_X
from sklearn.utils.estimator_checks import set_checking_parameters
from sklearn.utils.optimize import newton_cg
from sklearn.utils.random import random_choice_csc
from sklearn.utils import safe_indexing
# This file tests the utils that are deprecated
# TODO: remove in 0.24
def test_choose_check_classifiers_labels_deprecated():
    """Calling ``choose_check_classifiers_labels`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        choose_check_classifiers_labels(None, None, None)
# TODO: remove in 0.24
def test_enforce_estimator_tags_y():
    """Calling ``enforce_estimator_tags_y`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    estimator = DummyClassifier()
    targets = np.array([0, 1])
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        enforce_estimator_tags_y(estimator, targets)
# TODO: remove in 0.24
def test_notanarray():
    """Instantiating ``NotAnArray`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        NotAnArray([1, 2])
# TODO: remove in 0.24
def test_is_public_parameter():
    """Calling ``is_public_parameter`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        is_public_parameter('hello')
# TODO: remove in 0.24
def test_pairwise_estimator_convert_X():
    """Calling ``pairwise_estimator_convert_X`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    samples = [[1, 2]]
    estimator = DummyClassifier()
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        pairwise_estimator_convert_X(samples, estimator)
# TODO: remove in 0.24
def test_set_checking_parameters():
    """Calling ``set_checking_parameters`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    estimator = DummyClassifier()
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        set_checking_parameters(estimator)
# TODO: remove in 0.24
def test_newton_cg():
    """Calling ``newton_cg`` must emit a FutureWarning.

    The solver is driven with a small quadratic problem built from a
    seeded random 10x10 matrix; only the warning is checked, not the
    returned solution.
    """
    random_state = np.random.RandomState(0)
    matrix = random_state.normal(size=(10, 10))
    start = np.ones(10)

    def objective(x):
        # 0.5 * (Ax).(Ax)
        projected = matrix.dot(x)
        return .5 * projected.dot(projected)

    def gradient(x):
        return matrix.T.dot(matrix.dot(x))

    def gradient_and_hessian_product(x):
        # Returns the gradient at x together with a callable that applies
        # A^T A to a vector.
        def hessian_product(vector):
            return matrix.T.dot(matrix.dot(vector))

        return gradient(x), hessian_product

    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        newton_cg(gradient_and_hessian_product, objective, gradient, start)
# TODO: remove in 0.24
def test_random_choice_csc():
    """Calling ``random_choice_csc`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        random_choice_csc(10, [[2]])
# TODO: remove in 0.24
def test_safe_indexing():
    """Calling ``safe_indexing`` must emit a FutureWarning.

    The warning message has to mention the removal in version 0.24.
    """
    expected_message = "removed in version 0.24"
    with pytest.warns(FutureWarning, match=expected_message):
        safe_indexing([1, 2], 0)
# TODO: remove in 0.24
def test_partial_dependence_no_shadowing():
    """``sklearn.inspection.partial_dependence`` must resolve to a function.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    with warnings.catch_warnings():
        # FutureWarnings raised by the imports below are not under test here.
        warnings.simplefilter("ignore", category=FutureWarning)
        from sklearn.inspection.partial_dependence import partial_dependence as _unused  # noqa

        # Calling all_estimators() also triggers a recursive import of all
        # submodules, including deprecated ones.
        all_estimators()

    from sklearn.inspection import partial_dependence

    assert isinstance(partial_dependence, types.FunctionType)
# TODO: remove in 0.24
def test_dict_learning_no_shadowing():
    """``sklearn.decomposition.dict_learning`` must resolve to a function.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    with warnings.catch_warnings():
        # FutureWarnings raised by the imports below are not under test here.
        warnings.simplefilter("ignore", category=FutureWarning)
        from sklearn.decomposition.dict_learning import dict_learning as _unused  # noqa

        # Calling all_estimators() also triggers a recursive import of all
        # submodules, including deprecated ones.
        all_estimators()

    from sklearn.decomposition import dict_learning

    assert isinstance(dict_learning, types.FunctionType)
|