1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
|
from itertools import chain, product
import warnings
import pytest
import numpy as np
import scipy.sparse as sp
from scipy.linalg import pinv2
from scipy.sparse.csgraph import laplacian
from sklearn.utils.testing import (assert_equal, assert_raises,
assert_almost_equal, assert_array_equal,
SkipTest, assert_raises_regex,
assert_greater_equal, ignore_warnings,
assert_warns_message, assert_no_warnings)
from sklearn.utils import check_random_state
from sklearn.utils import deprecated
from sklearn.utils import resample
from sklearn.utils import safe_mask
from sklearn.utils import column_or_1d
from sklearn.utils import safe_indexing
from sklearn.utils import shuffle
from sklearn.utils import gen_even_slices
from sklearn.utils import get_chunk_n_rows
from sklearn.utils import is_scalar_nan
from sklearn.utils.extmath import pinvh
from sklearn.utils.arpack import eigsh
from sklearn.utils.mocking import MockDataFrame
from sklearn import config_context
def test_make_rng():
    # Exercise check_random_state with each kind of seed checked below.

    # None and the ``np.random`` module both map to numpy's global generator.
    global_rng = np.random.mtrand._rand
    for seed in (None, np.random):
        assert check_random_state(seed) is global_rng

    # An integer seed draws the same first value as RandomState(seed).
    reference = np.random.RandomState(42)
    assert check_random_state(42).randint(100) == reference.randint(100)

    # A RandomState instance is handed back unchanged (same object).
    reference = np.random.RandomState(42)
    assert check_random_state(reference) is reference

    # A different integer seed gives a different first draw.
    reference = np.random.RandomState(42)
    assert check_random_state(43).randint(100) != reference.randint(100)

    # A string that is not a valid seed is rejected with ValueError.
    assert_raises(ValueError, check_random_state, "some invalid seed")
def test_deprecated():
    # Check that the ``deprecated`` decorator produces a DeprecationWarning
    # while the decorated object stays usable.
    # Adapted from https://docs.python.org/library/warnings.html

    # Case 1: decorating and calling a function records exactly one warning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated()
        def ham():
            return "spam"

        result = ham()

        # The function must remain usable.
        assert_equal(result, "spam")

        assert_equal(len(caught), 1)
        first = caught[0]
        assert issubclass(first.category, DeprecationWarning)
        assert "deprecated" in str(first.message).lower()

    # Case 2: decorating and instantiating a class records exactly one
    # warning as well.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated("don't use this")
        class Ham(object):
            SPAM = 1

        instance = Ham()

        # Class attributes must still be reachable on the instance.
        assert hasattr(instance, "SPAM")

        assert_equal(len(caught), 1)
        first = caught[0]
        assert issubclass(first.category, DeprecationWarning)
        assert "deprecated" in str(first.message).lower()
def test_resample():
    # Corner cases of ``resample`` that are not worth showing in doctests.

    # Called without any array, it returns None.
    assert resample() is None

    # Each of these (args, kwargs) combinations must raise ValueError.
    invalid_calls = [
        (([0], [0, 1]), {}),
        (([0, 1], [0, 1]), {"replace": False, "n_samples": 3}),
        (([0, 1], [0, 1]), {"meaning_of_life": 42}),
    ]
    for args, kwargs in invalid_calls:
        assert_raises(ValueError, resample, *args, **kwargs)

    # Issue:6581, n_samples can exceed the input length when replace is True
    # (the default).
    oversampled = resample([1, 2], n_samples=5)
    assert_equal(len(oversampled), 5)
