1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
|
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <set>
#include <thread>
#include <vector>
#include <time.h>
#include <assert.h>
#include "IntervalTree.h"
#define CATCH_CONFIG_RUNNER // Mark this as file as the test-runner for catch
#include "catch.hpp" // Include the catch unit test framework
using namespace std;
// Convenience aliases for an interval tree keyed by std::size_t with bool payloads.
using intervalTree = IntervalTree<std::size_t, bool>;
using interval = intervalTree::interval;
using intervalVector = intervalTree::interval_vector;
// A default-constructed (empty) tree must report no overlaps for any query.
TEST_CASE( "Empty tree" ) {
    IntervalTree<std::size_t, int> t;
    // Keys are std::size_t, so query bounds must be non-negative: a literal such
    // as -1 would wrap to SIZE_MAX and form an inverted [SIZE_MAX, 1] query.
    REQUIRE( t.findOverlapping(0,1).size() == 0 );
    // The widest possible query must also come back empty.
    REQUIRE( t.findOverlapping(0, std::numeric_limits<std::size_t>::max()).size() == 0 );
}
// Queries against a tree holding the single interval [1,3] with value 5.5:
// point queries at both ends and in the middle, misses on either side, and
// (in a separate double-keyed tree) infinite and NaN query bounds.
TEST_CASE( "Singleton tree" ) {
    IntervalTree<std::size_t, double> t{ {{1,3,5.5}},
        1, 64, 512};
    SECTION ("Point query on left") {
        auto v = t.findOverlapping(1,1);
        REQUIRE( v.size() == 1);
        REQUIRE( v.front().start == 1 );
        REQUIRE( v.front().stop == 3 );
        REQUIRE( v.front().value == 5.5 );
    }
    SECTION ("Wild search values") {
        // Probe infinite and NaN bounds on a separate double-keyed tree holding
        // [0,1]. It is named dt so it does not shadow the outer t. REQUIRE (not
        // assert) is used so failures are reported by Catch instead of aborting
        // the runner, and so the checks are not compiled out under NDEBUG.
        typedef IntervalTree<double, std::size_t> IT;
        IT dt { {{0.0, 1.0, 0}} };
        const auto inf = std::numeric_limits<double>::infinity();
        const auto nan = std::numeric_limits<double>::quiet_NaN();
        // Infinite bounds are expected to behave like very large/small keys.
        REQUIRE( dt.findOverlapping(inf, inf).size() == 0 );
        REQUIRE( dt.findOverlapping(-inf, inf).size() == 1 );
        REQUIRE( dt.findOverlapping(0, inf).size() == 1 );
        REQUIRE( dt.findOverlapping(0.5, inf).size() == 1 );
        REQUIRE( dt.findOverlapping(1.1, inf).size() == 0 );
        REQUIRE( dt.findOverlapping(-inf, 1.0).size() == 1 );
        REQUIRE( dt.findOverlapping(-inf, 0.5).size() == 1 );
        REQUIRE( dt.findOverlapping(-inf, 0.0).size() == 1 );
        REQUIRE( dt.findOverlapping(-inf, -0.1).size() == 0 );
        // A NaN bound on either side is expected to match nothing.
        REQUIRE( dt.findOverlapping(nan, nan).size() == 0 );
        REQUIRE( dt.findOverlapping(-nan, nan).size() == 0 );
        REQUIRE( dt.findOverlapping(nan, 1).size() == 0 );
        REQUIRE( dt.findOverlapping(0, nan).size() == 0 );
    }
    SECTION ("Point query in middle") {
        auto v = t.findOverlapping(2,2);
        REQUIRE( v.size() == 1);
        REQUIRE( v.front().start == 1 );
        REQUIRE( v.front().stop == 3 );
        REQUIRE( v.front().value == 5.5 );
    }
    SECTION ("Point query on right") {
        auto v = t.findOverlapping(3,3);
        REQUIRE( v.size() == 1);
        REQUIRE( v.front().start == 1 );
        REQUIRE( v.front().stop == 3 );
        REQUIRE( v.front().value == 5.5 );
    }
    SECTION ("Non-overlapping queries") {
        REQUIRE( t.findOverlapping(4,4).size() == 0);
        REQUIRE( t.findOverlapping(0,0).size() == 0);
    }
}
// Two intervals with identical bounds [5,10] but different payloads must both
// be returned by a point query inside them.
TEST_CASE( "Two identical intervals with different contents" ) {
    IntervalTree<std::size_t, double> tree{{{5,10,10.5},{5,10,5.5}}};
    const auto hits = tree.findOverlapping(6,6);
    REQUIRE( hits.size() == 2);
    set<double> payloads;
    for (const auto& hit : hits) {
        REQUIRE( hit.start == 5 );
        REQUIRE( hit.stop == 10 );
        payloads.insert(hit.value);
    }
    // Compare payloads as a set so the test does not depend on result order.
    const set<double> wanted{5.5, 10.5};
    REQUIRE( payloads == wanted);
}
/// Draw a pseudo-random key from [floor, ceiling) using rand().
///
/// The span is computed in double rather than in Scalar, so a wide signed range
/// (e.g. one crossing zero with large magnitudes) cannot overflow Scalar
/// arithmetic. For integral Scalar the double result is truncated toward zero
/// on conversion; for non-negative ranges the result lies in [floor, ceiling).
///
/// @param floor   inclusive lower bound.
/// @param ceiling upper bound; returns floor when ceiling == floor.
/// @return the drawn key, converted back to Scalar.
template<typename Scalar>
Scalar randKey(Scalar floor, Scalar ceiling) {
    // rand() / (RAND_MAX + 1.0) lies in [0, 1), so the offset never reaches the span.
    const double unit = static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0);
    const double span = static_cast<double>(ceiling) - static_cast<double>(floor);
    return static_cast<Scalar>(static_cast<double>(floor) + span * unit);
}
/// Build a random interval carrying `value`: its start is drawn with
/// randKey(0, maxStart), its stop with randKey(start, start + maxLength), and
/// the stop is then capped at maxStop.
template<class Scalar, typename Value>
Interval<Scalar, Value> randomInterval(Scalar maxStart, Scalar maxLength, Scalar maxStop,
                                       const Value& value) {
    const Scalar lo = randKey<Scalar>(0, maxStart);
    const Scalar uncapped = randKey<Scalar>(lo, lo + maxLength);
    const Scalar hi = std::min<Scalar>(uncapped, maxStop);
    return Interval<Scalar, Value>(lo, hi, value);
}
int main(int argc, char**argv) {
typedef vector<std::size_t> countsVector;
// a simple sanity check
typedef IntervalTree<int, bool> ITree;
ITree::interval_vector sanityIntervals;
sanityIntervals.push_back(ITree::interval(60, 80, true));
sanityIntervals.push_back(ITree::interval(20, 40, true));
ITree sanityTree(std::move(sanityIntervals), 16, 1);
ITree::interval_vector sanityResults;
sanityResults = sanityTree.findOverlapping(30, 50);
assert(sanityResults.size() == 1);
sanityResults = sanityTree.findContained(15, 45);
assert(sanityResults.size() == 1);
srand((unsigned)time(NULL));
ITree::interval_vector intervals;
ITree::interval_vector queries;
// generate a test set of target intervals
for (int i = 0; i < 10000; ++i) {
intervals.push_back(randomInterval<int, bool>(100000, 1000, 100000 + 1, true));
}
// and queries
for (int i = 0; i < 5000; ++i) {
queries.push_back(randomInterval<int, bool>(100000, 1000, 100000 + 1, true));
}
typedef chrono::high_resolution_clock Clock;
typedef chrono::milliseconds milliseconds;
// using brute-force search
countsVector bruteforcecounts;
Clock::time_point t0 = Clock::now();
for (auto q = queries.begin(); q != queries.end(); ++q) {
ITree::interval_vector results;
for (auto i = intervals.begin(); i != intervals.end(); ++i) {
if (i->start >= q->start && i->stop <= q->stop) {
results.push_back(*i);
}
}
bruteforcecounts.push_back(results.size());
}
Clock::time_point t1 = Clock::now();
milliseconds ms = chrono::duration_cast<milliseconds>(t1 - t0);
cout << "brute force:\t" << ms.count() << "ms" << endl;
// using the interval tree
cout << intervals[0];
ITree tree = ITree(std::move(intervals), 16, 1);
countsVector treecounts;
t0 = Clock::now();
for (auto q = queries.begin(); q != queries.end(); ++q) {
auto results = tree.findContained(q->start, q->stop);
treecounts.push_back(results.size());
}
t1 = Clock::now();
ms = std::chrono::duration_cast<milliseconds>(t1 - t0);
cout << "interval tree:\t" << ms.count() << "ms" << endl;
// check that the same number of results are returned
countsVector::iterator b = bruteforcecounts.begin();
for (countsVector::iterator t = treecounts.begin(); t != treecounts.end(); ++t, ++b) {
assert(*b == *t);
}
return Catch::Session().run( argc, argv );
}
|