File: search_scheme_algorithm_test.cpp

package info (click to toggle)
seqan3 3.0.2+ds-9
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 16,052 kB
  • sloc: cpp: 144,641; makefile: 1,288; ansic: 294; sh: 228; xml: 217; javascript: 50; python: 27; php: 25
file content (335 lines) | stat: -rw-r--r-- 15,953 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
// -----------------------------------------------------------------------------------------------------
// Copyright (c) 2006-2020, Knut Reinert & Freie Universität Berlin
// Copyright (c) 2016-2020, Knut Reinert & MPI für molekulare Genetik
// This file may be used, modified and/or redistributed under the terms of the 3-clause BSD-License
// shipped with this file and also available at: https://github.com/seqan/seqan3/blob/master/LICENSE.md
// -----------------------------------------------------------------------------------------------------

#include <algorithm>
#include <type_traits>

#include "helper.hpp"
#include "helper_search_scheme.hpp"
#include <seqan3/test/performance/sequence_generator.hpp>

#include <seqan3/core/debug_stream.hpp>
#include <seqan3/search/configuration/default_configuration.hpp>
#include <seqan3/search/detail/search_configurator.hpp>
#include <seqan3/search/detail/search_scheme_algorithm.hpp>
#include <seqan3/search/detail/unidirectional_search_algorithm.hpp>
#include <seqan3/search/detail/policy_max_error.hpp>
#include <seqan3/search/detail/policy_search_result_builder.hpp>
#include <seqan3/search/fm_index/all.hpp>
#include <seqan3/range/views/slice.hpp>
#include <seqan3/range/views/to.hpp>

#include <gtest/gtest.h>

// Uses the trivial search of the unidirectional search algorithm.
// The algorithm is configured with the corresponding configuration types.
// To modify the trivial search use the configuration settings of the search algorithm.
//
// \param index      The index to search in.
// \param query      The query sequence to search for.
// \param error_left Error budget; its total, substitution, insertion and deletion counts are forwarded to the
//                   matching `max_error_*` configuration elements.
// \param delegate   Invoked with the index cursor of every result the configured algorithm reports.
template <typename index_t, typename query_t, typename delegate_t>
static void search_trivial(index_t const & index,
                           query_t const & query,
                           seqan3::detail::search_param error_left,
                           delegate_t && delegate)
{
    // Configure the algorithm according to the given specifications: one budget per error type, request all hits and
    // have every result expose the index cursor it was found with.
    auto cfg = seqan3::search_cfg::max_error_total{seqan3::search_cfg::error_count{error_left.total}} |
               seqan3::search_cfg::max_error_substitution{seqan3::search_cfg::error_count{error_left.substitution}} |
               seqan3::search_cfg::max_error_insertion{seqan3::search_cfg::error_count{error_left.insertion}} |
               seqan3::search_cfg::max_error_deletion{seqan3::search_cfg::error_count{error_left.deletion}} |
               seqan3::search_cfg::hit_all{} |
               seqan3::search_cfg::output_index_cursor{};

    // The configured algorithm is invoked with an (id, query) pair; a dummy id of 0 is used here.
    auto indexed_query = std::pair{size_t{0}, query};
    auto algo = std::get<0>(seqan3::detail::search_configurator::configure_algorithm<decltype(indexed_query)>(cfg,
                                                                                                              index));

    // Call the algorithm and hand the index cursor of every result to the delegate.
    algo(indexed_query, [&] (auto result) { delegate(result.index_cursor()); });
}

