File: conjugate_gradient.cc

package info (click to toggle)
sopt 4.2.0%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,632 kB
  • sloc: cpp: 13,011; xml: 182; makefile: 6
file content (42 lines) | stat: -rw-r--r-- 1,861 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "sopt/conjugate_gradient.h"
#include <sstream>
#include <benchmark/benchmark.h>

// Benchmarks sopt::ConjugateGradient when the linear operator is passed as a matrix.
//
// A random N x N array A and a random length-N array b are generated once, and
// the timed loop repeatedly calls cg(output, A^H A, A^H b).
//
// Benchmark arguments:
//   state.range_x() -- N, the problem size.
//   state.range_y() -- p, turned into the value 10^-p that is passed as the second
//                      constructor argument of the solver.
template <typename TYPE>
void matrix_cg(benchmark::State &state) {
  auto const N = state.range_x();
  auto const epsilon = std::pow(10, -state.range_y());
  auto const A = sopt::Image<TYPE>::Random(N, N).eval();
  auto const b = sopt::Array<TYPE>::Random(N).eval();

  // Materialise both products with eval(). Without it, `auto` deduces a lazy
  // product expression (assuming sopt's containers are Eigen types, as their API
  // suggests) that may be re-evaluated every time the solver touches it, so the
  // timed loop would be dominated by recomputing A^H A rather than by the solve.
  auto const AhA = (A.matrix().transpose().conjugate() * A.matrix()).eval();
  auto const Ahb = (A.matrix().transpose().conjugate() * b.matrix()).eval();
  auto output = sopt::Vector<TYPE>::Zero(N).eval();
  // NOTE(review): the first argument (0) is presumably the iteration limit, with 0
  // meaning "no limit" -- confirm against sopt::ConjugateGradient.
  sopt::ConjugateGradient cg(0, epsilon);
  while (state.KeepRunning()) cg(output, AhA, Ahb);
  // Reports N * sizeof(TYPE) bytes per timed iteration.
  state.SetBytesProcessed(int64_t(state.iterations()) * int64_t(N) * sizeof(TYPE));
}

// Benchmarks sopt::ConjugateGradient when the linear operator is passed as a
// callable instead of a matrix.
//
// A random N x N array A and a random length-N array b are generated once. The
// operator handed to the solver is a lambda computing out = (A^H A) * input, and
// the timed loop repeatedly calls cg(output, func, A^H b).
//
// Benchmark arguments:
//   state.range_x() -- N, the problem size.
//   state.range_y() -- p, turned into the value 10^-p that is passed as the second
//                      constructor argument of the solver.
template <typename TYPE>
void function_cg(benchmark::State &state) {
  auto const N = state.range_x();
  auto const epsilon = std::pow(10, -state.range_y());
  auto const A = sopt::Image<TYPE>::Random(N, N).eval();
  auto const b = sopt::Array<TYPE>::Random(N).eval();

  // Materialise both products with eval(). Without it, `auto` deduces a lazy
  // product expression (assuming sopt's containers are Eigen types, as their API
  // suggests), and every call of the lambda below would recompute the full
  // A^H A product before applying it to `input`.
  auto const AhA = (A.matrix().transpose().conjugate() * A.matrix()).eval();
  auto const Ahb = (A.matrix().transpose().conjugate() * b.matrix()).eval();
  using t_Vector = sopt::Vector<TYPE>;
  // AhA is captured by reference; it outlives the lambda because both are locals
  // of this function.
  auto func = [&AhA](t_Vector &out, t_Vector const &input) { out = AhA * input; };
  auto output = sopt::Vector<TYPE>::Zero(N).eval();
  // NOTE(review): the first argument (0) is presumably the iteration limit, with 0
  // meaning "no limit" -- confirm against sopt::ConjugateGradient.
  sopt::ConjugateGradient cg(0, epsilon);
  while (state.KeepRunning()) cg(output, func, Ahb);
  // Reports N * sizeof(TYPE) bytes per timed iteration.
  state.SetBytesProcessed(int64_t(state.iterations()) * int64_t(N) * sizeof(TYPE));
}

BENCHMARK_TEMPLATE(matrix_cg, sopt::t_complex)->RangePair(1, 256, 4, 12)->UseRealTime();
BENCHMARK_TEMPLATE(matrix_cg, sopt::t_real)->RangePair(1, 256, 4, 12)->UseRealTime();
BENCHMARK_TEMPLATE(function_cg, sopt::t_complex)->RangePair(1, 256, 4, 12)->UseRealTime();
BENCHMARK_TEMPLATE(function_cg, sopt::t_real)->RangePair(1, 256, 4, 12)->UseRealTime();

BENCHMARK_MAIN()