File: test_stack.py

package info (click to toggle)
orange3 3.40.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 15,908 kB
  • sloc: python: 162,745; ansic: 622; makefile: 322; sh: 93; cpp: 77
file content (41 lines) | stat: -rw-r--r-- 1,249 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import unittest

from Orange.data import Table
from Orange.ensembles.stack import StackedFitter, StackedLearner
from Orange.evaluation import CA, CrossValidation, MSE
from Orange.modelling import KNNLearner, TreeLearner


class TestStackedFitter(unittest.TestCase):
    """Tests for stacking ensembles (``StackedFitter`` / ``StackedLearner``)."""

    @classmethod
    def setUpClass(cls):
        # Load the datasets once for the whole class; the tests only read them.
        cls.iris = Table('iris')
        cls.housing = Table('housing')

    def test_classification(self):
        """A stacked tree + kNN should reach CA above 0.9 on iris."""
        sf = StackedFitter([TreeLearner(), KNNLearner()])
        # Pin the fold split explicitly (as test_regression does) so the
        # score is reproducible.
        cv = CrossValidation(k=3, random_state=0)
        results = cv(self.iris, [sf])
        ca = CA(results)
        # One score per learner passed to cv; index 0 is the stacked fitter.
        self.assertGreater(ca[0], 0.9)

    def test_regression(self):
        """The stacked fitter should beat each of its base learners on MSE."""
        sf = StackedFitter([TreeLearner(), KNNLearner()])
        cv = CrossValidation(k=3, random_state=0)
        # A small subset keeps the test fast.
        results = cv(self.housing[:50], [sf, TreeLearner(), KNNLearner()])
        mse = MSE()(results)
        # Scores follow the learner order above: stacked, tree, kNN.
        self.assertLess(mse[0], mse[1])
        self.assertLess(mse[0], mse[2])

    def test_timeseries(self):
        """The aggregate learner must receive a plain ``Table``.

        This must hold even when the input data is a ``Table`` subclass.
        NOTE(review): presumably motivated by subclasses such as time-series
        tables; confirm against the change that introduced this test.
        """
        received = []

        def aggregate(data):
            # Use a TestCase assertion rather than a bare ``assert`` so the
            # check is not stripped when Python runs with -O.
            self.assertIs(type(data), Table)
            received.append(data)

        class CustomTable(Table):
            pass

        sl = StackedLearner([TreeLearner(), KNNLearner()],
                            aggregate=aggregate)

        custom_data = CustomTable(self.iris)
        sl(custom_data)
        # Guard against a vacuous pass if aggregate is never invoked.
        self.assertTrue(received, "aggregate was never called")