1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
|
import warnings
import numpy as np
import scipy.sparse as sp
from scipy.linalg import pinv2
from itertools import chain
from sklearn.utils.testing import (assert_equal, assert_raises, assert_true,
assert_almost_equal, assert_array_equal,
SkipTest, assert_raises_regex,
assert_greater_equal)
from sklearn.utils import check_random_state
from sklearn.utils import deprecated
from sklearn.utils import resample
from sklearn.utils import safe_mask
from sklearn.utils import column_or_1d
from sklearn.utils import safe_indexing
from sklearn.utils import shuffle
from sklearn.utils import gen_even_slices
from sklearn.utils.extmath import pinvh
from sklearn.utils.arpack import eigsh
from sklearn.utils.mocking import MockDataFrame
from sklearn.utils.graph import graph_laplacian
def test_make_rng():
    # Exercise check_random_state with every kind of seed used below.

    # None and the np.random module must both resolve to numpy's global
    # RandomState singleton.
    global_rng = np.random.mtrand._rand
    assert_true(check_random_state(None) is global_rng)
    assert_true(check_random_state(np.random) is global_rng)

    # An integer seed must give the same first draw as RandomState(seed).
    reference = np.random.RandomState(42)
    seeded = check_random_state(42)
    assert_true(seeded.randint(100) == reference.randint(100))

    # An existing RandomState instance must be returned as the same object.
    existing = np.random.RandomState(42)
    assert_true(check_random_state(existing) is existing)

    # A different integer seed must give a different first draw.
    reference = np.random.RandomState(42)
    other_seeded = check_random_state(43)
    assert_true(other_seeded.randint(100) != reference.randint(100))

    # A string is not an acceptable seed.
    assert_raises(ValueError, check_random_state, "some invalid seed")
def test_deprecated():
    # Check that exactly one DeprecationWarning is recorded while defining
    # and using an object wrapped with the ``deprecated`` decorator.
    # Adapted almost verbatim from
    # http://docs.python.org/library/warnings.html

    # Case 1: a decorated function.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated()
        def ham():
            return "spam"

        result = ham()

        # The decorated function must remain usable.
        assert_equal(result, "spam")

        assert_equal(len(caught), 1)
        first = caught[0]
        assert_true(issubclass(first.category, DeprecationWarning))
        assert_true("deprecated" in str(first.message).lower())

    # Case 2: a decorated class.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        @deprecated("don't use this")
        class Ham(object):
            SPAM = 1

        instance = Ham()

        # Class attributes must still be reachable on the instance.
        assert_true(hasattr(instance, "SPAM"))

        assert_equal(len(caught), 1)
        first = caught[0]
        assert_true(issubclass(first.category, DeprecationWarning))
        assert_true("deprecated" in str(first.message).lower())
def test_resample():
    # Border case not worth mentioning in doctests: with no arrays at all,
    # resample returns None.
    assert_true(resample() is None)

    # Each (positional args, keyword args) pair below must be rejected
    # with a ValueError.
    invalid_calls = [
        (([0], [0, 1]), {}),
        (([0, 1], [0, 1]), {"replace": False, "n_samples": 3}),
        (([0, 1], [0, 1]), {"meaning_of_life": 42}),
    ]
    for args, kwargs in invalid_calls:
        assert_raises(ValueError, resample, *args, **kwargs)

    # Issue:6581, n_samples can be more when replace is True (default).
    oversampled = resample([1, 2], n_samples=5)
    assert_equal(len(oversampled), 5)
def test_safe_mask():
    # The mask returned by safe_mask must select the three True rows from
    # both a dense array and its CSR counterpart.
    rng = check_random_state(0)
    dense = rng.rand(5, 4)
    sparse = sp.csr_matrix(dense)
    bool_mask = [False, False, True, True, True]

    dense_mask = safe_mask(dense, bool_mask)
    assert_equal(dense[dense_mask].shape[0], 3)

    # The sparse call receives the mask already returned for the dense array.
    sparse_mask = safe_mask(sparse, dense_mask)
    assert_equal(sparse[sparse_mask].shape[0], 3)
def test_pinvh_simple_real():
    # For a real symmetric matrix built as B . B^T, multiplying by its
    # pinvh result should give the 3x3 identity.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=np.float64)
    sym = np.dot(base, base.T)
    product = np.dot(sym, pinvh(sym))
    assert_almost_equal(product, np.eye(3))