def test_safe_mask():
    # The mask returned by ``safe_mask`` must be usable to index both a dense
    # array and a CSR matrix.
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    sparse = sp.csr_matrix(dense)
    selected = [False, False, True, True, True]

    # Dense input: the three True entries select three rows.
    selected = safe_mask(dense, selected)
    assert_equal(dense[selected].shape[0], 3)

    # Sparse input: the mask produced above is passed through ``safe_mask``
    # again and must still select three rows.
    selected = safe_mask(sparse, selected)
    assert_equal(sparse[selected].shape[0], 3)
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_simple_real():
    # For this real symmetric matrix, multiplying by its ``pinvh`` result
    # must give (almost) the identity.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=np.float64)
    gram = base.dot(base.T)  # symmetric by construction
    gram_pinv = pinvh(gram)
    assert_almost_equal(gram.dot(gram_pinv), np.eye(3))
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_nonpositive():
    # ``pinvh`` must agree with scipy's ``pinv2`` on a symmetric matrix whose
    # leading singular value has been negated.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
    gram = np.dot(base, base.T)
    u, sing_vals, vt = np.linalg.svd(gram)
    sing_vals[0] *= -1
    # The rebuilt matrix is now symmetric non-positive and singular.
    modified = np.dot(u * sing_vals, vt)
    assert_almost_equal(pinv2(modified), pinvh(modified))
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_pinvh_simple_complex():
    # Same identity check as the real case, on a complex Hermitian matrix.
    real_part = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    imag_part = np.array([[10, 8, 7], [6, 5, 4], [3, 2, 1]])
    base = real_part + 1j * imag_part
    hermitian = base.dot(base.conj().T)  # Hermitian by construction
    hermitian_pinv = pinvh(hermitian)
    assert_almost_equal(hermitian.dot(hermitian_pinv), np.eye(3))
@ignore_warnings  # Test deprecated backport to be removed in 0.21
def test_arpack_eigsh_initialization():
    # Non-regression test that shows null-space computation is better with
    # initialization of eigsh from [-1,1] instead of [0,1]
    rng = check_random_state(42)
    n_dim = 50
    n_eigenvalues = 5

    gram = rng.rand(n_dim, n_dim)
    gram = gram.T.dot(gram)  # create s.p.d. matrix
    # Small diagonal shift added on top of the graph Laplacian.
    shifted = laplacian(gram) + 1e-7 * np.identity(n_dim)

    # Test if eigsh is working correctly
    # New initialization [-1,1] (as in original ARPACK)
    # Was [0,1] before, with which this test could fail
    start_vector = rng.uniform(-1, 1, n_dim)
    eigenvalues, _ = eigsh(shifted, k=n_eigenvalues, sigma=0.0,
                           v0=start_vector)

    # Eigenvalues of s.p.d. matrix should be nonnegative; the first entry is
    # the smallest.
    assert_greater_equal(eigenvalues[0], 0)
def test_column_or_1d():
    # ``column_or_1d`` must ravel targets labelled binary / multiclass /
    # continuous below and raise ValueError for every other label.
    flattenable_types = ("binary", "multiclass", "continuous")
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]

    for y_type, y in cases:
        if y_type not in flattenable_types:
            # Multi-column targets cannot be flattened.
            assert_raises(ValueError, column_or_1d, y)
            continue
        assert_array_equal(column_or_1d(y), np.ravel(y))
def test_safe_indexing():
    # Indexing a list of lists must select the same rows as indexing the
    # equivalent ndarray.
    rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    picked = np.array([1, 2])

    from_list = np.array(safe_indexing(rows, picked))
    from_array = safe_indexing(np.array(rows), picked)

    assert_array_equal(from_list, from_array)
    # Cross-check against plain numpy fancy indexing.
    assert_array_equal(from_list, np.array(rows)[picked])
def test_safe_indexing_pandas():
    # Indexing a pandas DataFrame with ``safe_indexing`` must select the same
    # rows as indexing the underlying ndarray, including when the data and/or
    # the indices are read-only.
    try:
        import pandas as pd
    except ImportError:
        raise SkipTest("Pandas not found")
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = pd.DataFrame(X)
    inds = np.array([1, 2])
    X_df_indexed = safe_indexing(X_df, inds)
    # The reference rows come from the plain ndarray, so the comparison is
    # DataFrame indexing versus array indexing rather than the DataFrame
    # result against itself.
    X_indexed = safe_indexing(X, inds)
    assert_array_equal(np.array(X_df_indexed), X_indexed)
    # fun with read-only data in dataframes
    # this happens in joblib memmapping
    X.setflags(write=False)
    X_df_readonly = pd.DataFrame(X)
    inds_readonly = inds.copy()
    inds_readonly.setflags(write=False)

    # Every combination of writable / read-only frame and indices.
    for this_df, this_inds in product([X_df, X_df_readonly],
                                      [inds, inds_readonly]):
        # Warnings raised while indexing are recorded and discarded.
        with warnings.catch_warnings(record=True):
            X_df_indexed = safe_indexing(this_df, this_inds)

        assert_array_equal(np.array(X_df_indexed), X_indexed)
