File: test_seq_dataset.py

package info (click to toggle)
scikit-learn 0.20.2%2Bdfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 51,036 kB
  • sloc: python: 108,171; ansic: 8,722; cpp: 5,651; makefile: 192; sh: 40
file content (83 lines) | stat: -rw-r--r-- 2,468 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Author: Tom Dupre la Tour
#
# License: BSD 3 clause

import numpy as np
from numpy.testing import assert_array_equal
import scipy.sparse as sp

from sklearn.utils.seq_dataset import ArrayDataset, CSRDataset
from sklearn.datasets import load_iris

from sklearn.utils.testing import assert_equal

# Module-level fixtures shared by the tests below; built once at import time.
iris = load_iris()
# Features and targets are both cast to float64 before being handed to the
# dataset constructors.
X = iris.data.astype(np.float64)
y = iris.target.astype(np.float64)
# CSR representation of the same feature matrix, used to build the sparse
# dataset and as the reference when comparing returned rows.
X_csr = sp.csr_matrix(X)
# One distinct weight per sample (0.0, 1.0, 2.0, ...), so a weight returned by
# a dataset can be checked against the sample index it came with.
sample_weight = np.arange(y.size, dtype=np.float64)


def assert_csr_equal(X, Y):
    """Assert that two CSR matrices hold the same stored values.

    Explicitly stored zeros are ignored: both operands are compared after
    such entries have been dropped. The comparison works on copies, so the
    matrices passed in by the caller are never modified.

    Parameters
    ----------
    X, Y : scipy.sparse.csr_matrix
        The two matrices to compare.

    Raises
    ------
    AssertionError
        If the shapes differ, or if the ``data``, ``indices`` or ``indptr``
        arrays differ once explicit zeros have been removed.
    """
    # eliminate_zeros() operates in place; copy first so that checking
    # equality has no side effect on the caller's matrices.
    X = X.copy()
    Y = Y.copy()
    X.eliminate_zeros()
    Y.eliminate_zeros()

    # The three CSR arrays do not encode the number of columns, so the shape
    # has to be checked on its own.
    assert_array_equal(X.shape, Y.shape)
    assert_array_equal(X.data, Y.data)
    assert_array_equal(X.indices, Y.indices)
    assert_array_equal(X.indptr, Y.indptr)


def test_seq_dataset():
    """Check that samples drawn from each dataset match the source data.

    A dense ``ArrayDataset`` and a sparse ``CSRDataset`` are built from the
    same module-level fixtures. For each of them, five rounds of one
    sequential draw followed by one random draw are performed, and every
    returned (features, target, weight) triple is compared with the row of
    ``X_csr``, ``y`` and ``sample_weight`` at the returned index.
    """
    n_features = X.shape[1]
    dense_dataset = ArrayDataset(X, y, sample_weight, seed=42)
    sparse_dataset = CSRDataset(X_csr.data, X_csr.indptr, X_csr.indices,
                                y, sample_weight, seed=42)

    for dataset in (dense_dataset, sparse_dataset):
        for _ in range(5):
            # Each round: the next sample first, then a random sample.
            for draw in (dataset._next_py, dataset._random_py):
                features, target, weight, idx = draw()
                # Rebuild a one-row CSR matrix from the returned features so
                # it can be compared with the matching row of X_csr.
                row = sp.csr_matrix(features, shape=(1, n_features))

                assert_csr_equal(row, X_csr[idx])
                assert_equal(target, y[idx])
                assert_equal(weight, sample_weight[idx])


def test_seq_dataset_shuffle():
    """Check that the dense and CSR datasets yield samples in the same order.

    Before shuffling, sequential draws must return indices 0 to 4 in order
    for both datasets, and random draws must agree between the two. After
    both are shuffled with the same seed, sequential and random draws must
    still return matching indices.
    """
    def index_of(sample):
        # A draw is a 4-tuple whose last item is the sample index.
        _, _, _, idx = sample
        return idx

    dense_dataset = ArrayDataset(X, y, sample_weight, seed=42)
    sparse_dataset = CSRDataset(X_csr.data, X_csr.indptr, X_csr.indices,
                                y, sample_weight, seed=42)

    # Not shuffled yet: sequential draws walk the samples in storage order.
    for expected_idx in range(5):
        dense_idx = index_of(dense_dataset._next_py())
        sparse_idx = index_of(sparse_dataset._next_py())
        assert_equal(dense_idx, expected_idx)
        assert_equal(sparse_idx, expected_idx)

    # Random draws from the two datasets must agree.
    for _ in range(5):
        dense_idx = index_of(dense_dataset._random_py())
        sparse_idx = index_of(sparse_dataset._random_py())
        assert_equal(dense_idx, sparse_idx)

    # Shuffle both datasets with one common seed.
    shuffle_seed = 77
    for dataset in (dense_dataset, sparse_dataset):
        dataset._shuffle_py(shuffle_seed)

    # After the shuffle, each round makes a sequential draw and then a
    # random draw; both must return matching indices.
    for _ in range(5):
        for method_name in ("_next_py", "_random_py"):
            dense_idx = index_of(getattr(dense_dataset, method_name)())
            sparse_idx = index_of(getattr(sparse_dataset, method_name)())
            assert_equal(dense_idx, sparse_idx)