File: conjugate_gradient.cc

package info (click to toggle)
sopt 5.0.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,704 kB
  • sloc: cpp: 13,620; xml: 182; makefile: 6
file content (39 lines) | stat: -rw-r--r-- 1,411 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <cmath>
#include <limits>
#include <random>

#include "catch2/catch_all.hpp"

#include "sopt/conjugate_gradient.h"

TEST_CASE("Conjugate gradient", "[cg]") {
  using namespace sopt;

  ConjugateGradient const cg(std::numeric_limits<t_uint>::max(), 1e-12);
  SECTION("Real valued") {
    auto const A = Image<>::Random(10, 10).eval();
    auto const AtA = (A.transpose().matrix() * A.matrix()).eval();
    auto const expected = Array<>::Random(A.rows()).eval();

    auto const actual = cg(AtA, (A.transpose().matrix() * expected.matrix()).eval());

    CHECK(actual.niters > 0);
    CHECK(std::abs(actual.residual) < 1e-6);
    CAPTURE(actual.residual);
    CAPTURE((A.matrix() * actual.result).transpose());
    CAPTURE(expected.transpose());
    CHECK((A.matrix() * actual.result).isApprox(expected.matrix(), 1e-6));
  }

  SECTION("Complex valued") {
    auto const A = Image<t_complex>::Random(10, 10).eval();
    auto const AhA = (A.conjugate().transpose().matrix() * A.matrix()).eval();
    auto const expected = Array<t_complex>::Random(A.rows()).eval();

    auto const actual = cg(AhA, (A.conjugate().transpose().matrix() * expected.matrix()).eval());

    CHECK(actual.niters > 0);
    CHECK(std::abs(actual.residual) < 1e-6);
    CAPTURE(actual.residual);
    CAPTURE((A.matrix() * actual.result).transpose());
    CAPTURE(expected.transpose());
    CHECK((A.matrix() * actual.result).isApprox(expected.matrix(), 1e-6));
  }
}