import os
import warnings
from os import environ
from os.path import exists, join
import pytest
from _pytest.doctest import DoctestItem
from sklearn.datasets import get_data_home
from sklearn.datasets._base import _pkl_filepath
from sklearn.datasets._twenty_newsgroups import CACHE_NAME
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest, check_skip_network
from sklearn.utils.fixes import np_base_version, parse_version
def setup_labeled_faces():
    """Skip the labeled_faces.rst doctests unless the LFW data is cached locally."""
    lfw_folder = join(get_data_home(), "lfw_home")
    if not exists(lfw_folder):
        raise SkipTest("Skipping dataset loading doctests")
def setup_rcv1():
    """Skip the rcv1.rst doctests unless the RCV1 data is already downloaded."""
    check_skip_network()
    # skip the test in rcv1.rst if the dataset is not already loaded
    if not exists(join(get_data_home(), "RCV1")):
        raise SkipTest("Download RCV1 dataset to run this test.")
def setup_twenty_newsgroups():
    """Skip the twenty_newsgroups.rst doctests unless the dataset cache exists."""
    if not exists(_pkl_filepath(get_data_home(), CACHE_NAME)):
        raise SkipTest("Skipping dataset loading doctests")
def setup_working_with_text_data():
    """Skip working_with_text_data.rst doctests on PyPy CI, offline runs,
    or when the twenty-newsgroups cache is absent."""
    # This doctest file is too slow under PyPy on continuous integration.
    if IS_PYPY and os.environ.get("CI", None):
        raise SkipTest("Skipping too slow test with PyPy on CI")
    check_skip_network()
    if not exists(_pkl_filepath(get_data_home(), CACHE_NAME)):
        raise SkipTest("Skipping dataset loading doctests")
def setup_loading_other_datasets():
    """Skip loading_other_datasets.rst doctests unless pandas is installed
    and network tests are explicitly enabled."""
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping loading_other_datasets.rst, pandas not installed")

    # checks SKLEARN_SKIP_NETWORK_TESTS to see if test should run
    if environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") != "0":
        raise SkipTest(
            "Skipping loading_other_datasets.rst, tests can be "
            "enabled by setting SKLEARN_SKIP_NETWORK_TESTS=0"
        )
def setup_compose():
    """Skip the compose.rst doctests when pandas is not installed."""
    try:
        import pandas  # noqa
    except ImportError:
        # pandas is an optional dependency for these doctests
        raise SkipTest("Skipping compose.rst, pandas not installed")
def setup_impute():
    """Skip the impute.rst doctests when pandas is not installed."""
    try:
        import pandas  # noqa
    except ImportError:
        # pandas is an optional dependency for these doctests
        raise SkipTest("Skipping impute.rst, pandas not installed")
def setup_grid_search():
    """Skip the grid_search.rst doctests when pandas is not installed."""
    try:
        import pandas  # noqa
    except ImportError:
        # pandas is an optional dependency for these doctests
        raise SkipTest("Skipping grid_search.rst, pandas not installed")
def setup_preprocessing():
    """Skip the preprocessing.rst doctests unless pandas >= 1.1.0 is installed."""
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest("Skipping preprocessing.rst, pandas not installed")
    # The version check lives outside the try: SkipTest is not an ImportError,
    # so this is equivalent to the version check running inside it.
    if parse_version(pandas.__version__) < parse_version("1.1.0"):
        raise SkipTest("Skipping preprocessing.rst, pandas version < 1.1.0")
def setup_unsupervised_learning():
    """Skip unsupervised_learning.rst doctests when scikit-image is missing,
    and silence a known scipy deprecation warning otherwise."""
    try:
        import skimage  # noqa
    except ImportError:
        raise SkipTest("Skipping unsupervised_learning.rst, scikit-image not installed")

    # ignore deprecation warnings from scipy.misc.face
    warnings.filterwarnings(
        "ignore",
        message="The binary mode of fromstring",
        category=DeprecationWarning,
    )
