File: primal_dual.cc

package info (click to toggle)
sopt 4.2.0%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,632 kB
  • sloc: cpp: 13,011; xml: 182; makefile: 6
file content (101 lines) | stat: -rw-r--r-- 3,592 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <catch2/catch_all.hpp>
#include <random>

#include <cassert>

#include <Eigen/Dense>

#include "sopt/imaging_primal_dual.h"
#include "sopt/primal_dual.h"
#include "sopt/proximal.h"
#include "sopt/types.h"

// Scalar and dense container types shared by the test cases in this file.
using Scalar = sopt::t_real;
using t_Vector = sopt::Vector<Scalar>;
using t_Matrix = sopt::Matrix<Scalar>;
// Catch2 helper for approximate floating-point comparisons.
using Catch::Approx;

// Length of the random vectors (and size of the identity operators) built by the tests.
constexpr auto N = 5;

// Runs ImagingPrimalDual with identity Phi and Psi operators on a random target and
// checks that the solution lies within the requested distance of the target. Afterwards,
// custom l1 proximal operators are installed and the solver is expected to throw.
TEST_CASE("Primal Dual Imaging", "[primaldual]") {
  using namespace sopt;

  t_Matrix const identity = t_Matrix::Identity(N, N);

  // Random target, passed through sopt::positive_quadrant before being used as input.
  t_Vector target = t_Vector::Random(N);
  target = sopt::positive_quadrant(target);

  // Half the norm of the target: used both as the l2-ball epsilon and as the
  // residual-convergence setting.
  auto const radius = target.stableNorm() / 2;

  auto solver = algorithm::ImagingPrimalDual<Scalar>(target)
                    .l1_proximal_weights(t_Vector::Ones(target.size()))
                    .Phi(identity)
                    .Psi(identity)
                    .itermax(5000)
                    .tau(0.1)
                    .gamma(0.4)
                    .l2ball_proximal_epsilon(radius)
                    .relative_variation(1e-4)
                    .residual_convergence(radius);

  // First run: default proximal operators.
  auto const diagnostic = solver();
  auto const distance = (diagnostic.x - target).stableNorm();
  CHECK(distance <= Approx(radius).margin(1e-10));
  CHECK(diagnostic.good);

  // Second run: replace the l1 proximal operators with two hand-written ones.
  auto const scaling_prox = [](t_Vector &out, const t_real &scale, const t_Vector &in) {
    out = scale * in;
  };
  auto const weighted_scaling_prox = [](t_Vector &out, const Vector<t_real> &weights,
                                        const t_Vector &in) {
    out = 10 * weights.array() * in.array();
  };
  solver.l1_proximal(scaling_prox).l1_proximal_weighted(weighted_scaling_prox);
  // NOTE(review): presumably the solver rejects this pair because the two operators do
  // not agree with each other -- confirm against ImagingPrimalDual.
  CHECK_THROWS(solver());
}
// Runs the generic PrimalDual algorithm with f forwarding to proximal::id and g set to
// proximal::L2Norm, starting from a random guess, and checks that the iterate reaches
// target0 (to a precision of 1e-9) in fewer than 200 iterations.
TEST_CASE("Primal Dual with 0.5 * ||x - x0||_2^2 function", "[primaldual]") {
  using namespace sopt;
  t_Vector const target0 = t_Vector::Random(N);
  // Proximal operator handed to the solver as f: simply forwards to proximal::id.
  auto const f = [](t_Vector &out, const t_real gamma, const t_Vector &x) {
    proximal::id(out, gamma, x);
  };
  auto const g = proximal::L2Norm<Scalar>();
  // Initial guess and the matching initial residual passed to the solver.
  const t_Vector x_guess = t_Vector::Random(target0.size());
  const t_Vector res = x_guess - target0;
  // Convergence only looks at the iterate. The residual argument is unused, so it is left
  // unnamed: calling it `res` shadowed the local `res` above and triggered
  // shadow / unused-parameter warnings.
  auto const convergence = [&target0](const t_Vector &x, const t_Vector &) -> bool {
    return x.isApprox(target0, 1e-9);
  };
  // Record the inputs so that they are reported if an assertion below fails.
  CAPTURE(target0);
  CAPTURE(x_guess);
  CAPTURE(res);
  auto const pd = algorithm::PrimalDual<Scalar>(f, g, target0)
                      .itermax(3000)
                      .gamma(0.9)
                      .rho(0.5)
                      .update_scale(0.5)
                      .is_converged(convergence);
  // Warm start from the explicit (guess, residual) pair.
  auto const result = pd(std::make_tuple(x_guess, res));
  CAPTURE(result.niters);
  CAPTURE(result.x);
  CAPTURE(result.residual);
  CHECK(result.x.isApprox(target0, 1e-9));
  CHECK(result.good);
  CHECK(result.niters < 200);
}

template <typename T>
struct is_primal_dual_ref : public std::is_same<sopt::algorithm::ImagingPrimalDual<double> &, T> {};
// Verifies that each listed setter yields an ImagingPrimalDual<double>& (the type
// is_primal_dual_ref tests for). The setter calls sit inside decltype and are therefore
// never executed; only their return types are inspected.
TEST_CASE("Check type returned on setting variables") {
  using namespace sopt;
  using namespace sopt::algorithm;
  // The checked values are compile-time constants, so static assertions would work too.
  ImagingPrimalDual<double> solver(Vector<double>::Zero(0));
  CHECK(is_primal_dual_ref<decltype(solver.itermax(500))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.sigma(1))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.tau(1))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.rho(1))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.xi(1))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.gamma(1e0))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.update_scale(1e0))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.positivity_constraint(true))>::value);
  CHECK(is_primal_dual_ref<decltype(solver.real_constraint(true))>::value);
}