File: bench_rabitq_simd.cpp

package info (click to toggle)
faiss 1.13.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 9,228 kB
  • sloc: cpp: 91,727; python: 31,865; sh: 874; ansic: 425; makefile: 41
file content (100 lines) | stat: -rw-r--r-- 3,077 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <benchmark/benchmark.h>
#include <faiss/utils/AlignedTable.h>
#include <faiss/utils/rabitq_simd.h>
#include <faiss/utils/random.h>

namespace faiss {

const auto& randomData() {
    static auto data = [] {
        AlignedTable<uint8_t> x(10 << 20); // 10 MiB
        byte_rand(x.data(), x.size(), 456);
        return x;
    }();
    return data;
}

void bench_rabitq_generic(benchmark::State& state, auto distFn) {
    uint8_t qb = state.range(0);
    size_t d = state.range(1);
    size_t size = (d + 7) / 8;

    auto& x = randomData();
    AlignedTable<uint8_t> q(qb * size);
    byte_rand(q.data(), q.size(), 123);

    size_t n = x.size() / size;

    uint64_t sum = 0;
    size_t r = 0;
    for (auto _ : state) {
        ++r;
        for (size_t i = 0; i < n; ++i) {
            sum += distFn(q.data(), x.data() + i * size, size, qb);
        }
        benchmark::DoNotOptimize(sum);
    }
    state.SetItemsProcessed(n * r);
    state.SetBytesProcessed(r * x.size());
}

void bench_rabitq_sum(benchmark::State& state) {
    bench_rabitq_generic(
            state,
            [](const uint8_t*, const uint8_t* x, size_t size, size_t)
                    -> int64_t { return rabitq::popcount(x, size); });
}

void bench_rabitq_and_dot_product(benchmark::State& state) {
    bench_rabitq_generic(
            state,
            [](const uint8_t* q, const uint8_t* x, size_t size, size_t qb)
                    -> int64_t {
                return rabitq::bitwise_and_dot_product(q, x, size, qb);
            });
}

void bench_rabitq_xor_dot_product(benchmark::State& state) {
    bench_rabitq_generic(
            state,
            [](const uint8_t* q, const uint8_t* x, size_t size, size_t qb)
                    -> int64_t {
                return rabitq::bitwise_xor_dot_product(q, x, size, qb);
            });
}

void bench_rabitq_and_dot_product_with_sum(benchmark::State& state) {
    bench_rabitq_generic(
            state,
            [](const uint8_t* q, const uint8_t* x, size_t size, size_t qb)
                    -> int64_t {
                auto sum_q = rabitq::popcount(x, size);
                auto dp = rabitq::bitwise_and_dot_product(q, x, size, qb);
                // Synthetic operation using both inputs for benchmarking.
                return sum_q + dp;
            });
}

const std::vector<int64_t> qbs{1, 2, 4, 8};
const std::vector<int64_t> dims{64, 100, 256, 512, 1000, 1024, 3072};

BENCHMARK(bench_rabitq_sum)->ArgsProduct({{0}, dims})->ArgNames({"qb", "d"});
BENCHMARK(bench_rabitq_and_dot_product)
        ->ArgsProduct({qbs, dims})
        ->ArgNames({"qb", "d"});
BENCHMARK(bench_rabitq_xor_dot_product)
        ->ArgsProduct({qbs, dims})
        ->ArgNames({"qb", "d"});
BENCHMARK(bench_rabitq_and_dot_product_with_sum)
        ->ArgsProduct({qbs, dims})
        ->ArgNames({"qb", "d"});
BENCHMARK_MAIN();

} // namespace faiss