1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
"""
Base class for ensemble-based estimators.
"""
# Authors: Gilles Louppe
# License: BSD 3 clause
import numpy as np
import numbers
from ..base import clone
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..utils import check_random_state
from ..utils._joblib import effective_n_jobs
from ..externals import six
from abc import ABCMeta, abstractmethod
MAX_RAND_SEED = np.iinfo(np.int32).max
def _set_random_states(estimator, random_state=None):
    """Assign integer seeds to the ``random_state`` parameters of an estimator.

    Every parameter reported by ``estimator.get_params(deep=True)`` that is
    named ``random_state`` or whose name ends with ``__random_state`` is set
    to an integer drawn from ``random_state``.

    Parameters
    ----------
    estimator : estimator supporting get/set_params
        Estimator with potential randomness managed by random_state
        parameters.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    Notes
    -----
    Only the ``random_state`` parameters reachable through
    ``estimator.get_params()`` are set, which is not necessarily *all* the
    attributes controlling an estimator's randomness. ``random_state``s that
    are not controlled include those belonging to:

        * cross-validation splitters
        * ``scipy.stats`` rvs
    """
    rng = check_random_state(random_state)

    # Walk the parameter names in sorted order so the seeds are always drawn
    # from the generator in the same sequence.
    param_names = sorted(estimator.get_params(deep=True))
    seeds = {
        name: rng.randint(MAX_RAND_SEED)
        for name in param_names
        if name == 'random_state' or name.endswith('__random_state')
    }

    # Nothing to do when the estimator exposes no random_state parameter.
    if seeds:
        estimator.set_params(**seeds)
class BaseEnsemble(six.with_metaclass(ABCMeta, BaseEstimator,
                                      MetaEstimatorMixin)):
    """Abstract base class shared by all ensemble estimators.

    Warning: This class should not be used directly. Use derived classes
    instead.

    Parameters
    ----------
    base_estimator : object, optional (default=None)
        The base estimator from which the ensemble is built.

    n_estimators : integer
        The number of estimators in the ensemble.

    estimator_params : list of strings
        Names of the attributes of the ensemble that are forwarded as
        parameters to each newly instantiated base estimator. If none are
        given, default parameters are used.

    Attributes
    ----------
    base_estimator_ : estimator
        The base estimator from which the ensemble is grown.

    estimators_ : list of estimators
        The collection of fitted base estimators.
    """

    @abstractmethod
    def __init__(self, base_estimator, n_estimators=10,
                 estimator_params=tuple()):
        # Only record the constructor arguments here. Sub-estimators must not
        # be instantiated yet: the parameters of base_estimator might still
        # change, e.g. when grid-searching with the nested object syntax.
        # Derived classes are responsible for filling self.estimators_ in fit.
        self.base_estimator = base_estimator
        self.n_estimators = n_estimators
        self.estimator_params = estimator_params

    def _validate_estimator(self, default=None):
        """Validate ``n_estimators`` and set the ``base_estimator_`` attribute.

        ``base_estimator_`` is ``self.base_estimator`` when that is not None,
        otherwise ``default``.

        Raises
        ------
        ValueError
            If ``n_estimators`` is not an integer, is not strictly positive,
            or if neither ``base_estimator`` nor ``default`` is provided.
        """
        n_estimators = self.n_estimators
        if not isinstance(n_estimators, (numbers.Integral, np.integer)):
            raise ValueError("n_estimators must be an integer, "
                             "got {0}.".format(type(n_estimators)))
        if n_estimators <= 0:
            raise ValueError("n_estimators must be greater than zero, "
                             "got {0}.".format(n_estimators))

        # Fall back on the default only when no base estimator was given.
        chosen = default if self.base_estimator is None else self.base_estimator
        self.base_estimator_ = chosen
        if chosen is None:
            raise ValueError("base_estimator cannot be None")

    def _make_estimator(self, append=True, random_state=None):
        """Return a configured clone of the ``base_estimator_`` attribute.

        The clone receives the current values of the attributes listed in
        ``estimator_params``. When ``random_state`` is not None, its
        ``random_state`` parameters are seeded from it. When ``append`` is
        true, the clone is also appended to ``self.estimators_``, which must
        already exist.

        Warning: This method should be used to properly instantiate new
        sub-estimators.
        """
        new_estimator = clone(self.base_estimator_)
        forwarded = {name: getattr(self, name)
                     for name in self.estimator_params}
        new_estimator.set_params(**forwarded)

        if random_state is not None:
            _set_random_states(new_estimator, random_state)

        if append:
            self.estimators_.append(new_estimator)

        return new_estimator

    def __len__(self):
        """Return the number of estimators in the ensemble."""
        return len(self.estimators_)

    def __getitem__(self, index):
        """Return the estimator stored at position ``index``."""
        return self.estimators_[index]

    def __iter__(self):
        """Return an iterator over the estimators in the ensemble."""
        return iter(self.estimators_)
def _partition_estimators(n_estimators, n_jobs):
    """Private function used to partition estimators between jobs.

    The estimators are split into contiguous chunks whose sizes differ by at
    most one; the first ``n_estimators % n_jobs`` jobs get the extra one.

    Parameters
    ----------
    n_estimators : int
        Total number of estimators to distribute.

    n_jobs : int or None
        Requested number of jobs; passed unchanged to ``effective_n_jobs``.

    Returns
    -------
    n_jobs : int
        Number of jobs actually used, i.e. the effective number of jobs
        capped at ``n_estimators``.

    n_estimators_per_job : list of int
        Number of estimators assigned to each job.

    starts : list of int
        Cumulative offsets of length ``n_jobs + 1``: job ``i`` handles the
        estimators in ``starts[i]:starts[i + 1]``.
    """
    # Never use more jobs than there are estimators.
    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

    # Partition estimators between jobs. The builtin ``int`` is used as the
    # dtype: ``np.int`` was merely an alias for it, deprecated in NumPy 1.20
    # and removed in NumPy 1.24.
    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs,
                                   dtype=int)
    n_estimators_per_job[:n_estimators % n_jobs] += 1
    starts = np.cumsum(n_estimators_per_job)

    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
|