1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
#include <catch.hpp>
#include <random>
#include <Eigen/Dense>
#include "sopt/imaging_padmm.h"
#include "sopt/reweighted.h"
using namespace sopt;
//! \brief Minimum set of functions and typedefs needed by reweighting
//! \details The attributes are public and static so we can access them during the tests.
struct DummyAlgorithm {
typedef t_real Scalar;
typedef Vector<Scalar> t_Vector;
typedef ConvergenceFunction<Scalar> t_IsConverged;
struct DiagnosticAndResult {
//! Expected by reweighted algorithm
static t_Vector x;
};
DiagnosticAndResult operator()(t_Vector const &x) const {
++called_with_x;
DiagnosticAndResult::x = x.array() + 0.1;
return {};
}
DiagnosticAndResult operator()(DiagnosticAndResult const &warm) const {
++called_with_warm;
DiagnosticAndResult::x = warm.x.array() + 0.1;
return {};
}
//! Applies Ψ^T * x
static t_Vector reweightee(DummyAlgorithm const &, t_Vector const &x) {
++DummyAlgorithm::called_reweightee;
return x * 2;
}
//! sets the weights
static void set_weights(DummyAlgorithm &, t_Vector const &weights) {
++DummyAlgorithm::called_weights;
DummyAlgorithm::weights = weights;
}
static t_Vector weights;
static int called_with_x;
static int called_with_warm;
static int called_reweightee;
static int called_weights;
};
int DummyAlgorithm::called_with_x = 0;
int DummyAlgorithm::called_with_warm = 0;
int DummyAlgorithm::called_reweightee = 0;
int DummyAlgorithm::called_weights = 0;
DummyAlgorithm::t_Vector DummyAlgorithm::DiagnosticAndResult::x;
DummyAlgorithm::t_Vector DummyAlgorithm::weights;
TEST_CASE("L0-Approximation") {
  auto const N = 6;
  DummyAlgorithm::t_Vector const input = DummyAlgorithm::t_Vector::Random(N);
  auto reweighter = algorithm::reweighted(DummyAlgorithm(), DummyAlgorithm::set_weights,
                                          DummyAlgorithm::reweightee);

  // Catch runs this body from the top once per leaf section. Wipe the dummy's static
  // bookkeeping so that each section only sees the calls it triggered itself.
  auto const reset_dummy_state = [] {
    DummyAlgorithm::called_with_x = 0;
    DummyAlgorithm::called_with_warm = 0;
    DummyAlgorithm::called_reweightee = 0;
    DummyAlgorithm::called_weights = 0;
    DummyAlgorithm::DiagnosticAndResult::x = DummyAlgorithm::t_Vector::Zero(0);
    DummyAlgorithm::weights = DummyAlgorithm::t_Vector::Zero(0);
  };
  reset_dummy_state();

  GIVEN("The maximum number of iteration is zero") {
    reweighter.itermax(0);
    WHEN("The reweighting algorithm is called") {
      auto const result = reweighter(input);
      THEN("The algorithm exited at the first iteration") {
        CHECK(result.niters == 0);
        CHECK(result.good == true);
      }
      THEN("The weights is set to 1") {
        CHECK(result.weights.size() == 1);
        CHECK(std::abs(result.weights(0) - 1) < 1e-12);
      }
      THEN("The inner algorithm was called once") {
        CHECK(DummyAlgorithm::called_with_x == 1);
        CHECK(DummyAlgorithm::called_with_warm == 0);
        CHECK(result.algo.x.array().isApprox(input.array() + 0.1));
      }
    }
  }

  GIVEN("The maximum number of iterations is one") {
    reweighter.itermax(1);
    WHEN("The reweighting algorithm is called") {
      auto const result = reweighter(input);
      THEN("The algorithm exited at the second iteration") {
        CHECK(result.niters == 1);
        CHECK(result.good == true);
      }
      THEN("The weights are not one") {
        CHECK(result.weights.size() == input.size());
        // The expected weights come from Ψ^T x, where x is the solution of the first (cold)
        // inner call, i.e. input + 0.1. delta is the standard deviation of Ψ^T x.
        DummyAlgorithm::t_Vector const first_solution = input.array() + 0.1;
        Vector<> const PsiT_x = DummyAlgorithm::reweightee(DummyAlgorithm(), first_solution);
        auto const delta = standard_deviation(PsiT_x);
        CHECK(result.weights.array().isApprox(delta / (delta + PsiT_x.array().abs())));
      }
      THEN("The inner algorithm was called twice") {
        CHECK(DummyAlgorithm::called_with_x == 1);
        CHECK(DummyAlgorithm::called_with_warm == 1);
        CHECK(result.algo.x.array().isApprox(input.array() + 0.2));
      }
    }
  }
}
|