File: sdmm_warm_start.cc

package info (click to toggle)
sopt 4.2.0%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,632 kB
  • sloc: cpp: 13,011; xml: 182; makefile: 6
file content (59 lines) | stat: -rw-r--r-- 2,026 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <catch2/catch_all.hpp>
#include <limits>
#include <random>

#include <Eigen/Dense>

#include "sopt/proximal.h"
#include "sopt/sdmm.h"
#include "sopt/types.h"

// Scalar, matrix and vector types used throughout this test file, all built on
// sopt's real scalar type.
using Scalar = sopt::t_real;
using t_Matrix = sopt::Matrix<Scalar>;
using t_Vector = sopt::Vector<Scalar>;

// Problem dimension: size of the random vectors and of the identity operator.
constexpr int N{30};
/// Integration test: SDMM must be able to resume from a previous, unconverged
/// result (warm start).
///
/// Two translated Euclidean-norm proximals are appended, each paired with the
/// identity operator. Presumably (assuming `translate(f, -t)` yields x -> f(x - t);
/// TODO confirm against sopt) this minimises ||x - target0|| + ||x - target1||,
/// whose minimisers are the points of the segment [target0, target1]. That is
/// exactly what the custom convergence predicate below tests for.
///
/// The scenario runs SDMM to convergence and then re-runs it with an iteration
/// budget 5 iterations short of that. Finally it restarts from the truncated
/// result: a genuine warm start should need only a handful of extra iterations.
SCENARIO("SDMM with warm start", "[sdmm][integration]") {
  using namespace sopt;

  GIVEN("An SDMM instance with its input") {
    t_Matrix const Id = t_Matrix::Identity(N, N).eval();
    t_Vector const target0 = t_Vector::Random(N);
    t_Vector const target1 = t_Vector::Random(N) * 4;

    // The segment geometry never changes during the scenario, so compute it
    // once rather than on every call of the convergence predicate.
    t_Vector const segment = (target1 - target0).normalized();
    t_real const length = (target1 - target0).norm();

    // Converged when x lies on the segment [target0, target1]. Its projection
    // onto the segment direction must fall within [0, length], and the
    // component orthogonal to the segment must be negligible.
    auto convergence = [target0, segment, length](t_Vector const &x) -> bool {
      t_real const alpha = (x - target0).dot(segment);
      return alpha >= 0e0 and length >= alpha and
             (x - target0 - alpha * segment).stableNorm() < 1e-8;
    };

    auto sdmm = algorithm::SDMM<Scalar>()
                    .is_converged(convergence)
                    .itermax(5000)
                    .gamma(1)
                    .conjugate_gradient(std::numeric_limits<t_uint>::max(), 1e-12)
                    .append(proximal::translate(proximal::EuclidianNorm(), -target0), Id)
                    .append(proximal::translate(proximal::EuclidianNorm(), -target1), Id);
    t_Vector const input = t_Vector::Random(N);

    WHEN("the algorithms runs") {
      auto const full = sdmm(input);
      THEN("it converges") {
        CHECK(full.niters > 20);
        CHECK(full.good);
      }

      WHEN("It is set to stop before convergence") {
        // Catch2 executes one leaf section per pass, so the sibling THEN above
        // is not checked on this path. Guard the subtraction explicitly: if
        // niters is an unsigned type, niters < 5 would wrap around to a huge
        // budget and let this "truncated" run converge, failing misleadingly.
        REQUIRE(full.niters > 5);
        auto const first_half = sdmm.itermax(full.niters - 5)(input);
        THEN("It is not converged") { CHECK(not first_half.good); }

        WHEN("A warm restart is attempted") {
          // Restore a generous budget and resume from the truncated result.
          auto const second_half = sdmm.itermax(5000)(first_half);
          THEN("The warm restart is validated by the fast convergence") {
            CHECK(second_half.good);
            CHECK(second_half.niters < 10);
          }
        }
      }
    }
  }
}