1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
|
from typing import List
from Orange.base import Learner
from Orange.data import Table
from Orange.ensembles.stack import StackedFitter
from Orange.widgets.settings import Setting
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
from Orange.widgets.widget import Input, MultiInput
class OWStackedLearner(OWBaseLearner):
    """Learner widget that combines connected learners via ``StackedFitter``.

    Base learners are received on the multi-input "Learners"; an optional
    aggregating learner is received on the "Aggregate" input.
    """

    name = "Stacking"
    description = "Stack multiple models."
    icon = "icons/Stacking.svg"
    priority = 100

    # Class instantiated by ``create_learner``.
    LEARNER = StackedFitter

    learner_name = Setting("Stack")

    class Inputs(OWBaseLearner.Inputs):
        learners = MultiInput("Learners", Learner, filter_none=True)
        aggregate = Input("Aggregate", Learner)

    def __init__(self):
        # Input state is set up before the base constructor runs
        # (presumably because it may already ask for a learner --
        # confirm against OWBaseLearner).
        self.learners: List[Learner] = []
        self.aggregate = None
        super().__init__()

    def add_main_layout(self):
        """Add nothing: this widget defines no controls of its own here."""

    @Inputs.learners
    def set_learner(self, index: int, learner: Learner):
        """Replace the base learner stored at position ``index``."""
        self.learners[index] = learner
        self._invalidate()

    @Inputs.learners.insert
    def insert_learner(self, index, learner):
        """Insert a newly connected base learner at position ``index``."""
        self.learners.insert(index, learner)
        self._invalidate()

    @Inputs.learners.remove
    def remove_learner(self, index):
        """Drop the base learner stored at position ``index``."""
        self.learners.pop(index)
        self._invalidate()

    @Inputs.aggregate
    def set_aggregate(self, aggregate):
        """Store the aggregating learner (``None`` clears it)."""
        self.aggregate = aggregate
        self._invalidate()

    def _invalidate(self):
        """Discard the current learner and model."""
        # Nothing is rebuilt here; handleNewSignals is expected to do the rest.
        self.learner = None
        self.model = None

    def create_learner(self):
        """Build the stacked learner, or return ``None`` without base learners.

        The aggregate is passed on only when one is set; otherwise the
        ``LEARNER`` class is left to use its own default.
        """
        if self.learners:
            kwargs = dict(preprocessors=self.preprocessors)
            if self.aggregate:
                kwargs["aggregate"] = self.aggregate
            return self.LEARNER(tuple(self.learners), **kwargs)
        return None

    def get_learner_parameters(self):
        """Return (label, value) pairs naming the base learners and aggregator."""
        base_names = [learner.name for learner in self.learners]
        if self.aggregate:
            aggregator_name = self.aggregate.name
        else:
            aggregator_name = 'default'
        return (
            ("Base learners", base_names),
            ("Aggregator", aggregator_name),
        )
if __name__ == "__main__":
    # Manual preview: open the widget with a data table loaded.
    import sys
    from AnyQt.QtWidgets import QApplication

    app = QApplication(sys.argv)
    widget = OWStackedLearner()

    # The data set name comes from the first command-line argument,
    # falling back to 'iris' when none is given.
    if len(sys.argv) > 1:
        source = sys.argv[1]
    else:
        source = 'iris'
    widget.set_data(Table(source))

    widget.show()
    app.exec()
    widget.saveSettings()
|