File: bindings_test_pickle.py

package info (click to toggle)
hnswlib 0.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 628 kB
  • sloc: cpp: 4,809; python: 1,113; makefile: 32; sh: 18
file content (151 lines) | stat: -rw-r--r-- 6,494 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import pickle
import unittest

import numpy as np

import hnswlib


def get_dist(metric, pt1, pt2):
    """Return the brute-force reference distance between two vectors.

    The metric names match the space names passed to hnswlib.Index in this
    file ('l2', 'ip', 'cosine').

    Args:
        metric: 'l2' for the squared Euclidean distance (no square root),
            'ip' for 1 - <pt1, pt2>, or 'cosine' for 1 - cosine similarity.
        pt1, pt2: the two vectors to compare (NumPy arrays of equal shape).

    Returns:
        The distance as a NumPy scalar.

    Raises:
        ValueError: if *metric* is not one of the supported names.  Previously
            an unknown metric silently returned None.
    """
    if metric == 'l2':
        return np.sum((pt1 - pt2) ** 2)
    if metric == 'ip':
        return 1. - np.sum(np.multiply(pt1, pt2))
    if metric == 'cosine':
        return 1. - np.sum(np.multiply(pt1, pt2)) / (np.sum(pt1 ** 2) * np.sum(pt2 ** 2)) ** .5
    raise ValueError(f"unsupported metric: {metric!r} (expected 'l2', 'ip' or 'cosine')")


def brute_force_distances(metric, items, query_items, k):
    """Exhaustively find the k nearest items for every query row.

    Every query row is compared with every item row through ``get_dist``, so
    the cost is len(query_items) * len(items) distance evaluations; this is a
    reference implementation for small test data only.

    Args:
        metric: metric name forwarded to ``get_dist``.
        items: 2-D array of indexed vectors, one per row.
        query_items: 2-D array of query vectors, one per row.
        k: number of neighbours to keep per query.

    Returns:
        (labels, dists): two arrays with one row per query and at most k
        columns, holding the item row indices ordered by increasing distance
        and the matching distances.
    """
    n_items = items.shape[0]
    n_queries = query_items.shape[0]

    # Rows correspond to queries, columns to items.
    table = np.zeros((n_queries, n_items))
    for col in range(n_items):
        item = items[col, :]
        for row in range(n_queries):
            table[row, col] = get_dist(metric, item, query_items[row, :])

    # A full sort keeps this reference simple; np.argpartition / np.partition
    # with range(k) would be equivalent but faster.
    order = np.argsort(table, axis=1)
    ranked = table.copy()
    ranked.sort(axis=1)

    return order[:, :k], ranked[:, :k]


def check_ann_results(self, metric, items, query_items, k, ann_l, ann_d, err_thresh=0, total_thresh=0, dists_thresh=0):
    """Assert that approximate k-NN results agree with a brute-force search.

    Args:
        self: the unittest.TestCase whose assertion methods are used.
        metric: metric name forwarded to ``brute_force_distances``.
        items: 2-D array of the vectors that were indexed.
        query_items: 2-D array of query vectors.
        k: number of neighbours requested per query.
        ann_l: labels returned by the approximate index, one row per query.
        ann_d: distances returned by the approximate index, one row per query.
        err_thresh: a query counts as failed only when more than this many of
            its true neighbours are absent from its ANN labels.
        total_thresh: maximum number of failed queries tolerated.
        dists_thresh: maximum number of distance entries whose squared
            difference from the brute-force value may exceed 1e-3.
    """
    true_labels, true_dists = brute_force_distances(metric, items, query_items, k)

    failed_queries = 0
    for row in range(query_items.shape[0]):
        absent = np.isin(true_labels[row, :], ann_l[row, :], invert=True)
        n_missing = np.sum(absent)
        if n_missing > 0:
            print(f"Warning: {n_missing} labels are missing from ann results (k={k}, err_thresh={err_thresh})")
        if n_missing > err_thresh:
            failed_queries += 1

    self.assertLessEqual(failed_queries, total_thresh, f"Error: knn_query returned incorrect labels for {failed_queries} items (k={k})")

    # Entries are compared position by position: both distance matrices are
    # expected to be ordered from nearest to farthest.
    mismatched = ((true_dists - ann_d) ** 2.) > 1e-3
    wrong_dists = np.sum(mismatched)
    if wrong_dists > 0:
        dists_count = true_dists.shape[0] * true_dists.shape[1]
        print(f"Warning: {wrong_dists} ann distance values are different from brute-force values (total # of values={dists_count}, dists_thresh={dists_thresh})")

    self.assertLessEqual(wrong_dists, dists_thresh, msg=f"Error: {wrong_dists} ann distance values are different from brute-force values")


