File: wavelets.cc

package info (click to toggle)
sopt 5.0.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,704 kB
  • sloc: cpp: 13,620; xml: 182; makefile: 6
file content (322 lines) | stat: -rw-r--r-- 12,205 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
#include <catch2/catch_all.hpp>
#include <memory>
#include <random>

#include "sopt/types.h"
#include "sopt/wavelets/direct.h"
#include "sopt/wavelets/indirect.h"
#include "sopt/wavelets/wavelet_data.h"
#include "sopt/wavelets/wavelets.h"

// Array of sopt::t_uint used as the integer fixture type by the helpers and the
// integer-data tests below (t_uint is presumably unsigned — confirm in sopt/types.h).
using t_iVector = sopt::Array<sopt::t_uint>;
// Returns the entries of `x` sitting at even positions (0, 2, 4, ...).
// The output has (x.size() + 1) / 2 entries, so an odd-length input keeps its last entry.
t_iVector even(t_iVector const &x) {
  t_iVector::Index const count = (x.size() + 1) / 2;
  t_iVector gathered(count);
  for (t_iVector::Index j = 0; j < count; ++j) gathered(j) = x(2 * j);
  return gathered;
}
// Returns the entries of `x` sitting at odd positions (1, 3, 5, ...).
// The output has x.size() / 2 entries.
t_iVector odd(t_iVector const &x) {
  t_iVector::Index const count = x.size() / 2;
  t_iVector gathered(count);
  for (t_iVector::Index j = 0; j < count; ++j) gathered(j) = x(2 * j + 1);
  return gathered;
}
// Up-samples a 1D array by a factor of two by interleaving zeros:
// output(2k) = input(k) and output(2k + 1) = 0, so the output has 2 * input.size() entries.
// The result keeps the input's compile-time shape parameters; only inputs with a
// Dynamic length can be resized to twice their size.
template <typename T>
Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime> upsample(
    Eigen::ArrayBase<T> const &input) {
  using Upsampled = Eigen::Array<typename T::Scalar, T::RowsAtCompileTime, T::ColsAtCompileTime>;
  using Index = typename Upsampled::Index;
  Index const n = input.size();
  Upsampled interleaved(2 * n);
  for (Index k = 0; k < n; ++k) {
    interleaved(2 * k) = input(k);  // original sample
    interleaved(2 * k + 1) = 0;     // inserted zero
  }
  return interleaved;
}

// Draws one integer uniformly from the closed interval [min, max].
// The engine `mersenne` is only declared `extern` here, so it must be defined (and
// seeded) elsewhere in the test program — presumably in the test main; confirm.
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::mt19937_64 &engine = *mersenne;
  return std::uniform_int_distribution<sopt::t_int>{min, max}(engine);
}
// Builds a t_iVector of `size` entries, each drawn uniformly from [min, max] using the
// externally defined `mersenne` engine. The draws are sopt::t_int but are stored in a
// t_uint array, so negative draws wrap around (modular integer arithmetic in the tests
// still stays consistent).
t_iVector random_ivector(sopt::t_int size, sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::uniform_int_distribution<sopt::t_int> draw(min, max);
  t_iVector values(size);
  for (t_iVector::Index k = 0; k < values.size(); ++k) values(k) = draw(*mersenne);
  return values;
}

