File: sara.cc

package info (click to toggle)
sopt 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,932 kB
  • ctags: 1,162
  • sloc: cpp: 7,220; php: 287; python: 57; ansic: 33; makefile: 5
file content (91 lines) | stat: -rw-r--r-- 3,599 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <catch.hpp>
#include <cmath>
#include <memory>
#include <random>
#include <string>
#include <tuple>

#include "sopt/wavelets.h"
#include "sopt/wavelets/sara.h"

// Draws a uniformly distributed integer from the closed interval [min, max].
//
// The random engine `mersenne` is only declared (extern) here, so it must be
// defined in another translation unit and must point to a live engine before
// this is called (it is dereferenced unconditionally) — presumably the test
// runner's main sets it up; TODO confirm.
//
// Precondition: min <= max (required by std::uniform_int_distribution).
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
  return uniform_dist(*mersenne);
}

// Mechanical checks of the SARA dictionary:
//   * construction from (wavelet name, levels) tuples and the vector-like
//     interface (size, indexing, max_levels);
//   * layout of the direct transform: one column block per wavelet, each equal
//     to that wavelet's own direct transform divided by sqrt(number of wavelets);
//   * indirect(direct(x)) recovers x.
TEST_CASE("Check SARA implementation mechanically", "[wavelet]") {
  using namespace sopt;
  using namespace sopt::wavelets;

  // Each dictionary entry is a (wavelet name, decomposition levels) pair.
  using t_i = std::tuple<std::string, sopt::t_uint>;
  SARA const sara{t_i{std::string{"DB3"}, 1u}, t_i{std::string{"DB1"}, 2u},
                  t_i{std::string{"DB1"}, 3u}};

  SECTION("Construction and vector functionality") {
    CHECK(sara.size() == 3);
    // Levels come from the tuples, in construction order.
    CHECK(sara[0].levels() == 1);
    CHECK(sara[1].levels() == 2);
    CHECK(sara[2].levels() == 3);
    CHECK(sara.max_levels() == 3);
    // The test expects the filter coefficients to depend on the wavelet name
    // only, hence comparison against single-level factory wavelets.
    CHECK(sara[0].coefficients.isApprox(factory("DB3", 1).coefficients));
    CHECK(sara[1].coefficients.isApprox(factory("DB1", 1).coefficients));
    CHECK(sara[2].coefficients.isApprox(factory("DB1", 1).coefficients));
  }

  // Both image dimensions are multiples of 2^max_levels (presumably required
  // by the multi-level transforms).
  auto const block = 1u << sara.max_levels();
  Image<> image = Image<>::Random(block * 3, block);
  Image<> analysis;
  sara.direct(analysis, image);

  SECTION("Direct transform") {
    auto const norm = std::sqrt(sara.size());
    Image<> const db3_part = sara[0].direct(image) / norm;
    Image<> const db1_two_levels = sara[1].direct(image) / norm;
    Image<> const db1_three_levels = sara[2].direct(image) / norm;

    // Coefficients are laid side by side: [wavelet 0 | wavelet 1 | wavelet 2].
    auto const width = image.cols();
    CAPTURE(analysis.leftCols(width));
    CAPTURE(db3_part);
    CHECK(analysis.leftCols(width).isApprox(db3_part));
    CHECK(analysis.middleCols(width, width).isApprox(db1_two_levels));
    CHECK(analysis.rightCols(width).isApprox(db1_three_levels));
  }

  SECTION("Indirect transform") {
    auto const recovered = sara.indirect(analysis);
    CHECK(recovered.isApprox(image));
  }
}

// Checks that the linear-transform wrapper built from a SARA dictionary agrees
// with calling the dictionary directly: Psi.adjoint() must reproduce
// sara.direct, and Psi must reproduce sara.indirect. Images and coefficient
// arrays are passed to/returned from Psi as flat column vectors obtained by
// mapping the matrices' raw data.
TEST_CASE("Linear-transform wrapper", "[wavelet]") {
  using namespace sopt::wavelets;
  using namespace sopt;
  SARA const sara{std::make_tuple(std::string{"DB3"}, 1u), std::make_tuple(std::string{"DB1"}, 2u),
                  std::make_tuple(std::string{"DB1"}, 3u)};
  // Number of wavelets as a signed integer, so the shape arithmetic below does
  // not mix Eigen's signed indices with std::size_t.
  auto const nwavelets = static_cast<sopt::t_int>(sara.size());

  auto const rows = 256, cols = 256;
  auto const Psi = linear_transform<t_real>(sara, rows, cols);
  SECTION("Indirect transform") {
    Image<> const image = Image<>::Random(rows, cols);
    Image<> const expected = sara.direct(image);
    // The linear transform expects a column vector as input
    auto const as_vector = Vector<>::Map(image.data(), image.size());
    // And it returns a column vector as well
    Vector<> const actual = Psi.adjoint() * as_vector;
    CHECK(actual.size() == expected.size());
    // REQUIRE, not CHECK: Map does no bounds checking, so mapping a buffer
    // shorter than the requested shape would read out of bounds if the test
    // carried on after a failed size check.
    REQUIRE(actual.size() == image.size() * nwavelets);
    auto const coeffs = Image<>::Map(actual.data(), image.rows(), image.cols() * nwavelets);
    CHECK(expected.rows() == coeffs.rows());
    CHECK(expected.cols() == coeffs.cols());
    CHECK(coeffs.isApprox(expected, 1e-8));
  }
  SECTION("direct transform") {
    Image<> const coeffs = Image<>::Random(rows, cols * nwavelets);
    Image<> const expected = sara.indirect(coeffs);
    // The linear transform expects a column vector as input
    auto const as_vector = Vector<>::Map(coeffs.data(), coeffs.size());
    // And it returns a column vector as well
    Vector<> const actual = Psi * as_vector;
    CHECK(actual.size() == expected.size());
    // Both conditions below are preconditions of the unchecked Map, so the
    // section must stop if either fails.
    REQUIRE(coeffs.cols() % nwavelets == 0);
    auto const image_cols = coeffs.cols() / nwavelets;
    REQUIRE(actual.size() == coeffs.rows() * image_cols);
    auto const image = Image<>::Map(actual.data(), coeffs.rows(), image_cols);
    CHECK(expected.rows() == image.rows());
    CHECK(expected.cols() == image.cols());
    CHECK(image.isApprox(expected, 1e-8));
  }
}