File: multiThreadLoad_test.cpp

package info
hnswlib 0.8.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 628 kB
  • sloc: cpp: 4,809; python: 1,113; makefile: 32; sh: 18
file content (140 lines) | stat: -rw-r--r-- 5,272 bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include "../../hnswlib/hnswlib.h"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <iostream>
#include <random>
#include <stdexcept>
#include <thread>
#include <vector>


int main() {
    std::cout << "Running multithread load test" << std::endl;
    int d = 16;
    int max_elements = 1000;

    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib_real;

    hnswlib::L2Space space(d);
    hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * max_elements);

    std::cout << "Building index" << std::endl;
    int num_threads = 40;
    int num_labels = 10;

    int num_iterations = 10;
    int start_label = 0;

    // run threads that will add elements to the index
    // about 7 threads (the number depends on num_threads and num_labels)
    // will add/update element with the same label simultaneously
    while (true) {
        // add elements by batches
        std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1);
        std::vector<std::thread> threads;
        for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
            threads.push_back(
                std::thread(
                    [&] {
                        for (int iter = 0; iter < num_iterations; iter++) {
                            std::vector<float> data(d);
                            hnswlib::labeltype label = distrib_int(rng);
                            for (int i = 0; i < d; i++) {
                                data[i] = distrib_real(rng);
                            }
                            alg_hnsw->addPoint(data.data(), label);
                        }
                    }
                )
            );
        }
        for (auto &thread : threads) {
            thread.join();
        }
        if (alg_hnsw->cur_element_count > max_elements - num_labels) {
            break;
        }
        start_label += num_labels;
    }

    // insert remaining elements if needed
    for (hnswlib::labeltype label = 0; label < max_elements; label++) {
        auto search = alg_hnsw->label_lookup_.find(label);
        if (search == alg_hnsw->label_lookup_.end()) {
            std::cout << "Adding " << label << std::endl;
            std::vector<float> data(d);
            for (int i = 0; i < d; i++) {
                data[i] = distrib_real(rng);
            }
            alg_hnsw->addPoint(data.data(), label);
        }
    }

    std::cout << "Index is created" << std::endl;

    bool stop_threads = false;
    std::vector<std::thread> threads;

    // create threads that will do markDeleted and unmarkDeleted of random elements
    // each thread works with specific range of labels
    std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl;
    num_threads = 20;
    int chunk_size = max_elements / num_threads;
    for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
        threads.push_back(
            std::thread(
                [&, thread_id] {
                    std::uniform_int_distribution<> distrib_int(0, chunk_size - 1);
                    int start_id = thread_id * chunk_size;
                    std::vector<bool> marked_deleted(chunk_size);
                    while (!stop_threads) {
                        int id = distrib_int(rng);
                        hnswlib::labeltype label = start_id + id;
                        if (marked_deleted[id]) {
                            alg_hnsw->unmarkDelete(label);
                            marked_deleted[id] = false;
                        } else {
                            alg_hnsw->markDelete(label);
                            marked_deleted[id] = true;
                        }
                    }
                }
            )
        );
    }

    // create threads that will add and update random elements
    std::cout << "Starting add and update elements threads" << std::endl;
    num_threads = 20;
    std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1);
    for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
        threads.push_back(
            std::thread(
                [&] {
                    std::vector<float> data(d);
                    while (!stop_threads) {
                        hnswlib::labeltype label = distrib_int_add(rng);
                        for (int i = 0; i < d; i++) {
                            data[i] = distrib_real(rng);
                        }
                        alg_hnsw->addPoint(data.data(), label);
                        std::vector<float> data = alg_hnsw->getDataByLabel<float>(label);
                        float max_val = *max_element(data.begin(), data.end());
                        // never happens but prevents compiler from deleting unused code
                        if (max_val > 10) {
                            throw std::runtime_error("Unexpected value in data");
                        }
                    }
                }
            )
        );
    }

    std::cout << "Sleep and continue operations with index" << std::endl;
    int sleep_ms = 60 * 1000;
    std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
    stop_threads = true;
    for (auto &thread : threads) {
        thread.join();
    }
    
    std::cout << "Finish" << std::endl;
    return 0;
}