File: padmm.cc

package info (click to toggle)
sopt 5.0.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,704 kB
  • sloc: cpp: 13,620; xml: 182; makefile: 6
file content (86 lines) | stat: -rw-r--r-- 4,436 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <algorithm>
#include <memory>
#include <random>

#include <catch2/catch_all.hpp>

#include <Eigen/Dense>

#include "sopt/imaging_padmm.h"
#include "sopt/padmm.h"
#include "sopt/proximal.h"
#include "sopt/types.h"

//! Draws a uniformly distributed integer from the closed interval spanned by `min` and `max`.
//!
//! The draw uses the shared generator `mersenne`, which is only declared here. It must be defined
//! elsewhere (presumably in the test runner's main — TODO confirm), and it must point to a live
//! generator when this is called, because it is dereferenced unconditionally.
//!
//! std::uniform_int_distribution requires a <= b and is undefined otherwise. The bounds are
//! therefore ordered before the distribution is built, so they may be passed in either order.
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::uniform_int_distribution<sopt::t_int> uniform_dist(std::min(min, max), std::max(min, max));
  return uniform_dist(*mersenne);
}

// Real scalar type and matching Eigen containers shared by the test cases in this file.
using Scalar = sopt::t_real;
using t_Vector = sopt::Vector<Scalar>;
using t_Matrix = sopt::Matrix<Scalar>;

// Dimension of the randomly generated vectors and matrices in the ADMM integration test.
constexpr int N = 5;

TEST_CASE("Proximal ADMM with ||x - x0||_2 functions", "[padmm][integration]") {
  using namespace sopt;
  // Two random anchor points; target1 is scaled so the two are generally well separated.
  t_Vector const target0 = t_Vector::Random(N);
  t_Vector const target1 = t_Vector::Random(N) * 4;
  // Euclidean norms translated by -target0 / -target1 (presumably ||x - target0|| and
  // ||x - target1|| — see proximal::translate).
  auto const g0 = proximal::translate(proximal::EuclidianNorm(), -target0);
  auto const g1 = proximal::translate(proximal::EuclidianNorm(), -target1);

  // Phi is minus the identity; the algorithm's target vector is zero.
  t_Matrix const mId = -t_Matrix::Identity(N, N);

  auto const padmm =
      algorithm::ProximalADMM<Scalar>(g0, g1, t_Vector::Zero(N)).Phi(mId).itermax(3000).regulariser_strength(0.01);
  auto const result = padmm();

  // Decompose the solution relative to target0 along the unit direction target0 -> target1.
  t_Vector const segment = (target1 - target0).normalized();
  t_real const alpha = (result.x - target0).dot(segment);
  t_real const segment_length = (target1 - target0).dot(segment);

  // CAPTURE only annotates assertions that come after it, so record diagnostics before any CHECK.
  CAPTURE(alpha);
  CAPTURE(segment_length);
  CAPTURE(segment.transpose());
  CAPTURE((result.x - target0).transpose());
  CAPTURE((result.x - target1).transpose());

  // The solution must lie on the segment [target0, target1]: its projection falls within the
  // segment ...
  CHECK(segment_length >= alpha);
  CHECK(alpha >= 0e0);
  // ... and its component orthogonal to the segment vanishes.
  CHECK((result.x - target0 - alpha * segment).stableNorm() < 1e-8);
}

//! Type trait: `value` is true exactly when `Candidate` is
//! `sopt::algorithm::ImagingProximalADMM<double> &`. The test below applies it to the result type
//! of each setter.
template <typename Candidate>
struct is_imaging_proximal_ref
    : std::is_same<sopt::algorithm::ImagingProximalADMM<double> &, Candidate> {};
TEST_CASE("Check type returned on setting variables") {
  // Every setter below should return `ImagingProximalADMM<double> &`, so calls can be chained.
  // Only unevaluated decltype operands are inspected, so these are compile-time facts that could
  // be static_asserts. CHECK keeps each one visible as an individual test assertion.
  using namespace sopt;
  using namespace sopt::algorithm;
  using ConvFunc = ConvergenceFunction<double>;
  using LinTrans = LinearTransform<Vector<double>>;

  ImagingProximalADMM<double> admm(Vector<double>::Zero(0));

  // Scalar, flag and vector parameter setters.
  CHECK(is_imaging_proximal_ref<decltype(admm.itermax(500))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.regulariser_strength(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.relative_variation(5e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l2ball_proximal_epsilon(1e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.tight_frame(false))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_tolerance(1e-2))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_nu(1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_itermax(50))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_positivity_constraint(true))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_real_constraint(true))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.residual_convergence(1.001))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.lagrange_update_scale(0.9))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.target(Vector<double>::Zero(0)))>::value);

  // is_converged, for each reference flavour of the argument type.
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc const &>()))>::value);

  // Phi, from the identity helper and from each reference flavour of a LinearTransform.
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(linear_transform_identity<double>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans const &>()))>::value);

  // Psi, covering the same cases as Phi.
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(linear_transform_identity<double>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans const &>()))>::value);
}