File: tests-OSA.cpp

package info (click to toggle)
rapidfuzz-cpp 3.3.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,480 kB
  • sloc: cpp: 30,893; python: 63; makefile: 26; sh: 8
file content (97 lines) | stat: -rw-r--r-- 3,218 bytes | parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#if CATCH2_VERSION == 2
#    include <catch2/catch.hpp>
#else
#    include <catch2/catch_test_macros.hpp>
#    include <catch2/matchers/catch_matchers_floating_point.hpp>
#endif

#include <rapidfuzz/details/types.hpp>
#include <rapidfuzz/distance/OSA.hpp>

#include <limits>
#include <string>
#include <vector>

#include "../common.hpp"

#ifdef RAPIDFUZZ_SIMD
/**
 * Require that the SIMD batch scorer MultiOSA<MaxLen>, loaded with s1 alone,
 * reports `expected` as the distance to s2.
 *
 * Does nothing when s1 is longer than MaxLen, so callers can invoke it for
 * every width unconditionally.
 *
 * @param s1       string inserted into the batch scorer
 * @param s2       string the batch scorer is compared against
 * @param max      score cutoff forwarded to the scorer
 * @param expected distance the scalar scorers reported for (s1, s2, max)
 */
template <int MaxLen, typename Sentence1, typename Sentence2>
void check_multi_osa(const Sentence1& s1, const Sentence2& s2, size_t max, size_t expected)
{
    if (s1.size() > static_cast<size_t>(MaxLen)) return;

    // 256 / 8 = 32 result slots; presumably enough for any MultiOSA width on the
    // widest supported vector unit -- TODO confirm against the scorer's result-count
    // requirement.
    std::vector<size_t> results(256 / 8);
    rapidfuzz::experimental::MultiOSA<MaxLen> simd_scorer(1);
    simd_scorer.insert(s1);
    simd_scorer.distance(results.data(), results.size(), s2, max);
    REQUIRE(expected == results[0]);
}
#endif

/**
 * Compute the OSA distance between s1 and s2 through every public entry point
 * and require that all of them agree.
 *
 * Exercised entry points:
 *  - rapidfuzz::osa_distance with ranges, with iterator pairs, and with iterator
 *    pairs wrapped by make_bidir (helper from common.hpp);
 *  - CachedOSA built from s1, queried with s2 as a range and as an iterator pair;
 *  - when RAPIDFUZZ_SIMD is defined, MultiOSA<8/16/32/64> for every width that
 *    can hold s1.
 *
 * @param s1  first sequence (needs begin()/end(), and size() for the SIMD checks)
 * @param s2  second sequence (needs begin()/end())
 * @param max score cutoff forwarded to every scorer; defaults to "no cutoff".
 *            The callers in this file expect a distance above max to be
 *            reported as max + 1.
 * @return the distance reported by rapidfuzz::osa_distance(s1, s2, max)
 */
template <typename Sentence1, typename Sentence2>
size_t osa_distance(const Sentence1& s1, const Sentence2& s2, size_t max = std::numeric_limits<size_t>::max())
{
    // free-function overloads
    size_t res1 = rapidfuzz::osa_distance(s1, s2, max);
    size_t res2 = rapidfuzz::osa_distance(s1.begin(), s1.end(), s2.begin(), s2.end(), max);
    size_t res3 = rapidfuzz::osa_distance(make_bidir(s1.begin()), make_bidir(s1.end()),
                                          make_bidir(s2.begin()), make_bidir(s2.end()), max);

    // cached scorer preprocessed from s1
    rapidfuzz::CachedOSA<rapidfuzz::char_type<Sentence1>> scorer(s1);
    size_t res4 = scorer.distance(s2, max);
    size_t res5 = scorer.distance(s2.begin(), s2.end(), max);

#ifdef RAPIDFUZZ_SIMD
    // each check is skipped when s1 does not fit in that width
    check_multi_osa<8>(s1, s2, max, res1);
    check_multi_osa<16>(s1, s2, max, res1);
    check_multi_osa<32>(s1, s2, max, res1);
    check_multi_osa<64>(s1, s2, max, res1);
#endif

    REQUIRE(res1 == res2);
    REQUIRE(res1 == res3);
    REQUIRE(res1 == res4);
    REQUIRE(res1 == res5);
    return res1;
}

/* Simple cases of the OSA distance: empty inputs, the score cutoff, insertion vs.
 * transposition costs, argument symmetry, and transpositions inside strings longer
 * than 64 characters (including one straddling the 64-character boundary). */
TEST_CASE("osa[simple]")
{
    {
        std::string s1 = "";
        std::string s2 = "";
        REQUIRE(osa_distance(s1, s2) == 0);
    }

    {
        std::string s1 = "aaaa";
        std::string s2 = "";
        REQUIRE(osa_distance(s1, s2) == 4);
        REQUIRE(osa_distance(s2, s1) == 4);
        // a distance above the cutoff is reported as cutoff + 1
        REQUIRE(osa_distance(s1, s2, 1) == 2);
        REQUIRE(osa_distance(s2, s1, 1) == 2);
    }

    {
        // OSA may not edit a transposed pair again, so this costs 3
        // (unrestricted Damerau-Levenshtein would give 2)
        std::string s1 = "CA";
        std::string s2 = "ABC";
        REQUIRE(osa_distance(s1, s2) == 3);
        REQUIRE(osa_distance(s2, s1) == 3);
    }

    {
        std::string s1 = "CA";
        std::string s2 = "AC";
        REQUIRE(osa_distance(s1, s2) == 1);
    }

    {
        // The first and last characters differ ('a' vs 'b'), so the strings share no
        // common prefix or suffix; presumably this keeps the full 132-character
        // inputs on the multi-word code path.
        std::string filler = str_multiply(std::string("a"), 64);
        std::string s1 = std::string("a") + filler + "CA" + filler + std::string("a");
        std::string s2 = std::string("b") + filler + "AC" + filler + std::string("b");
        REQUIRE(osa_distance(s1, s2) == 3);
        REQUIRE(osa_distance(s2, s1) == 3);
    }

    {
        // Same shape, but the transposed pair sits at indices 63 and 64, i.e. across
        // the 64-character boundary (presumably a word boundary of the bit-parallel
        // implementation -- verify).
        std::string head = str_multiply(std::string("a"), 62);
        std::string tail = str_multiply(std::string("a"), 63);
        std::string s1 = std::string("a") + head + "CA" + tail + std::string("a");
        std::string s2 = std::string("b") + head + "AC" + tail + std::string("b");
        REQUIRE(osa_distance(s1, s2) == 3);
        REQUIRE(osa_distance(s2, s1) == 3);
    }
}