1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/positive_quadrant.h"
#include "sopt/relative_variation.h"
#include "sopt/reweighted.h"
#include "sopt/sampling.h"
#include "sopt/sdmm.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
// This header is not part of the installed sopt interface
// It is only present in tests
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
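// Reweighted l1 minimisation, solved at each step j with SDMM:
// \min_{x} ||W_j\Psi^Tx||_1 \quad \mbox{s.t.} \quad ||y - Ax||_2 < \epsilon \quad \mbox{and} \quad x \geq 0
// where W_j is a diagonal weighting computed from the wavelet coefficients |\Psi^Tx_{j-1}| of the
// previous solution (presumably inversely proportional to them and regularised by delta;
// confirm against sopt/reweighted.h).
// Iterating this scheme approximates the l0 problem by a sequence of weighted l1 problems.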
int main(int argc, char const **argv) {
// Some type aliases for simplicity
using Scalar = double;
// Column vector - linear algebra - A * x is a matrix-vector multiplication
// type expected by SDMM
using Vector = sopt::Vector<Scalar>;
// Matrix - linear algebra - A * x is a matrix-vector multiplication
// type expected by SDMM
using Matrix = sopt::Matrix<Scalar>;
// Image - 2D array - A * x is a coefficient-wise multiplication
// Type expected by wavelets and image write/read functions
using Image = sopt::Image<Scalar>;
std::string const input = argc >= 2 ? argv[1] : "cameraman256";
std::string const output = argc == 3 ? argv[2] : "none";
if (argc > 3) {
std::cout << "Usage:\n"
"$ "
<< argv[0]
<< " [input [output]]\n\n"
"- input: path to the image to clean (or name of standard SOPT image)\n"
"- output: filename pattern for output image\n";
exit(0);
}
// Set up random numbers for C and C++
auto const seed = std::time(nullptr);
std::srand(static_cast<unsigned int>(seed));
std::mt19937 mersenne(std::time(nullptr));
SOPT_HIGH_LOG("Read input file {}", input);
Image const image = sopt::tools::read_standard_tiff(input);
SOPT_HIGH_LOG("Initializing sensing operator");
sopt::t_uint const nmeasure = 0.33 * image.size();
auto const sampling =
sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, mersenne));
SOPT_HIGH_LOG("Initializing wavelets");
auto const wavelet = sopt::wavelets::factory("DB4", 4);
auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
SOPT_HIGH_LOG("Computing sdmm parameters");
Vector const y0 = sampling * Vector::Map(image.data(), image.size());
auto constexpr snr = 30.0;
auto const sigma = y0.stableNorm() / std::sqrt(y0.size()) * std::pow(10.0, -(snr / 20.0));
auto const epsilon = std::sqrt(nmeasure + 2 * std::sqrt(y0.size())) * sigma;
SOPT_HIGH_LOG("Create dirty vector");
std::normal_distribution<> gaussian_dist(0, sigma);
Vector y(y0.size());
for (sopt::t_int i = 0; i < y0.size(); i++) y(i) = y0(i) + gaussian_dist(mersenne);
// Write dirty imagte to file
if (output != "none") {
Vector const dirty = sampling.adjoint() * y;
sopt::utilities::write_tiff(Matrix::Map(dirty.data(), image.rows(), image.cols()),
"dirty_" + output + ".tiff");
}
SOPT_HIGH_LOG("Initializing convergence function");
auto relvar = sopt::RelativeVariation<Scalar>(5e-2);
auto convergence = [&y, &sampling, &psi, &relvar](sopt::Vector<Scalar> const &x) -> bool {
SOPT_MEDIUM_LOG("||x - y||_2: {}", (y - sampling * x).stableNorm());
SOPT_MEDIUM_LOG("||Psi^Tx||_1: {}", sopt::l1_norm(psi.adjoint() * x));
SOPT_MEDIUM_LOG("||abs(x) - x||_2: {}", (x.array().abs().matrix() - x).stableNorm());
return relvar(x);
};
SOPT_HIGH_LOG("Creating SDMM Functor");
auto const sdmm =
sopt::algorithm::SDMM<Scalar>()
.itermax(3000)
.gamma(0.1)
.conjugate_gradient(200, 1e-8)
.is_converged(convergence)
// Any number of (proximal g_i, L_i) pairs can be added
// ||Psi^dagger x||_1
.append(sopt::proximal::l1_norm<Scalar>, psi.adjoint(), psi)
// ||y - A x|| < epsilon
.append(sopt::proximal::translate(sopt::proximal::L2Ball<Scalar>(epsilon), -y), sampling)
// x in positive quadrant
.append(sopt::proximal::positive_quadrant<Scalar>);
SOPT_HIGH_LOG("Creating the reweighted algorithm");
// positive_quadrant projects the result of SDMM on the positive quadrant.
// This follows the reweighted algorithm in the original C implementation.
auto const posq = positive_quadrant(sdmm);
using t_PosQuadSDMM = std::remove_const<decltype(posq)>::type;
auto const min_delta = sigma * std::sqrt(y.size()) / std::sqrt(8 * image.size());
// Sets weight after each sdmm iteration.
// In practice, this means replacing the proximal of the l1 objective function.
auto set_weights = [](t_PosQuadSDMM &sdmm, Vector const &weights) {
sdmm.algorithm().proximals(0) = [weights](Vector &out, Scalar gamma, Vector const &x) {
out = sopt::soft_threshhold(x, gamma * weights);
};
};
auto call_PsiT = [&psi](t_PosQuadSDMM const &, Vector const &x) -> Vector {
return psi.adjoint() * x;
};
auto const reweighted = sopt::algorithm::reweighted(posq, set_weights, call_PsiT)
.itermax(5)
.min_delta(min_delta)
.is_converged(sopt::RelativeVariation<Scalar>(1e-3));
SOPT_HIGH_LOG("Computing warm-start SDMM");
auto warm_start = sdmm(Vector::Zero(image.size()));
warm_start.x = sopt::positive_quadrant(warm_start.x);
SOPT_HIGH_LOG("SDMM returned {}", warm_start.good);
SOPT_HIGH_LOG("Computing warm-start SDMM");
auto const result = reweighted(warm_start);
// result should tell us the function converged
// it also contains result.niters - the number of iterations, and cg_diagnostic - the
// result from the last call to the conjugate gradient.
if (not result.good) throw std::runtime_error("Did not converge!");
SOPT_HIGH_LOG("SOPT-SDMM converged in {} iterations", result.niters);
if (output != "none")
sopt::utilities::write_tiff(Matrix::Map(result.algo.x.data(), image.rows(), image.cols()),
output + ".tiff");
return 0;
}
|