1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import pickle
import unittest
from Orange.base import SklLearner, Learner, Model
from Orange.data import Domain, Table
from Orange.preprocess import Discretize, Randomize, Continuize
from Orange.regression import LinearRegressionLearner
class DummyLearner(Learner):
def fit(self, *_, **__):
return unittest.mock.Mock()
class DummySklLearner(SklLearner):
def fit(self, *_, **__):
return unittest.mock.Mock()
class DummyLearnerPP(Learner):
preprocessors = (Randomize(),)
class TestLearner(unittest.TestCase):
def test_uses_default_preprocessors_unless_custom_pps_specified(self):
"""Learners should use their default preprocessors unless custom
preprocessors were passed in to the constructor"""
learner = DummyLearner()
self.assertEqual(
type(learner).preprocessors, tuple(learner.active_preprocessors),
'Learner should use default preprocessors, unless preprocessors '
'were specified in init')
def test_overrides_custom_preprocessors(self):
"""Passing preprocessors to the learner constructor should override the
default preprocessors defined on the learner"""
pp = Discretize()
learner = DummyLearnerPP(preprocessors=(pp,))
self.assertEqual(
tuple(learner.active_preprocessors), (pp,),
'Learner should override default preprocessors when specified in '
'constructor')
def test_use_default_preprocessors_property(self):
"""We can specify that we want to use default preprocessors despite
passing our own ones in the constructor"""
learner = DummyLearnerPP(preprocessors=(Discretize(),))
learner.use_default_preprocessors = True
preprocessors = list(learner.active_preprocessors)
self.assertEqual(
len(preprocessors), 2,
'Learner did not properly insert custom preprocessor into '
'preprocessor list')
self.assertIsInstance(
preprocessors[0], Discretize,
'Custom preprocessor was inserted in incorrect order')
self.assertIsInstance(preprocessors[1], Randomize)
def test_preprocessors_can_be_passed_in_as_non_iterable(self):
"""For convenience, we can pass a single preprocessor instance"""
pp = Discretize()
learner = DummyLearnerPP(preprocessors=pp)
self.assertEqual(
tuple(learner.active_preprocessors), (pp,),
'Preprocessors should be able to be passed in as single object '
'as well as an iterable object')
def test_preprocessors_can_be_passed_in_as_generator(self):
"""Since we support iterables, we should support generators as well"""
pp = (Discretize(),)
learner = DummyLearnerPP(p for p in pp)
self.assertEqual(
tuple(learner.active_preprocessors), pp,
'Preprocessors should be able to be passed in as single object '
'as well as an iterable object')
def test_callback(self):
callback = unittest.mock.Mock()
learner = DummyLearner(preprocessors=[Discretize(), Randomize()])
learner(Table("iris"), callback)
args = [x[0][0] for x in callback.call_args_list]
self.assertEqual(min(args), 0)
self.assertEqual(max(args), 1)
self.assertListEqual(args, sorted(args))
class TestSklLearner(unittest.TestCase):
def test_linreg(self):
self.assertTrue(
LinearRegressionLearner().supports_weights,
"Either LinearRegression no longer supports weighted tables or "
"SklLearner.supports_weights is out-of-date.")
def test_callback(self):
callback = unittest.mock.Mock()
learner = DummySklLearner(preprocessors=[Continuize(), Randomize()])
learner(Table("iris"), callback)
args = [x[0][0] for x in callback.call_args_list]
self.assertEqual(min(args), 0)
self.assertEqual(max(args), 1)
self.assertListEqual(args, sorted(args))
class TestModel(unittest.TestCase):
def test_pickle(self):
"""Make sure data is not saved when pickling a model."""
model = Model(Domain([]))
model.original_data = [1, 2, 3]
model2 = pickle.loads(pickle.dumps(model))
self.assertEqual(model.domain, model2.domain)
self.assertEqual(model.original_data, [1, 2, 3])
self.assertEqual(model2.original_data, None)
if __name__ == "__main__":
unittest.main()
|