def test_pinvh_nonpositive():
    # pinvh must agree with pinv2 on a symmetric matrix whose leading
    # singular value has been negated.
    base = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
    gram = np.dot(base, base.T)
    left, sing_vals, right_t = np.linalg.svd(gram)
    sing_vals[0] = -sing_vals[0]
    # the rebuilt matrix is now symmetric non-positive and singular
    mat = np.dot(left * sing_vals, right_t)
    expected = pinv2(mat)
    actual = pinvh(mat)
    assert_almost_equal(expected, actual)
def test_pinvh_simple_complex():
    # For a complex matrix built as Z . Z^H, multiplying by its pinvh
    # result should give the 3x3 identity.
    real_part = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
    imag_part = np.array([[10, 8, 7], [6, 5, 4], [3, 2, 1]])
    z = real_part + 1j * imag_part
    herm = np.dot(z, z.conj().T)
    product = np.dot(herm, pinvh(herm))
    assert_almost_equal(product, np.eye(3))
def test_arpack_eigsh_initialization():
    # Non-regression test that shows null-space computation is better with
    # initialization of eigsh from [-1,1] instead of [0,1]
    rng = check_random_state(42)

    raw = rng.rand(50, 50)
    spd = np.dot(raw.T, raw)  # create s.p.d. matrix
    # Add a small multiple of the identity to the graph Laplacian.
    shift = 1e-7 * np.identity(spd.shape[0])
    laplacian = graph_laplacian(spd) + shift

    n_eigenvalues = 5

    # Test if eigsh is working correctly
    # New initialization [-1,1] (as in original ARPACK)
    # Was [0,1] before, with which this test could fail
    start_vector = rng.uniform(-1, 1, laplacian.shape[0])
    eigenvalues, _ = eigsh(laplacian, k=n_eigenvalues, sigma=0.0,
                           v0=start_vector)

    # Eigenvalues of s.p.d. matrix should be nonnegative; index 0 is
    # taken to be the smallest.
    assert_greater_equal(eigenvalues[0], 0)
