File: sklearn_test.py

package info (click to toggle)
jsonpickle 3.0.0%2Bdfsg1-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,184 kB
  • sloc: python: 6,088; javascript: 654; makefile: 90; sh: 17
file content (57 lines) | stat: -rw-r--r-- 1,687 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import absolute_import, division, unicode_literals

import pytest

try:
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
except ImportError:
    pytest.skip('sklearn is not available', allow_module_level=True)

import jsonpickle
import jsonpickle.ext.numpy


@pytest.fixture(scope='module', autouse=True)
def numpy_extension():
    """Initialize the numpy extension for this test module"""
    jsonpickle.ext.numpy.register_handlers()
    yield  # control to the test function.
    jsonpickle.ext.numpy.unregister_handlers()


def test_decision_tree():
    #  Create data.
    np.random.seed(13)
    x_values = np.random.randint(low=0, high=10, size=12)
    x = x_values.reshape(4, 3)
    y_values = np.random.randint(low=0, high=2, size=4)
    y = y_values.reshape(-1, 1)

    # train model
    classifier = DecisionTreeClassifier(max_depth=1)
    classifier.fit(x, y)

    # freeze and thaw
    pickler = jsonpickle.pickler.Pickler()
    unpickler = jsonpickle.unpickler.Unpickler()
    actual = unpickler.restore(pickler.flatten(classifier))

    assert isinstance(actual, classifier.__class__)
    if hasattr(classifier, 'tree_'):
        assert isinstance(actual.tree_, classifier.tree_.__class__)

    # predict from thawed
    array_values = np.array([1, 2, 3])
    array = array_values.reshape(1, -1)
    prediction = actual.predict(array)
    assert prediction[0] == 1

    assert actual.max_depth == classifier.max_depth
    assert actual.score(x, y) == classifier.score(x, y)
    if hasattr(classifier, 'get_depth'):
        assert actual.get_depth() == classifier.get_depth()


if __name__ == '__main__':
    pytest.main([__file__])