1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
#include <catch.hpp>
#include <random>
#include <Eigen/Dense>
#include "sopt/proximal.h"
#include "sopt/sdmm.h"
#include "sopt/types.h"
typedef sopt::t_real Scalar;
typedef sopt::Vector<Scalar> t_Vector;
typedef sopt::Matrix<Scalar> t_Matrix;
auto const N = 30;
SCENARIO("SDMM with warm start", "[sdmm][integration]") {
using namespace sopt;
GIVEN("An SDMM instance with its input") {
t_Matrix const Id = t_Matrix::Identity(N, N).eval();
t_Vector const target0 = t_Vector::Random(N);
t_Vector target1 = t_Vector::Random(N) * 4;
auto convergence = [&target1, &target0](t_Vector const &x) -> bool {
t_Vector const segment = (target1 - target0).normalized();
t_real const alpha = (x - target0).transpose() * segment;
return alpha >= 0e0 and (target1 - target0).transpose() * segment >= alpha
and (x - target0 - alpha * segment).stableNorm() < 1e-8;
};
auto sdmm = algorithm::SDMM<Scalar>()
.is_converged(convergence)
.itermax(5000)
.gamma(1)
.conjugate_gradient(std::numeric_limits<t_uint>::max(), 1e-12)
.append(proximal::translate(proximal::EuclidianNorm(), -target0), Id)
.append(proximal::translate(proximal::EuclidianNorm(), -target1), Id);
t_Vector input = t_Vector::Random(N);
WHEN("the algorithms runs") {
auto const full = sdmm(input);
THEN("it converges") {
CHECK(full.niters > 20);
CHECK(full.good);
}
WHEN("It is set to stop before convergence") {
auto const first_half = sdmm.itermax(full.niters - 5)(input);
THEN("It is not converged") { CHECK(not first_half.good); }
WHEN("A warm restart is attempted") {
auto const second_half = sdmm.itermax(5000)(first_half);
THEN("The warm restart is validated by the fast convergence") {
CHECK(second_half.niters < 10);
}
}
}
}
}
}
|