1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import unittest
from unittest.mock import Mock
import numpy as np
import sklearn.tree as skl_tree
from sklearn.tree._tree import TREE_LEAF
from Orange.data import Table
from Orange.classification import SklTreeLearner, TreeLearner
from Orange.regression import SklTreeRegressionLearner
class TestSklTreeLearner(unittest.TestCase):
    """Tests for ``SklTreeLearner`` and ``SklTreeRegressionLearner``."""

    def test_classification(self):
        # A tree learned with default settings must reproduce the class
        # of every instance of the table it was trained on.
        data = Table('iris')
        model = SklTreeLearner()(data)
        predicted = model(data)
        expected = data.Y.flatten()
        self.assertTrue(np.all(expected == predicted))

    def test_regression(self):
        # Same check for the regression learner: predictions on the
        # training table must equal the stored target values.
        data = Table('housing')
        model = SklTreeRegressionLearner()(data)
        predicted = model(data)
        expected = data.Y.flatten()
        self.assertTrue(np.all(expected == predicted))

    def test_supports_weights(self):
        # Both learners must report that they accept instance weights.
        for learner_class in (SklTreeRegressionLearner, SklTreeLearner):
            self.assertTrue(learner_class().supports_weights)
class TestTreeLearner(unittest.TestCase):
    """Tests for ``TreeLearner``."""

    def test_uses_preprocessors(self):
        # The mock stands in for a preprocessor and hands the table back
        # unchanged; the learner must have called it with that table.
        data = Table('iris')
        preprocessor = Mock(return_value=data)
        TreeLearner(preprocessors=[preprocessor])(data)
        preprocessor.assert_called_with(data)

    def test_supports_weights(self):
        # TreeLearner must report that it does not accept weights.
        learner = TreeLearner()
        self.assertFalse(learner.supports_weights)
class TestDecisionTreeClassifier(unittest.TestCase):
    """Checks of ``skl_tree.DecisionTreeClassifier`` fitted on the iris table.

    Every test fits a classifier on ``cls.iris`` and then inspects either
    its predictions or the fitted ``tree_`` structure.
    """

    @classmethod
    def setUpClass(cls):
        # Load the table once; the tests only read from it.
        cls.iris = Table('iris')

    def _fit(self, **params):
        # Build a classifier with the given constructor arguments and fit
        # it on the shared iris table.
        table = self.iris
        return skl_tree.DecisionTreeClassifier(**params).fit(table.X, table.Y)

    def test_full_tree(self):
        # With default settings the tree must reproduce every training
        # label.
        table = self.iris
        predicted = self._fit().predict(table.X)
        self.assertTrue(np.all(table.Y.flatten() == predicted))

    def test_min_samples_split(self):
        # Every node that was split must hold at least `lim` samples.
        lim = 5
        t = self._fit(min_samples_split=lim).tree_
        for node in range(t.node_count):
            if t.children_left[node] != TREE_LEAF:
                self.assertGreaterEqual(t.n_node_samples[node], lim)

    def test_min_samples_leaf(self):
        # Every leaf must hold at least `lim` samples.
        lim = 5
        t = self._fit(min_samples_leaf=lim).tree_
        for node in range(t.node_count):
            if t.children_left[node] == TREE_LEAF:
                self.assertGreaterEqual(t.n_node_samples[node], lim)

    def test_max_leaf_nodes(self):
        # A binary tree with at most `lim` leaves has at most
        # 2 * lim - 1 nodes in total.
        lim = 5
        t = self._fit(max_leaf_nodes=lim).tree_
        self.assertLessEqual(t.node_count, lim * 2 - 1)

    def test_criterion(self):
        # Smoke test: fitting with the "entropy" criterion must not raise.
        self._fit(criterion="entropy")

    def test_splitter(self):
        # Smoke test: fitting with the "random" splitter must not raise.
        self._fit(splitter="random")

    def test_weights(self):
        # Weighting instance i by i must lead to a different sequence of
        # split features than the unweighted fit at the same depth.
        table = self.iris
        clf = self._fit(max_depth=2)
        clfw = skl_tree.DecisionTreeClassifier(max_depth=2)
        clfw = clfw.fit(table.X, table.Y, sample_weight=np.arange(len(table)))
        # array_equal is False when the shapes differ, so it covers both
        # the length check and the element-wise comparison.
        self.assertFalse(
            np.array_equal(clf.tree_.feature, clfw.tree_.feature))

    def test_impurity(self):
        # In a tree fitted with default settings every leaf must have zero
        # impurity, and each split must have at least one child that is no
        # more impure than its parent.
        t = self._fit().tree_
        for node in range(t.node_count):
            left, right = t.children_left[node], t.children_right[node]
            if left == TREE_LEAF:
                self.assertEqual(t.impurity[node], 0)
            else:
                child_impurity = min(t.impurity[left], t.impurity[right])
                self.assertLessEqual(child_impurity, t.impurity[node])

    def test_navigate_tree(self):
        # Follow the first instance through a depth-1 tree by hand and
        # check that the majority class of the child it lands in is what
        # predict() returns for that instance.
        table = self.iris
        clf = skl_tree.DecisionTreeClassifier(max_depth=1)
        clf = clf.fit(table.X, table.Y.reshape(-1, 1))
        t = clf.tree_
        x = table.X[0]
        if x[t.feature[0]] <= t.threshold[0]:
            child = t.children_left[0]
        else:
            child = t.children_right[0]
        v = t.value[child][0]
        # Compare label with label and scalar with scalar: argmax is an
        # index into classes_, and predict() returns a one-element array.
        expected = clf.classes_[np.argmax(v)]
        predicted = clf.predict(table.X[:1])[0]
        self.assertEqual(expected, predicted)
|