File: mpi_wavelets.cc

package info (click to toggle)
sopt 5.0.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,704 kB
  • sloc: cpp: 13,620; xml: 182; makefile: 6
file content (47 lines) | stat: -rw-r--r-- 1,858 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <catch2/catch_all.hpp>
#include <memory>
#include <random>

#include "sopt/types.h"
#include "sopt/wavelets.h"

// Compares a wavelet operator built from the complete SARA dictionary (no communicator) against
// one built from this rank's slice of the dictionary together with the world communicator.
TEST_CASE("Wavelet transform innards with integer data", "[wavelet]") {
  using namespace sopt;
  using namespace sopt::wavelets;

  auto const world = mpi::Communicator::World();
  sopt::wavelets::SARA const serial{std::make_tuple("DB4", 5), std::make_tuple("DB8", 2)};
  CAPTURE(serial.size());
  CAPTURE(world.size());

  // Contiguous block distribution of the serial.size() dictionary entries over the ranks:
  // every rank owns `share` entries and the first `remainder` ranks own one more.
  auto const share = serial.size() / world.size();
  auto const remainder = serial.size() % world.size();
  auto const start = world.rank() * share + std::min(world.rank(), remainder);
  auto const extra = (remainder >= 1 && world.rank() < remainder) ? 1 : 0;
  auto const end = start + share + extra;

  // Dictionary restricted to the half-open entry range [start, end) owned by this rank.
  sopt::wavelets::SARA const local(serial.begin() + start, serial.begin() + end);

  constexpr auto Nx = 32;
  constexpr auto Ny = 32;
  constexpr auto npixels = Nx * Ny;
  auto const psi_full = linear_transform<t_real>(serial, Nx, Ny);
  auto const psi_local = linear_transform<t_real>(local, Nx, Ny, world);

  // Where this rank's coefficients sit inside the full coefficient vector. The test lays the
  // full vector out as one run of Nx * Ny values per dictionary entry.
  auto const offset = start * npixels;
  auto const length = (end - start) * npixels;

  SECTION("Signal to Coefficients") {
    // Random input passed through world.broadcast -- presumably so every rank holds the same
    // vector; confirm against the communicator documentation.
    auto const signal = world.broadcast<Vector<t_real>>(Vector<t_real>::Random(npixels));
    // Reference: the slice of the full-dictionary coefficients that belongs to this rank.
    Vector<t_real> const serial_coeffs = (psi_full.adjoint() * signal).segment(offset, length);
    Vector<t_real> const para_coeffs = psi_local.adjoint() * signal;
    CAPTURE(start);
    CAPTURE(end);
    CHECK(serial_coeffs.isApprox(para_coeffs));
  }

  SECTION("Coefficients to Signal") {
    auto const coefficients =
        world.broadcast<Vector<t_real>>(Vector<t_real>::Random(npixels * serial.size()));
    Vector<t_real> const serial_signal = (psi_full * coefficients);
    // Each rank feeds only its own slice of the coefficients to the distributed operator; the
    // result is expected to match the signal synthesised from all coefficients.
    Vector<t_real> const para_signal = psi_local * coefficients.segment(offset, length);
    CHECK(serial_signal.isApprox(para_signal));
  }
}