1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
import numpy as np
import tempfile
class TestIndexOps(TestCase):
    """Exercise the Caffe2 index operators for string, int32 and int64 keys."""

    @staticmethod
    def _run_op(op_type, inputs, outputs, **kwargs):
        """Create a single Caffe2 operator and run it once in the workspace."""
        workspace.RunOperatorOnce(
            core.CreateOperator(op_type, inputs, outputs, **kwargs))

    def _test_index_ops(self, entries, dtype, index_create_op):
        """Run the shared index-op scenario for one key type.

        Args:
            entries: eight distinct keys. entries[0:3] are loaded up front,
                entries[3:5] are first seen by a lookup before the freeze,
                and entries[5:8] are only queried after the freeze.
            dtype: numpy dtype (or ``str``) used to build the key arrays.
            index_create_op: name of the operator that creates the index,
                e.g. 'StringIndexCreate'.
        """
        self._run_op(index_create_op, [], ['index'], max_elements=10)

        # Seed the index with the first three keys.
        my_entries = np.array(
            [entries[0], entries[1], entries[2]], dtype=dtype)
        workspace.FeedBlob('entries', my_entries)
        self._run_op('IndexLoad', ['index', 'entries'], ['index'])

        # While unfrozen, unseen keys (entries[3], entries[4]) get new ids;
        # repeated keys keep the id they already have.
        query1 = np.array(
            [entries[0], entries[3], entries[0], entries[4]], dtype=dtype)
        workspace.FeedBlob('query1', query1)
        self._run_op('IndexGet', ['index', 'query1'], ['result1'])
        result1 = workspace.FetchBlob('result1')
        np.testing.assert_array_equal([1, 4, 1, 5], result1)

        # Once frozen, unseen keys map to 0 instead of being added.
        self._run_op('IndexFreeze', ['index'], ['index'])
        query2 = np.array(
            [entries[5], entries[4], entries[0], entries[6], entries[7]],
            dtype=dtype)
        workspace.FeedBlob('query2', query2)
        self._run_op('IndexGet', ['index', 'query2'], ['result2'])
        result2 = workspace.FetchBlob('result2')
        np.testing.assert_array_equal([0, 5, 1, 0, 0], result2)

        # Five keys are stored but the expected size is 6 -- presumably the
        # size counts the reserved id 0 as well; confirm against op docs.
        self._run_op('IndexSize', ['index'], ['index_size'])
        size = workspace.FetchBlob('index_size')
        self.assertEqual(size, 6)

        # IndexStore dumps the keys: the loaded ones followed by the ones
        # added during query1.
        self._run_op('IndexStore', ['index'], ['stored_entries'])
        stored_actual = workspace.FetchBlob('stored_entries')
        new_entries = np.array([entries[3], entries[4]], dtype=dtype)
        expected = np.concatenate((my_entries, new_entries))
        if dtype is str:
            # we'll always get bytes back from Caffe2
            expected = np.array([
                x.item().encode('utf-8') if isinstance(x, np.str_) else x
                for x in expected
            ], dtype=object)
        np.testing.assert_array_equal(expected, stored_actual)

        # Rebuild a second index from the stored keys; with
        # skip_first_entry=1 the expected size drops from 6 to 5.
        self._run_op(index_create_op, [], ['index2'])
        self._run_op(
            'IndexLoad', ['index2', 'stored_entries'], ['index2'],
            skip_first_entry=1)
        self._run_op('IndexSize', ['index2'], ['index2_size'])
        index2_size = workspace.FetchBlob('index2_size')
        self.assertEqual(index2_size, 5)

        # test serde: round-trip the index through a minidb file.
        with tempfile.NamedTemporaryFile() as tmp:
            self._run_op(
                'Save', ['index'], [],
                absolute_path=1, db_type='minidb', db=tmp.name)
            # frees up the blob so the lookups below must hit the reloaded one
            workspace.FeedBlob('index', np.array([]))
            # reloads the index
            self._run_op(
                'Load', [], ['index'],
                absolute_path=1, db_type='minidb', db=tmp.name)
            query3 = np.array(
                [entries[0], entries[3], entries[0], entries[4], entries[4]],
                dtype=dtype)
            workspace.FeedBlob('query3', query3)
            self._run_op('IndexGet', ['index', 'query3'], ['result3'])
            result3 = workspace.FetchBlob('result3')
            np.testing.assert_array_equal([1, 4, 1, 5, 5], result3)

    def test_string_index_ops(self):
        """String keys via StringIndexCreate."""
        self._test_index_ops([
            'entry1', 'entry2', 'entry3', 'new_entry1',
            'new_entry2', 'miss1', 'miss2', 'miss3',
        ], str, 'StringIndexCreate')

    def test_int_index_ops(self):
        """int32 keys via IntIndexCreate."""
        self._test_index_ops(list(range(8)), np.int32, 'IntIndexCreate')

    def test_long_index_ops(self):
        """int64 keys via LongIndexCreate."""
        self._test_index_ops(list(range(8)), np.int64, 'LongIndexCreate')
if __name__ == "__main__":
    # Executed as a script: hand control to the standard unittest runner,
    # which discovers the TestCase subclasses defined in this module.
    from unittest import main as run_unittest_main
    run_unittest_main()
|