File: padmm.cc

package info (click to toggle)
sopt 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,932 kB
  • ctags: 1,162
  • sloc: cpp: 7,220; php: 287; python: 57; ansic: 33; makefile: 5
file content (89 lines) | stat: -rw-r--r-- 4,525 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <catch.hpp>
#include <random>

#include <Eigen/Dense>

#include "sopt/imaging_padmm.h"
#include "sopt/padmm.h"
#include "sopt/proximal.h"
#include "sopt/types.h"

/// Draws an integer from a uniform distribution over the closed range [min, max].
///
/// The random engine is the `mersenne` object declared `extern` below; it is
/// defined outside this file and is dereferenced here without a null check.
sopt::t_int random_integer(sopt::t_int min, sopt::t_int max) {
  extern std::unique_ptr<std::mt19937_64> mersenne;
  using Distribution = std::uniform_int_distribution<sopt::t_int>;
  Distribution draw(min, max);
  return draw(*mersenne);
}

// Scalar, vector and matrix aliases shared by the tests in this file.
using Scalar = sopt::t_real;
using t_Vector = sopt::Vector<Scalar>;
using t_Matrix = sopt::Matrix<Scalar>;

// Size passed to the Random/Zero/Identity constructors in the integration test.
int const N = 5;

// Integration test for algorithm::ProximalADMM.
//
// Two proximals are built by translating proximal::EuclidianNorm() by -target0
// and -target1 (the "||x - x0||_2 functions" of the test name), and the solver
// is run with Phi set to minus the identity. The assertions then verify that
// result.x lies on the straight segment joining target0 and target1:
//   * alpha is the coordinate of (result.x - target0) along the unit direction
//     `segment`, and must fall between 0 and the projected length of
//     (target1 - target0);
//   * the component of (result.x - target0) not explained by alpha * segment
//     must have a norm below 1e-8.
TEST_CASE("Proximal ADMM with ||x - x0||_2 functions", "[padmm][integration]") {
  using namespace sopt;
  t_Vector const target0 = t_Vector::Random(N);
  t_Vector const target1 = t_Vector::Random(N) * 4;
  auto const g0 = proximal::translate(proximal::EuclidianNorm(), -target0);
  auto const g1 = proximal::translate(proximal::EuclidianNorm(), -target1);

  // Minus the N x N identity, used as the Phi operator of the solver.
  t_Matrix const mId = -t_Matrix::Identity(N, N);

  auto const padmm = algorithm::ProximalADMM<Scalar>(g0, g1, t_Vector::Zero(N))
                         .Phi(mId)
                         .itermax(3000)
                         .gamma(0.01);
  auto const result = padmm();

  // Unit vector pointing from target0 to target1, and the coordinate of the
  // solution along it (measured from target0).
  t_Vector const segment = (target1 - target0).normalized();
  t_real const alpha = (result.x - target0).transpose() * segment;

  // Captures come first so that their values accompany a failure of any of
  // the assertions below, not only the last one.
  CAPTURE(segment.transpose());
  CAPTURE((result.x - target0).transpose());
  CAPTURE((result.x - target1).transpose());
  CHECK((target1 - target0).transpose() * segment >= alpha);
  CHECK(alpha >= 0e0);
  CHECK((result.x - target0 - alpha * segment).stableNorm() < 1e-8);
}

template <class T>
struct is_imaging_proximal_ref
    : public std::is_same<sopt::algorithm::ImagingProximalADMM<double> &, T> {};
// Verifies, for every setter called below, that the call expression has type
// `ImagingProximalADMM<double> &`. Each check only inspects the type through
// decltype, whose operand is unevaluated, so no setter is actually invoked on
// `admm`.
// NOTE(review): presumably the point is that setters can be chained without
// losing the ImagingProximalADMM type -- confirm against sopt/imaging_padmm.h.
TEST_CASE("Check type returned on setting variables") {
  // Yeah, could be static asserts
  // (every condition below is a compile-time constant; CHECK reports a
  // mismatch as a test failure at run time rather than as a build error).
  using namespace sopt;
  using namespace sopt::algorithm;
  // Only needed as an object to name inside decltype; built from an empty vector.
  ImagingProximalADMM<double> admm(Vector<double>::Zero(0));
  // Setters taking plain values.
  CHECK(is_imaging_proximal_ref<decltype(admm.itermax(500))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.gamma(1e-1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.relative_variation(5e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l2ball_proximal_epsilon(1e-4))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.tight_frame(false))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_tolerance(1e-2))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_nu(1))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_itermax(50))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_positivity_constraint(true))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.l1_proximal_real_constraint(true))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.residual_convergence(1.001))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.lagrange_update_scale(0.9))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.nu(1e0))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.target(Vector<double>::Zero(0)))>::value);
  // is_converged is checked with the argument passed as a prvalue/rvalue,
  // an lvalue reference, an rvalue reference and a const lvalue reference.
  typedef ConvergenceFunction<double> ConvFunc;
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.is_converged(std::declval<ConvFunc &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(
            admm.is_converged(std::declval<ConvFunc const &>()))>::value);
  // Phi and Psi get the same treatment: once with the result of
  // linear_transform_identity<double>(), then with each value category of
  // LinearTransform<Vector<double>>.
  typedef LinearTransform<Vector<double>> LinTrans;
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(linear_transform_identity<double>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Phi(std::declval<LinTrans const &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(linear_transform_identity<double>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &&>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans &>()))>::value);
  CHECK(is_imaging_proximal_ref<decltype(admm.Psi(std::declval<LinTrans const &>()))>::value);
}