File: wavelet_operator_mpi.cc

package info (click to toggle)
purify 4.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 182,264 kB
  • sloc: cpp: 16,485; python: 449; xml: 182; makefile: 7; sh: 6
file content (129 lines) | stat: -rw-r--r-- 4,702 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <chrono>
#include <benchmark/benchmark.h>
#include "benchmarks/utilities.h"
#include "purify/wavelet_operator_factory.h"
#include <sopt/mpi/communicator.h>
#include <sopt/mpi/session.h>

using namespace purify;

// -------------- Constructor benchmark -------------------------//

// void wavelet_operator_constructor_mpi(benchmark::State &state) {

//   // Image size
//   t_uint m_imsizex = state.range(0);
//   t_uint m_imsizey = state.range(0);
//   // MPI communicator
//   sopt::mpi::Communicator m_world = sopt::mpi::Communicator::World();

//   // benchmark the creation of the wavelet operator
//   while(state.KeepRunning()) {

//     auto start = std::chrono::high_resolution_clock::now();

//     const  sopt::wavelets::SARA m_sara{std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u),
//     std::make_tuple("DB2", 3u),
//       std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
//       std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};

//     sopt::wavelets::SARA saraDistr = sopt::wavelets::distribute_sara(m_sara, m_world);

//     sopt::LinearTransform<Vector<t_complex>> Psi = sopt::linear_transform<t_complex>(saraDistr,
//     m_imsizey, m_imsizex, m_world);

//     auto end   = std::chrono::high_resolution_clock::now();

//     state.SetIterationTime(b_utilities::duration(start,end));
//   }
// }

// BENCHMARK(wavelet_operator_constructor_mpi)
// //->Apply(b_utilities::Arguments)
// //->Args({1024})
// ->RangeMultiplier(2)->Range(1024, 1024<<4)
// ->UseManualTime()
// ->Repetitions(1)->ReportAggregatesOnly(true)
// ->Unit(benchmark::kMillisecond);

// ----------------- Application benchmarks -----------------------//

class WaveletOperatorMPIFixture : public ::benchmark::Fixture {
 public:
  WaveletOperatorMPIFixture(){};
  void SetUp(const ::benchmark::State& state) {
    m_imsizex = state.range(0);
    m_imsizey = state.range(0);
    b_utilities::update_comm(m_world);
    sopt::wavelets::SARA m_sara{
        std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
        std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
        std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};

    sopt::wavelets::SARA const saraDistr = sopt::wavelets::distribute_sara(m_sara, m_world);

    // Get the number of wavelet coefs
    n_wave_coeff = saraDistr.size() * m_imsizey * m_imsizex;
    m_Psi = sopt::linear_transform<t_complex>(saraDistr, m_imsizey, m_imsizex, m_world);
  }

  void TearDown(const ::benchmark::State& state) {}

  // A bunch of useful variables
  t_uint m_counter;
  // MPI communicator
  sopt::mpi::Communicator m_world;
  sopt::LinearTransform<Vector<t_complex>> m_Psi = sopt::linear_transform_identity<t_complex>();
  t_uint m_imsizex;
  t_uint m_imsizey;
  // Get the number of wavelet coefs
  t_uint n_wave_coeff;
};

BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Forward)(benchmark::State& state) {
  // Times the product m_Psi * (wavelet coefficients), writing into an image-sized vector.
  using hr_clock = std::chrono::high_resolution_clock;

  // Destination buffer (image-sized) and constant input (coefficient-sized).
  Vector<t_complex> pixels = Vector<t_complex>::Random(m_imsizey * m_imsizex);
  const Vector<t_complex> coefficients = Vector<t_complex>::Ones(n_wave_coeff);

  for (;;) {
    // The KeepRunning() result is passed through an MPI broadcast before it
    // decides whether the loop continues.
    if (!m_world.broadcast<int>(state.KeepRunning())) break;

    const auto tic = hr_clock::now();
    pixels = m_Psi * coefficients;
    const auto toc = hr_clock::now();

    // Manual timing: report the measured interval for this iteration.
    state.SetIterationTime(b_utilities::duration(tic, toc, m_world));
  }
}

BENCHMARK_DEFINE_F(WaveletOperatorMPIFixture, Adjoint)(benchmark::State& state) {
  // Times the product m_Psi.adjoint() * (image), writing into a coefficient-sized vector.
  using hr_clock = std::chrono::high_resolution_clock;

  // Constant input (image-sized) and destination buffer (coefficient-sized).
  const Vector<t_complex> pixels = Vector<t_complex>::Ones(m_imsizey * m_imsizex);
  Vector<t_complex> coefficients = Vector<t_complex>::Zero(n_wave_coeff);

  for (;;) {
    // The KeepRunning() result is passed through an MPI broadcast before it
    // decides whether the loop continues.
    if (!m_world.broadcast<int>(state.KeepRunning())) break;

    const auto tic = hr_clock::now();
    coefficients = m_Psi.adjoint() * pixels;
    const auto toc = hr_clock::now();

    // Manual timing: report the measured interval for this iteration.
    state.SetIterationTime(b_utilities::duration(tic, toc, m_world));
  }
}

// Forward benchmark: image sizes 128..1024 in powers of two, manually timed,
// five repetitions with only the aggregate statistics reported, in milliseconds.
BENCHMARK_REGISTER_F(WaveletOperatorMPIFixture, Forward)
    // Disabled alternative argument generator: ->Apply(b_utilities::Arguments)
    ->RangeMultiplier(2)
    ->Range(128, 1024)
    ->Unit(benchmark::kMillisecond)
    ->UseManualTime()
    ->Repetitions(5)
    ->ReportAggregatesOnly(true);

// Adjoint benchmark: image sizes 128..1024 in powers of two, manually timed,
// five repetitions with only the aggregate statistics reported, in milliseconds.
BENCHMARK_REGISTER_F(WaveletOperatorMPIFixture, Adjoint)
    // Disabled alternative argument generator: ->Apply(b_utilities::Arguments)
    ->RangeMultiplier(2)
    ->Range(128, 1024)
    ->Unit(benchmark::kMillisecond)
    ->UseManualTime()
    ->Repetitions(5)
    ->ReportAggregatesOnly(true);

// BENCHMARK_MAIN();