File: test_parameters_mpi

package info (click to toggle)
purify 5.0.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 186,836 kB
  • sloc: cpp: 17,731; python: 510; xml: 182; makefile: 7; sh: 6
file content (91 lines) | stat: -rw-r--r-- 4,018 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// Serial read: load the visibilities for the given file names on the calling process.
// The literal `false` is handed to utilities::read_visibility unchanged
// (NOTE(review): its meaning is defined by read_visibility -- confirm against that signature).
utilities::vis_params dirty_visibilities(const std::vector<std::string> &names) {
  auto visibilities = utilities::read_visibility(names, false);
  return visibilities;
}

// MPI-aware read of the visibilities.
//  - communicator of size 1: falls back to the serial overload;
//  - root process: reads the data, asks distribute::distribute_measurements for an
//    ordering (w_term plan) and calls utilities::regroup_and_scatter with it;
//  - every other process: calls utilities::scatter_visibilities(comm) and returns its result.
utilities::vis_params
dirty_visibilities(const std::vector<std::string> &names, sopt::mpi::Communicator const &comm) {
  // Only one process: plain serial read.
  if(comm.size() == 1)
    return dirty_visibilities(names);

  // Non-root processes do not read the files themselves.
  if(not comm.is_root())
    return utilities::scatter_visibilities(comm);

  // Root process: read everything, compute the distribution order, then scatter.
  auto full_set = dirty_visibilities(names);
  auto const ordering
      = distribute::distribute_measurements(full_set, comm, distribute::plan::w_term);
  return utilities::regroup_and_scatter(full_set, ordering, comm);
}

// MPI communicator returned by sopt::mpi::Communicator::World(); every collective call below goes through it
  auto const world = sopt::mpi::Communicator::World();

// how the data is read in: dirty_visibilities(names, comm) is called with the single path input_data_path
  auto uv_data = dirty_visibilities({input_data_path}, world);
  // the units field is set explicitly to radians after reading
  uv_data.units = utilities::vis_units::radians;

// the uv_data.size() of each process, combined with all_sum_all, must equal 13107
  REQUIRE(world.all_sum_all(uv_data.size()) == 13107);
// image size (rows = imsizey, columns = imsizex), both 256
  t_uint const imsizey = 256;
  t_uint const imsizex = 256;

// input measurement operator parameters
// NOTE(review): the positional arguments after imsizex (1, 1, 2, 100, 0.0001, the "kb" kernel, 4, 4)
// are defined by measurement_operator_factory -- confirm their meaning against its signature.
  auto const measurements_transform
      = factory::measurement_operator_factory<Vector<t_complex>>(
          factory::distributed_measurement_operator::mpi_distribute_image,
          uv_data, imsizey, imsizex, 1, 1, 2, 100,
          0.0001, kernels::kernel_from_string.at("kb"), 4, 4);

// wavelets used: nine (name, 3u) entries -- Dirac and DB1 to DB8
  std::vector<std::tuple<std::string, t_uint>> const sara{
      std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
      std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
      std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
  // wavelet operator built from the list above with the mpi_sara distribution option
  auto const wavelets = factory::wavelet_operator_factory<Vector<t_complex>>(
      factory::distributed_wavelet_operator::mpi_sara, sara, imsizey, imsizex);
// value of sigma used to calculate epsilon in the algorithm factory;
// the literal is passed through world.broadcast before being stored
  t_real const sigma = world.broadcast(0.02378738741225); //see test_parameters file
  // "global" section: PADMM built with the mpi_serial algorithm distribution option.
  SECTION("global"){
    //input padmm factory parameters
    // NOTE(review): the trailing 500 is passed positionally; presumably an iteration limit -- confirm
    // against algorithm_factory.
  auto const padmm
      = factory::algorithm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
          factory::algorithm::padmm, factory::algo_distribution::mpi_serial,
          measurements_transform, wavelets, uv_data, sigma, imsizey, imsizex, sara.size(), 500);

  // run the algorithm; the returned diagnostic must report exactly 139 iterations
  auto const diagnostic = (*padmm)();
  CHECK(diagnostic.niters == 139);

  // paths of the reference images, resolved through data_filename relative to test_dir
  const std::string &expected_solution_path = data_filename(test_dir + "solution.fits");
  const std::string &expected_residual_path = data_filename(test_dir + "residual.fits");

  // reference images are read here; no comparison against them appears in the lines shown
  const auto solution = pfitsio::read2d(expected_solution_path);
  const auto residual = pfitsio::read2d(expected_residual_path);

  }
  // "local" section: PADMM built with the mpi_distributed algorithm distribution option.
  SECTION("local"){
  auto const padmm
      = factory::algorithm_factory<sopt::algorithm::ImagingProximalADMM<t_complex>>(
          factory::algorithm::padmm, factory::algo_distribution::mpi_distributed,
          measurements_transform, wavelets, uv_data, sigma, imsizey, imsizex, sara.size(), 500);

  auto const diagnostic = (*padmm)();
  // epsilon is computed from the all_sum_all-combined uv_data.vis.size() and sigma
  // (sigma is passed through world.broadcast again here, although it was already broadcast when defined)
  t_real const epsilon = utilities::calculate_l2_radius(world.all_sum_all(uv_data.vis.size()), world.broadcast(sigma));
  // bounds check: the weighted l2 norm of the residual over the communicator must be below epsilon
  CHECK(sopt::mpi::l2_norm(diagnostic.residual,padmm->l2ball_proximal_weights(), world) < epsilon);
  // The result of the algorithm depends on the number of processes, so apart from the
  // bounds check above it is hard to state an exact expected precision.
  // Stop here unless the communicator has exactly one or two processes.
  if(world.size() > 2 or world.size() == 0)
    return;

  // Exact checks for the one- and two-process cases.
  // The mpi_* reference files are selected for two processes, the serial ones otherwise;
  // the paths are not used further in the lines shown.
  const std::string &expected_solution_path = (world.size() == 2) ?
    data_filename(test_dir + "mpi_solution.fits"):
    data_filename(test_dir + "solution.fits");
  const std::string &expected_residual_path = (world.size() == 2) ? 
    data_filename(test_dir + "mpi_residual.fits"):
    data_filename(test_dir + "residual.fits");
  // expected iteration count: 139 with one process, 138 with two
  if(world.size() == 1)
    CHECK(diagnostic.niters == 139);
  if(world.size() == 2)
    CHECK(diagnostic.niters == 138);

  }
}