1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstdio>
#include <cstdlib>
#include <random>
#include <gtest/gtest.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
TEST(IVFPQ, accuracy) {
// dimension of the vectors to index
int d = 64;
// size of the database we plan to index
size_t nb = 1000;
// make a set of nt training vectors in the unit cube
// (could be the database)
size_t nt = 1500;
// make the index object and train it
faiss::IndexFlatL2 coarse_quantizer(d);
// a reasonable number of cetroids to index nb vectors
int ncentroids = 25;
faiss::IndexIVFPQ index(&coarse_quantizer, d, ncentroids, 16, 8);
// index that gives the ground-truth
faiss::IndexFlatL2 index_gt(d);
std::mt19937 rng;
std::uniform_real_distribution<> distrib;
{ // training
std::vector<float> trainvecs(nt * d);
for (size_t i = 0; i < nt * d; i++) {
trainvecs[i] = distrib(rng);
}
index.verbose = true;
index.train(nt, trainvecs.data());
}
{ // populating the database
std::vector<float> database(nb * d);
for (size_t i = 0; i < nb * d; i++) {
database[i] = distrib(rng);
}
index.add(nb, database.data());
index_gt.add(nb, database.data());
}
int nq = 200;
int n_ok;
{ // searching the database
std::vector<float> queries(nq * d);
for (size_t i = 0; i < nq * d; i++) {
queries[i] = distrib(rng);
}
std::vector<faiss::idx_t> gt_nns(nq);
std::vector<float> gt_dis(nq);
index_gt.search(nq, queries.data(), 1, gt_dis.data(), gt_nns.data());
index.nprobe = 5;
int k = 5;
std::vector<faiss::idx_t> nns(k * nq);
std::vector<float> dis(k * nq);
index.search(nq, queries.data(), k, dis.data(), nns.data());
n_ok = 0;
for (int q = 0; q < nq; q++) {
for (int i = 0; i < k; i++)
if (nns[q * k + i] == gt_nns[q])
n_ok++;
}
EXPECT_GT(n_ok, nq * 0.4);
}
}
|