# Pickle round-trip tests for hnswlib.Index, validated against a brute-force nearest-neighbour search.
import pickle
import unittest
import numpy as np
import hnswlib
def get_dist(metric, pt1, pt2):
    """Return the reference distance between two points under ``metric``.

    Args:
        metric: ``'l2'`` (squared Euclidean distance), ``'ip'`` (one minus
            the inner product) or ``'cosine'`` (one minus the cosine
            similarity).
        pt1: First point; must support elementwise arithmetic (the callers
            in this file pass NumPy rows).
        pt2: Second point, same requirements as ``pt1``.

    Returns:
        The distance as computed by NumPy reductions.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == 'l2':
        # Squared distance: no square root is taken.
        return np.sum((pt1 - pt2) ** 2)
    if metric == 'ip':
        return 1. - np.sum(np.multiply(pt1, pt2))
    if metric == 'cosine':
        norm_product = (np.sum(pt1 ** 2) * np.sum(pt2 ** 2)) ** .5
        return 1. - np.sum(np.multiply(pt1, pt2)) / norm_product
    # An unknown metric used to fall through and return None silently;
    # fail loudly so a typo in the metric name cannot go unnoticed.
    raise ValueError(f"unsupported metric: {metric!r}")
def brute_force_distances(metric, items, query_items, k):
    """Find the nearest items to every query by exhaustive comparison.

    Args:
        metric: Metric name forwarded unchanged to ``get_dist``.
        items: Array whose first axis indexes the stored points.
        query_items: Array whose first axis indexes the query points.
        k: Number of leading (closest) columns to keep per query.

    Returns:
        A ``(labels, dists)`` pair. Row ``j`` of ``labels`` holds the item
        indices ordered by increasing distance to query ``j``; row ``j`` of
        ``dists`` holds the matching sorted distances. Both are truncated to
        the first ``k`` columns.
    """
    n_items = items.shape[0]
    n_queries = query_items.shape[0]

    # One row per query, one column per item.
    dist_matrix = np.zeros((n_queries, n_items))
    for q_idx in range(n_queries):
        for item_idx in range(n_items):
            dist_matrix[q_idx, item_idx] = get_dist(
                metric, items[item_idx, :], query_items[q_idx, :]
            )

    # Take the ordering first, then sort the (local) matrix in place.
    # A faster equivalent would be np.argpartition / np.partition over range(k).
    order = np.argsort(dist_matrix, axis=1)
    dist_matrix.sort(axis=1)
    return order[:, :k], dist_matrix[:, :k]
def check_ann_results(self, metric, items, query_items, k, ann_l, ann_d, err_thresh=0, total_thresh=0, dists_thresh=0):
    """Assert that approximate k-NN results agree with a brute-force search.

    Args:
        self: Test case providing ``assertLessEqual``.
        metric: Metric name passed on to ``brute_force_distances``.
        items: Indexed points.
        query_items: Query points; one result row is checked per query.
        k: Number of neighbours per query.
        ann_l: Labels returned by the approximate search.
        ann_d: Distances returned by the approximate search.
        err_thresh: Missing labels tolerated for a single query before that
            query counts as failed.
        total_thresh: Maximum number of failed queries.
        dists_thresh: Maximum number of distance entries whose squared
            difference from the brute-force value exceeds 1e-3.
    """
    exact_labels, exact_dists = brute_force_distances(metric, items, query_items, k)

    failed_queries = 0
    for row in range(query_items.shape[0]):
        # Exact neighbours that the approximate answer did not return.
        found = np.isin(exact_labels[row, :], ann_l[row, :])
        missing = np.sum(~found)
        if missing > 0:
            print(f"Warning: {missing} labels are missing from ann results (k={k}, err_thresh={err_thresh})")
        failed_queries += int(missing > err_thresh)

    self.assertLessEqual(
        failed_queries,
        total_thresh,
        f"Error: knn_query returned incorrect labels for {failed_queries} items (k={k})",
    )

    squared_diff = (exact_dists - ann_d) ** 2.
    mismatches = np.sum(squared_diff > 1e-3)
    if mismatches > 0:
        n_rows, n_cols = exact_dists.shape
        total_values = n_rows * n_cols
        print(f"Warning: {mismatches} ann distance values are different from brute-force values (total # of values={total_values}, dists_thresh={dists_thresh})")

    self.assertLessEqual(
        mismatches,
        dists_thresh,
        msg=f"Error: {mismatches} ann distance values are different from brute-force values",
    )
def test_space_main(self, space, dim):
    """Check that an ``hnswlib.Index`` survives pickling at several stages.

    Four indexes are built and compared:
      * ``p``  - the reference index, never restored from a pickle
      * ``p0`` - copy pickled before ``init_index`` was called
      * ``p1`` - copy pickled after ``init_index`` but before adding items
      * ``p2`` - copy pickled after the items were added

    Args:
        self: Test case supplying the parameters (``num_elements``, ``k``,
            ``ef``, thresholds, ...) and the assertion methods.
        space: Space name handed to ``hnswlib.Index`` and to the
            brute-force check (the tests here use l2, cosine and ip).
        dim: Dimensionality of the random vectors.
    """
    def roundtrip(index):
        # Serialize and immediately restore the index.
        return pickle.loads(pickle.dumps(index))

    # Random sample data: indexed vectors first, then query vectors.
    data = np.random.random((self.num_elements, dim)).astype(np.float32)
    test_data = np.random.random((self.num_test_elements, dim)).astype(np.float32)

    p = hnswlib.Index(space=space, dim=dim)
    print(f"Running pickle tests for {p}")
    p.num_threads = self.num_threads

    p0 = roundtrip(p)  # copy of the un-initialized index

    init_args = dict(
        max_elements=self.num_elements,
        ef_construction=self.ef_construction,
        M=self.M,
    )
    for index in (p, p0):
        index.init_index(**init_args)
    for index in (p, p0):
        index.ef = self.ef

    p1 = roundtrip(p)  # copy of the initialized but still empty index

    for index in (p, p1, p0):
        index.add_items(data)

    p2 = roundtrip(p)  # copy of the populated index

    indexes = {"p": p, "p0": p0, "p1": p1, "p2": p2}
    chain = ("p", "p0", "p1", "p2")
    neighbours = tuple(zip(chain, chain[1:]))  # (p, p0), (p0, p1), (p1, p2)

    # Stored items must be the same along the whole chain.
    for left, right in neighbours:
        self.assertTrue(
            np.allclose(indexes[left].get_items(), indexes[right].get_items()),
            f"items for {left} and {right} must be same",
        )

    # Query every index with the same data.
    labels, dists = {}, {}
    for name in chain:
        labels[name], dists[name] = indexes[name].knn_query(test_data, k=self.k)

    # Returned distances must agree between neighbouring indexes, up to
    # dists_err_thresh entries differing by more than 1e-3 (squared).
    for left, right in neighbours:
        n_diff = np.sum(((dists[left] - dists[right]) ** 2.) > 1e-3)
        self.assertLessEqual(
            n_diff,
            self.dists_err_thresh,
            msg=f"knn distances returned by {left} and {right} must match",
        )

    # Compare the reference index and the fully populated copy against a
    # brute-force search; tolerances come from the test case's thresholds.
    for name in ("p", "p2"):
        check_ann_results(
            self, space, data, test_data, self.k, labels[name], dists[name],
            err_thresh=self.label_err_thresh,
            total_thresh=self.item_err_thresh,
            dists_thresh=self.dists_err_thresh,
        )

    # The ef, M and ef_construction parameters must be preserved everywhere.
    parameter_checks = (
        ("ef", self.ef, ("p", "p0", "p2", "p1")),
        ("M", self.M, chain),
        ("ef_construction", self.ef_construction, chain),
    )
    for attr, expected, names in parameter_checks:
        for name in names:
            self.assertEqual(
                getattr(indexes[name], attr),
                expected,
                f"incorrect value of {name}.{attr}",
            )
class PickleUnitTests(unittest.TestCase):
    """Runs ``test_space_main`` once for each of the ip, l2 and cosine spaces."""

    def setUp(self):
        """Set the index parameters, data sizes and tolerances used by every test."""
        # Index construction and query parameters.
        self.M = 32
        self.ef_construction = 200
        self.ef = 400
        self.num_threads = 4

        # Data set sizes and number of neighbours requested per query.
        self.num_elements = 1000
        self.num_test_elements = 100
        self.k = 25

        # Tolerances.
        # label_err_thresh: max number of missing labels allowed per test item.
        self.label_err_thresh = 5
        # item_err_thresh: max number of items allowed with incorrect labels.
        self.item_err_thresh = 5
        # dists_err_thresh: for two distance matrices d1 and d2, the max number
        # of value pairs allowed to differ, i.e. entries with (d1-d2)**2 > 1e-3.
        self.dists_err_thresh = 50

    def test_inner_product_space(self):
        # Inner-product space with 16-dimensional vectors.
        test_space_main(self, space='ip', dim=16)

    def test_l2_space(self):
        # L2 space with 53-dimensional vectors.
        test_space_main(self, space='l2', dim=53)

    def test_cosine_space(self):
        # Cosine space with 32-dimensional vectors.
        test_space_main(self, space='cosine', dim=32)
|