// Checks one search of a search scheme in the Hamming-distance setting.
//
// A query of length `query_length` is cut out of `text` at a random position and then mutated so that block `b`
// (blocks taken left to right with the lengths in `ordered_blocks_length`) receives exactly `error_distribution[b]`
// substitutions at distinct positions. The mutated query is searched with `search_ss` on a cursor of `index` and with
// `search_trivial`, both using a total budget of `search.u.back()`, a random substitution budget in [0, total] and no
// insertions or deletions. Search-scheme hits are kept only where the text equals the unmodified query. Trivial hits
// must additionally show exactly the requested number of mismatches per block. The deduplicated hit sets must then be
// equal; on mismatch, diagnostics are printed.
//
// `index` is taken by const reference: only const access (`cursor()`, `search_trivial`) is needed, and a by-value
// parameter copied the whole index on every call.
//
// NOTE(review): randomness comes from std::rand(), which is not seeded with `seed` here; `seed` only appears in the
// failure diagnostics.
template <typename text_t>
inline void test_search_hamming(auto const & index, text_t const & text, auto const & search,
                                uint64_t const query_length, std::vector<uint8_t> const & error_distribution,
                                size_t const seed, auto const & blocks_length, auto const & ordered_blocks_length,
                                uint64_t const start_pos)
{
    using char_t = typename text_t::value_type;

    uint64_t const pos = std::rand() % (text.size() - query_length + 1);
    text_t const orig_query = text | seqan3::views::slice(pos, pos + query_length) | seqan3::views::to<text_t>;

    // Modify query s.t. it has errors matching error_distribution.
    text_t query = orig_query;
    auto it = index.cursor();
    uint64_t current_blocks_length = 0;
    for (uint8_t block = 0; block < search.blocks(); ++block)
    {
        uint64_t const single_block_length = ordered_blocks_length[block];
        EXPECT_LE(error_distribution[block], single_block_length);
        if (error_distribution[block] > single_block_length)
        {
            seqan3::debug_stream << "Error in block " << block << "(+ 1): " << error_distribution[block]
                                 << " errors cannot fit into a block of length " << single_block_length << "." << '\n'
                                 << "Error Distribution: " << error_distribution << '\n';
            exit(1);
        }

        // Choose random positions inside the block for substitutions. Repeat until all error positions are unique.
        std::vector<uint8_t> error_positions;
        error_positions.reserve(error_distribution[block]);
        do
        {
            error_positions.clear();
            for (uint8_t error = 0; error < error_distribution[block]; ++error)
                error_positions.push_back(std::rand() % single_block_length);
            std::sort(error_positions.begin(), error_positions.end());
        } while (std::adjacent_find(error_positions.begin(), error_positions.end()) != error_positions.end());

        // Construct query sequence with chosen error positions.
        for (uint8_t const error_position : error_positions)
        {
            uint64_t const error_pos = error_position + current_blocks_length;
            // Decrease alphabet size by one because we don't want to replace query[error_pos] with the same character.
            uint8_t new_rank = std::rand() % (seqan3::alphabet_size<char_t> - 1);
            // If it is a match now, it can't be the highest rank of the alphabet. Thus we set it to the highest rank.
            if (new_rank == seqan3::to_rank(query[error_pos]))
                new_rank = seqan3::alphabet_size<char_t> - 1;
            seqan3::assign_rank_to(new_rank, query[error_pos]);
        }
        current_blocks_length += single_block_length;
    }

    std::vector<uint64_t> hits_trivial;
    std::vector<uint64_t> hits_ss;

    // Both delegates record the position component (`second`) of every occurrence below the reported cursor.
    auto delegate_trivial = [&hits_trivial] (auto const & cursor)
    {
        for (auto && res : cursor.locate())
            hits_trivial.push_back(res.second);
    };

    auto delegate_ss = [&hits_ss] (auto const & cursor)
    {
        for (auto && res : cursor.locate())
            hits_ss.push_back(res.second);
    };

    // Remove search-scheme hits at which the text does not equal the unmodified query.
    auto remove_predicate_ss = [&text, &orig_query, query_length] (uint64_t const hit)
    {
        text_t const matched_seq = text | seqan3::views::slice(hit, hit + query_length) | seqan3::views::to<text_t>;
        return matched_seq != orig_query;
    };

    auto remove_predicate_trivial = [&] (uint64_t const hit)
    {
        // Filter only correct error distributions: the hit must be located where the unmodified query occurs ...
        text_t const matched_seq = text | seqan3::views::slice(hit, hit + query_length) | seqan3::views::to<text_t>;
        if (orig_query != matched_seq)
            return true;

        // ... and every block must contain exactly the requested number of mismatches.
        uint64_t lb = 0;
        uint64_t rb = 0;
        for (uint8_t block = 0; block < search.blocks(); ++block)
        {
            rb += ordered_blocks_length[block];

            uint8_t errors = 0;
            for (uint64_t i = lb; i < rb; ++i)
            {
                if (hit + i >= text.size())
                    ++errors;
                else
                    errors += query[i] != text[hit + i];
            }
            if (errors != error_distribution[block])
                return true;
            lb = rb;
        }
        return false;
    };

    uint8_t const total        = search.u.back();
    uint8_t const substitution = std::rand() % (total + 1);

    seqan3::detail::search_param error_left{total, substitution, 0, 0};

    // Find all hits using search schemes.
    seqan3::detail::search_ss<false>(it, query, start_pos, start_pos + 1, 0, 0, true, search, blocks_length, error_left,
                                     delegate_ss);

    // Find all hits using trivial backtracking.
    search_trivial(index, query, error_left, delegate_trivial);

    // Eliminate hits that we are not interested in (based on the search and chosen error distribution)
    hits_ss.erase(std::remove_if(hits_ss.begin(), hits_ss.end(), remove_predicate_ss), hits_ss.end());

    hits_trivial.erase(std::remove_if(hits_trivial.begin(), hits_trivial.end(), remove_predicate_trivial),
                       hits_trivial.end());

    // Eliminate duplicates
    hits_ss = seqan3::uniquify(hits_ss);
    hits_trivial = seqan3::uniquify(hits_trivial);

    EXPECT_EQ(hits_ss, hits_trivial);
    if (hits_ss != hits_trivial)
    {
        seqan3::debug_stream << "Seed: " << seed << '\n'
                             << "Text: " << text << '\n'
                             << "Query: " << query << '\n'
                             << "Errors: " << total << ", " << substitution << '\n'
                             << "SS hits: " << hits_ss << '\n'
                             << "Trivial hits: " << hits_trivial << '\n';
    }
}

