1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
#include <cmath>
#include <limits>
#include <random>
#include "catch2/catch_all.hpp"
#include "sopt/conjugate_gradient.h"
// Checks ConjugateGradient on the normal equations (A^H A) x = A^H b for a random square A.
// A random square A is almost surely non-singular, so the solution x should satisfy A x ~= b,
// where b is the random vector named `expected`. Both a real and a complex variant are run.
TEST_CASE("Conjugate gradient", "[cg]") {
  using namespace sopt;
  // NOTE(review): arguments presumably are (max iterations, tolerance); the maximum t_uint makes
  // the iteration cap effectively unbounded -- confirm against sopt/conjugate_gradient.h.
  ConjugateGradient const cg(std::numeric_limits<t_uint>::max(), 1e-12);

  SECTION("Real valued") {
    auto const A = Image<>::Random(10, 10).eval();
    // Normal-equations operator A^T A; the right-hand side is A^T b.
    auto const AtA = (A.transpose().matrix() * A.matrix()).eval();
    auto const expected = Array<>::Random(A.rows()).eval();
    auto const actual = cg(AtA, (A.transpose().matrix() * expected.matrix()).eval());
    // Evaluate A x once; it is both logged and compared below.
    auto const Ax = (A.matrix() * actual.result).eval();

    // CAPTURE only decorates assertions that come after it, so log before any CHECK.
    CAPTURE(actual.niters);
    CAPTURE(actual.residual);
    CAPTURE(Ax.transpose());
    CAPTURE(expected.transpose());
    CHECK(actual.niters > 0);
    CHECK(std::abs(actual.residual) < 1e-6);
    CHECK(Ax.isApprox(expected.matrix(), 1e-6));
  }

  SECTION("Complex valued") {
    auto const A = Image<t_complex>::Random(10, 10).eval();
    // Hermitian normal-equations operator A^H A; the right-hand side is A^H b.
    auto const AhA = (A.conjugate().transpose().matrix() * A.matrix()).eval();
    auto const expected = Array<t_complex>::Random(A.rows()).eval();
    auto const actual = cg(AhA, (A.conjugate().transpose().matrix() * expected.matrix()).eval());
    // Evaluate A x once; it is both logged and compared below.
    auto const Ax = (A.matrix() * actual.result).eval();

    // CAPTURE only decorates assertions that come after it, so log before any CHECK.
    CAPTURE(actual.niters);
    CAPTURE(actual.residual);
    CAPTURE(Ax.transpose());
    CAPTURE(expected.transpose());
    CHECK(actual.niters > 0);
    CHECK(std::abs(actual.residual) < 1e-6);
    CHECK(Ax.isApprox(expected.matrix(), 1e-6));
  }
}
|