File: reweighted.cc

package info (click to toggle)
sopt 4.2.0%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,632 kB
  • sloc: cpp: 13,011; xml: 182; makefile: 6
file content (149 lines) | stat: -rw-r--r-- 6,292 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/positive_quadrant.h"
#include "sopt/relative_variation.h"
#include "sopt/reweighted.h"
#include "sopt/sampling.h"
#include "sopt/sdmm.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
// This header is not part of the installed sopt interface
// It is only present in tests
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"

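// Reweighted-l1 reconstruction. At reweighting step j, SDMM solves
//   \min_{x} \|W_j \Psi^T x\|_1 \quad \text{s.t.} \quad \|y - A x\|_2 < \epsilon
//   \quad \text{and} \quad x \geq 0,
// where W_j is a vector of per-coefficient weights recomputed from \Psi^T x_{j-1}, the
// wavelet coefficients of the previous solution (not a single scalar l1 norm); the exact
// weighting formula lives in sopt::algorithm::reweighted.
// Iterating these weighted l1 problems approximates l0 (sparsity-count) minimization.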
int main(int argc, char const **argv) {
  // Some type aliases for simplicity
  using Scalar = double;
  // Column vector - linear algebra - A * x is a matrix-vector multiplication
  // type expected by SDMM
  using Vector = sopt::Vector<Scalar>;
  // Matrix - linear algebra - A * x is a matrix-vector multiplication
  // type expected by SDMM
  using Matrix = sopt::Matrix<Scalar>;
  // Image - 2D array - A * x is a coefficient-wise multiplication
  // Type expected by wavelets and image write/read functions
  using Image = sopt::Image<Scalar>;

  // At most two optional positional arguments are accepted. Check this before reading argv,
  // and report misuse with a non-zero exit status so calling scripts can detect it.
  if (argc > 3) {
    std::cout << "Usage:\n"
                 "$ "
              << argv[0]
              << " [input [output]]\n\n"
                 "- input: path to the image to clean (or name of standard SOPT image)\n"
                 "- output: filename pattern for output image\n";
    return EXIT_FAILURE;
  }
  std::string const input = argc >= 2 ? argv[1] : "cameraman256";
  // The sentinel "none" disables writing the dirty and reconstructed images.
  std::string const output = argc == 3 ? argv[2] : "none";

  // Set up random numbers for C and C++ from one time-based seed, so both generators are
  // seeded from the same value (previously std::time was called twice).
  auto const seed = std::time(nullptr);
  std::srand(static_cast<unsigned int>(seed));
  std::mt19937 mersenne(static_cast<std::mt19937::result_type>(seed));

  SOPT_HIGH_LOG("Read input file {}", input);
  Image const image = sopt::tools::read_standard_tiff(input);

  SOPT_HIGH_LOG("Initializing sensing operator");
  // Keep roughly a third of the pixels as measurements (fraction truncated towards zero).
  sopt::t_uint const nmeasure = static_cast<sopt::t_uint>(0.33 * image.size());
  auto const sampling =
      sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, mersenne));

  SOPT_HIGH_LOG("Initializing wavelets");
  // psi.adjoint() maps an image to its wavelet coefficients; psi maps coefficients back.
  auto const wavelet = sopt::wavelets::factory("DB4", 4);
  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());

  SOPT_HIGH_LOG("Computing sdmm parameters");
  // Noise-free measurements of the input image.
  Vector const y0 = sampling * Vector::Map(image.data(), image.size());
  // Noise standard deviation giving the requested signal-to-noise ratio (in dB).
  auto constexpr snr = 30.0;
  auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
  // Radius of the l2-ball data-fidelity constraint: a heuristic bound on the norm of the
  // Gaussian noise added below.
  auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma;

  SOPT_HIGH_LOG("Create dirty vector");
  std::normal_distribution<> gaussian_dist(0, sigma);
  Vector y(y0.size());
  for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(mersenne);
  // Write dirty image to file
  if (output != "none") {
    Vector const dirty = sampling.adjoint() * y;
    sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
                                "dirty_" + output + ".tiff");
  }

  SOPT_HIGH_LOG("Initializing convergence function");
  // Inner SDMM stopping rule: relvar(x) decides convergence (relative-variation criterion,
  // tolerance 5e-2). The three objective terms are only logged for diagnostics.
  auto relvar = sopt::RelativeVariation<Scalar>(5e-2);
  auto convergence = [&y, &sampling, &psi, &relvar](sopt::Vector<Scalar> const &x) -> bool {
    SOPT_MEDIUM_LOG("||y - A x||_2: {}", (y - sampling * x).stableNorm());
    SOPT_MEDIUM_LOG("||Psi^Tx||_1: {}", sopt::l1_norm(psi.adjoint() * x));
    SOPT_MEDIUM_LOG("||abs(x) - x||_2: {}", (x.array().abs().matrix() - x).stableNorm());
    return relvar(x);
  };

  SOPT_HIGH_LOG("Creating SDMM Functor");
  auto const sdmm =
      sopt::algorithm::SDMM<Scalar>()
          .itermax(3000)
          .gamma(0.1)
          .conjugate_gradient(200, 1e-8)
          .is_converged(convergence)
          // Any number of (proximal g_i, L_i) pairs can be added
          // ||Psi^dagger x||_1  (proximal index 0, replaced by set_weights below)
          .append(sopt::proximal::l1_norm<Scalar>, psi.adjoint(), psi)
          // ||y - A x|| < epsilon
          .append(sopt::proximal::translate(sopt::proximal::L2Ball<Scalar>(epsilon), -y), sampling)
          // x in positive quadrant
          .append(sopt::proximal::positive_quadrant<Scalar>);

  SOPT_HIGH_LOG("Creating the reweighted algorithm");
  // positive_quadrant projects the result of SDMM on the positive quadrant.
  // This follows the reweighted algorithm in the original C implementation.
  auto const posq = positive_quadrant(sdmm);
  using t_PosQuadSDMM = std::remove_const<decltype(posq)>::type;
  // Lower bound handed to the reweighting scheme (presumably a floor on the delta used when
  // computing weights; see sopt::algorithm::reweighted).
  auto const min_delta = sigma * std::sqrt(y.size()) / std::sqrt(8 * image.size());
  // Sets the weights after each SDMM solve by replacing proximal 0 (the l1 term appended
  // first) with a coefficient-wise weighted soft-threshold.
  auto set_weights = [](t_PosQuadSDMM &posq_sdmm, Vector const &weights) {
    posq_sdmm.algorithm().proximals(0) = [weights](Vector &out, Scalar gamma, Vector const &x) {
      out = sopt::soft_threshhold(x, gamma * weights);
    };
  };
  // Gives the reweighting scheme the wavelet coefficients Psi^T x it derives weights from.
  auto call_PsiT = [&psi](t_PosQuadSDMM const &, Vector const &x) -> Vector {
    return psi.adjoint() * x;
  };
  auto const reweighted = sopt::algorithm::reweighted(posq, set_weights, call_PsiT)
                              .itermax(5)
                              .min_delta(min_delta)
                              .is_converged(sopt::RelativeVariation<Scalar>(1e-3));

  SOPT_HIGH_LOG("Computing warm-start SDMM");
  // Unweighted SDMM from zero, projected onto x >= 0, seeds the reweighting loop.
  auto warm_start = sdmm(Vector::Zero(image.size()));
  warm_start.x = sopt::positive_quadrant(warm_start.x);
  SOPT_HIGH_LOG("SDMM returned {}", warm_start.good);

  SOPT_HIGH_LOG("Computing reweighted SDMM");
  auto const result = reweighted(warm_start);

  // result.good reports whether the reweighting loop converged, result.niters its iteration
  // count, and result.algo the final inner-algorithm result (result.algo.x is the image).
  if (not result.good) throw std::runtime_error("Did not converge!");

  SOPT_HIGH_LOG("SOPT-SDMM converged in {} iterations", result.niters);
  if (output != "none")
    sopt::utilities::write_tiff(Matrix::Map(result.algo.x.data(), image.rows(), image.cols()),
                                output + ".tiff");

  return 0;
}