File: test_svmlight_format.py

package info (click to toggle)
scikit-learn 0.11.0-2%2Bdeb7u1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 13,900 kB
  • sloc: python: 34,740; ansic: 8,860; cpp: 8,849; pascal: 230; makefile: 211; sh: 14
file content (137 lines) | stat: -rw-r--r-- 3,843 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import os.path
from io import BytesIO

from numpy.testing import assert_equal, assert_array_equal
from nose.tools import raises

from sklearn.datasets import (load_svmlight_file, load_svmlight_files,
                              dump_svmlight_file)

# Absolute directory of this test module; the fixture files used by the tests
# below live in its "data" sub-directory.
currdir = os.path.dirname(os.path.abspath(__file__))
datafile, multifile, invalidfile = (
    os.path.join(currdir, "data", fixture_name)
    for fixture_name in ("svmlight_classification.txt",
                         "svmlight_multilabel.txt",
                         "svmlight_invalid.txt"))


def test_load_svmlight_file():
    # Load the classification fixture and verify its dimensions, the entries
    # that must hold a value, the entries that must read back as zero, that
    # X can be modified in place, and the labels.
    X, y = load_svmlight_file(datafile)

    # dimensions of X and y
    assert_equal(X.indptr.shape[0], 4)
    n_rows, n_cols = X.shape
    assert_equal(n_rows, 3)
    assert_equal(n_cols, 21)
    assert_equal(y.shape[0], 3)

    # entries expected to hold a value, as (row, column, value)
    expected_entries = [
        (0, 2, 2.5), (0, 10, -5.2), (0, 15, 1.5),
        (1, 5, 1.0), (1, 12, -3),
        (2, 20, 27),
    ]
    for row, col, expected in expected_entries:
        assert_equal(X[row, col], expected)

    # entries expected to read back as zero, as (row, column)
    for row, col in [(0, 3), (0, 5), (1, 8), (1, 16), (2, 18)]:
        assert_equal(X[row, col], 0)

    # X's values can be changed: doubling the 2.5 at (0, 2) gives 5
    X[0, 2] *= 2
    assert_equal(X[0, 2], 5)

    # labels
    assert_array_equal(y, [1, 2, 3])


def test_load_svmlight_file_multilabel():
    # With multilabel=True the labels of each sample come back as a tuple.
    expected_labels = [(0, 1), (2,), (1, 2)]
    X, y = load_svmlight_file(multifile, multilabel=True)
    assert_equal(y, expected_labels)


def test_load_svmlight_files():
    # Loading the same file twice must give identical (X, y) pairs carrying
    # the requested dtype.
    two_copies = [datafile, datafile]
    X_train, y_train, X_test, y_test = load_svmlight_files(two_copies,
                                                           dtype=np.float32)
    assert_array_equal(X_train.toarray(), X_test.toarray())
    assert_array_equal(y_train, y_test)
    for X in (X_train, X_test):
        assert_equal(X.dtype, np.float32)

    # Same check with three copies and float64: all matrices share the dtype.
    three_copies = [datafile, datafile, datafile]
    X1, _, X2, _, X3, _ = load_svmlight_files(three_copies, dtype=np.float64)
    assert_equal(X1.dtype, X2.dtype)
    assert_equal(X2.dtype, X3.dtype)
    assert_equal(X3.dtype, np.float64)


def test_load_svmlight_file_n_features():
    # Load the classification fixture with an explicit n_features=20 and
    # check the resulting shape and a few entries.
    X, y = load_svmlight_file(datafile, n_features=20)

    # test X's shape
    assert_equal(X.indptr.shape[0], 4)
    n_rows, n_cols = X.shape
    assert_equal(n_rows, 3)
    assert_equal(n_cols, 20)

    # entries expected to hold a value, as (row, column, value)
    expected_entries = [
        (0, 2, 2.5), (0, 10, -5.2),
        (1, 5, 1.0), (1, 12, -3),
    ]
    for row, col, expected in expected_entries:
        assert_equal(X[row, col], expected)


# Loading the fixture at `invalidfile` must raise ValueError; the @raises
# decorator fails the test if it does not.
@raises(ValueError)
def test_load_invalid_file():
    load_svmlight_file(invalidfile)


# A sample using feature index 0 is loaded with zero_based=False; the test
# expects the loader to raise ValueError.
@raises(ValueError)
def test_load_zero_based():
    # BytesIO only accepts bytes. A plain str literal is text on Python 3 and
    # would make BytesIO raise TypeError before the loader is ever called, so
    # the payload is written as a bytes literal (same value on Python 2.6+).
    f = BytesIO(b"-1 4:1.\n1 0:1\n")
    load_svmlight_file(f, zero_based=False)


def test_load_zero_based_auto():
    # Check the column count produced by zero_based="auto" for one file and
    # for two files loaded together.
    #
    # The payloads are bytes literals because BytesIO only accepts bytes; a
    # plain str literal is text on Python 3 and makes BytesIO raise TypeError.
    data1 = b"-1 1:1 2:2 3:3\n"
    data2 = b"-1 0:0 1:1\n"

    # data1 alone (feature indices 1..3) must give a single row of 3 columns
    f1 = BytesIO(data1)
    X, y = load_svmlight_file(f1, zero_based="auto")
    assert_equal(X.shape, (1, 3))

    # loaded together with data2, which uses index 0, both matrices must
    # have 4 columns
    f1 = BytesIO(data1)
    f2 = BytesIO(data2)
    X1, y1, X2, y2 = load_svmlight_files([f1, f2], zero_based="auto")
    assert_equal(X1.shape, (1, 4))
    assert_equal(X2.shape, (1, 4))


# One invalid file placed between two valid ones must make the multi-file
# loader raise ValueError.
@raises(ValueError)
def test_load_invalid_file2():
    paths = [datafile, invalidfile, datafile]
    load_svmlight_files(paths)


# Passing something that is neither a path nor a file must raise TypeError.
@raises(TypeError)
def test_not_a_filename():
    # A float is used rather than an int because, in Python 3, integers are
    # valid file opening arguments (taken as unix file descriptors).
    not_a_path = 0.42
    load_svmlight_file(not_a_path)


# Loading from a path string that is expected not to exist must raise IOError.
@raises(IOError)
def test_invalid_filename():
    missing_path = "trou pic nic douille"
    load_svmlight_file(missing_path)


def test_dump():
    # Round trip: dump the loaded data to an in-memory buffer and load it
    # back, for both the matrix returned by the loader and its toarray()
    # form, and for both values of zero_based.
    Xs, y = load_svmlight_file(datafile)
    Xd = Xs.toarray()

    cases = [(matrix, flag)
             for matrix in (Xs, Xd)
             for flag in (True, False)]
    for X, zero_based in cases:
        buf = BytesIO()
        dump_svmlight_file(X, y, buf, zero_based=zero_based)
        buf.seek(0)  # rewind so the loader reads from the start
        X2, y2 = load_svmlight_file(buf, zero_based=zero_based)
        assert_array_equal(Xd, X2.toarray())
        assert_array_equal(y, y2)