File: test_tree.py

package info (click to toggle)
orange3 3.40.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 15,908 kB
  • sloc: python: 162,745; ansic: 622; makefile: 322; sh: 93; cpp: 77
file content (132 lines) | stat: -rw-r--r-- 4,504 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

import unittest
from unittest.mock import Mock

import numpy as np
import sklearn.tree as skl_tree
from sklearn.tree._tree import TREE_LEAF

from Orange.data import Table
from Orange.classification import SklTreeLearner, TreeLearner
from Orange.regression import SklTreeRegressionLearner


class TestSklTreeLearner(unittest.TestCase):
    """Tests for Orange's learners that wrap scikit-learn decision trees."""

    def test_classification(self):
        # An unrestricted tree is expected to reproduce the training labels.
        table = Table('iris')
        learn = SklTreeLearner()
        clf = learn(table)
        Z = clf(table)
        # assert_array_equal reports the mismatching elements on failure,
        # whereas assertTrue(np.all(...)) only says "False is not true".
        np.testing.assert_array_equal(Z, table.Y.flatten())

    def test_regression(self):
        # An unrestricted regression tree is expected to reproduce the
        # training targets exactly.
        table = Table('housing')
        learn = SklTreeRegressionLearner()
        model = learn(table)
        pred = model(table)
        np.testing.assert_array_equal(pred, table.Y.flatten())

    def test_supports_weights(self):
        # Both scikit-learn tree wrappers must advertise weight support.
        self.assertTrue(SklTreeRegressionLearner().supports_weights)
        self.assertTrue(SklTreeLearner().supports_weights)


class TestTreeLearner(unittest.TestCase):
    """Tests for Orange's native ``TreeLearner``."""

    def test_uses_preprocessors(self):
        """A preprocessor passed to the learner is called with the input data."""
        data = Table('iris')
        preprocessor = Mock(return_value=data)

        learner = TreeLearner(preprocessors=[preprocessor])
        learner(data)

        preprocessor.assert_called_with(data)

    def test_supports_weights(self):
        """The native tree learner reports that it does not support weights."""
        learner = TreeLearner()
        self.assertFalse(learner.supports_weights)


class TestDecisionTreeClassifier(unittest.TestCase):
    """Check scikit-learn's ``DecisionTreeClassifier`` on the iris data.

    The tests fit the raw scikit-learn estimator and inspect its fitted
    ``tree_`` arrays (``children_left``/``children_right``, ``feature``,
    ``threshold``, ``impurity``, ``n_node_samples``, ``value``).
    """

    @classmethod
    def setUpClass(cls):
        cls.iris = Table('iris')

    def test_full_tree(self):
        # With default (unrestricted) parameters the tree should reproduce
        # the training labels exactly.
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier()
        clf = clf.fit(table.X, table.Y)
        Z = clf.predict(table.X)
        np.testing.assert_array_equal(Z, table.Y.flatten())

    def test_min_samples_split(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(min_samples_split=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        # Every internal (split) node must hold at least `lim` samples.
        for i in range(t.node_count):
            if t.children_left[i] != TREE_LEAF:
                with self.subTest(node=i):
                    self.assertGreaterEqual(t.n_node_samples[i], lim)

    def test_min_samples_leaf(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(min_samples_leaf=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        # Every leaf must hold at least `lim` samples.
        for i in range(t.node_count):
            if t.children_left[i] == TREE_LEAF:
                with self.subTest(node=i):
                    self.assertGreaterEqual(t.n_node_samples[i], lim)

    def test_max_leaf_nodes(self):
        table = self.iris
        lim = 5
        clf = skl_tree.DecisionTreeClassifier(max_leaf_nodes=lim)
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        # Check the limit directly: count the nodes without a left child.
        n_leaves = np.count_nonzero(t.children_left == TREE_LEAF)
        self.assertLessEqual(n_leaves, lim)
        # A binary tree with at most `lim` leaves has at most 2*lim - 1 nodes.
        self.assertLessEqual(t.node_count, lim * 2 - 1)

    def test_criterion(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(criterion="entropy")
        clf = clf.fit(table.X, table.Y)
        # The root node sees the whole training set, so with the entropy
        # criterion its impurity is the Shannon entropy (in bits) of the
        # class distribution; the default Gini impurity would differ.
        _, counts = np.unique(table.Y, return_counts=True)
        p = counts / counts.sum()
        expected = -np.sum(p * np.log2(p))
        self.assertAlmostEqual(clf.tree_.impurity[0], expected)

    def test_splitter(self):
        table = self.iris
        # random_state makes the randomly chosen thresholds reproducible.
        clf = skl_tree.DecisionTreeClassifier(splitter="random",
                                              random_state=0)
        clf = clf.fit(table.X, table.Y)
        # Random thresholds change the tree's shape, but an unrestricted
        # tree should still reproduce the training labels.
        np.testing.assert_array_equal(clf.predict(table.X), table.Y.flatten())

    def test_weights(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(max_depth=2)
        clf = clf.fit(table.X, table.Y)
        clfw = skl_tree.DecisionTreeClassifier(max_depth=2)
        clfw = clfw.fit(table.X, table.Y, sample_weight=np.arange(len(table)))
        # Non-uniform weights must change which features the tree splits on.
        # The length check short-circuits so arrays of different sizes are
        # never compared element-wise.
        self.assertFalse(len(clf.tree_.feature) == len(clfw.tree_.feature) and
                         np.all(clf.tree_.feature == clfw.tree_.feature))

    def test_impurity(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier()
        clf = clf.fit(table.X, table.Y)
        t = clf.tree_
        for i in range(t.node_count):
            with self.subTest(node=i):
                if t.children_left[i] == TREE_LEAF:
                    # Leaves of the unrestricted tree are pure.
                    self.assertEqual(t.impurity[i], 0)
                else:
                    # At least one child is no less pure than its parent.
                    l, r = t.children_left[i], t.children_right[i]
                    child_impurity = min(t.impurity[l], t.impurity[r])
                    self.assertLessEqual(child_impurity, t.impurity[i])

    def test_navigate_tree(self):
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(max_depth=1)
        clf = clf.fit(table.X, table.Y.reshape(-1, 1))
        t = clf.tree_

        # Walk the single split by hand: go left when x[feature] <= threshold.
        x = table.X[0]
        if x[t.feature[0]] <= t.threshold[0]:
            leaf = t.children_left[0]
        else:
            leaf = t.children_right[0]
        # value[node] has one row per output; y was given as a single column.
        v = t.value[leaf][0]
        # Map the argmax index to a class label and compare scalars, rather
        # than comparing an index with a length-1 prediction array.
        self.assertEqual(clf.classes_[np.argmax(v)],
                         clf.predict(table.X[:1])[0])