// Checks round trip operation.
//
// Verifies two properties for the Daubechies wavelet `db`:
//   1. indirect_transform(direct_transform(x, nlevels), nlevels) reproduces x
//      (relative precision 1e-14);
//   2. a transform with `nlevels` levels differs from one with `nlevels - 1` levels,
//      i.e. the last requested level actually changes the output.
//
// input_  : 1D signal or 2D image; evaluated once so Eigen expressions are accepted.
// db      : index of the Daubechies wavelet passed to daubechies_data.
// nlevels : number of decomposition levels. Must be >= 1: property 2 evaluates
//           nlevels - 1 on a sopt::t_uint, which would wrap around for zero.
template <typename T0>
void check_round_trip(Eigen::ArrayBase<T0> const &input_, sopt::t_uint db,
                      sopt::t_uint nlevels = 1) {
  // Guard against unsigned wrap-around of `nlevels - 1` below.
  REQUIRE(nlevels >= 1);
  auto const input = input_.eval();
  auto const &dbwave = sopt::wavelets::daubechies_data(db);
  auto const transform = sopt::wavelets::direct_transform(input, nlevels, dbwave);
  auto const actual = sopt::wavelets::indirect_transform(transform, nlevels, dbwave);
  CAPTURE(actual);
  CAPTURE(input);
  CAPTURE(transform);
  CHECK(input.isApprox(actual, 1e-14));
  CHECK(not transform.isApprox(sopt::wavelets::direct_transform(input, nlevels - 1, dbwave), 1e-4));
}

// The test expects Daubechies wavelets 1..38 to be tabulated, wavelet N carrying
// 2N coefficients, and any index beyond the table (checked up to 99) to throw.
TEST_CASE("wavelet data") {
  sopt::t_int constexpr last_tabulated = 38;
  for (sopt::t_int num = 1; num <= last_tabulated; ++num)
    REQUIRE(sopt::wavelets::daubechies_data(num).coefficients.size() == 2 * num);
  for (sopt::t_int num = last_tabulated + 1; num < 100; ++num)
    REQUIRE_THROWS(sopt::wavelets::daubechies_data(num));
}

