File: owstack.py

package info (click to toggle)
orange3 3.40.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 15,912 kB
  • sloc: python: 162,745; ansic: 622; makefile: 322; sh: 93; cpp: 77
file content (81 lines) | stat: -rw-r--r-- 2,272 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from typing import List

from Orange.base import Learner
from Orange.data import Table
from Orange.ensembles.stack import StackedFitter
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.widget import Input, MultiInput


class OWStackedLearner(OWBaseLearner):
    """Widget that stacks several base learners into one ensemble learner.

    Learners arriving on the "Learners" multi-input are collected in
    ``self.learners`` and wrapped in ``self.LEARNER`` (``StackedFitter``).
    An optional learner on the "Aggregate" input is forwarded to the fitter
    as its ``aggregate`` argument; when none is connected the argument is
    omitted and the fitter's own default is used.
    """
    name = "Stacking"
    description = "Stack multiple models."
    icon = "icons/Stacking.svg"
    priority = 100

    LEARNER = StackedFitter

    learner_name = Setting("Stack")

    class Inputs(OWBaseLearner.Inputs):
        # With filter_none=True, None values sent by predecessors are not
        # passed to the handlers below (MultiInput contract in Orange's
        # widget base), so self.learners never contains None.
        learners = MultiInput("Learners", Learner, filter_none=True)
        aggregate = Input("Aggregate", Learner)

    def __init__(self):
        # These are assigned *before* super().__init__(), presumably because
        # the base constructor may already call create_learner() /
        # get_learner_parameters(), which read them. TODO confirm.
        self.learners: List[Learner] = []
        # Learner from the "Aggregate" input, or None when not connected.
        self.aggregate = None
        super().__init__()

    def add_main_layout(self):
        """Add no widget-specific controls; this widget has no options."""

    @Inputs.learners
    def set_learner(self, index: int, learner: Learner) -> None:
        """Replace the base learner at input slot ``index``."""
        self.learners[index] = learner
        self._invalidate()

    @Inputs.learners.insert
    def insert_learner(self, index: int, learner: Learner) -> None:
        """Insert a newly connected base learner at input slot ``index``."""
        self.learners.insert(index, learner)
        self._invalidate()

    @Inputs.learners.remove
    def remove_learner(self, index: int) -> None:
        """Drop the base learner whose input at slot ``index`` went away."""
        self.learners.pop(index)
        self._invalidate()

    @Inputs.aggregate
    def set_aggregate(self, aggregate):
        """Store the aggregate (meta) learner; None clears it."""
        self.aggregate = aggregate
        self._invalidate()

    def _invalidate(self):
        """Discard the current learner and model after an input change."""
        self.learner = self.model = None
        # ... and handleNewSignals will do the rest
        # (i.e. presumably rebuild both via the OWBaseLearner machinery).

    def create_learner(self):
        """Build the stacked learner from the current inputs.

        Returns:
            None when no base learners are connected. Otherwise, an
            instance of ``self.LEARNER`` built from the base learners (as a
            tuple) and ``self.preprocessors``, plus ``aggregate`` when an
            aggregate learner is connected.
        """
        if not self.learners:
            return None
        # self.preprocessors is maintained by OWBaseLearner (not shown here).
        params = {"preprocessors": self.preprocessors}
        # Pass `aggregate` only when one is connected, so the fitter keeps
        # its own default aggregator otherwise.
        if self.aggregate is not None:
            params["aggregate"] = self.aggregate
        return self.LEARNER(tuple(self.learners), **params)

    def get_learner_parameters(self):
        """Describe the current configuration as (label, value) pairs.

        Returns:
            A tuple with the names of the base learners and the name of
            the aggregator, or ``'default'`` when none is connected.
        """
        base_names = [learner.name for learner in self.learners]
        aggregator = (self.aggregate.name
                      if self.aggregate is not None else 'default')
        return (("Base learners", base_names),
                ("Aggregator", aggregator))


if __name__ == "__main__":
    # Manual preview: open the widget on a dataset named on the command
    # line, or on 'iris' when no argument is given.
    import sys
    from AnyQt.QtWidgets import QApplication

    app = QApplication(sys.argv)
    widget = OWStackedLearner()
    if len(sys.argv) > 1:
        dataset_name = sys.argv[1]
    else:
        dataset_name = 'iris'
    widget.set_data(Table(dataset_name))
    widget.show()
    app.exec()
    # Save the widget's settings once the event loop has exited.
    widget.saveSettings()