File: bindings_test_replace.py

package info (click to toggle)
hnswlib 0.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 628 kB
  • sloc: cpp: 4,809; python: 1,113; makefile: 32; sh: 18
file content (245 lines) | stat: -rw-r--r-- 9,873 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import os
import pickle
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
    """Tests for replacing deleted elements in an hnswlib index.

    Covers adding new items into the slots of deleted ones
    (``allow_replace_deleted`` / ``replace_deleted=True``), persistence of an
    index that contains replaced elements (``save_index``/``load_index`` and
    pickle), and the recall of an index built with replacement compared with
    one built without it.
    """

    @staticmethod
    def _count_correct(found_labels, true_labels):
        """Count found labels that also occur in the ground-truth results.

        ``found_labels`` and ``true_labels`` are walked row by row in lockstep
        (one row per query). For every row, each entry of the found row that
        appears anywhere in the matching ground-truth row counts once.

        Returns:
            int: total number of matching entries over all rows.
        """
        return int(sum(np.isin(found, truth).sum()
                       for found, truth in zip(found_labels, true_labels)))

    def testRandomSelf(self):
        """
            Tests if replace of deleted elements works correctly
            Tests serialization of the index with replaced elements
        """
        dim = 16
        num_elements = 5000
        max_num_elements = 2 * num_elements

        recall_threshold = 0.98

        # Generating sample data: four batches of num_elements vectors each,
        # labelled with consecutive, non-overlapping id ranges.
        print("Generating data")
        # batch 1
        first_id = 0
        last_id = num_elements
        labels1 = np.arange(first_id, last_id)
        data1 = np.float32(np.random.random((num_elements, dim)))
        # batch 2
        first_id += num_elements
        last_id += num_elements
        labels2 = np.arange(first_id, last_id)
        data2 = np.float32(np.random.random((num_elements, dim)))
        # batch 3
        first_id += num_elements
        last_id += num_elements
        labels3 = np.arange(first_id, last_id)
        data3 = np.float32(np.random.random((num_elements, dim)))
        # batch 4
        first_id += num_elements
        last_id += num_elements
        labels4 = np.arange(first_id, last_id)
        data4 = np.float32(np.random.random((num_elements, dim)))

        # Declaring index
        hnsw_index = hnswlib.Index(space='l2', dim=dim)
        hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)

        hnsw_index.set_ef(100)
        hnsw_index.set_num_threads(4)

        # Add batch 1 and 2
        print("Adding batch 1")
        hnsw_index.add_items(data1, labels1)
        print("Adding batch 2")
        hnsw_index.add_items(data2, labels2)  # maximum number of elements is reached

        # Delete nearest neighbors of batch 2
        print("Deleting neighbors of batch 2")
        labels2_deleted, _ = hnsw_index.knn_query(data2, k=1)
        # Several queries can share the same nearest neighbor, so deduplicate
        # before deleting. num_duplicates is how many fewer than num_elements
        # elements were actually deleted (= how many free slots exist).
        labels2_deleted_no_dup = set(labels2_deleted.flatten())
        num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup)
        for label in labels2_deleted_no_dup:
            hnsw_index.mark_deleted(label)
        # Batch 1 queries must still return vectors matching batch 1.
        labels1_found, _ = hnsw_index.knn_query(data1, k=1)
        items = hnsw_index.get_items(labels1_found)
        diff_with_gt_labels = np.mean(np.abs(data1 - items))
        self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3)

        # Deleted elements must never be returned by a search.
        labels2_after, _ = hnsw_index.knn_query(data2, k=1)
        for la in labels2_after:
            if la[0] in labels2_deleted_no_dup:
                self.fail(f"Found deleted label {la[0]} during knn search")
        print("All the neighbors of data2 are removed")

        # Replace deleted elements
        print("Inserting batch 3 by replacing deleted elements")
        # Maximum number of elements is reached therefore we cannot add new items
        # but we can replace the deleted ones
        # Note: there may be less than num_elements elements.
        #       As we could delete less than num_elements because of duplicates
        labels3_tr = labels3[0:labels3.shape[0] - num_duplicates]
        data3_tr = data3[0:data3.shape[0] - num_duplicates]
        hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True)

        # After replacing, all labels should be retrievable
        print("Checking that remaining labels are in index")
        # Get remaining data from batch 1 and batch 2 after deletion of elements.
        # Labels of batches 1 and 2 are 0..2*num_elements-1, so a label doubles
        # as a row index into the concatenated data.
        remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup
        remaining_labels_list = list(remaining_labels)
        comb_data = np.concatenate((data1, data2), axis=0)
        remaining_data = comb_data[remaining_labels_list]

        returned_items = hnsw_index.get_items(remaining_labels_list)
        np.testing.assert_array_equal(remaining_data, returned_items)

        returned_items = hnsw_index.get_items(labels3_tr)
        np.testing.assert_array_equal(data3_tr, returned_items)

        # Check index serialization
        # Delete batch 3
        print("Deleting batch 3")
        for label in labels3_tr:
            hnsw_index.mark_deleted(label)

        # Save index
        index_path = "index.bin"
        print(f"Saving index to {index_path}")
        hnsw_index.save_index(index_path)
        # Remove the saved file even if a later assertion fails.
        self.addCleanup(os.remove, index_path)
        del hnsw_index

        # Reinit and load the index
        hnsw_index = hnswlib.Index(space='l2', dim=dim)  # the space can be changed - keeps the data, alters the distance function.
        hnsw_index.set_num_threads(4)
        print(f"Loading index from {index_path}")
        hnsw_index.load_index(index_path, max_elements=max_num_elements, allow_replace_deleted=True)

        # Insert batch 4
        print("Inserting batch 4 by replacing deleted elements")
        labels4_tr = labels4[0:labels4.shape[0] - num_duplicates]
        data4_tr = data4[0:data4.shape[0] - num_duplicates]
        hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True)

        # Check recall: each batch 4 vector should find itself as its 1-NN.
        print("Checking recall")
        labels_found, _ = hnsw_index.knn_query(data4_tr, k=1)
        recall = np.mean(labels_found.reshape(-1) == labels4_tr)
        print(f"Recall for the 4 batch: {recall}")
        self.assertGreater(recall, recall_threshold)

        # Delete batch 4
        print("Deleting batch 4")
        for label in labels4_tr:
            hnsw_index.mark_deleted(label)

        print("Testing pickle serialization")
        hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index))
        del hnsw_index
        # Insert batch 3
        print("Inserting batch 3 by replacing deleted elements")
        hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True)

        # Check recall
        print("Checking recall")
        labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1)
        recall = np.mean(labels_found.reshape(-1) == labels3_tr)
        print(f"Recall for the 3 batch: {recall}")
        self.assertGreater(recall, recall_threshold)

    def test_recall_degradation(self):
        """
            Compares recall of the index with replaced elements and without
            Measures recall degradation
        """
        dim = 16
        num_elements = 10_000
        max_num_elements = 2 * num_elements
        query_size = 1_000
        k = 100

        recall_threshold = 0.98
        max_recall_diff = 0.02

        # Generating sample data
        print("Generating data")
        # batch 1
        first_id = 0
        last_id = num_elements
        labels1 = np.arange(first_id, last_id)
        data1 = np.float32(np.random.random((num_elements, dim)))
        # batch 2
        first_id += num_elements
        last_id += num_elements
        labels2 = np.arange(first_id, last_id)
        data2 = np.float32(np.random.random((num_elements, dim)))
        # batch 3
        first_id += num_elements
        last_id += num_elements
        labels3 = np.arange(first_id, last_id)
        data3 = np.float32(np.random.random((num_elements, dim)))
        # query to test recall
        query_data = np.float32(np.random.random((query_size, dim)))

        # Declaring index
        hnsw_index_no_replace = hnswlib.Index(space='l2', dim=dim)
        hnsw_index_no_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=False)
        hnsw_index_with_replace = hnswlib.Index(space='l2', dim=dim)
        hnsw_index_with_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)

        # Brute-force index over the same final contents (batches 1 and 3)
        # provides the exact ground truth.
        bf_index = hnswlib.BFIndex(space='l2', dim=dim)
        bf_index.init_index(max_elements=max_num_elements)

        hnsw_index_no_replace.set_ef(100)
        hnsw_index_no_replace.set_num_threads(50)
        hnsw_index_with_replace.set_ef(100)
        hnsw_index_with_replace.set_num_threads(50)

        # Add data
        print("Adding data")
        hnsw_index_with_replace.add_items(data1, labels1)
        hnsw_index_with_replace.add_items(data2, labels2)  # maximum number of elements is reached
        bf_index.add_items(data1, labels1)
        bf_index.add_items(data3, labels3)  # maximum number of elements is reached

        # Replace all of batch 2 with batch 3 in the "with replace" index.
        for label in labels2:
            hnsw_index_with_replace.mark_deleted(label)
        hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True)

        hnsw_index_no_replace.add_items(data1, labels1)
        hnsw_index_no_replace.add_items(data3, labels3)  # maximum number of elements is reached

        # Query the elements and measure recall:
        labels_hnsw_with_replace, _ = hnsw_index_with_replace.knn_query(query_data, k)
        labels_hnsw_no_replace, _ = hnsw_index_no_replace.knn_query(query_data, k)
        labels_bf, _ = bf_index.knn_query(query_data, k)

        # Measure recall
        correct_with_replace = self._count_correct(labels_hnsw_with_replace, labels_bf)
        correct_no_replace = self._count_correct(labels_hnsw_no_replace, labels_bf)

        recall_with_replace = float(correct_with_replace) / (k*query_size)
        recall_no_replace = float(correct_no_replace) / (k*query_size)
        print("recall with replace:", recall_with_replace)
        print("recall without replace:", recall_no_replace)

        # Degradation caused by replacement, relative to a freshly built index.
        recall_diff = abs(recall_with_replace - recall_no_replace)

        self.assertGreater(recall_no_replace, recall_threshold)
        self.assertLess(recall_diff, max_recall_diff)