File: mpi_proximals.cc

package info (click to toggle)
sopt 4.2.0%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,632 kB
  • sloc: cpp: 13,011; xml: 182; makefile: 6
file content (58 lines) | stat: -rw-r--r-- 1,833 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <catch2/catch_all.hpp>
#include <numeric>
#include <random>
#include <utility>

#include "sopt/proximal.h"
#include "sopt/types.h"

TEST_CASE("Parallel Euclidian norm", "[proximal]") {
  using namespace sopt;
  auto const world = mpi::Communicator::World();
  proximal::EuclidianNorm const eucl(world);

  Vector<t_real> out(5);
  Vector<t_real> x(5);
  x << 1, 2, 3, 4, 5;
  eucl(out, std::sqrt(world.all_sum_all(x.squaredNorm())) * 1.001, x);
  CHECK(out.isApprox(Vector<t_real>::Zero(x.size())));

  out = eucl(0.1, x);
  CHECK(out.isApprox(x * (1e0 - 0.1 / std::sqrt(world.all_sum_all(x.squaredNorm())))));
}

TEST_CASE("Parallel L2 ball", "[proximal]") {
  using namespace sopt;
  auto const world = mpi::Communicator::World();
  proximal::L2Ball<t_real> ball(0.5, world);
  Vector<t_real> out;
  Vector<t_real> x(5);
  x << 1, 2, 3, 4, 5;
  x *= world.rank();

  out = ball(0, x);
  CHECK(x.isApprox(out / 0.5 * std::sqrt(world.all_sum_all(x.squaredNorm()))));
  ball.epsilon(std::sqrt(world.all_sum_all(x.squaredNorm())) * 1.001);
  out = ball(0, x);
  CHECK(x.isApprox(out));
}

TEST_CASE("Parallel WeightedL2Ball", "[proximal]") {
  using namespace sopt;
  auto const world = mpi::Communicator::World();
  Vector<t_real> const weights = 0.01 * Vector<t_real>::Random(5).array() + 1e0;
  Vector<t_real> x(5);
  x << 1, 2, 3, 4, 5;
  x *= world.rank();
  proximal::WeightedL2Ball<t_real> wball(0.5, weights, world);
  proximal::L2Ball<t_real> const ball(0.5, world);

  Vector<t_real> const expected =
      ball((x.array() * weights.array()).matrix()).array() / weights.array();
  Vector<t_real> const actual = wball(x);
  CHECK(actual.isApprox(expected));

  wball.epsilon(std::sqrt(world.all_sum_all((x.array() * weights.array()).matrix().squaredNorm())) *
                1.001);
  CHECK(x.isApprox(wball(x)));
}