File: conjugate_gradient.cc

package info (click to toggle)
sopt 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 3,932 kB
  • ctags: 1,162
  • sloc: cpp: 7,220; php: 287; python: 57; ansic: 33; makefile: 5
file content (39 lines) | stat: -rw-r--r-- 1,392 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <limits>
#include <random>
#include "catch.hpp"

#include "sopt/conjugate_gradient.h"

// Checks sopt's ConjugateGradient on small random linear systems.
//
// Each section draws a random operator A and a random vector `expected`,
// forms the normal equations (A^T A) x = A^T expected -- with the conjugate
// transpose in the complex case -- hands them to the solver, and then verifies
// that A applied to the returned solution reproduces `expected`.
TEST_CASE("Conjugate gradient", "[cg]") {
  using namespace sopt;

  // NOTE(review): the two arguments are presumably (maximum iterations,
  // convergence tolerance) -- confirm against ConjugateGradient's constructor.
  const ConjugateGradient cg(std::numeric_limits<t_uint>::max(), 1e-12);

  SECTION("Real valued") {
    // Random inputs: the operator first, then the vector A * x should match.
    const auto A = Image<>::Random(10, 10).eval();
    const auto expected = Array<>::Random(A.rows()).eval();

    // Left- and right-hand sides of the normal equations.
    const auto normal_matrix = (A.transpose().matrix() * A.matrix()).eval();
    const auto rhs = (A.transpose().matrix() * expected.matrix()).eval();

    const auto actual = cg(normal_matrix, rhs);

    // The solver must have iterated and report a (near) zero residual.
    CHECK(actual.niters > 0);
    CHECK(actual.residual == Approx(0));
    // Values logged here are reported alongside the final check if it fails.
    CAPTURE(actual.residual);
    CAPTURE((A.matrix() * actual.result).transpose());
    CAPTURE(expected.transpose());
    CHECK((A.matrix() * actual.result).isApprox(expected.matrix(), 1e-6));
  }

  SECTION("Complex valued") {
    // Random inputs: the operator first, then the vector A * x should match.
    const auto A = Image<t_complex>::Random(10, 10).eval();
    const auto expected = Array<t_complex>::Random(A.rows()).eval();

    // Same normal equations as above, using the conjugate transpose of A.
    const auto normal_matrix = (A.conjugate().transpose().matrix() * A.matrix()).eval();
    const auto rhs = (A.conjugate().transpose().matrix() * expected.matrix()).eval();

    const auto actual = cg(normal_matrix, rhs);

    // The solver must have iterated and report a (near) zero residual.
    CHECK(actual.niters > 0);
    CHECK(actual.residual == Approx(0));
    // Values logged here are reported alongside the final check if it fails.
    CAPTURE(actual.residual);
    CAPTURE((A.matrix() * actual.result).transpose());
    CAPTURE(expected.transpose());
    CHECK((A.matrix() * actual.result).isApprox(expected.matrix(), 1e-6));
  }
}