// Exercises the low-level building blocks of the wavelet transform with exact integer
// data: periodic scalar product, (down-/up-sampling) convolutions and copy.
TEST_CASE("Wavelet transform innards with integer data", "[wavelet]") {
  using namespace sopt::wavelets;

  // `small` plays the role of a filter, `large` the role of a signal.
  t_iVector small(3);
  small << 1, 2, 3;
  t_iVector large(6);
  large << 4, 5, 6, 7, 8, 9;

  // periodic_scalar_product(signal, filter, offset) is expected to equal
  // sum_k filter(k) * signal((offset + k) mod signal.size()).
  SECTION("Periodic scalar product") {
    // no wrapping
    CHECK(periodic_scalar_product(large, small, 0) == 1 * 4 + 2 * 5 + 3 * 6);
    CHECK(periodic_scalar_product(large, small, 1) == 1 * 5 + 2 * 6 + 3 * 7);
    CHECK(periodic_scalar_product(large, small, 3) == 1 * 7 + 2 * 8 + 3 * 9);

    // with wrapping past the end of the signal
    CHECK(periodic_scalar_product(large, small, 4) == 1 * 8 + 2 * 9 + 3 * 4);
    // with wrapping and an Eigen expression as the filter
    CHECK(periodic_scalar_product(large, small.reverse(), 4) == 3 * 8 + 2 * 9 + 1 * 4);
    // offsets outside [0, large.size()) are wrapped modulo the signal length
    CHECK(periodic_scalar_product(large, small, 4 + large.size()) == 1 * 8 + 2 * 9 + 3 * 4);
    CHECK(periodic_scalar_product(large, small, 4 - 3 * large.size()) == 1 * 8 + 2 * 9 + 3 * 4);

    // signal shorter than the filter: the signal wraps more than once
    CHECK(periodic_scalar_product(small, large.head(4), 1) == 4 * 2 + 5 * 3 + 6 * 1 + 7 * 2);
  }

  // convolve(result, signal, filter): result(i) is the periodic scalar product at offset i.
  SECTION("Convolve") {
    t_iVector result(large.size());

    convolve(result, large, small);

    CHECK(result(0) == 1 * 4 + 2 * 5 + 3 * 6);
    CHECK(result(1) == 1 * 5 + 2 * 6 + 3 * 7);
    CHECK(result(3) == 1 * 7 + 2 * 8 + 3 * 9);
    CHECK(result(4) == 1 * 8 + 2 * 9 + 3 * 4);
  }

  SECTION("Convolve and sum") {
    t_iVector result(large.size());
    t_iVector noOffset(large.size());

    // If the high-pass filter is zero, the result is the plain convolution shifted
    // (periodically) by small.size() - 1 entries.
    convolve_sum(result, large, small, large, 0 * small);
    convolve(noOffset, large, small);
    CHECK(result(small.size() - 1) == noOffset(0));
    CHECK(result(0) == noOffset(result.size() - small.size() + 1));

    // Same check with a zero low-pass filter instead.
    convolve_sum(result, large, 0 * small, large, small);
    CHECK(result(small.size() - 1) == noOffset(0));
    CHECK(result(0) == noOffset(result.size() - small.size() + 1));

    // Check linearity/symmetry relationships: trial(a, b, c, d) convolves
    // (a * large, b * small) and (c * large, d * small) and sums them.
    auto const trial = [&small, &large](int a, int b, int c, int d) {
      t_iVector result(large.size());
      convolve_sum(result, a * large, b * small, c * large, d * small);
      return result;
    };

    // By linearity, results must match whenever (a * b) + (c * d) == (a' * b') + (c' * d').
    CHECK((trial(0, 1, 3, 1) == trial(0, 1, 1, 3)).all());
    CHECK((trial(5, 1, 3, 1) == trial(3, 1, 5, 1)).all());
    CHECK((trial(1, 5, 3, 1) == trial(3, 1, 5, 1)).all());
    CHECK((trial(1, 3, 5, 1) == trial(3, 1, 5, 1)).all());
    CHECK((trial(1, 3, 1, 5) == trial(3, 1, 5, 1)).all());
    CHECK((trial(1, 0, 4, 2) == trial(3, 1, 5, 1)).all());
    CHECK((trial(1, -1, 1, 1) == trial(0, 1, 0, 1)).all());
    CHECK((trial(4, -3, 2, 6) == trial(0, 1, 0, 1)).all());
  }

  // down_convolve must equal a full convolution keeping only every other (even) entry.
  SECTION("Convolve and Down-sample simultaneously") {
    t_iVector expected(large.size());
    convolve(expected, large, small);
    t_iVector actual(large.size() / 2);
    down_convolve(actual, large, small);
    for (size_t i(0); i < static_cast<size_t>(actual.size()); ++i)
      CHECK(expected(i * 2) == actual(i));
  }

  // convolve must accept a writable Eigen expression (here a head() block) as output.
  SECTION("Convolve output to expression") {
    t_iVector actual(large.size() * 2);
    t_iVector expected(large.size());
    convolve(actual.head(large.size()), large, small);
    convolve(expected, large, small);
    CHECK((actual.head(large.size()) == expected).all());
  }

  // copy() must allocate new storage, even when given a block expression.
  SECTION("Copy does copy") {
    auto result = copy(large);
    CHECK(large.data() != result.data());

    auto actual = copy(large.head(3));
    CHECK(large.data() != actual.data());
    CHECK(large.data() == large.head(3).data());
  }

  // up_convolve_sum (with filters pre-split into even/odd taps) must match
  // up-sampling both coefficient halves and then calling convolve_sum.
  SECTION("Convolve, Sum and Up-sample simultaneously") {
    for (sopt::t_int i(0); i < 100; ++i) {
      auto const Ncoeffs = random_integer(2, 10) * 2;
      auto const Nfilters = random_integer(2, 5);
      auto const Nhead = Ncoeffs / 2;
      auto const Ntail = Ncoeffs - Nhead;

      auto const coeffs = random_ivector(Ncoeffs, -10, 10);
      auto const low = random_ivector(Nfilters, -10, 10);
      auto const high = random_ivector(Nfilters, -10, 10);

      t_iVector actual(Ncoeffs);
      t_iVector expected(Ncoeffs);
      // does everything in one go: more complicated, but computationally less intensive
      up_convolve_sum(actual, coeffs, even(low), odd(low), even(high), odd(high));
      // first up-samples, then does convolve: conceptually simpler but does unnecessary operations
      convolve_sum(expected, upsample(coeffs.head(Nhead)), low, upsample(coeffs.tail(Ntail)), high);
      CHECK((actual == expected).all());
    }
  }
}

