File: forward_backward.cc

package info (click to toggle)
sopt 5.0.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,704 kB
  • sloc: cpp: 13,620; xml: 182; makefile: 6
file content (101 lines) | stat: -rw-r--r-- 4,764 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <catch2/catch_all.hpp>
#include <random>
#include <vector>

#include <Eigen/Dense>

#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/proximal.h"
#include "sopt/types.h"

// These headers are not part of the installed sopt interface.
// They are only present in tests.
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"

/// Draws an integer from a uniform distribution over [min, max].
///
/// The random engine is the `mersenne` object, which is only declared here
/// (block-scope `extern`); its definition is not part of this function.
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::uniform_int_distribution<sopt::t_int> pick{min, max};
  return pick(*mersenne);
}

// Shorthand for the sopt numeric and container types used in this file.
using t_real = sopt::t_real;
using Scalar = sopt::t_real;
using t_Vector = sopt::Vector<Scalar>;
using t_LinearTransform = sopt::LinearTransform<t_Vector>;
// Length passed to t_Vector::Random when building the test vectors.
constexpr auto N = 5;

// Runs algorithm::ForwardBackward from a random starting point towards a random
// target, using proximal::id as the proximal step, and checks that the result
// matches the target, is flagged good, and needed fewer than `itermax` iterations.
TEST_CASE("Forward Backward with ||x - x0||_2^2 function", "[fb]") {
  using namespace sopt;
  t_Vector const target0 = t_Vector::Random(N);
  t_real constexpr beta = 0.2;
  t_real constexpr regulariser_strength = 0.1;
  int constexpr itermax = 300;
  // Proximal step: forwards straight to proximal::id. The strength parameter is
  // named `gamma` so that it does not shadow the `regulariser_strength` constant.
  auto const g0 = [](t_Vector &out, const t_real gamma, const t_Vector &x) {
    proximal::id(out, gamma, x);
  };
  // Gradient step: adjoint of Phi applied to the residual. The image argument is
  // not used, so it is taken by const reference (unnamed) instead of by value,
  // which avoids copying the vector on every call.
  auto const grad = [](t_Vector &out, const t_Vector & /*image*/, const t_Vector &res,
                       const t_LinearTransform &Phi) { out = Phi.adjoint() * res; };
  const t_Vector x_guess = t_Vector::Random(target0.size());
  const t_Vector res = x_guess - target0;
  // Converged once the iterate matches the target to 1e-9; the residual argument
  // is ignored and left unnamed so it does not shadow `res` above.
  auto const convergence = [&target0](const t_Vector &x, const t_Vector & /*residual*/) -> bool {
    return x.isApprox(target0, 1e-9);
  };
  CAPTURE(target0);
  CAPTURE(x_guess);
  CAPTURE(res);
  auto fb = algorithm::ForwardBackward<Scalar>(grad, g0, target0)
                      .itermax(itermax)
                      .regulariser_strength(regulariser_strength)
                      .step_size(beta)
                      .is_converged(convergence);
  // Start the solver from the initial guess and its residual.
  auto const result = fb(std::make_tuple(x_guess, res));
  CAPTURE(result.niters);
  CAPTURE(result.x);
  CAPTURE(result.residual);
  CHECK(result.x.isApprox(target0, 1e-9));
  CHECK(result.good);
  CHECK(result.niters < itermax);
}

// Trait whose `value` is true exactly when T is
// `sopt::algorithm::ImagingForwardBackward<double> &`.
template <typename T>
struct is_imaging_proximal_ref
    : std::is_same<sopt::algorithm::ImagingForwardBackward<double> &, T> {};
// Trait whose `value` is true exactly when T is
// `sopt::algorithm::L1GProximal<double> &`.
template <typename T>
struct is_l1_g_proximal_ref
    : std::is_same<sopt::algorithm::L1GProximal<double> &, T> {};

// Verifies the static return type of the setter-style member functions: each one
// called on an ImagingForwardBackward<double> must yield
// `ImagingForwardBackward<double> &`, and each one called on an L1GProximal must
// yield `L1GProximal<double> &`.
TEST_CASE("Check type returned on setting variables") {
  // The checks only inspect types via decltype, so the setter calls are never
  // evaluated; they could equally be written as static_asserts.
  using namespace sopt;
  using namespace sopt::algorithm;
  // Constructed from an empty (size 0) vector; only the object's type matters here.
  ImagingForwardBackward<double> fb(Vector<double>::Zero(0).eval());
  CHECK(is_imaging_proximal_ref<decltype(fb.itermax(500))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.step_size(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.regulariser_strength(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.sigma(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.residual_convergence(1.001))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.target(Vector<double>::Zero(0)))>::value);
  // is_converged is checked for every value category of the convergence functor:
  // prvalue, lvalue reference, rvalue reference and const lvalue reference.
  using ConvFunc = ConvergenceFunction<double>;
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc const &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.relative_variation(5e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.tight_frame(false))>::value);

  // Test the types of the l1 g_proximal object separately
  // NOTE(review): gp is L1GProximal<Scalar> while the trait compares against
  // L1GProximal<double>; these agree only if sopt::t_real is double - confirm.
  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_tolerance(1e-2))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_nu(1))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_itermax(50))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_positivity_constraint(true))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_real_constraint(true))>::value);
  // Psi is checked with the identity transform and with every value category
  // of a LinearTransform argument.
  using LinTrans = LinearTransform<Vector<double>>;
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(linear_transform_identity<double>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans const &>()))>::value);
}