// Catch2 tests comparing measurement operators obtained from
// factory::measurement_operator_factory with operators built directly via
// measurementoperator::init_degrid_operator_2d.
#include "catch2/catch_all.hpp"
#include "purify/logging.h"
#include "purify/measurement_operator_factory.h"
#include "purify/utilities.h"
#include <sopt/power_method.h>
using namespace purify;
// Checks that the operator returned by measurement_operator_factory for the
// `serial` option agrees with one built directly through
// init_degrid_operator_2d on the same randomly sampled uv coverage.
//
// Neither construction call passes a kernel type or kernel support size, so
// both rely on the defaults declared by the library; the previously declared
// (and never used) `J` and `kernel` locals have been dropped.
TEST_CASE("Serial vs Distributed Operator") {
  purify::logging::set_level("debug");

  auto const N = 100;
  auto const uv_serial = utilities::random_sample_density(N, 0, constant::pi / 3);
  auto const over_sample = 2;
  auto const width = 128;
  auto const height = 128;

  // Reference operator, constructed without going through the factory.
  const auto op_serial = purify::measurementoperator::init_degrid_operator_2d<Vector<t_complex>>(
      uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample);
  // Operator under test, requested from the factory with the `serial` option.
  const auto op = factory::measurement_operator_factory<Vector<t_complex>>(
      factory::distributed_measurement_operator::serial, uv_serial.u, uv_serial.v, uv_serial.w,
      uv_serial.weights, height, width, over_sample);

  SECTION("Degridding") {
    // Apply both forward operators to the same random image.
    Vector<t_complex> const image = Vector<t_complex>::Random(width * height);
    Vector<t_complex> const expected = *op_serial * image;
    Vector<t_complex> const degridded = *op * image;
    REQUIRE(degridded.size() == expected.size());
    REQUIRE(degridded.isApprox(expected, 1e-4));
  }
  SECTION("Gridding") {
    // Apply both adjoint operators to the sampled visibilities.
    Vector<t_complex> const gridded = op->adjoint() * uv_serial.vis;
    Vector<t_complex> const gridded_serial = op_serial->adjoint() * uv_serial.vis;
    REQUIRE(gridded.size() == gridded_serial.size());
    REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
  }
}
// Exercises the `gpu_serial` option of measurement_operator_factory.
//
// When PURIFY_ARRAYFIRE is not defined, the test only requires that asking
// the factory for a gpu_serial operator throws. When it is defined, the
// factory operator is compared against one built directly through
// init_degrid_operator_2d, for both the forward and the adjoint application.
//
// The unused `J` and `kernel` locals have been removed, and the reference
// operator is now only constructed in the branch that actually uses it.
TEST_CASE("GPU Serial vs Distributed Operator") {
  auto const N = 100;
  auto const uv_serial = utilities::random_sample_density(N, 0, constant::pi / 3);
  auto const over_sample = 2;
  auto const width = 128;
  auto const height = 128;
#ifndef PURIFY_ARRAYFIRE
  // Without ArrayFire support the factory is expected to reject gpu_serial.
  REQUIRE_THROWS(factory::measurement_operator_factory<Vector<t_complex>>(
      factory::distributed_measurement_operator::gpu_serial, uv_serial.u, uv_serial.v, uv_serial.w,
      uv_serial.weights, height, width, over_sample));
#else
  // Reference operator, constructed without going through the factory.
  const auto op_serial = purify::measurementoperator::init_degrid_operator_2d<Vector<t_complex>>(
      uv_serial.u, uv_serial.v, uv_serial.w, uv_serial.weights, height, width, over_sample);
  // Operator under test, requested from the factory with the `gpu_serial` option.
  const auto op = factory::measurement_operator_factory<Vector<t_complex>>(
      factory::distributed_measurement_operator::gpu_serial, uv_serial.u, uv_serial.v, uv_serial.w,
      uv_serial.weights, height, width, over_sample);

  SECTION("Degridding") {
    // Apply both forward operators to the same random image.
    Vector<t_complex> const image = Vector<t_complex>::Random(width * height);
    Vector<t_complex> const expected = *op_serial * image;
    Vector<t_complex> const degridded = *op * image;
    REQUIRE(degridded.size() == expected.size());
    REQUIRE(degridded.isApprox(expected, 1e-4));
  }
  SECTION("Gridding") {
    // Apply both adjoint operators to the sampled visibilities.
    Vector<t_complex> const gridded = op->adjoint() * uv_serial.vis;
    Vector<t_complex> const gridded_serial = op_serial->adjoint() * uv_serial.vis;
    REQUIRE(gridded.size() == gridded_serial.size());
    REQUIRE(gridded.isApprox(gridded_serial, 1e-4));
  }
#endif
}
|