// Checks the 1D direct/indirect Daubechies transforms against explicit
// convolve-and-resample constructions, then round-trips random signals.
TEST_CASE("1D wavelet transform with floating point data", "[wavelet]") {
  using namespace sopt;
  using namespace sopt::wavelets;

  Image<> const data = Image<>::Random(16, 16);
  auto const &wavelet = daubechies_data(4);

  // Condition on input fixture data
  REQUIRE((data.rows() % 2 == 0 and (data.cols() == 1 or data.cols() % 2 == 0)));

  // One level of the direct transform: low-pass coefficients in the first half of the
  // output, high-pass coefficients in the second half.
  SECTION("Direct transform == two downsample + convolution") {
    auto const actual = direct_transform(data.row(0), 1, wavelet);
    Array<> high(data.cols() / 2);
    Array<> low(data.cols() / 2);
    down_convolve(high, data.row(0), wavelet.direct_filter.high);
    down_convolve(low, data.row(0), wavelet.direct_filter.low);
    CHECK(low.transpose().isApprox(actual.head(data.row(0).size() / 2)));
    CHECK(high.transpose().isApprox(actual.tail(data.row(0).size() / 2)));
  }

  // One level of the indirect transform: up-sample each half of the coefficients and
  // convolve with the time-reversed direct filters.
  SECTION("Indirect transform == two upsample + convolution") {
    // The coefficients are row 0, whose length is data.cols(). Split it by its own
    // length rather than data.rows(); the two only coincide for a square fixture.
    auto const half = data.row(0).size() / 2;
    auto const actual = indirect_transform(data.row(0).transpose(), 1, wavelet);
    auto const low = upsample(data.row(0).transpose().head(half));
    auto const high = upsample(data.row(0).transpose().tail(half));
    auto expected = copy(data.row(0).transpose());
    convolve_sum(expected, low, wavelet.direct_filter.low.reverse(), high,
                 wavelet.direct_filter.high.reverse());
    CAPTURE(expected.transpose());
    CAPTURE(actual.transpose());
    CHECK(expected.isApprox(actual));
  }

  SECTION("Round-trip test for single level") {
    for (t_int i(0); i < 20; ++i) {
      check_round_trip(Array<>::Random(random_integer(2, 100) * 2), random_integer(1, 38), 1);
    }
  }

  SECTION("Round-trip test for two levels") {
    check_round_trip(Array<>::Random(8), 1, 2);
    check_round_trip(Array<>::Random(8), 2, 2);
    check_round_trip(Array<>::Random(16), 4, 2);
    check_round_trip(Array<>::Random(52), 10, 2);
  }

  t_uint constexpr nlevels = 5;
  SECTION("Round-trip test for multiple levels") {
    // Signal lengths are multiples of 2^n so that n levels can be requested.
    for (t_int i(0); i < 10; ++i) {
      auto const n = random_integer(2, nlevels);
      check_round_trip(Array<>::Random(random_integer(2, 100) * (1u << n)), random_integer(1, 38),
                       n);
    }
  }
}

// Round-trips a random even-length complex signal through one level of a randomly
// chosen Daubechies wavelet (1..38), and checks the transform is not the identity.
TEST_CASE("1D wavelet transform with complex data", "[wavelet]") {
  using namespace sopt;
  using namespace sopt::wavelets;
  SECTION("Round-trip test for complex data") {
    auto const length = random_integer(2, 100) * 2;
    auto input = Array<t_complex>::Random(length).eval();
    auto const &dbwave = daubechies_data(random_integer(1, 38));
    auto const coefficients = direct_transform(input, 1, dbwave);
    auto const recovered = indirect_transform(coefficients, 1, dbwave);
    CHECK(input.isApprox(recovered, 1e-14));
    CHECK(not input.isApprox(coefficients, 1e-4));
  }
}