// Compares search-scheme search against trivial backtracking in the Hamming-distance setting.
//
// For every search of `search_scheme`, all error distributions are computed and each is ordered from left to right.
// Then, for texts of length 10, 100 and 1000 generated from `seed`, every query length in
// [max(3, blocks * max_error), min(16, text_length)) and every error distribution of every search is checked
// `iterations` times via test_search_hamming.
template <typename search_scheme_t>
inline void test_search_scheme_hamming(search_scheme_t const & search_scheme, size_t const seed,
                                       uint64_t const iterations)
{
    // Per search: the list of its error distributions, each ordered from left to right.
    std::vector<std::vector<std::vector<uint8_t>>> distributions_per_search(search_scheme.size());
    // Output argument of get_ordered_search below; seeded with a copy of each search.
    search_scheme_t ordered_scheme;
    // Largest upper error bound over all searches; it determines the minimal query length.
    uint8_t highest_error = 0;

    for (uint8_t s = 0; s < search_scheme.size(); ++s)
    {
        auto const & current = search_scheme[s];
        ordered_scheme[s] = current;
        seqan3::search_error_distribution(distributions_per_search[s], current);
        for (std::vector<uint8_t> & distribution : distributions_per_search[s])
            seqan3::order_search_vector(distribution, current);
        highest_error = std::max(highest_error, current.u.back());
    }

    for (uint64_t text_length = 10; text_length < 10000; text_length *= 10)
    {
        uint64_t const min_query_length = std::max<uint64_t>(3, search_scheme.front().blocks() * highest_error);
        uint64_t const max_query_length = std::min<uint64_t>(16, text_length);

        seqan3::dna4_vector text = seqan3::test::generate_sequence<seqan3::dna4>(text_length, 0/*variance*/, seed);
        seqan3::bi_fm_index index(text);

        for (uint64_t iteration = 0; iteration < iterations; ++iteration)
        {
            for (uint64_t query_length = min_query_length; query_length < max_query_length; ++query_length)
            {
                auto const block_info = search_scheme_block_info(search_scheme, query_length);
                for (uint8_t s = 0; s < search_scheme.size(); ++s)
                {
                    auto const & [blocks_length, start_pos] = block_info[s];

                    std::vector<uint64_t> ordered_blocks_length;
                    seqan3::get_ordered_search(search_scheme[s], blocks_length, ordered_scheme[s],
                                               ordered_blocks_length);

                    for (auto const & distribution : distributions_per_search[s])
                        test_search_hamming(index, text, search_scheme[s], query_length, distribution, seed,
                                            blocks_length, ordered_blocks_length, start_pos);
                }
            }
        }
    }
}

