1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
|
from ..externals import six
class TransformerMixin(object):
"""Mixin class for all transformers in scikit-learn."""
def fit_transform(self, X, y=None, **fit_params):
"""Fit to data, then transform it.
Fits transformer to X and y with optional parameters fit_params
and returns a transformed version of X.
Parameters
----------
X : numpy array of shape [n_samples, n_features]
Training set.
y : numpy array of shape [n_samples]
Target values.
Returns
-------
X_new : numpy array of shape [n_samples, n_features_new]
Transformed array.
"""
# non-optimized default implementation; override when a better
# method is possible for a given clustering algorithm
if y is None:
# fit method of arity 1 (unsupervised transformation)
return self.fit(X, **fit_params).transform(X)
else:
# fit method of arity 2 (supervised transformation)
return self.fit(X, y, **fit_params).transform(X)
class EstimatorMixin(object):
"""Mixin class for estimators."""
def get_params(self):
"""Get the estimator params."""
pass
def set_params(self, **params):
"""Set parameters (mimics sklearn API)."""
if not params:
return self
valid_params = self.get_params(deep=True)
for key, value in six.iteritems(params):
split = key.split('__', 1)
if len(split) > 1:
# nested objects case
name, sub_name = split
if name not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(name, self))
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})
else:
# simple objects case
if key not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self.__class__.__name__))
setattr(self, key, value)
return self
|