// Round-trip tests of the 2D transform on random real images: square and non-square
// single-level cases, then random multi-level cases.
TEST_CASE("2D wavelet transform with real data", "[wavelet]") {
  using namespace sopt;
  using namespace sopt::wavelets;
  SECTION("Single level round-trip test for square matrix") {
    auto const side = random_integer(2, 100) * 2;
    check_round_trip(Image<>::Random(side, side), random_integer(1, 38), 1);
  }
  SECTION("Single level round-trip test for non-square matrix") {
    // Both dimensions even; the column count exceeds the row count by 10.
    auto const nrows = random_integer(2, 5) * 2;
    auto const ncols = nrows + 10;
    check_round_trip(Image<>::Random(nrows, ncols), random_integer(1, 38), 1);
  }
  SECTION("Round-trip test for multiple levels") {
    for (t_int trial = 0; trial < 10; ++trial) {
      // Both dimensions are multiples of 2^levels.
      auto const levels = random_integer(2, 5);
      auto const scale = 1u << levels;
      auto const nrows = random_integer(2, 5) * scale;
      auto const ncols = random_integer(2, 5) * scale;
      check_round_trip(Image<>::Random(nrows, ncols), random_integer(1, 38), levels);
    }
  }
}

// Checks the functor returned by wavelets::factory against the free transform
// functions, both on whole images and on row expressions used as input/output.
TEST_CASE("Functor implementation", "[wavelet]") {
  using namespace sopt;
  auto const wavelet = wavelets::factory("DB3", 4);
  auto const input = Image<t_complex>::Random(256, 128).eval();
  SECTION("Normal instances") {
    auto const coefficients = wavelet.direct(input);
    auto const reference = wavelets::direct_transform(input, wavelet.levels(), wavelet);
    CHECK(coefficients.isApprox(reference));
    CHECK(input.isApprox(wavelet.indirect(coefficients)));
  }
  SECTION("Expression instances") {
    // Row 0 receives the forward transform of input row 0; row 1 then receives the
    // inverse transform of row 0 and must match input row 0.
    Image<t_complex> workspace(2, input.cols());
    wavelet.direct(workspace.row(0).transpose(), input.row(0).transpose());
    wavelet.indirect(workspace.row(0).transpose(), workspace.row(1).transpose());
    CHECK(input.row(0).isApprox(workspace.row(1)));
  }
}

// Both direct and indirect transforms must resize a wrongly shaped (1x1) output
// to the shape of the input.
TEST_CASE("Automatic input resizing", "[wavelet]") {
  using namespace sopt;
  auto const wavelet = wavelets::factory("DB3", 4);
  auto const input = Image<t_complex>::Random(256, 128).eval();
  auto const check_shape = [&input](Image<t_complex> const &candidate) {
    CHECK(candidate.rows() == input.rows());
    CHECK(candidate.cols() == input.cols());
  };

  Image<t_complex> output(1, 1);
  wavelet.direct(output, input);
  check_shape(output);

  output.resize(1, 1);
  wavelet.indirect(input, output);
  check_shape(output);
}

// The test expects the "Dirac" wavelet to act as the identity in both directions,
// resizing a 1x1 output to the input's shape along the way.
TEST_CASE("Dirac wavelets") {
  using namespace sopt;
  auto const dirac = wavelets::factory("Dirac");
  Image<t_complex> const signal = Image<t_complex>::Random(256, 128);
  Image<t_complex> result(1, 1);

  dirac.direct(result, signal);
  CHECK(result.isApprox(signal));

  result.setZero(1, 1);
  dirac.indirect(signal, result);
  CHECK(result.isApprox(signal));
}