1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
import numpy as np
import os.path
from io import BytesIO
from numpy.testing import assert_equal, assert_array_equal
from nose.tools import raises
from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
dump_svmlight_file)
# Fixture files live in the "data" directory next to this module.
currdir = os.path.dirname(os.path.abspath(__file__))
datafile, multifile, invalidfile = (
    os.path.join(currdir, "data", fixture_name)
    for fixture_name in ("svmlight_classification.txt",
                         "svmlight_multilabel.txt",
                         "svmlight_invalid.txt")
)
def test_load_svmlight_file():
    # Load the classification fixture and check its shape, its stored
    # entries, its empty positions, mutability and the label vector.
    X, y = load_svmlight_file(datafile)

    # dimensions: 3 rows, 21 columns, and a row-pointer array with one
    # more entry than there are rows
    assert_equal(X.indptr.shape[0], 4)
    assert_equal(X.shape[0], 3)
    assert_equal(X.shape[1], 21)
    assert_equal(y.shape[0], 3)

    # entries that must hold a specific non-zero value
    expected_entries = (
        (0, 2, 2.5), (0, 10, -5.2), (0, 15, 1.5),
        (1, 5, 1.0), (1, 12, -3),
        (2, 20, 27),
    )
    for row, col, expected in expected_entries:
        assert_equal(X[row, col], expected)

    # positions that must read back as zero
    for row, col in ((0, 3), (0, 5), (1, 8), (1, 16), (2, 18)):
        assert_equal(X[row, col], 0)

    # the returned matrix can be modified in place
    X[0, 2] *= 2
    assert_equal(X[0, 2], 5)

    # labels
    assert_array_equal(y, [1, 2, 3])
def test_load_svmlight_file_multilabel():
    # With multilabel=True each sample's target is a tuple of labels.
    expected_labels = [(0, 1), (2,), (1, 2)]
    _, labels = load_svmlight_file(multifile, multilabel=True)
    assert_equal(labels, expected_labels)
def test_load_svmlight_files():
    # The same file loaded twice must give identical data, and the
    # requested dtype must be honoured for every returned matrix.
    loaded_pair = load_svmlight_files([datafile, datafile],
                                      dtype=np.float32)
    X_train, y_train, X_test, y_test = loaded_pair
    assert_array_equal(X_train.toarray(), X_test.toarray())
    assert_array_equal(y_train, y_test)
    for matrix in (X_train, X_test):
        assert_equal(matrix.dtype, np.float32)

    # Same check with three copies and float64.
    loaded_triple = load_svmlight_files([datafile, datafile, datafile],
                                        dtype=np.float64)
    X1, _y1, X2, _y2, X3, _y3 = loaded_triple
    assert_equal(X1.dtype, X2.dtype)
    assert_equal(X2.dtype, X3.dtype)
    assert_equal(X3.dtype, np.float64)
def test_load_svmlight_file_n_features():
    # Load the classification fixture with an explicit n_features=20.
    X, _ = load_svmlight_file(datafile, n_features=20)

    # shape of X
    assert_equal(X.indptr.shape[0], 4)
    assert_equal(X.shape[0], 3)
    assert_equal(X.shape[1], 20)

    # non-zero entries expected in the loaded matrix
    expected_entries = (
        (0, 2, 2.5), (0, 10, -5.2),
        (1, 5, 1.0), (1, 12, -3),
    )
    for row, col, expected in expected_entries:
        assert_equal(X[row, col], expected)
@raises(ValueError)
def test_load_invalid_file():
    # Loading the `invalidfile` fixture must raise ValueError; the
    # @raises decorator turns any other outcome into a failure.
    load_svmlight_file(invalidfile)
@raises(ValueError)
def test_load_zero_based():
    # The input contains a feature index of 0, so loading it with
    # zero_based=False is expected to raise ValueError.
    # The payload is a bytes literal: io.BytesIO rejects text strings on
    # Python 3 with a TypeError, which would mask the ValueError under test.
    f = BytesIO(b"-1 4:1.\n1 0:1\n")
    load_svmlight_file(f, zero_based=False)
def test_load_zero_based_auto():
    # Check the shapes produced by zero_based="auto", first for a single
    # file and then for two files loaded together.
    # Payloads are bytes literals: io.BytesIO raises TypeError when given
    # a text string on Python 3.
    data1 = b"-1 1:1 2:2 3:3\n"
    data2 = b"-1 0:0 1:1\n"

    # Loaded alone, data1 (indices 1..3) must yield three columns.
    f1 = BytesIO(data1)
    X, y = load_svmlight_file(f1, zero_based="auto")
    assert_equal(X.shape, (1, 3))

    # Loaded together with data2, which contains index 0, both matrices
    # must have four columns. A fresh buffer is needed for data1 because
    # the previous one has already been consumed.
    f1 = BytesIO(data1)
    f2 = BytesIO(data2)
    X1, y1, X2, y2 = load_svmlight_files([f1, f2], zero_based="auto")
    assert_equal(X1.shape, (1, 4))
    assert_equal(X2.shape, (1, 4))
@raises(ValueError)
def test_load_invalid_file2():
    # The `invalidfile` fixture sits between two loads of `datafile`;
    # the batch load as a whole must raise ValueError.
    paths = [datafile, invalidfile, datafile]
    load_svmlight_files(paths)
@raises(TypeError)
def test_not_a_filename():
    # A float is used rather than an int: on Python 3 an integer is a
    # valid argument for opening a file (it is taken as a unix file
    # descriptor), so it would not be rejected as "not a filename".
    not_a_path = 0.42
    load_svmlight_file(not_a_path)
@raises(IOError)
def test_invalid_filename():
    # A string path that is expected not to exist must raise IOError.
    missing_path = "trou pic nic douille"
    load_svmlight_file(missing_path)
def test_dump():
    # Round-trip: dump the fixture to an in-memory buffer and load it
    # back, for sparse and dense input and for both index bases.
    X_sparse, y = load_svmlight_file(datafile)
    X_dense = X_sparse.toarray()

    # Same iteration order as nesting the two loops: input kind is the
    # outer dimension, index base the inner one.
    cases = [(matrix, base)
             for matrix in (X_sparse, X_dense)
             for base in (True, False)]

    for X, zero_based in cases:
        buf = BytesIO()
        dump_svmlight_file(X, y, buf, zero_based=zero_based)
        buf.seek(0)  # rewind so the loader reads from the start

        X_loaded, y_loaded = load_svmlight_file(buf, zero_based=zero_based)
        assert_array_equal(X_dense, X_loaded.toarray())
        assert_array_equal(y, y_loaded)
|