File: test_deprecated_utils.py

package info (click to toggle)
scikit-learn 0.23.2-5
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 21,892 kB
  • sloc: python: 132,020; cpp: 5,765; javascript: 2,201; ansic: 831; makefile: 213; sh: 44
file content (128 lines) | stat: -rw-r--r-- 4,032 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import pytest
import types
import numpy as np
import warnings

from sklearn.dummy import DummyClassifier
from sklearn.utils import all_estimators
from sklearn.utils.estimator_checks import choose_check_classifiers_labels
from sklearn.utils.estimator_checks import NotAnArray
from sklearn.utils.estimator_checks import enforce_estimator_tags_y
from sklearn.utils.estimator_checks import is_public_parameter
from sklearn.utils.estimator_checks import pairwise_estimator_convert_X
from sklearn.utils.estimator_checks import set_checking_parameters
from sklearn.utils.optimize import newton_cg
from sklearn.utils.random import random_choice_csc
from sklearn.utils import safe_indexing


# This file tests the utils that are deprecated


# TODO: remove in 0.24
def test_choose_check_classifiers_labels_deprecated():
    """Calling choose_check_classifiers_labels emits a FutureWarning."""
    expected_msg = "removed in version 0.24"
    # Only the deprecation warning is asserted; all three arguments are None.
    placeholder_args = (None,) * 3
    with pytest.warns(FutureWarning, match=expected_msg):
        choose_check_classifiers_labels(*placeholder_args)


# TODO: remove in 0.24
def test_enforce_estimator_tags_y():
    """Calling enforce_estimator_tags_y emits a FutureWarning."""
    estimator = DummyClassifier()
    y = np.array([0, 1])
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        enforce_estimator_tags_y(estimator, y)


# TODO: remove in 0.24
def test_notanarray():
    """Instantiating NotAnArray emits a FutureWarning."""
    data = [1, 2]
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        NotAnArray(data)


# TODO: remove in 0.24
def test_is_public_parameter():
    """Calling is_public_parameter emits a FutureWarning."""
    param_name = 'hello'
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        is_public_parameter(param_name)


# TODO: remove in 0.24
def test_pairwise_estimator_convert_X():
    """Calling pairwise_estimator_convert_X emits a FutureWarning."""
    X = [[1, 2]]
    estimator = DummyClassifier()
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        pairwise_estimator_convert_X(X, estimator)


# TODO: remove in 0.24
def test_set_checking_parameters():
    """Calling set_checking_parameters emits a FutureWarning."""
    estimator = DummyClassifier()
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        set_checking_parameters(estimator)


# TODO: remove in 0.24
def test_newton_cg():
    """newton_cg still runs on a quadratic problem but emits a FutureWarning.

    The objective is 0.5 * ||M w||^2 for a fixed random 10x10 matrix M.
    Its gradient and Hessian-vector product are both M^T M applied to the
    respective argument.
    """
    n_features = 10
    random_state = np.random.RandomState(0)
    M = random_state.normal(size=(n_features, n_features))
    start = np.ones(n_features)

    def objective(w):
        Mw = M.dot(w)
        return .5 * Mw.dot(Mw)

    def gradient(w):
        return M.T.dot(M.dot(w))

    def hessian_product(p):
        # Hessian-vector product: the Hessian of the objective is M^T M.
        return M.T.dot(M.dot(p))

    def grad_and_hessp(w):
        # newton_cg expects (gradient at w, callable computing H @ p).
        return gradient(w), hessian_product

    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        newton_cg(grad_and_hessp, objective, gradient, start)


# TODO: remove in 0.24
def test_random_choice_csc():
    """Calling random_choice_csc emits a FutureWarning."""
    n_samples, classes = 10, [[2]]
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        random_choice_csc(n_samples, classes)


# TODO: remove in 0.24
def test_safe_indexing():
    """Calling the public safe_indexing alias emits a FutureWarning."""
    X, indices = [1, 2], 0
    with pytest.warns(FutureWarning, match="removed in version 0.24"):
        safe_indexing(X, indices)


# TODO: remove in 0.24
def test_partial_dependence_no_shadowing():
    """Importing the deprecated module must not shadow the public function.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    with warnings.catch_warnings():
        # Silence FutureWarnings (presumably raised by the deprecated
        # module path) while it is being imported.
        warnings.filterwarnings("ignore", category=FutureWarning)
        from sklearn.inspection.partial_dependence import partial_dependence as _deprecated_pd  # noqa

        # all_estimators() recursively imports every submodule, including
        # the deprecated ones.
        all_estimators()

    # Afterwards, the package-level name must still be the function.
    from sklearn.inspection import partial_dependence as public_pd
    assert isinstance(public_pd, types.FunctionType)


# TODO: remove in 0.24
def test_dict_learning_no_shadowing():
    """Importing the deprecated module must not shadow the public function.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/15842
    """
    with warnings.catch_warnings():
        # Silence FutureWarnings (presumably raised by the deprecated
        # module path) while it is being imported.
        warnings.filterwarnings("ignore", category=FutureWarning)
        from sklearn.decomposition.dict_learning import dict_learning as _deprecated_dl  # noqa

        # all_estimators() recursively imports every submodule, including
        # the deprecated ones.
        all_estimators()

    # Afterwards, the package-level name must still be the function.
    from sklearn.decomposition import dict_learning as public_dl
    assert isinstance(public_dl, types.FunctionType)