1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
|
import unittest
import warnings
from unittest.mock import Mock
import numpy as np
from Orange.data.util import scale, one_hot, SharedComputeValue
import Orange
class TestDataUtil(unittest.TestCase):
    """Unit tests for the array helpers ``scale`` and ``one_hot``."""

    def test_scale(self):
        """``scale`` maps input values linearly onto a target range."""
        check = np.testing.assert_equal
        # Explicit bounds: [0, 1, 2] is stretched onto [-1, 1].
        check(scale([0, 1, 2], -1, 1), [-1, 0, 1])
        # A constant input is expected to come out as all ones.
        check(scale([3, 3, 3]), [1, 1, 1])
        # NaN entries are expected to stay NaN in the output.
        check(scale([.1, .5, np.nan]), [0, 1, np.nan])
        # Empty in, empty out.
        check(scale(np.array([])), np.array([]))

    def test_one_hot(self):
        """``one_hot`` expands integer labels into indicator rows."""
        labels = [0, 1, 2, 1]
        expected = [[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1],
                    [0, 1, 0]]
        np.testing.assert_equal(one_hot(labels, int), expected)
        # No labels are expected to give a 0x0 integer matrix.
        empty = np.zeros((0, 0), dtype=int)
        np.testing.assert_equal(one_hot([], int), empty)
class DummyPlus(SharedComputeValue):
    """Test helper: add the shared value to the first column of ``data.X``."""

    def compute(self, data, shared_data):
        first_column = data.X[:, 0]
        return first_column + shared_data
class DummyTable(Orange.data.Table):
    """Trivial ``Orange.data.Table`` subclass.

    Used to check that ``from_table`` called on a ``Table`` descendant
    still invokes ``compute_shared`` only once per conversion.
    """
class TestSharedComputeValue(unittest.TestCase):
    """Tests for ``SharedComputeValue``: results, call counts, warnings,
    and equality/hashing."""

    def test_compat_compute_value(self):
        """A shared compute value gives the same result as a plain callable."""
        data = Orange.data.Table("iris")
        obj = DummyPlus(lambda data: 1.)
        res = obj(data)

        # Reference: an ordinary compute-value callable doing the same work.
        def plain_compute_value(table):
            return table.X[:, 0] + 1.

        res2 = plain_compute_value(data)
        np.testing.assert_equal(res, res2)

    def test_with_row_indices(self):
        """Converting then slicing equals converting with row indices."""
        obj = DummyPlus(lambda data: 1.)
        data = Orange.data.Table("iris")
        domain = Orange.data.Domain(
            [Orange.data.ContinuousVariable("cv", compute_value=obj)])
        data1 = Orange.data.Table.from_table(domain, data)[:10]
        data2 = Orange.data.Table.from_table(domain, data, range(10))
        np.testing.assert_equal(data1.X, data2.X)

    def test_single_call(self):
        """``compute_shared`` runs once per table conversion, even though
        every attribute in the domain uses the same compute value."""
        obj = DummyPlus(Mock(return_value=1))
        self.assertEqual(obj.compute_shared.call_count, 0)
        data = Orange.data.Table("iris")[45:55]  # two classes
        domain = Orange.data.Domain([at.copy(compute_value=obj)
                                     for at in data.domain.attributes],
                                    data.domain.class_vars)
        Orange.data.Table.from_table(domain, data)
        self.assertEqual(obj.compute_shared.call_count, 1)
        ndata = Orange.data.Table.from_table(domain, data)
        self.assertEqual(obj.compute_shared.call_count, 2)

        # The learner performs imputation; fitting on already-converted
        # data must not call compute_shared again.
        c = Orange.classification.LogisticRegressionLearner()(ndata)
        self.assertEqual(obj.compute_shared.call_count, 2)

        # The new data should be converted with one call.
        c(data)
        self.assertEqual(obj.compute_shared.call_count, 3)

        # Test with descendants of Table.
        DummyTable.from_table(c.domain, data)
        self.assertEqual(obj.compute_shared.call_count, 4)

    def test_compute_shared_eq_warning(self):
        """Warn unless ``compute_shared``'s own class defines both
        ``__eq__`` and ``__hash__``."""
        with warnings.catch_warnings(record=True) as warns:
            # Record every warning regardless of the interpreter's global
            # filters (-W ignore, -W error, "once"), so the assertions below
            # do not depend on how the test suite is launched.
            warnings.simplefilter("always")
            DummyPlus(compute_shared=lambda *_: 42)

            class Valid:
                def __eq__(self, other):
                    pass

                def __hash__(self):
                    pass

            DummyPlus(compute_shared=Valid())
            self.assertEqual(warns, [])

            class Invalid:
                pass

            DummyPlus(compute_shared=Invalid())
            self.assertNotEqual(warns, [])

        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")

            class MissingHash:
                def __eq__(self, other):
                    pass

            DummyPlus(compute_shared=MissingHash())
            self.assertNotEqual(warns, [])

        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")

            class MissingEq:
                def __hash__(self):
                    pass

            DummyPlus(compute_shared=MissingEq())
            self.assertNotEqual(warns, [])

        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")

            # Inheriting __eq__/__hash__ from Valid is not enough.
            class Subclass(Valid):
                pass

            DummyPlus(compute_shared=Subclass())
            self.assertNotEqual(warns, [])

    def test_eq_hash(self):
        """Equality and hash depend on both the function and the variable."""
        x = Orange.data.ContinuousVariable("x")
        y = Orange.data.ContinuousVariable("y")
        x2 = Orange.data.ContinuousVariable("x")
        # Preconditions on the variables themselves; use unittest assertions
        # rather than bare ``assert`` so they survive ``python -O`` and
        # report the compared values on failure.
        self.assertEqual(x, x2)
        self.assertEqual(hash(x), hash(x2))
        self.assertNotEqual(x, y)
        self.assertNotEqual(hash(x), hash(y))

        c1 = SharedComputeValue(abs, x)
        c2 = SharedComputeValue(abs, x2)
        d = SharedComputeValue(abs, y)
        e = SharedComputeValue(len, x)

        self.assertNotEqual(c1, None)
        self.assertEqual(c1, c2)
        self.assertEqual(hash(c1), hash(c2))
        # Different variable.
        self.assertNotEqual(c1, d)
        self.assertNotEqual(hash(c1), hash(d))
        # Different compute_shared function.
        self.assertNotEqual(c1, e)
        self.assertNotEqual(hash(c1), hash(e))

    def test_eq_hash_inheritance(self):
        """Classes flagged with ``InheritEq = True`` suppress the warning."""
        class NoFlag:
            pass

        class WithFlag:
            InheritEq = True

        x = Orange.data.ContinuousVariable("x")
        self.assertWarnsRegex(
            UserWarning, ".*define __eq__ and __hash__.*",
            SharedComputeValue, NoFlag(), x)

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            SharedComputeValue(WithFlag(), x)
            self.assertEqual(len(w), 0)
|