def test_column_or_1d():
    # (label, y) pairs: the label alone decides whether column_or_1d is
    # expected to return the raveled input or raise ValueError.
    cases = [
        ("binary", ["spam", "egg", "spam"]),
        ("binary", [0, 1, 0, 1]),
        ("continuous", np.arange(10) / 20.),
        ("multiclass", [1, 2, 3]),
        ("multiclass", [0, 1, 2, 2, 0]),
        ("multiclass", [[1], [2], [3]]),
        ("multilabel-indicator", [[0, 1, 0], [0, 0, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("multiclass-multioutput", [[1, 1], [2, 2], [3, 1]]),
        ("multiclass-multioutput", [[5, 1], [4, 2], [3, 1]]),
        ("multiclass-multioutput", [[1, 2, 3]]),
        ("continuous-multioutput", np.arange(30).reshape((-1, 3))),
    ]
    accepted_labels = ("binary", "multiclass", "continuous")

    for label, target in cases:
        if label not in accepted_labels:
            # Every other label must be rejected.
            assert_raises(ValueError, column_or_1d, target)
            continue
        assert_array_equal(column_or_1d(target), np.ravel(target))
def test_safe_indexing():
    # Indexing a nested list and indexing the equivalent ndarray must give
    # the same rows, and both must match plain numpy fancy indexing.
    rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    rows_array = np.array(rows)
    picks = np.array([1, 2])

    from_list = np.array(safe_indexing(rows, picks))
    from_array = safe_indexing(rows_array, picks)

    assert_array_equal(from_list, from_array)
    assert_array_equal(from_list, rows_array[picks])
def test_safe_indexing_pandas():
    # Indexing a pandas DataFrame must give the same rows as indexing the
    # ndarray it was built from. Skipped when pandas is not installed.
    try:
        import pandas as pd
    except ImportError:
        raise SkipTest("Pandas not found")

    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = pd.DataFrame(X)
    inds = np.array([1, 2])

    X_df_indexed = safe_indexing(X_df, inds)
    # Take the reference from the plain ndarray. Indexing X_df a second
    # time would compare the DataFrame result with itself and could
    # never fail.
    X_indexed = safe_indexing(X, inds)
    assert_array_equal(np.array(X_df_indexed), X_indexed)

    # fun with read-only data in dataframes
    # this happens in joblib memmapping
    X.setflags(write=False)
    X_df_readonly = pd.DataFrame(X)
    # Warnings raised while indexing the read-only frame are recorded and
    # discarded rather than shown.
    with warnings.catch_warnings(record=True):
        X_df_ro_indexed = safe_indexing(X_df_readonly, inds)

    assert_array_equal(np.array(X_df_ro_indexed), X_indexed)
def test_safe_indexing_mock_pandas():
    # Indexing a MockDataFrame must give the same rows as indexing the
    # ndarray it wraps.
    X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    X_df = MockDataFrame(X)
    inds = np.array([1, 2])

    X_df_indexed = safe_indexing(X_df, inds)
    # Take the reference from the plain ndarray. Indexing X_df a second
    # time would compare the mock result with itself and could never fail.
    X_indexed = safe_indexing(X, inds)

    assert_array_equal(np.array(X_df_indexed), X_indexed)
def test_shuffle_on_ndim_equals_three():
    # shuffle must accept a 3-dimensional array and return the same set of
    # first-axis entries.
    def to_tuple(A):  # to make the inner arrays hashable
        return tuple(tuple(tuple(C) for C in B) for B in A)

    A = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # A.shape = (2,2,2)
    S = set(to_tuple(A))

    # shouldn't raise a ValueError for dim = 3
    A_shuffled = shuffle(A)

    # Check the returned array. shuffle is documented to return shuffled
    # copies and leave the input alone, so re-checking A itself would
    # always pass.
    assert_equal(set(to_tuple(A_shuffled)), S)
def test_shuffle_dont_convert_to_array():
    # Check that shuffle does not coerce its inputs to float numpy arrays
    # and lets any indexable data structure pass through with its
    # container type (and dtype, where applicable) unchanged.
    str_list = ['a', 'b', 'c']
    obj_array = np.array(['a', 'b', 'c'], dtype=object)
    int_list = [1, 2, 3]
    frame_data = np.array([['a', 0],
                           ['b', 1],
                           ['c', 2]],
                          dtype=object)
    mock_frame = MockDataFrame(frame_data)
    sparse_mat = sp.csc_matrix(np.arange(6).reshape(3, 2))

    shuffled = shuffle(str_list, obj_array, int_list, mock_frame, sparse_mat,
                       random_state=0)
    str_list_s, obj_array_s, int_list_s, mock_frame_s, sparse_mat_s = shuffled

    # Plain list of strings stays a list.
    assert_equal(str_list_s, ['c', 'b', 'a'])
    assert_equal(type(str_list_s), list)

    # Object array keeps its object dtype.
    assert_array_equal(obj_array_s, ['c', 'b', 'a'])
    assert_equal(obj_array_s.dtype, object)

    # Plain list of ints stays a list.
    assert_equal(int_list_s, [3, 2, 1])
    assert_equal(type(int_list_s), list)

    # MockDataFrame stays a MockDataFrame.
    expected_frame = np.array([['c', 2],
                               ['b', 1],
                               ['a', 0]],
                              dtype=object)
    assert_array_equal(mock_frame_s, expected_frame)
    assert_equal(type(mock_frame_s), MockDataFrame)

    # Sparse matrix rows follow the same permutation.
    expected_sparse = np.array([[4, 5],
                                [2, 3],
                                [0, 1]])
    assert_array_equal(sparse_mat_s.toarray(), expected_sparse)
def test_gen_even_slices():
    # check that gen_even_slices contains all samples
    some_range = range(10)
    # The loop variable is named ``chunk`` so it does not shadow the
    # builtin ``slice``.
    joined_range = list(chain.from_iterable(
        some_range[chunk] for chunk in gen_even_slices(10, 3)))
    assert_array_equal(some_range, joined_range)

    # check that passing negative n_chunks raises an error
    # The generator is created first and ``next`` is what gets called
    # under assert_raises_regex.
    slices = gen_even_slices(10, -1)
    assert_raises_regex(ValueError, "gen_even_slices got n_packs=-1, must be"
                        " >=1", next, slices)
|