File: reweighted.cc

package info (click to toggle)
sopt 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,932 kB
  • ctags: 1,162
  • sloc: cpp: 7,220; php: 287; python: 57; ansic: 33; makefile: 5
file content (116 lines) | stat: -rw-r--r-- 3,888 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include <catch.hpp>
#include <random>

#include <Eigen/Dense>

#include "sopt/imaging_padmm.h"
#include "sopt/reweighted.h"

using namespace sopt;

//! \brief Minimum set of functions and typedefs needed by reweighting
//! \details The attributes are public and static so we can access them during the tests.
struct DummyAlgorithm {
  typedef t_real Scalar;
  typedef Vector<Scalar> t_Vector;
  typedef ConvergenceFunction<Scalar> t_IsConverged;

  struct DiagnosticAndResult {
    //! Expected by reweighted algorithm
    static t_Vector x;
  };

  DiagnosticAndResult operator()(t_Vector const &x) const {
    ++called_with_x;
    DiagnosticAndResult::x = x.array() + 0.1;
    return {};
  }
  DiagnosticAndResult operator()(DiagnosticAndResult const &warm) const {
    ++called_with_warm;
    DiagnosticAndResult::x = warm.x.array() + 0.1;
    return {};
  }

  //! Applies Ψ^T * x
  static t_Vector reweightee(DummyAlgorithm const &, t_Vector const &x) {
    ++DummyAlgorithm::called_reweightee;
    return x * 2;
  }
  //! sets the weights
  static void set_weights(DummyAlgorithm &, t_Vector const &weights) {
    ++DummyAlgorithm::called_weights;
    DummyAlgorithm::weights = weights;
  }

  static t_Vector weights;
  static int called_with_x;
  static int called_with_warm;
  static int called_reweightee;
  static int called_weights;
};

// Out-of-class definitions of DummyAlgorithm's static state; the L0-Approximation test case
// resets all of them explicitly before exercising the wrapper.
DummyAlgorithm::t_Vector DummyAlgorithm::DiagnosticAndResult::x;
DummyAlgorithm::t_Vector DummyAlgorithm::weights;
int DummyAlgorithm::called_with_x{0};
int DummyAlgorithm::called_with_warm{0};
int DummyAlgorithm::called_reweightee{0};
int DummyAlgorithm::called_weights{0};

TEST_CASE("L0-Approximation") {

  // Random input of fixed length handed to the reweighting wrapper.
  auto const n_elements = 6;
  DummyAlgorithm::t_Vector const input = DummyAlgorithm::t_Vector::Random(n_elements);

  auto reweighter = algorithm::reweighted(DummyAlgorithm(), DummyAlgorithm::set_weights,
                                          DummyAlgorithm::reweightee);

  // Catch runs this body once per leaf section, so the dummy's shared static state is
  // cleared on every pass before any section executes.
  DummyAlgorithm::called_with_x = 0;
  DummyAlgorithm::called_with_warm = 0;
  DummyAlgorithm::called_reweightee = 0;
  DummyAlgorithm::called_weights = 0;
  DummyAlgorithm::DiagnosticAndResult::x = DummyAlgorithm::t_Vector::Zero(0);
  DummyAlgorithm::weights = DummyAlgorithm::t_Vector::Zero(0);

  GIVEN("The maximum number of iteration is zero") {
    reweighter.itermax(0);
    WHEN("The reweighting algorithm is called") {
      auto const outcome = reweighter(input);
      THEN("The algorithm exited at the first iteration") {
        CHECK(outcome.niters == 0);
        CHECK(outcome.good == true);
      }
      THEN("The weights is set to 1") {
        // A single unit weight is expected when no reweighting step has happened.
        CHECK(outcome.weights.size() == 1);
        CHECK(std::abs(outcome.weights(0) - 1) < 1e-12);
      }
      THEN("The inner algorithm was called once") {
        // Only the cold-start overload runs; it adds 0.1 to the input.
        CHECK(DummyAlgorithm::called_with_x == 1);
        CHECK(DummyAlgorithm::called_with_warm == 0);
        CHECK(outcome.algo.x.array().isApprox(input.array() + 0.1));
      }
    }
  }

  GIVEN("The maximum number of iterations is one") {
    reweighter.itermax(1);
    WHEN("The reweighting algorithm is called") {
      auto const outcome = reweighter(input);
      THEN("The algorithm exited at the second iteration") {
        CHECK(outcome.niters == 1);
        CHECK(outcome.good == true);
      }
      THEN("The weights are not one") {
        CHECK(outcome.weights.size() == input.size());
        // Rebuild Ψ^T x from the first inner-algorithm output (input + 0.1) and compare the
        // weights against sigma / (sigma + |Ψ^T x|), sigma being its standard deviation.
        Vector<> const psi_transpose_x = DummyAlgorithm::reweightee(DummyAlgorithm{}, input.array() + 0.1);
        auto const sigma = standard_deviation(psi_transpose_x);
        CHECK(outcome.weights.array().isApprox(sigma / (sigma + psi_transpose_x.array().abs())));
      }
      THEN("The inner algorithm was called twice") {
        // One cold start followed by one warm start: 0.1 is added twice.
        CHECK(DummyAlgorithm::called_with_x == 1);
        CHECK(DummyAlgorithm::called_with_warm == 1);
        CHECK(outcome.algo.x.array().isApprox(input.array() + 0.2));
      }
    }
  }
}