1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
import os
from os.path import exists
from os.path import join
import warnings
import numpy as np
from sklearn.utils import IS_PYPY
from sklearn.utils.testing import SkipTest
from sklearn.utils.testing import check_skip_network
from sklearn.datasets import get_data_home
from sklearn.datasets.base import _pkl_filepath
from sklearn.datasets.twenty_newsgroups import CACHE_NAME
from sklearn.utils.testing import install_mldata_mock
from sklearn.utils.testing import uninstall_mldata_mock
def setup_labeled_faces():
    """Skip the doctests unless the ``lfw_home`` folder already exists.

    Raises
    ------
    SkipTest
        If ``<data_home>/lfw_home`` is not present on disk.
    """
    lfw_dir = join(get_data_home(), 'lfw_home')
    if exists(lfw_dir):
        return
    raise SkipTest("Skipping dataset loading doctests")
def setup_mldata():
    """Install the mldata mock, registering three placeholder datasets.

    The registered arrays are uninitialised ``np.empty`` buffers, plus a
    label vector for ``mnist-original`` made of the values 0..9 each
    repeated 7000 times.
    """
    # setup mock urllib2 module to avoid downloading from mldata.org
    mnist_original = {
        'data': np.empty((70000, 784)),
        'label': np.repeat(np.arange(10, dtype='d'), 7000),
    }
    iris = {
        'data': np.empty((150, 4)),
    }
    uci_iris = {
        'double0': np.empty((150, 4)),
        'class': np.empty((150,)),
    }
    install_mldata_mock({
        'mnist-original': mnist_original,
        'iris': iris,
        'datasets-uci-iris': uci_iris,
    })
def teardown_mldata():
    """Remove the mldata mock by calling ``uninstall_mldata_mock``."""
    uninstall_mldata_mock()
    return None
def setup_rcv1():
    """Skip the RCV1 doctests unless the dataset folder is already there.

    ``check_skip_network`` runs first and may skip on its own.

    Raises
    ------
    SkipTest
        If ``<data_home>/RCV1`` does not exist.
    """
    check_skip_network()
    # skip the test in rcv1.rst if the dataset is not already loaded
    rcv1_present = exists(join(get_data_home(), "RCV1"))
    if not rcv1_present:
        raise SkipTest("Download RCV1 dataset to run this test.")
def setup_twenty_newsgroups():
    """Skip the doctests unless the 20 newsgroups cache file exists.

    The cache location is built by ``_pkl_filepath`` from the data home
    directory and ``CACHE_NAME``.

    Raises
    ------
    SkipTest
        If the cache file is not present on disk.
    """
    # Resolve the data home once and reuse it, rather than leaving the
    # local unused and calling get_data_home() a second time.
    data_home = get_data_home()
    cache_path = _pkl_filepath(data_home, CACHE_NAME)
    if not exists(cache_path):
        raise SkipTest("Skipping dataset loading doctests")
def setup_working_with_text_data():
    """Decide whether the text-data tutorial doctests can run.

    Raises
    ------
    SkipTest
        If ``IS_PYPY`` is true and the ``CI`` environment variable is set
        to a non-empty value, or if the 20 newsgroups cache file is
        missing. ``check_skip_network`` may also skip.
    """
    if IS_PYPY:
        if os.environ.get('CI'):
            raise SkipTest('Skipping too slow test with PyPy on CI')
    check_skip_network()
    newsgroups_cache = _pkl_filepath(get_data_home(), CACHE_NAME)
    if exists(newsgroups_cache):
        return
    raise SkipTest("Skipping dataset loading doctests")
def setup_compose():
    """Skip the compose.rst doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing ``pandas`` raises ``ImportError``.
    """
    skip_reason = "Skipping compose.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(skip_reason)
    return None
def setup_impute():
    """Skip the impute.rst doctests when pandas cannot be imported.

    Raises
    ------
    SkipTest
        If importing ``pandas`` raises ``ImportError``.
    """
    skip_reason = "Skipping impute.rst, pandas not installed"
    try:
        import pandas  # noqa
    except ImportError:
        raise SkipTest(skip_reason)
    return None
def setup_unsupervised_learning():
    """Register a filter that silences one specific DeprecationWarning.

    The filter ignores ``DeprecationWarning`` instances whose message
    starts with 'The binary mode of fromstring'.
    """
    # ignore deprecation warnings from scipy.misc.face
    warnings.filterwarnings(
        action='ignore',
        message='The binary mode of fromstring',
        category=DeprecationWarning,
    )
def pytest_runtest_setup(item):
    """Run the setup routine that matches the documentation file under test.

    Parameters
    ----------
    item : object
        Only ``item.fspath.strpath`` is read. It is matched by suffix
        against the known ``.rst`` paths below; files that match none of
        them need no setup.

    Raises
    ------
    SkipTest
        Directly for ``modules/feature_extraction.rst`` when ``IS_PYPY``
        is true, or from whichever setup routine is selected.
    """
    fname = item.fspath.strpath
    # NOTE(review): every branch used to be written as ``... or is_index``
    # (is_index meaning the path ends with 'datasets/index.rst'). In an
    # if/elif chain only the first such branch can ever fire, so for the
    # index page only setup_labeled_faces() has ever run. The unreachable
    # conditions are removed here to make that explicit. Running further
    # setups for the index page would be a behaviour change; confirm the
    # intent before adding it.
    if fname.endswith(('datasets/labeled_faces.rst', 'datasets/index.rst')):
        setup_labeled_faces()
    elif fname.endswith('datasets/mldata.rst'):
        setup_mldata()
    elif fname.endswith('datasets/rcv1.rst'):
        setup_rcv1()
    elif fname.endswith('datasets/twenty_newsgroups.rst'):
        setup_twenty_newsgroups()
    elif fname.endswith('tutorial/text_analytics/working_with_text_data.rst'):
        setup_working_with_text_data()
    elif fname.endswith('modules/compose.rst'):
        setup_compose()
    elif IS_PYPY and fname.endswith('modules/feature_extraction.rst'):
        raise SkipTest('FeatureHasher is not compatible with PyPy')
    elif fname.endswith('modules/impute.rst'):
        setup_impute()
    elif fname.endswith('statistical_inference/unsupervised_learning.rst'):
        setup_unsupervised_learning()
def pytest_runtest_teardown(item):
    """Undo the mldata mock after a test from ``datasets/mldata.rst``.

    Parameters
    ----------
    item : object
        Only ``item.fspath.strpath`` is read. Any path that does not end
        with 'datasets/mldata.rst' is left alone.
    """
    doc_path = item.fspath.strpath
    if not doc_path.endswith('datasets/mldata.rst'):
        return
    teardown_mldata()
|