1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
|
#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <random>
#include <vector>
#include <ctime>
#include <catch2/catch_all.hpp>
#include "sopt/imaging_forward_backward.h"
#include "sopt/l1_non_diff_function.h"
#include "sopt/logging.h"
#include "sopt/maths.h"
#include "sopt/relative_variation.h"
#include "sopt/sampling.h"
#include "sopt/types.h"
#include "sopt/utilities.h"
#include "sopt/wavelets.h"
#include "sopt/gradient_utils.h"
// This header is not part of the installed sopt interface
// It is only present in tests
#include "tools_for_tests/directories.h"
#include "tools_for_tests/tiffwrappers.h"
// \min_{x} \|\Psi^T x\|_1 \quad \text{s.t.} \quad \|y - A x\|_2 < \epsilon \quad \text{and} \quad x \geq 0
// Real scalar type for this file, plus the sopt container aliases built on it.
using Scalar = double;
using Image = sopt::Image<Scalar>;
using Matrix = sopt::Matrix<Scalar>;
using Vector = sopt::Vector<Scalar>;
// Inpainting test: reconstruct the "cameraman256" image from a random subset
// of half of its pixels using ImagingForwardBackward with an L1 proximal on the
// DB8 wavelet coefficients. The target/operator updater draws a fresh sampling
// pattern and fresh Gaussian noise on every call. The test checks convergence,
// the iteration count, and the mean squared reconstruction error.
TEST_CASE("Inpainting") {
  // Random engine defined elsewhere in the test suite; only declared here.
  extern std::unique_ptr<std::mt19937_64> mersenne;
  std::string const input = "cameraman256";
  Image const image = sopt::tools::read_standard_tiff(input);
  auto const wavelet = sopt::wavelets::factory("DB8", 4);
  auto const psi = sopt::linear_transform<Scalar>(wavelet, image.rows(), image.cols());
  // Number of pixels kept by the sampling operator: half of the image.
  size_t const nmeasure = static_cast<size_t>(image.size() * 0.5);
  double constexpr snr = 30.0;

  // A noiseless measurement, used only to derive the noise standard deviation
  // sigma from the target SNR (in dB).
  std::shared_ptr<sopt::LinearTransform<Vector>> Phi =
      std::make_shared<sopt::LinearTransform<Vector>>(
          sopt::linear_transform<Scalar>(sopt::Sampling(image.size(), nmeasure, *mersenne)));
  Vector const y = (*Phi) * Vector::Map(image.data(), image.size());
  auto const sigma = y.stableNorm() / std::sqrt(y.size()) * std::pow(10.0, -(snr / 20.0));
  sopt::t_real constexpr regulariser_strength = 18;
  sopt::t_real const beta = sigma * sigma * 0.5;

  // Define a stochastic target/operator updater: every call builds a new random
  // sampling operator and adds fresh Gaussian noise (std. dev. sigma) to the
  // measurements. `mersenne` has static storage duration, so the lambda can use
  // it directly without capturing it.
  std::function<std::shared_ptr<sopt::IterationState<Vector>>()> random_updater =
      [&image, sigma, nmeasure]() {
        std::shared_ptr<sopt::LinearTransform<Vector>> sampling =
            std::make_shared<sopt::LinearTransform<Vector>>(sopt::linear_transform<Scalar>(
                sopt::Sampling(image.size(), nmeasure, *mersenne)));
        Vector measurements = (*sampling) * Vector::Map(image.data(), image.size());
        std::normal_distribution<> gaussian_dist(0, sigma);
        for (sopt::t_int i = 0; i < measurements.size(); i++)
          measurements(i) = measurements(i) + gaussian_dist(*mersenne);
        return std::make_shared<sopt::IterationState<Vector>>(measurements, sampling);
      };

  auto fb = sopt::algorithm::ImagingForwardBackward<Scalar>(random_updater);
  fb.itermax(1000)
      .step_size(beta)                             // stepsize
      .sigma(sigma)                                // sigma
      .regulariser_strength(regulariser_strength)  // regularisation parameter
      .relative_variation(1e-3)
      .residual_tolerance(0)
      .tight_frame(true);

  // Create a shared pointer to an instance of the L1GProximal class
  // and set its properties.
  auto gp = std::make_shared<sopt::algorithm::L1GProximal<Scalar>>(false);
  gp->l1_proximal_tolerance(1e-4)
      .l1_proximal_nu(1)
      .l1_proximal_itermax(50)
      .l1_proximal_positivity_constraint(true)
      .l1_proximal_real_constraint(true)
      .Psi(psi);
  // Once the properties are set, inject it into the ImagingForwardBackward object.
  fb.g_function(gp);

  auto const diagnostic = fb();
  CHECK(diagnostic.good);
  CHECK(diagnostic.niters < 500);

  // Compare the input image with the cleaned output image: the mean squared
  // error (1/N) * sum_i (x_true(i) - x_est(i))^2 over the N pixels must be
  // below 0.01.
  Eigen::Map<const Eigen::VectorXd> flat_image(image.data(), image.size());
  auto const mse = (flat_image - diagnostic.x).array().square().sum() / image.size();
  CAPTURE(mse);
  CHECK(mse < 0.01);
}
|