File: test_multi_transform.cpp

package info (click to toggle)
spfft 1.1.1-5
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,688 kB
  • sloc: cpp: 11,562; f90: 665; ansic: 437; python: 41; makefile: 24
file content (95 lines) | stat: -rw-r--r-- 3,486 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <fftw3.h>

#include <algorithm>
#include <memory>
#include <random>
#include <tuple>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "gtest_mpi.hpp"
#include "memory/array_view_utility.hpp"
#include "memory/host_array.hpp"
#include "memory/host_array_view.hpp"
#include "mpi_util/mpi_communicator_handle.hpp"
#include "parameters/parameters.hpp"
#include "spfft/spfft.hpp"
#include "test_util/generate_indices.hpp"
#include "test_util/test_transform.hpp"
#include "util/common_types.hpp"

// Round-trip test of the multi-transform interface.
//
// Creates one distributed C2C host transform, clones it until there are
// `numTransforms` of them, fills the frequency-domain input of transform i
// with the constant (i, i), and runs multi_transform_backward followed by
// multi_transform_forward with SPFFT_NO_SCALING. Every frequency value is then
// expected to equal its input multiplied by dimX * dimY * dimZ.
TEST(MPIMultiTransformTest, BackwardsForwards) {
  GTEST_MPI_GUARD
  try {
    MPICommunicatorHandle comm(MPI_COMM_WORLD);

    // One entry of 1.0 per rank.
    // NOTE(review): presumably relative weights giving every rank an equal
    // share of z-sticks / xy-planes - confirm against test_util/generate_indices.hpp.
    const std::vector<double> zStickDistribution(comm.size(), 1.0);
    const std::vector<double> xyPlaneDistribution(comm.size(), 1.0);

    // Grid dimensions grow with the number of ranks.
    const int dimX = comm.size() * 10;
    const int dimY = comm.size() * 11;
    const int dimZ = comm.size() * 12;

    const int numTransforms = 3;

    // Fixed seed, so the generated index set is reproducible between runs.
    std::mt19937 randGen(42);
    const auto valueIndicesPerRank =
        create_value_indices(randGen, zStickDistribution, 0.7, 0.7, dimX, dimY, dimZ, false);
    const int numLocalXYPlanes =
        calculate_num_local_xy_planes(comm.rank(), dimZ, xyPlaneDistribution);

    // Indices are passed below as SPFFT_INDEX_TRIPLETS, i.e. three integers per
    // frequency value, hence the division by 3.
    const auto& localIndices = valueIndicesPerRank[comm.rank()];
    const int numValues = static_cast<int>(localIndices.size() / 3);

    // One frequency-domain buffer per transform.
    std::vector<std::vector<std::complex<double>>> freqValuesPerTrans(
        numTransforms, std::vector<std::complex<double>>(numValues));

    // The multi-transform functions take one double* per transform; each
    // pointer aliases the interleaved real/imaginary storage of a buffer.
    std::vector<double*> freqValuePtr;
    freqValuePtr.reserve(freqValuesPerTrans.size());
    for (auto& values : freqValuesPerTrans) {
      freqValuePtr.push_back(reinterpret_cast<double*>(values.data()));
    }

    // Set the frequency values of transform i to the constant (i, i), so the
    // result of each transform can be told apart from the others.
    for (std::size_t i = 0; i < freqValuesPerTrans.size(); ++i) {
      const std::complex<double> fillValue(static_cast<double>(i), static_cast<double>(i));
      std::fill(freqValuesPerTrans[i].begin(), freqValuesPerTrans[i].end(), fillValue);
    }

    std::vector<Transform> transforms;
    transforms.reserve(numTransforms);

    // Create the first transform from a temporary grid.
    // NOTE(review): dimX * dimY and -1 are presumably the maximum number of
    // local z-columns and the "default" thread count - confirm against the
    // spfft::Grid constructor documentation.
    transforms.push_back(Grid(dimX, dimY, dimZ, dimX * dimY, numLocalXYPlanes, SPFFT_PU_HOST, -1,
                              comm.get(), SPFFT_EXCH_DEFAULT)
                             .create_transform(SPFFT_PU_HOST, SPFFT_TRANS_C2C, dimX, dimY, dimZ,
                                               numLocalXYPlanes, numValues, SPFFT_INDEX_TRIPLETS,
                                               localIndices.data()));
    // The remaining transforms are clones of the first one.
    for (int i = 1; i < numTransforms; ++i) {
      transforms.push_back(transforms.front().clone());
    }

    const std::vector<SpfftProcessingUnitType> processingUnits(numTransforms, SPFFT_PU_HOST);
    const std::vector<SpfftScalingType> scalingTypes(numTransforms, SPFFT_NO_SCALING);

    // backward: frequency domain -> space domain, all transforms in one call
    multi_transform_backward(numTransforms, transforms.data(), freqValuePtr.data(),
                             processingUnits.data());

    // forward: space domain -> frequency domain, written back into the same buffers
    multi_transform_forward(numTransforms, transforms.data(), processingUnits.data(),
                            freqValuePtr.data(), scalingTypes.data());

    // Check all values: without scaling, the round trip is expected to multiply
    // every input value by the total grid size.
    for (std::size_t i = 0; i < freqValuesPerTrans.size(); ++i) {
      const double expected = static_cast<double>(i * dimX * dimY * dimZ);
      for (const auto& val : freqValuesPerTrans[i]) {
        ASSERT_NEAR(expected, val.real(), 1e-8);
        ASSERT_NEAR(expected, val.imag(), 1e-8);
      }
    }

  } catch (const std::exception& e) {
    // Attach the exception text to the gtest failure itself instead of
    // printing it separately from a message-less assertion.
    FAIL() << "ERROR: " << e.what();
  }
}