1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
"""
Testing for the base module (sklearn.ensemble.base).
"""
# Authors: Gilles Louppe
# License: BSD 3
from numpy.testing import assert_equal
from nose.tools import assert_raises, assert_true
from sklearn.ensemble import BaseEnsemble
from sklearn.tree import DecisionTreeClassifier
def test_base():
    """Exercise BaseEnsemble's estimator construction and container API.

    Three estimators are built with ``append`` left at its default and a
    fourth is built with ``append=False``. The ensemble must then report
    exactly three members, both through ``len()`` and through
    ``estimators_``. Its first member must be a ``DecisionTreeClassifier``.
    """
    ensemble = BaseEnsemble(base_estimator=DecisionTreeClassifier(),
                            n_estimators=3)

    # Calls that leave ``append`` at its default value.
    for _ in range(3):
        ensemble._make_estimator()
    # Built with append=False: the size checks below expect this one
    # NOT to be counted among the ensemble's members.
    ensemble._make_estimator(append=False)

    expected_size = 3
    assert_equal(expected_size, len(ensemble))
    assert_equal(expected_size, len(ensemble.estimators_))

    first = ensemble[0]
    assert_true(isinstance(first, DecisionTreeClassifier))
def test_error():
    """Check that invalid constructor arguments raise the expected errors.

    ``BaseEnsemble`` is handed straight to ``assert_raises``, which forwards
    the keyword arguments to it, so no wrapper function is needed.
    """
    # A plain ``object`` as the base estimator is expected to raise TypeError.
    assert_raises(TypeError, BaseEnsemble,
                  base_estimator=object(), n_estimators=1)

    # A valid base estimator combined with a negative ensemble size is
    # expected to raise ValueError.
    assert_raises(ValueError, BaseEnsemble,
                  base_estimator=DecisionTreeClassifier(), n_estimators=-1)
|