1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
|
#include <catch2/catch_all.hpp>
#include <random>
#include <vector>
#include <Eigen/Dense>
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/proximal.h"
#include "sopt/types.h"
// These headers are not part of the installed sopt interface;
// they are only present in the tests.
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
//! Draws an integer uniformly from the closed range [min, max].
//!
//! The random engine is the shared `mersenne` generator. It is only declared `extern` here, so it
//! must be defined elsewhere in the test suite (presumably the test main), and it must already
//! point at a live engine when this is called.
//!
//! \param min lower bound (inclusive)
//! \param max upper bound (inclusive); std::uniform_int_distribution requires min <= max
//! \return a value in [min, max]
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::uniform_int_distribution<sopt::t_int> uniform_dist(min, max);
  return uniform_dist(*mersenne);
}
// Numeric, vector and linear-operator types shared by the test cases in this file.
using Scalar = sopt::t_real;
using t_real = sopt::t_real;
using t_Vector = sopt::Vector<Scalar>;
using t_LinearTransform = sopt::LinearTransform<t_Vector>;
//! Number of elements in the random target vector of the forward-backward test.
constexpr int N = 5;
TEST_CASE("Forward Backward with ||x - x0||_2^2 function", "[fb]") {
  // Runs forward-backward with a gradient step on the data-fidelity term and the proximal
  // operator proximal::id as the non-smooth step, starting from a random guess, and expects the
  // iterates to reach target0 (to 1e-9 relative precision) within itermax iterations.
  using namespace sopt;
  t_Vector const target0 = t_Vector::Random(N);
  t_real constexpr beta = 0.2;
  t_real constexpr regulariser_strength = 0.1;
  int constexpr itermax = 300;
  // Non-smooth step: delegates to proximal::id (by its name presumably the identity prox, i.e. no
  // regularisation). The parameter is named `gamma` so it does not shadow the outer
  // `regulariser_strength`.
  auto const g0 = [](t_Vector &out, const t_real gamma, const t_Vector &x) {
    proximal::id(out, gamma, x);
  };
  // Gradient step: out = Phi^dagger * residual. The current image is not needed; it is taken by
  // const reference (and left unnamed) so no vector is copied on each call.
  auto const grad = [](t_Vector &out, const t_Vector & /*image*/, const t_Vector &residual,
                       const t_LinearTransform &Phi) { out = Phi.adjoint() * residual; };
  const t_Vector x_guess = t_Vector::Random(target0.size());
  // Initial residual x_guess - target0; this matches Phi x - target0 only if the algorithm's
  // default Phi is the identity (assumed here, not visible in this file).
  const t_Vector res = x_guess - target0;
  // Converged once the iterate matches target0; the residual argument is unused.
  auto const convergence = [&target0](const t_Vector &x, const t_Vector & /*residual*/) -> bool {
    return x.isApprox(target0, 1e-9);
  };
  CAPTURE(target0);
  CAPTURE(x_guess);
  CAPTURE(res);
  auto fb = algorithm::ForwardBackward<Scalar>(grad, g0, target0)
                .itermax(itermax)
                .regulariser_strength(regulariser_strength)
                .step_size(beta)
                .is_converged(convergence);
  auto const result = fb(std::make_tuple(x_guess, res));
  CAPTURE(result.niters);
  CAPTURE(result.x);
  CAPTURE(result.residual);
  CHECK(result.x.isApprox(target0, 1e-9));
  CHECK(result.good);
  CHECK(result.niters < itermax);
}
//! Type trait whose ::value is true exactly when T is `ImagingForwardBackward<double> &`,
//! i.e. the type a chainable setter of that algorithm is expected to return.
template <class T>
struct is_imaging_proximal_ref
    : std::is_same<sopt::algorithm::ImagingForwardBackward<double> &, T> {};
//! Type trait whose ::value is true exactly when T is `L1GProximal<double> &`.
template <class T>
struct is_l1_g_proximal_ref : std::is_same<sopt::algorithm::L1GProximal<double> &, T> {};
TEST_CASE("Check type returned on setting variables") {
  // Each setter should return a reference to the object it was called on, so calls can be
  // chained. Every expression checked below is a decltype operand and is therefore never
  // evaluated; only `fb` and `gp` are actually constructed. These could equally be static_asserts.
  using namespace sopt;
  using namespace sopt::algorithm;
  using ConvFunc = ConvergenceFunction<double>;
  using LinTrans = LinearTransform<Vector<double>>;

  // Setters of the imaging forward-backward algorithm itself.
  ImagingForwardBackward<double> fb(Vector<double>::Zero(0).eval());
  CHECK(is_imaging_proximal_ref<decltype(fb.itermax(500))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.step_size(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.regulariser_strength(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.sigma(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.residual_convergence(1.001))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.target(Vector<double>::Zero(0)))>::value);
  // is_converged must return the algorithm for every value category of its argument.
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.is_converged(std::declval<ConvFunc const &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.relative_variation(5e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(fb.tight_frame(false))>::value);

  // Setters of the L1 g-proximal object, checked separately from the algorithm.
  auto const gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_tolerance(1e-2))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_nu(1))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_itermax(50))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_positivity_constraint(true))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->l1_proximal_real_constraint(true))>::value);
  // Psi must return the proximal for every value category of its linear-transform argument.
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(linear_transform_identity<double>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans &>()))>::value);
  CHECK(is_l1_g_proximal_ref<decltype(gp->Psi(std::declval<LinTrans const &>()))>::value);
}
|