File: conftest.py

package info (click to toggle)
scikit-learn 1.4.2+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (201 lines) | stat: -rw-r--r-- 6,453 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import warnings
from os import environ
from os.path import exists, join

import pytest
from _pytest.doctest import DoctestItem

from sklearn.datasets import get_data_home
from sklearn.datasets._base import _pkl_filepath
from sklearn.datasets._twenty_newsgroups import CACHE_NAME
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest, check_skip_network
from sklearn.utils.fixes import np_base_version, parse_version


def setup_labeled_faces():
    """Skip the labeled-faces doctests when no local ``lfw_home`` copy exists.

    Raises
    ------
    SkipTest
        If ``lfw_home`` is not found under the path returned by
        ``get_data_home()``.
    """
    lfw_dir = join(get_data_home(), "lfw_home")
    if exists(lfw_dir):
        return
    raise SkipTest("Skipping dataset loading doctests")


def setup_rcv1():
    """Skip the RCV1 doctests unless the dataset is already on disk.

    ``check_skip_network()`` is invoked first, then the presence of an
    ``RCV1`` entry under ``get_data_home()`` is required.

    Raises
    ------
    SkipTest
        If the ``RCV1`` entry does not exist.
    """
    check_skip_network()
    # rcv1.rst is only exercised against an already-loaded dataset
    if exists(join(get_data_home(), "RCV1")):
        return
    raise SkipTest("Download RCV1 dataset to run this test.")


