1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
|
import unittest
from Orange.data import Table
from Orange.ensembles.stack import StackedFitter, StackedLearner
from Orange.evaluation import CA, CrossValidation, MSE
from Orange.modelling import KNNLearner, TreeLearner
class TestStackedFitter(unittest.TestCase):
    """Tests for the stacking ensembles ``StackedFitter`` and ``StackedLearner``."""

    @classmethod
    def setUpClass(cls):
        """Load the ``iris`` and ``housing`` tables once for every test."""
        cls.iris = Table('iris')
        cls.housing = Table('housing')

    def test_classification(self):
        """A stacked tree + kNN fitter scores above 0.9 CA on iris (3-fold CV)."""
        sf = StackedFitter([TreeLearner(), KNNLearner()])
        # Pin the fold split, as test_regression does, so the score and
        # therefore the pass/fail outcome are reproducible between runs.
        cv = CrossValidation(k=3, random_state=0)
        results = cv(self.iris, [sf])
        ca = CA(results)
        # Only one learner was evaluated, so compare its single score.
        self.assertGreater(ca[0], 0.9)

    def test_regression(self):
        """Stacking achieves lower MSE than each of its base learners.

        Evaluated with seeded 3-fold CV on the first 50 rows of ``housing``.
        """
        sf = StackedFitter([TreeLearner(), KNNLearner()])
        cv = CrossValidation(k=3, random_state=0)
        results = cv(self.housing[:50], [sf, TreeLearner(), KNNLearner()])
        mse = MSE()(results)
        # Scores follow the learner order passed to cv: [stacked, tree, kNN].
        self.assertLess(mse[0], mse[1])
        self.assertLess(mse[0], mse[2])

    def test_timeseries(self):
        """The ``aggregate`` callable receives exactly ``Table``, not a subclass.

        Fits a ``StackedLearner`` on an instance of a ``Table`` subclass and
        checks the type of the table handed to ``aggregate``.
        NOTE(review): the test name suggests this guards add-on Table
        subclasses such as time-series tables -- TODO confirm.
        """
        received = []

        def aggregate(data):
            # Exact type check on purpose: isinstance() would also accept
            # CustomTable and make the test vacuous.  A unittest assertion is
            # used instead of ``assert`` so the check survives ``python -O``.
            self.assertIs(type(data), Table)
            received.append(data)

        class CustomTable(Table):
            pass

        sl = StackedLearner([TreeLearner(), KNNLearner()],
                            aggregate=aggregate)
        data = CustomTable(self.iris)
        sl(data)
        # Make sure the type check above actually ran during fitting.
        self.assertTrue(received)
|