1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
from hypothesis import given, settings
import hypothesis.strategies as st
import numpy as np
import unittest
def mux(select, left, right):
    """Reference implementation of the element-wise ``Where`` operator.

    Picks ``left`` where ``select`` is true and ``right`` where it is false,
    broadcasting the three inputs against each other.

    Args:
        select: Array-like condition (boolean in this file's tests).
        left: Values chosen where ``select`` is true.
        right: Values chosen where ``select`` is false.

    Returns:
        A one-element list holding the selected array (one entry per
        operator output; the ``Where`` operators checked in this file
        have the single output ``Z``).
    """
    # np.where is the vectorized form of the per-element ``x if c else y``.
    # Unlike np.vectorize it does not probe the first element to infer the
    # output dtype, so it also works on size-0 inputs (np.vectorize raises
    # there unless ``otypes`` is supplied).
    return [np.where(select, left, right)]
def rowmux(select_vec, left, right):
    """Reference implementation of ``Where`` with ``broadcast_on_rows=True``.

    ``select_vec`` holds one condition per row (first axis) of ``left`` and
    ``right``; each condition selects its entire row.

    Args:
        select_vec: 1-D array-like of conditions, one per row.
        left: Values chosen for rows whose condition is true.
        right: Values chosen for rows whose condition is false.

    Returns:
        Whatever ``mux`` returns for the row-broadcast condition: a
        one-element list holding the selected array.
    """
    left = np.asarray(left)
    # Give the condition vector trailing singleton axes so it broadcasts
    # across every non-row axis. Repeating each flag len(left) (= row count)
    # times would only line up with the column count for square inputs.
    select = np.reshape(select_vec, (-1,) + (1,) * (left.ndim - 1))
    return mux(select, left, right)
class TestWhere(serial.SerializedTestCase):
    """Tests for the element-wise ``Where`` operator against ``mux``."""

    def test_reference(self):
        """``mux`` takes ``left`` where the condition holds, else ``right``."""
        flat = mux([True, False], [1, 2], [3, 4])[0]
        self.assertTrue((np.array([1, 4]) == flat).all())

        column = mux([[True], [False]], [[1], [2]], [[3], [4]])[0]
        self.assertTrue((np.array([[1], [4]]) == column).all())

    def _check_where(self, shape, gc, dc, engine):
        """Build random C/X/Y of ``shape`` and run the ``Where`` checks.

        Inputs are drawn in C, X, Y order: ``C`` as a boolean mask and
        ``X``/``Y`` as float32. Output 0 goes through the device check and
        the whole op through the reference check against ``mux``.
        """
        cond = np.random.rand(*shape).astype(bool)
        lhs = np.random.rand(*shape).astype(np.float32)
        rhs = np.random.rand(*shape).astype(np.float32)
        op = core.CreateOperator("Where", ["C", "X", "Y"], ["Z"], engine=engine)
        self.assertDeviceChecks(dc, op, [cond, lhs, rhs], [0])
        self.assertReferenceChecks(gc, op, [cond, lhs, rhs], mux)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_where(self, N, gc, dc, engine):
        """1-D inputs of length N."""
        self._check_where((N,), gc, dc, engine)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_where_dim2(self, N, gc, dc, engine):
        """2-D inputs of shape (N, N)."""
        self._check_where((N, N), gc, dc, engine)
class TestRowWhere(hu.HypothesisTestCase):
    """Tests for ``Where`` with ``broadcast_on_rows=True``.

    In this mode ``C`` carries one condition per row (first axis) of ``X``
    and ``Y``, and that condition selects the whole row.
    """

    def test_reference(self):
        """Sanity-check the ``rowmux`` reference on 1-D and 2-D inputs."""
        self.assertTrue((
            np.array([1, 2]) == rowmux([True],
                                       [1, 2],
                                       [3, 4])[0]
        ).all())
        self.assertTrue((
            np.array([[1, 2], [7, 8]]) == rowmux([True, False],
                                                 [[1, 2], [3, 4]],
                                                 [[5, 6], [7, 8]])[0]
        ).all())

    # Both property tests pin deadline=10000 like every other hypothesis
    # test in this file, instead of inheriting whatever per-example deadline
    # the active hypothesis profile supplies (hypothesis's own default is
    # 200 ms, which operator creation plus device checks can exceed).
    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_rowwhere(self, N, gc, dc, engine):
        """1-D inputs: C, X and Y all have length N.

        With one condition per element, row broadcasting is plain
        element-wise selection, so ``mux`` serves as the reference.
        """
        C = np.random.rand(N).astype(bool)
        X = np.random.rand(N).astype(np.float32)
        Y = np.random.rand(N).astype(np.float32)
        op = core.CreateOperator(
            "Where",
            ["C", "X", "Y"],
            ["Z"],
            broadcast_on_rows=True,
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
        self.assertReferenceChecks(gc, op, [C, X, Y], mux)

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_rowwhere_dim2(self, N, gc, dc, engine):
        """2-D inputs: an N-vector of conditions selects rows of N x N X/Y."""
        C = np.random.rand(N).astype(bool)
        X = np.random.rand(N, N).astype(np.float32)
        Y = np.random.rand(N, N).astype(np.float32)
        op = core.CreateOperator(
            "Where",
            ["C", "X", "Y"],
            ["Z"],
            broadcast_on_rows=True,
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [C, X, Y], [0])
        self.assertReferenceChecks(gc, op, [C, X, Y], rowmux)
class TestIsMemberOf(serial.SerializedTestCase):
    """Tests for ``IsMemberOf`` against a numpy membership reference."""

    @given(N=st.integers(min_value=1, max_value=10),
           engine=st.sampled_from(["", "CUDNN"]),
           **hu.gcs_cpu_only)
    @settings(deadline=10000)
    def test_is_member_of(self, N, gc, dc, engine):
        """Random int64 values in [0, 10) are tested against a fixed set."""
        allowed = [0, 3, 4, 6, 8]
        X = np.random.randint(10, size=N).astype(np.int64)
        op = core.CreateOperator(
            "IsMemberOf",
            ["X"],
            ["Y"],
            value=np.array(allowed),
            engine=engine,
        )
        self.assertDeviceChecks(dc, op, [X], [0])

        lookup = frozenset(allowed)

        def is_member_ref(data):
            # One boolean per input element: is it one of ``allowed``?
            return [np.vectorize(lambda item: item in lookup)(data)]

        self.assertReferenceChecks(gc, op, [X], is_member_ref)
if __name__ == "__main__":
    # Script entry point: let unittest discover and run this module's tests.
    unittest.main()
|