1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
|
# pylint: disable=protected-access
import numpy as np
from AnyQt.QtCore import Qt
from Orange.base import Model
from Orange.data import Table
from Orange.widgets.model.owtree import OWTreeLearner
from Orange.widgets.tests.base import (
DefaultParameterMapping,
ParameterMapping,
WidgetLearnerTestMixin,
WidgetTest,
)
class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
    """Tests for the tree learner widget (OWTreeLearner).

    Covers the mapping between widget controls and learner parameters,
    the expected parameter values when the parameter checkboxes are
    cleared, and that models built from dense and sparse versions of the
    same data are identical.
    """

    def setUp(self):
        self.widget = self.create_widget(
            OWTreeLearner, stored_settings={"auto_apply": False})
        self.init()
        self.model_class = Model
        # Each mapping links a widget attribute to a learner parameter; the
        # optional second name is used where the two names differ.
        self.parameters = [
            ParameterMapping.from_attribute(self.widget, 'max_depth'),
            ParameterMapping.from_attribute(
                self.widget, 'min_internal', 'min_samples_split'),
            ParameterMapping.from_attribute(
                self.widget, 'min_leaf', 'min_samples_leaf')]
        # NB. sufficient_majority is divided by 100, so it cannot be tested
        # like this
        # The checkbox attached to each parameter's GUI control, in the same
        # order as self.parameters.
        self.checks = [sb.gui_element.cbox for sb in self.parameters]

    def test_parameters_unchecked(self):
        """Check learner and model for various values of all parameters
        when pruning parameters are not checked.

        With every checkbox cleared, each parameter is expected to take a
        fixed value regardless of its control: None, 2 and 1 respectively,
        in the order the parameters are listed in setUp.
        """
        for cb in self.checks:
            cb.setCheckState(Qt.Unchecked)
        # Values are paired positionally with self.parameters from setUp.
        self.parameters = [DefaultParameterMapping(par.name, val)
                           for par, val in zip(self.parameters, (None, 2, 1))]
        self.test_parameters()

    def _assert_sparse_matches_dense(self, dataset):
        """Train on `dataset` in dense and in sparse form; compare the trees.

        Sends the dense table, then a sparse copy of the same data, to the
        widget. Asserts that the two output models have equal `_code` and
        `_values` arrays.

        Args:
            dataset: name of a dataset accepted by `Table(...)`.
        """
        self.send_signal(self.widget.Inputs.data, Table(dataset))
        model_dense = self.get_output(self.widget.Outputs.model)

        self.send_signal(self.widget.Inputs.data, Table(dataset).to_sparse())
        model_sparse = self.get_output(self.widget.Outputs.model)

        # assert_array_equal reports the differing elements/shapes on
        # failure, unlike assertTrue(np.array_equal(...)).
        np.testing.assert_array_equal(
            model_dense._code, model_sparse._code)
        np.testing.assert_array_equal(
            model_dense._values, model_sparse._values)

    def test_sparse_data_classification(self):
        """
        Classification Tree can handle sparse data.
        GH-2430
        """
        self._assert_sparse_matches_dense("iris")

    def test_sparse_data_regression(self):
        """
        Regression Tree can handle sparse data.
        GH-2497
        """
        self._assert_sparse_matches_dense("housing")
|