1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
|
import sklearn.neural_network as skl_nn
from Orange.base import NNBase
from Orange.classification import SklLearner
# Public API for ``from <module> import *``: only the learner is exported;
# the callback mixin and the estimator subclass are implementation details.
__all__ = ["NNClassificationLearner"]
class NIterCallbackMixin:
    """Mixin that reports every assignment to ``n_iter_`` to a callback.

    ``n_iter_`` is turned into a property. Each time it is set, the value
    is stored and then passed to ``orange_callback`` (when that is set).
    Reading ``n_iter_`` before it has ever been assigned raises
    ``AttributeError``, exactly as a plain missing attribute would.
    """

    # Callable invoked with the new ``n_iter_`` value after each assignment;
    # the class-level ``None`` means "no notification".
    orange_callback = None

    def _read_n_iter(self):
        # Stored under the name-mangled attribute
        # ``_NIterCallbackMixin__orange_n_iter``; this name must not change,
        # otherwise previously pickled estimators would lose the value.
        return self.__orange_n_iter

    def _write_n_iter(self, value):
        self.__orange_n_iter = value
        notify = self.orange_callback
        # Truthiness test (not ``is not None``) mirrors the established
        # behaviour: any falsy value disables notification.
        if notify:
            notify(value)

    n_iter_ = property(_read_n_iter, _write_n_iter)
    # Remove the helper names so the class namespace only exposes
    # ``orange_callback`` and ``n_iter_``.
    del _read_n_iter, _write_n_iter
class MLPClassifierWCallback(NIterCallbackMixin, skl_nn.MLPClassifier):
    """scikit-learn ``MLPClassifier`` whose ``n_iter_`` updates are reported.

    Every assignment to ``n_iter_`` is forwarded by ``NIterCallbackMixin``
    to ``orange_callback`` (when one is set on the instance).

    The mixin is deliberately listed *before* the scikit-learn base class.
    Attribute lookup follows the MRO. If the mixin came last, a class-level
    ``n_iter_`` defined anywhere in scikit-learn's hierarchy would be found
    first and would shadow the mixin's property, so the callback would
    silently never fire. The mixin defines no ``__init__`` or estimator
    parameters, so the constructor and ``get_params`` still come from
    ``MLPClassifier``.
    """
class NNClassificationLearner(NNBase, SklLearner):
    """Classification learner backed by a multi-layer perceptron.

    The wrapped estimator is ``MLPClassifierWCallback`` (a subclass of
    scikit-learn's ``MLPClassifier``). The learner's optional ``callback``
    attribute is attached to it so that the estimator's ``n_iter_``
    updates are forwarded to that callback.
    """

    # The learner declares that instance weights are not supported.
    supports_weights = False
    # Estimator class this learner wraps; presumably instantiated by
    # SklLearner._initialize_wrapped (not visible here) — confirm.
    __wraps__ = MLPClassifierWCallback

    def _initialize_wrapped(self):
        """Build the wrapped estimator and hook up the progress callback.

        Returns the estimator created by ``SklLearner._initialize_wrapped``
        with its ``orange_callback`` set to the learner's ``callback``
        attribute, or ``None`` when the learner has no such attribute.
        """
        # NOTE(review): SklLearner is called directly rather than through
        # super(); presumably deliberate with respect to NNBase in the MRO —
        # confirm before switching to super().
        estimator = SklLearner._initialize_wrapped(self)
        estimator.orange_callback = getattr(self, "callback", None)
        return estimator
|