// Compares search-scheme search against trivial backtracking in the edit-distance setting.
//
// For texts of length 10, 100 and 1000 (generated from `seed`), one error budget is drawn per text. Its total is the
// largest upper bound `u.back()` of any search in `search_scheme`. The substitution, insertion and deletion budgets
// are each drawn from [0, total] with std::rand(). For every iteration and every query length in
// [max(3, blocks * max_error), min(16, text_length)), a query is generated. It is searched with both search_ss and
// search_trivial, and the deduplicated hit positions must be equal; otherwise diagnostics are printed.
template <typename search_scheme_t>
inline void test_search_scheme_edit(search_scheme_t const & search_scheme, size_t const seed, uint64_t const iterations)
{
    // retrieve maximum number of errors from search_scheme
    uint8_t max_error = 0;
    for (auto const & search : search_scheme)
        max_error = std::max(max_error, search.u.back());

    for (uint64_t text_length = 10; text_length < 10000; text_length *= 10)
    {
        uint64_t const query_length_min = std::max<uint64_t>(3, search_scheme.front().blocks() * max_error);
        uint64_t const query_length_max = std::min<uint64_t>(16, text_length);

        seqan3::dna4_vector text = seqan3::test::generate_sequence<seqan3::dna4>(text_length, 0/*variance*/, seed);
        seqan3::bi_fm_index index(text);

        uint8_t const substitution = std::rand() % (max_error + 1);
        uint8_t const insertion    = std::rand() % (max_error + 1);
        uint8_t const deletion     = std::rand() % (max_error + 1);
        seqan3::detail::search_param error_left{max_error, substitution, insertion, deletion};

        for (uint64_t i = 0; i < iterations; ++i)
        {
            // Queries are generated from an explicit seed. With the same seed in every iteration, each iteration
            // searched the very same queries (assuming generate_sequence is deterministic in its seed, as the printed
            // "Seed" diagnostic implies), so derive a distinct seed per iteration. Iteration 0 keeps using `seed`.
            size_t const query_seed = seed + i;

            for (uint64_t query_length = query_length_min; query_length < query_length_max; ++query_length)
            {
                seqan3::dna4_vector query = seqan3::test::generate_sequence<seqan3::dna4>(query_length,
                                                                                          0/*variance*/,
                                                                                          query_seed);

                std::vector<uint64_t> hits_trivial;
                std::vector<uint64_t> hits_ss;

                // Both delegates record the position component (`second`) of every occurrence below the cursor.
                auto delegate_trivial = [&hits_trivial] (auto const & cursor)
                {
                    for (auto && res : cursor.locate())
                        hits_trivial.push_back(res.second);
                };

                auto delegate_ss = [&hits_ss] (auto const & cursor)
                {
                    for (auto && res : cursor.locate())
                        hits_ss.push_back(res.second);
                };

                // Find all hits using search schemes.
                seqan3::detail::search_ss<false>(index, query, error_left, search_scheme, delegate_ss);
                // Find all hits using trivial backtracking.
                search_trivial(index, query, error_left, delegate_trivial);

                // Eliminate duplicates
                hits_ss = seqan3::uniquify(hits_ss);
                hits_trivial = seqan3::uniquify(hits_trivial);

                EXPECT_EQ(hits_ss, hits_trivial);
                if (hits_ss != hits_trivial)
                {
                    seqan3::debug_stream << "Seed: " << seed << '\n'
                                         << "Query seed: " << query_seed << '\n'
                                         << "Text: " << text << '\n'
                                         << "Query: " << query << '\n'
                                         << "Errors: " << max_error << ", " << substitution << ", "
                                                       << insertion << ", " << deletion << '\n'
                                         << "SS hits: " << hits_ss << '\n'
                                         << "Trivial hits: " << hits_trivial << '\n';
                }
            }
        }
    }
}

TEST(search_scheme_test, search_scheme_hamming)
{
    // Every scheme is checked with the same seed and the same number of iterations.
    size_t const seed = 42;
    uint64_t const iterations = 10;
    auto const check = [&] (auto const & search_scheme)
    {
        test_search_scheme_hamming(search_scheme, seed, iterations);
    };

    // Optimum search schemes <lower bound, upper bound> for all pairs with an upper bound of at most 3.
    check(seqan3::detail::optimum_search_scheme<0, 0>);
    check(seqan3::detail::optimum_search_scheme<0, 1>);
    check(seqan3::detail::optimum_search_scheme<1, 1>);
    check(seqan3::detail::optimum_search_scheme<0, 2>);
    check(seqan3::detail::optimum_search_scheme<1, 2>);
    check(seqan3::detail::optimum_search_scheme<2, 2>);
    check(seqan3::detail::optimum_search_scheme<0, 3>);
    check(seqan3::detail::optimum_search_scheme<1, 3>);
    check(seqan3::detail::optimum_search_scheme<2, 3>);
    // Disabled; the reason is not recorded here.
    // check(seqan3::detail::optimum_search_scheme<3, 3>);
}

TEST(search_scheme_test, search_scheme_edit)
{
    // Every scheme is checked with the same seed and the same number of iterations.
    size_t const seed = 42;
    uint64_t const iterations = 10;
    auto const check = [&] (auto const & search_scheme)
    {
        test_search_scheme_edit(search_scheme, seed, iterations);
    };

    // Only schemes with a lower error bound of 0 are checked.
    // TODO: test with lower bounds != 0.
    // For that we need alignment statistics to know the number of errors spent in search_trivial
    check(seqan3::detail::optimum_search_scheme<0, 0>);
    check(seqan3::detail::optimum_search_scheme<0, 1>);
    check(seqan3::detail::optimum_search_scheme<0, 2>);
    check(seqan3::detail::optimum_search_scheme<0, 3>);
}