def test_safe_indexing_mock_pandas():
    # Indexing a MockDataFrame with ``safe_indexing`` must select the same
    # rows as indexing the ndarray it wraps.
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = MockDataFrame(X)
    inds = np.array([1, 2])
    X_df_indexed = safe_indexing(X_df, inds)
    # The reference rows come from the plain ndarray, so the comparison is
    # not the MockDataFrame result against itself.
    X_indexed = safe_indexing(X, inds)
    assert_array_equal(np.array(X_df_indexed), X_indexed)
def test_shuffle_on_ndim_equals_three():
    # ``shuffle`` must accept a 3-d array and only reorder its first-axis
    # blocks, without losing or altering any of them.
    def to_tuple(A):  # to make the inner arrays hashable
        return tuple(tuple(tuple(C) for C in B) for B in A)

    A = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # A.shape = (2,2,2)
    S = set(to_tuple(A))
    A_shuffled = shuffle(A)  # shouldn't raise a ValueError for dim = 3

    # Check the array that shuffle returns: it must hold exactly the same
    # set of first-axis blocks as the input.
    assert_equal(set(to_tuple(A_shuffled)), S)
    # The input array must still hold the same blocks as well.
    assert_equal(set(to_tuple(A)), S)
def test_shuffle_dont_convert_to_array():
    # Check that shuffle does not coerce its inputs to float numpy arrays:
    # each indexable container must come back with its own type.
    str_list = ['a', 'b', 'c']
    obj_array = np.array(['a', 'b', 'c'], dtype=object)
    int_list = [1, 2, 3]
    frame = MockDataFrame(
        np.array([['a', 0], ['b', 1], ['c', 2]], dtype=object))
    sparse = sp.csc_matrix(np.arange(6).reshape(3, 2))

    shuffled = shuffle(str_list, obj_array, int_list, frame, sparse,
                       random_state=0)
    str_list_s, obj_array_s, int_list_s, frame_s, sparse_s = shuffled

    # Plain list of strings stays a list.
    assert_equal(str_list_s, ['c', 'b', 'a'])
    assert_equal(type(str_list_s), list)

    # Object array keeps its object dtype.
    assert_array_equal(obj_array_s, ['c', 'b', 'a'])
    assert_equal(obj_array_s.dtype, object)

    # Plain list of ints stays a list.
    assert_equal(int_list_s, [3, 2, 1])
    assert_equal(type(int_list_s), list)

    # The mock dataframe comes back as a MockDataFrame.
    expected_frame = np.array([['c', 2], ['b', 1], ['a', 0]], dtype=object)
    assert_array_equal(frame_s, expected_frame)
    assert_equal(type(frame_s), MockDataFrame)

    # Sparse rows follow the same permutation.
    expected_sparse = np.array([[4, 5], [2, 3], [0, 1]])
    assert_array_equal(sparse_s.toarray(), expected_sparse)
def test_gen_even_slices():
    # The slices from gen_even_slices must jointly cover every sample, in
    # order.
    some_range = range(10)
    pieces = (some_range[chunk] for chunk in gen_even_slices(10, 3))
    joined_range = list(chain.from_iterable(pieces))
    assert_array_equal(some_range, joined_range)

    # A negative number of packs must raise; the error only surfaces when
    # the generator is advanced, hence the call to ``next``.
    slices = gen_even_slices(10, -1)
    expected_message = "gen_even_slices got n_packs=-1, must be >=1"
    assert_raises_regex(ValueError, expected_message, next, slices)
@pytest.mark.parametrize(
    ('row_bytes', 'max_n_rows', 'working_memory', 'expected', 'warning'),
    [
        (1024, None, 1, 1024, None),
        (1024, None, 0.99999999, 1023, None),
        (1023, None, 1, 1025, None),
        (1025, None, 1, 1023, None),
        (1024, None, 2, 2048, None),
        (1024, 7, 1, 7, None),
        (1024 * 1024, None, 1, 1, None),
        (1024 * 1024 + 1, None, 1, 1,
         'Could not adhere to working_memory config. '
         'Currently 1MiB, 2MiB required.'),
    ])