def skip_if_matplotlib_not_installed(fname):
    """Raise SkipTest for *fname*'s doctests when matplotlib is unavailable."""
    try:
        import matplotlib  # noqa
    except ImportError:
        raise SkipTest(
            f"Skipping doctests for {os.path.basename(fname)}, matplotlib not installed"
        )
def skip_if_cupy_not_installed(fname):
    """Raise SkipTest for *fname*'s doctests when cupy is unavailable."""
    try:
        import cupy  # noqa
    except ImportError:
        raise SkipTest(
            f"Skipping doctests for {os.path.basename(fname)}, cupy not installed"
        )
def pytest_runtest_setup(item):
    """Run the dataset/dependency setup matching the collected doc file.

    Each ``.rst`` documentation file with doctests has preconditions (a
    cached dataset, an optional dependency, network access); this hook
    dispatches on the file name and raises SkipTest when they are not met.

    Parameters
    ----------
    item : pytest item
        The collected test item; ``item.fspath.strpath`` is the doc file path.
    """
    fname = item.fspath.strpath
    # normalize filename to use forward slashes on Windows for easier handling
    # later
    fname = fname.replace(os.sep, "/")

    # datasets/index.rst aggregates doctests from several dataset pages, so it
    # needs *every* corresponding setup. Using independent `if`s (instead of an
    # `elif` chain) fixes the previous dead-code bug where `or is_index` on the
    # later branches could never fire because the first branch had already
    # matched on is_index. The filename suffixes are mutually exclusive, so
    # behavior for non-index files is unchanged.
    is_index = fname.endswith("datasets/index.rst")
    if fname.endswith("datasets/labeled_faces.rst") or is_index:
        setup_labeled_faces()
    if fname.endswith("datasets/rcv1.rst") or is_index:
        setup_rcv1()
    if fname.endswith("datasets/twenty_newsgroups.rst") or is_index:
        setup_twenty_newsgroups()
    if (
        fname.endswith("tutorial/text_analytics/working_with_text_data.rst")
        or is_index
    ):
        setup_working_with_text_data()
    if fname.endswith("modules/compose.rst") or is_index:
        setup_compose()
    if fname.endswith("datasets/loading_other_datasets.rst"):
        setup_loading_other_datasets()
    if fname.endswith("modules/impute.rst"):
        setup_impute()
    if fname.endswith("modules/grid_search.rst"):
        setup_grid_search()
    if fname.endswith("modules/preprocessing.rst"):
        setup_preprocessing()
    if fname.endswith("statistical_inference/unsupervised_learning.rst"):
        setup_unsupervised_learning()

    rst_files_requiring_matplotlib = [
        "modules/partial_dependence.rst",
        "modules/tree.rst",
        "tutorial/statistical_inference/settings.rst",
        "tutorial/statistical_inference/supervised_learning.rst",
    ]
    for each in rst_files_requiring_matplotlib:
        if fname.endswith(each):
            skip_if_matplotlib_not_installed(fname)

    if fname.endswith("array_api.rst"):
        skip_if_cupy_not_installed(fname)
def pytest_configure(config):
    """Select the non-interactive 'agg' matplotlib backend for all tests,
    including doctests."""
    try:
        import matplotlib
    except ImportError:
        # matplotlib is optional; tests that need it are skipped elsewhere
        pass
    else:
        matplotlib.use("agg")
def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    skip_doctests = False
    if np_base_version >= parse_version("2"):
        # Skip doctests when using numpy 2 for now. See the following discussion
        # to decide what to do in the longer term:
        # https://github.com/scikit-learn/scikit-learn/issues/27339
        reason = "Due to NEP 51 numpy scalar repr has changed in numpy 2"
        skip_doctests = True

    doctest_items = [item for item in items if isinstance(item, DoctestItem)]

    # Normally doctest has the entire module's scope. Here we set globs to an empty dict
    # to remove the module's scope:
    # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context
    for doctest_item in doctest_items:
        doctest_item.dtest.globs = {}

    if skip_doctests:
        skip_marker = pytest.mark.skip(reason=reason)
        for doctest_item in doctest_items:
            doctest_item.add_marker(skip_marker)