def test_space_main(self, space, dim):
    """Check that pickling an hnswlib.Index preserves its items, results and parameters.

    Helper shared by the PickleUnitTests methods. *self* is the calling
    unittest.TestCase. It supplies the assertion methods and the
    configuration attributes read below: num_elements, num_test_elements,
    num_threads, ef_construction, M, ef, k and the *_err_thresh tolerances.

    Four indices are built and compared:
      p  -- the original index;
      p0 -- unpickled from p before init_index, then initialised and filled;
      p1 -- unpickled from p after init_index and setting ef, then filled;
      p2 -- unpickled from p after its items were added.

    Args:
        self: the calling unittest.TestCase.
        space: hnswlib space name: 'l2', 'cosine' or 'ip'.
        dim: dimensionality of the randomly generated vectors.
    """
    # Generating sample data
    data = np.float32(np.random.random((self.num_elements, dim)))
    test_data = np.float32(np.random.random((self.num_test_elements, dim)))

    # Declaring index
    p = hnswlib.Index(space=space, dim=dim)  # possible options are l2, cosine or ip
    print(f"Running pickle tests for {p}")

    p.num_threads = self.num_threads  # by default using all available cores

    p0 = pickle.loads(pickle.dumps(p))  # pickle un-initialized Index
    p.init_index(max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M)
    p0.init_index(max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M)

    p.ef = self.ef
    p0.ef = self.ef

    p1 = pickle.loads(pickle.dumps(p))  # pickle Index before adding items

    # add items to ann index p,p0,p1
    p.add_items(data)
    p1.add_items(data)
    p0.add_items(data)

    p2 = pickle.loads(pickle.dumps(p))  # pickle Index after adding items

    indices = {"p": p, "p0": p0, "p1": p1, "p2": p2}
    # Each index is compared with its neighbour in the chain p -> p0 -> p1 -> p2.
    names = list(indices)
    chain = list(zip(names, names[1:]))

    for a, b in chain:
        self.assertTrue(np.allclose(indices[a].get_items(), indices[b].get_items()),
                        f"items for {a} and {b} must be same")

    # Test if returned distances are same
    results = {name: index.knn_query(test_data, k=self.k) for name, index in indices.items()}
    for a, b in chain:
        d_a, d_b = results[a][1], results[b][1]
        self.assertLessEqual(np.sum(((d_a - d_b) ** 2.) > 1e-3), self.dists_err_thresh,
                             msg=f"knn distances returned by {a} and {b} must match")

    # Check that the original index and the copy unpickled after add_items
    # match brute-force search. A query may miss up to label_err_thresh true
    # neighbours, and at most item_err_thresh queries may exceed that.
    for name in ("p", "p2"):
        labels, dists = results[name]
        check_ann_results(self, space, data, test_data, self.k, labels, dists,
                          err_thresh=self.label_err_thresh,
                          total_thresh=self.item_err_thresh,
                          dists_thresh=self.dists_err_thresh)

    # Check ef, M and ef_construction parameter values on every index
    expected_params = {"ef": self.ef, "M": self.M, "ef_construction": self.ef_construction}
    for attr, value in expected_params.items():
        for name, index in indices.items():
            self.assertEqual(getattr(index, attr), value, f"incorrect value of {name}.{attr}")


# Not a standalone test: it needs self/space/dim supplied by a TestCase, so keep
# pytest from collecting it as a module-level test function.
test_space_main.__test__ = False


class PickleUnitTests(unittest.TestCase):

    def setUp(self):
        self.ef_construction = 200
        self.M = 32
        self.ef = 400

        self.num_elements = 1000
        self.num_test_elements = 100

        self.num_threads = 4
        self.k = 25

        self.label_err_thresh = 5  # max number of missing labels allowed per test item
        self.item_err_thresh = 5   # max number of items allowed with incorrect labels

        self.dists_err_thresh = 50 # for two matrices, d1 and d2, dists_err_thresh controls max
                                 # number of value pairs that are allowed to be different in d1 and d2
                                 # i.e., number of values that are (d1-d2)**2>1e-3

    def test_inner_product_space(self):
        test_space_main(self, 'ip', 16)

    def test_l2_space(self):
        test_space_main(self, 'l2', 53)

    def test_cosine_space(self):
        test_space_main(self, 'cosine', 32)