def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory,
                          expected, warning):
    # Check the value and type returned by get_chunk_n_rows, with
    # working_memory given explicitly and then through config_context.

    # Pick the checker: a UserWarning with the given message when one is
    # expected, otherwise no warning at all.
    if warning is None:
        check_warning = assert_no_warnings
    else:
        def check_warning(*args, **kw):
            return assert_warns_message(UserWarning, warning, *args, **kw)

    sizes = dict(row_bytes=row_bytes, max_n_rows=max_n_rows)

    # working_memory passed as an explicit argument.
    actual = check_warning(get_chunk_n_rows,
                           working_memory=working_memory, **sizes)
    assert actual == expected
    assert type(actual) is type(expected)

    # working_memory supplied through the configuration context.
    with config_context(working_memory=working_memory):
        actual = check_warning(get_chunk_n_rows, **sizes)
        assert actual == expected
        assert type(actual) is type(expected)
# ``float(np.nan)`` stands in for the former ``np.float("nan")`` case:
# ``np.float`` was only an alias of the builtin ``float`` and is no longer
# available in recent NumPy releases, which broke test collection.
@pytest.mark.parametrize("value, result", [(float("nan"), True),
                                           (np.nan, True),
                                           (float(np.nan), True),
                                           (np.float32("nan"), True),
                                           (np.float64("nan"), True),
                                           (0, False),
                                           (0., False),
                                           (None, False),
                                           ("", False),
                                           ("nan", False),
                                           ([np.nan], False)])
def test_is_scalar_nan(value, result):
    # ``is_scalar_nan`` must return exactly True for scalar NaN values
    # (Python or numpy floats) and exactly False for everything else listed,
    # including a list that contains NaN.
    assert is_scalar_nan(value) is result
def dummy_func():
    """Do nothing; module-level placeholder callable used by the tests."""
def test_deprecation_joblib_api(tmpdir):
    # The joblib names re-exported from ``sklearn.utils`` must emit a
    # DeprecationWarning, while the same names imported from
    # ``sklearn.utils._joblib`` must not warn.
    #
    # ``tmpdir`` is the pytest fixture; its path is passed to ``Memory``.
    from sklearn.utils._joblib import joblib

    def check_warning(*args, **kw):
        return assert_warns_message(
            DeprecationWarning, "deprecated in version 0.20.1", *args, **kw)

    try:
        # Ensure that the joblib API is deprecated in sklearn.util
        from sklearn.utils import Parallel, Memory, delayed
        from sklearn.utils import cpu_count, hash, effective_n_jobs
        check_warning(Memory, str(tmpdir))
        check_warning(hash, 1)
        check_warning(Parallel)
        check_warning(cpu_count)
        check_warning(effective_n_jobs, 1)
        check_warning(delayed, dummy_func)

        # Only parallel_backend and register_parallel_backend are not
        # deprecated in sklearn.utils
        from sklearn.utils import parallel_backend, register_parallel_backend
        assert_no_warnings(parallel_backend, 'loky', None)
        assert_no_warnings(register_parallel_backend, 'failing', None)

        # Ensure that the deprecation have no side effect in
        # sklearn.utils._joblib
        from sklearn.utils._joblib import Parallel, Memory, delayed
        from sklearn.utils._joblib import cpu_count, hash, effective_n_jobs
        from sklearn.utils._joblib import parallel_backend
        from sklearn.utils._joblib import register_parallel_backend

        assert_no_warnings(Memory, str(tmpdir))
        assert_no_warnings(hash, 1)
        assert_no_warnings(Parallel)
        assert_no_warnings(cpu_count)
        assert_no_warnings(effective_n_jobs, 1)
        assert_no_warnings(delayed, dummy_func)
        assert_no_warnings(parallel_backend, 'loky', None)
        assert_no_warnings(register_parallel_backend, 'failing', None)

        # The backend registered above must be present in joblib's registry
        # (the unconditional ``del`` used to enforce this implicitly).
        assert 'failing' in joblib.parallel.BACKENDS
    finally:
        # Always drop the test backend so that a failing assertion does not
        # leave it registered in joblib's global BACKENDS mapping.
        joblib.parallel.BACKENDS.pop('failing', None)
|