def setup_twenty_newsgroups():
    """Skip the 20 newsgroups doctests when the cached pickle is missing.

    Raises
    ------
    SkipTest
        If the file located by ``_pkl_filepath(get_data_home(), CACHE_NAME)``
        does not exist.
    """
    newsgroups_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(newsgroups_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")


def setup_working_with_text_data():
    """Decide whether the "working with text data" tutorial doctests can run.

    Raises
    ------
    SkipTest
        If running on PyPy with a non-empty ``CI`` environment variable, or if
        the 20 newsgroups cache file is missing. ``check_skip_network()`` is
        called between those two checks.
    """
    if IS_PYPY:
        if os.environ.get("CI", None):
            raise SkipTest("Skipping too slow test with PyPy on CI")

    check_skip_network()

    newsgroups_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(newsgroups_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")


def setup_loading_other_datasets():
    """Skip loading_other_datasets.rst without pandas or network opt-in.

    Raises
    ------
    SkipTest
        If pandas cannot be imported, or if the environment variable
        ``SKLEARN_SKIP_NETWORK_TESTS`` is anything other than ``"0"``
        (it defaults to ``"1"`` when unset).
    """
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping loading_other_datasets.rst, pandas not installed")

    # Network doctests are opt-in: they only run when
    # SKLEARN_SKIP_NETWORK_TESTS is explicitly set to "0".
    if environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") != "0":
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )


def setup_compose():
    """Skip the compose.rst doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas raises ``ImportError``.
    """
    skip_reason = "Skipping compose.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(skip_reason)


def setup_impute():
    """Skip the impute.rst doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas raises ``ImportError``.
    """
    skip_reason = "Skipping impute.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(skip_reason)


def setup_grid_search():
    """Skip the grid_search.rst doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing pandas raises ``ImportError``.
    """
    skip_reason = "Skipping grid_search.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(skip_reason)


def setup_preprocessing():
    """Skip the preprocessing.rst doctests without pandas >= 1.1.0.

    Raises
    ------
    SkipTest
        If pandas cannot be imported, or if its ``__version__`` parses to
        something lower than 1.1.0.
    """
    try:
        import pandas  # noqa

        installed_version = parse_version(pandas.__version__)
        if installed_version < parse_version("1.1.0"):
            raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")
    except ImportError:
        raise SkipTest("Skipping preprocessing.rst, pandas not installed")


def setup_unsupervised_learning():
    """Prepare the unsupervised_learning.rst doctests.

    Registers an "ignore" warning filter for ``DeprecationWarning`` messages
    starting with "The binary mode of fromstring".

    Raises
    ------
    SkipTest
        If scikit-image (``skimage``) cannot be imported.
    """
    try:
        import skimage  # noqa
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")

    # ignore deprecation warnings from scipy.misc.face
    warnings.filterwarnings(
        action="ignore",
        message="The binary mode of fromstring",
        category=DeprecationWarning,
    )


def skip_if_matplotlib_not_installed(fname):
    """Skip the doctests of ``fname`` when matplotlib cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file; only its final component is used in
        the skip message.

    Raises
    ------
    SkipTest
        If importing matplotlib raises ``ImportError``.
    """
    try:
        import matplotlib  # noqa
    except ImportError:
        doc_name = os.path.split(fname)[1]
        raise SkipTest(f"Skipping doctests for {doc_name}, matplotlib not installed")


def skip_if_cupy_not_installed(fname):
    """Skip the doctests of ``fname`` when cupy cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file; only its final component is used in
        the skip message.

    Raises
    ------
    SkipTest
        If importing cupy raises ``ImportError``.
    """
    try:
        import cupy  # noqa
    except ImportError:
        doc_name = os.path.split(fname)[1]
        raise SkipTest(f"Skipping doctests for {doc_name}, cupy not installed")


def pytest_runtest_setup(item):
    """Run the skip checks matching the documentation file that owns ``item``.

    The file path is matched by suffix against the known ``.rst`` documents
    and the corresponding setup helper is called. Files that match nothing
    are left untouched.

    Parameters
    ----------
    item : pytest item
        The item about to be run. Only ``item.fspath.strpath`` is read.

    Raises
    ------
    SkipTest
        Propagated from whichever helper is called for the file.
    """
    # normalize filename to use forward slashes on Windows for easier handling
    # later
    fname = item.fspath.strpath.replace(os.sep, "/")

    # NOTE: datasets/index.rst only triggers the labeled-faces check. The
    # chain stops at the first matching branch, so testing for the index
    # again on later branches would be unreachable.
    if fname.endswith(("datasets/labeled_faces.rst", "datasets/index.rst")):
        setup_labeled_faces()
    elif fname.endswith("datasets/rcv1.rst"):
        setup_rcv1()
    elif fname.endswith("datasets/twenty_newsgroups.rst"):
        setup_twenty_newsgroups()
    elif fname.endswith("tutorial/text_analytics/working_with_text_data.rst"):
        setup_working_with_text_data()
    elif fname.endswith("modules/compose.rst"):
        setup_compose()
    elif fname.endswith("datasets/loading_other_datasets.rst"):
        setup_loading_other_datasets()
    elif fname.endswith("modules/impute.rst"):
        setup_impute()
    elif fname.endswith("modules/grid_search.rst"):
        setup_grid_search()
    elif fname.endswith("modules/preprocessing.rst"):
        setup_preprocessing()
    elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
        setup_unsupervised_learning()

    # These checks are independent of the chain above and always evaluated.
    rst_files_requiring_matplotlib = (
        "modules/partial_dependence.rst",
        "modules/tree.rst",
        "tutorial/statistical_inference/settings.rst",
        "tutorial/statistical_inference/supervised_learning.rst",
    )
    if fname.endswith(rst_files_requiring_matplotlib):
        skip_if_matplotlib_not_installed(fname)

    if fname.endswith("array_api.rst"):
        skip_if_cupy_not_installed(fname)


def pytest_configure(config):
    """Switch matplotlib to the ``agg`` backend when matplotlib is importable.

    Parameters
    ----------
    config : pytest config
        Unused.
    """
    # Use matplotlib agg backend during the tests including doctests
    try:
        from matplotlib import use as select_backend

        select_backend("agg")
    except ImportError:
        # matplotlib is optional: without it there is no backend to select
        pass


def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Every collected doctest item gets an empty globals dict. When the numpy
    base version is 2 or later, every doctest item is also marked as skipped.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    doctest_items = [entry for entry in items if isinstance(entry, DoctestItem)]

    # Normally doctest has the entire module's scope. Here we set globs to an empty dict
    # to remove the module's scope:
    # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
    for doctest_item in doctest_items:
        doctest_item.dtest.globs = {}

    if not np_base_version >= parse_version("2"):
        return

    # Skip doctests when using numpy 2 for now. See the following discussion
    # to decide what to do in the longer term:
    # https://github.com/scikit-learn/scikit-learn/issues/27339
    skip_marker = pytest.mark.skip(
        reason="Due to NEP 51 numpy scalar repr has changed in numpy 2"
    )
    for doctest_item in doctest_items:
        doctest_item.add_marker(skip_marker)