File: uq_main.cc

package info (click to toggle)
purify 5.0.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 186,836 kB
  • sloc: cpp: 17,731; python: 510; xml: 182; makefile: 7; sh: 6
file content (183 lines) | stat: -rw-r--r-- 8,195 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#include <iostream>
#include <stdlib.h>
#include <string>
#include <tuple>
#include <vector>
#include "purify/measurement_operator_factory.h"
#include "purify/pfitsio.h"
#include "purify/setup_utils.h"
#include "purify/utilities.h"
#include "purify/yaml-parser.h"
#include "yaml-cpp/yaml.h"
#include "sopt/differentiable_func.h"
#include "sopt/non_differentiable_func.h"
#include "sopt/objective_functions.h"
#include <sopt/l1_non_diff_function.h>
#include <sopt/l2_differentiable_func.h>
#include <sopt/real_indicator.h>

// Complex-valued vector type; main() maps the reference and surrogate images loaded via
// purify::pfitsio::read2d into this type before evaluating the cost functions on them.
using VectorC = sopt::Vector<std::complex<double>>;

// Entry point of the uncertainty-quantification tool.
//
// Usage: purify_UQ <config_path> <reference_image_path> <surrogate_image_path>
//
// Evaluates the objective (f + regulariser_strength * g) on a reference image and on a surrogate
// image, derives a threshold from the reference value, and reports whether the surrogate value
// lies at or below that threshold.
//
// Exit codes:
//   0 - analysis completed and the result was printed
//   1 - wrong number of arguments, or the confidence / alpha configuration is missing or invalid
//   2 - surrogate and reference images have different dimensions
//   3 - the measurement operator applied to the reference image does not match the data size
int main(int argc, char **argv) {
  if (argc != 4) {
    std::cout << "purify_UQ should be run using three additional arguments." << std::endl;
    std::cout << "purify_UQ <config_path> <reference_image_path> <surrogate_image_path>"
              << std::endl;
    std::cout << "<config_path>: path to a .yaml config file specifying details of measurement "
                 "operator, wavelet operator, observations, and cost functions."
              << std::endl;
    std::cout << "<reference_image_path>: path to image file (.fits) which was output from running "
                 "purify on observed data."
              << std::endl;
    std::cout << "<surrogate_image_path>: path to modified image file (.fits) for feature analysis."
              << std::endl;
    std::cout << std::endl;
    std::cout
        << "For more information about the contents of the config file please consult the README."
        << std::endl;
    return 1;
  }

  // Load and parse the config for parameters
  const std::string config_path = argv[1];
  const YAML::Node UQ_config = YAML::LoadFile(config_path);

  // Load the Reference and Surrogate images and view each one as a flat complex vector
  const std::string ref_image_path = argv[2];
  const std::string surrogate_image_path = argv[3];
  const auto reference_image = purify::pfitsio::read2d(ref_image_path);
  const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size());
  const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path);
  const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size());

  // The reference image defines the image dimensions used everywhere below
  const uint imsize_x = reference_image.cols();
  const uint imsize_y = reference_image.rows();

  // Smooth (f) and non-smooth (g) parts of the objective; filled in by one of the branches below
  std::unique_ptr<DifferentiableFunc<t_complex>> f;
  std::unique_ptr<NonDifferentiableFunc<t_complex>> g;

  // Prepare operators and data using purify config
  // If no purify config use basic version for now based on algo_factory test images
  purify::utilities::vis_params measurement_data;
  double regulariser_strength = 0;
  std::shared_ptr<sopt::LinearTransform<VectorC>> measurement_operator;
  std::shared_ptr<const sopt::LinearTransform<VectorC>> wavelet_operator;
  // Wavelet dictionary used only by the fallback (no purify config) branch
  std::vector<std::tuple<std::string, t_uint>> const sara{
      std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u),
      std::make_tuple("DB3", 3u),   std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u),
      std::make_tuple("DB6", 3u),   std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)};
  if (UQ_config["purify_config_file"]) {
    YamlParser purify_config = YamlParser(UQ_config["purify_config_file"].as<std::string>());

    const auto [mop_algo, wop_algo, using_mpi] = selectOperators(purify_config);
    auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] =
        getInputData(purify_config, mop_algo, wop_algo, using_mpi);

    auto transform =
        createMeasurementOperator(purify_config, mop_algo, wop_algo, using_mpi, image_index,
                                  w_stacks, uv_data, measurement_op_eigen_vector);

    const waveletInfo wavelets = createWaveletOperator(purify_config, wop_algo);

    // Apply the weights to the visibilities before they are used as the observed data
    t_real const flux_scale = 1.;
    uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale;

    measurement_data = uv_data;
    measurement_operator = transform;
    wavelet_operator = wavelets.transform;

    // setup f and g based on config file
    setupCostFunctions(purify_config, f, g, sigma, *measurement_operator);

    regulariser_strength = purify_config.regularisation_parameter();
  } else {
    const std::string measurements_path = UQ_config["measurements_path"].as<std::string>();
    // Load the images and measurements
    measurement_data = purify::utilities::read_visibility(measurements_path, false);

    // This is the measurement operator used in the test but this should probably be selectable
    measurement_operator = purify::factory::measurement_operator_factory<sopt::Vector<t_complex>>(
        purify::factory::distributed_measurement_operator::serial, measurement_data, imsize_y,
        imsize_x, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4);

    wavelet_operator = purify::factory::wavelet_operator_factory<Vector<t_complex>>(
        factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x);

    // default cost function
    f = std::make_unique<sopt::L2DifferentiableFunc<t_complex>>(
        1, *measurement_operator);  // what would a default sigma look like??
    g = std::make_unique<sopt::algorithm::L1GProximal<t_complex>>();

    // Best effort: any failure to read the value keeps the default of 0 and tells the user so.
    try {
      regulariser_strength = UQ_config["regulariser_strength"].as<double>();
    } catch (...) {
      std::cout
          << "Regulariser strength not provided in UQ config, and no purify config was provided.\n";
      std::cout << "Regulariser strength will be 0 by default." << std::endl;
    }
  }

  // Set up confidence and objective function params.
  // Exactly one of 'confidence_interval' / 'alpha' must be given; the other is derived from it.
  double confidence = 0.0;
  double alpha = 0.0;
  if ((UQ_config["confidence_interval"]) && (UQ_config["alpha"])) {
    std::cout << "Config should only contain one of 'confidence_interval' or 'alpha'." << std::endl;
    return 1;
  }
  if (UQ_config["confidence_interval"]) {
    confidence = UQ_config["confidence_interval"].as<double>();
    alpha = 1 - confidence;
  } else if (UQ_config["alpha"]) {
    alpha = UQ_config["alpha"].as<double>();
    confidence = 1 - alpha;
  } else {
    std::cout << "Config file must contain either 'confidence_interval' or 'alpha' as a parameter."
              << std::endl;
    return 1;
  }

  // The threshold below uses sqrt(16 * log(3 / alpha)): alpha == 0 makes it infinite, a negative
  // alpha or alpha > 3 makes it NaN, and any NaN threshold silently reports "excluded". Reject
  // values outside (0, 1) up front (the negated form also rejects NaN).
  if (!(alpha > 0.0 && alpha < 1.0)) {
    std::cout << "Parameter 'alpha' must lie strictly between 0 and 1 (equivalently, "
                 "'confidence_interval' must lie strictly between 0 and 1)."
              << std::endl;
    return 1;
  }

  if ((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) {
    std::cout << "Surrogate and reference images have different dimensions. Aborting." << std::endl;
    return 2;
  }

  // Applying the operator to the reference image must yield one value per visibility
  if (((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) {
    std::cout << "Image size is not compatible with the measurement operator and data provided."
              << std::endl;
    return 3;
  }

  // Calculate the posterior function for an image
  // posterior = likelihood + prior
  // Likelihood = |y - Phi(x)|^2 / sigma^2  (L2 norm)
  // Prior = Sum(Psi^t * |x_i|) * regulariser_strength  (L1 norm)
  // The actual evaluation is delegated to the configured f (likelihood) and g (prior) objects.
  auto Posterior = [&measurement_data, &measurement_operator, regulariser_strength, &f,
                    &g](const VectorC &image) {
    auto likelihood = f->function(image, measurement_data.vis, (*measurement_operator));
    auto prior = g->function(image);
    return likelihood + regulariser_strength * prior;
  };

  const double reference_posterior = Posterior(reference_vector);
  const double surrogate_posterior = Posterior(surrogate_vector);

  // Threshold for surrogate image posterior to be within confidence limit.
  // Convert each dimension to double before multiplying so the pixel count cannot wrap around
  // in unsigned integer arithmetic.
  const double N = static_cast<double>(imsize_x) * static_cast<double>(imsize_y);
  const double tau = std::sqrt(16 * std::log(3 / alpha));
  const double threshold = reference_posterior + tau * std::sqrt(N) + N;

  std::cout << "Uncertainty Quantification." << std::endl;
  std::cout << "Reference Log Posterior = " << reference_posterior << std::endl;
  std::cout << "Confidence interval = " << confidence << std::endl;
  std::cout << "Log Posterior threshold = " << threshold << std::endl;
  std::cout << "Surrogate Log Posterior = " << surrogate_posterior << std::endl;
  std::cout << "Surrogate image is "
            << ((surrogate_posterior <= threshold) ? "within the credible interval."
                                                   : "excluded by the credible interval.")
            << std::endl;

  return 0;
}