1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
|
#pragma once
#include <algorithm>
#include <chrono>
#include <iostream>
#include <limits>
#include <map>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "common_defs.h"
#include "cell_with_value.h"
#include "box.h"
#include "dual_point.h"
#include "dual_box.h"
#include "persistence_module.h"
#include "bifiltration.h"
#include "bottleneck.h"
namespace md {
#ifdef MD_PRINT_HEAT_MAP
// Heat-map types; they exist only when MD_PRINT_HEAT_MAP is defined.
// HeatMap associates one Real value with each DualPoint<Real> key.
template<class Real>
using HeatMap = std::map<DualPoint<Real>, Real>;
// Collection of heat maps indexed by an int key.
// NOTE(review): the key is presumably a refinement depth/level -- confirm against the implementation.
template<class Real>
using HeatMaps = std::map<int, HeatMap<Real>>;
#endif
// Bounding strategy, selectable through CalculationParams::bound_strategy.
// The trailing comments give the textual name used by the stream operators
// (operator<< / operator>>) for each enumerator.
// NOTE(review): the precise meaning of each strategy is defined by the
// implementation, which is not visible in this header.
enum class BoundStrategy {
bruteforce, // "bruteforce"
local_dual_bound, // "local_grob"
local_dual_bound_refined, // "local_refined"
local_dual_bound_for_each_point, // "local_for_each_point"
local_combined // "local_combined"
};
// Traversal strategy, selectable through CalculationParams::traverse_strategy.
// The trailing comments give the textual name used by the stream operators
// (operator<< / operator>>) for each enumerator.
// NOTE(review): how each strategy orders the traversal is defined by the
// implementation, which is not visible in this header.
enum class TraverseStrategy {
depth_first, // "DFS"
breadth_first, // "BFS"
breadth_first_value, // "BFS-VAL"
upper_bound // "UB"
};
// Writes the textual name of a BoundStrategy to os and returns os.
// Values that match no known enumerator are printed as
// "FORGOTTEN BOUND STRATEGY".
inline std::ostream& operator<<(std::ostream& os, const BoundStrategy& s)
{
    // Start from the fallback text; a recognised enumerator overrides it.
    const char* label = "FORGOTTEN BOUND STRATEGY";
    switch (s) {
        case BoundStrategy::bruteforce:
            label = "bruteforce";
            break;
        case BoundStrategy::local_dual_bound:
            label = "local_grob";
            break;
        case BoundStrategy::local_combined:
            label = "local_combined";
            break;
        case BoundStrategy::local_dual_bound_refined:
            label = "local_refined";
            break;
        case BoundStrategy::local_dual_bound_for_each_point:
            label = "local_for_each_point";
            break;
        default:
            break;
    }
    return os << label;
}
// Writes the short textual name of a TraverseStrategy to os and returns os.
// Values that match no known enumerator are printed as
// "FORGOTTEN TRAVERSE STRATEGY".
inline std::ostream& operator<<(std::ostream& os, const TraverseStrategy& s)
{
    // Start from the fallback text; a recognised enumerator overrides it.
    const char* label = "FORGOTTEN TRAVERSE STRATEGY";
    switch (s) {
        case TraverseStrategy::depth_first:
            label = "DFS";
            break;
        case TraverseStrategy::breadth_first:
            label = "BFS";
            break;
        case TraverseStrategy::breadth_first_value:
            label = "BFS-VAL";
            break;
        case TraverseStrategy::upper_bound:
            label = "UB";
            break;
        default:
            break;
    }
    return os << label;
}
inline std::istream& operator>>(std::istream& is, TraverseStrategy& s)
{
std::string ss;
is >> ss;
if (ss == "DFS") {
s = TraverseStrategy::depth_first;
} else if (ss == "BFS") {
s = TraverseStrategy::breadth_first;
} else if (ss == "BFS-VAL") {
s = TraverseStrategy::breadth_first_value;
} else if (ss == "UB") {
s = TraverseStrategy::upper_bound;
} else {
throw std::runtime_error("UNKNOWN TRAVERSE STRATEGY");
}
return is;
}
inline std::istream& operator>>(std::istream& is, BoundStrategy& s)
{
std::string ss;
is >> ss;
if (ss == "bruteforce") {
s = BoundStrategy::bruteforce;
} else if (ss == "local_grob") {
s = BoundStrategy::local_dual_bound;
} else if (ss == "local_combined") {
s = BoundStrategy::local_combined;
} else if (ss == "local_refined") {
s = BoundStrategy::local_dual_bound_refined;
} else if (ss == "local_for_each_point") {
s = BoundStrategy::local_dual_bound_for_each_point;
} else {
throw std::runtime_error("UNKNOWN BOUND STRATEGY");
}
return is;
}
// Converts the textual name of a bound strategy into the enum value by
// delegating to operator>>(std::istream&, BoundStrategy&); that operator
// throws std::runtime_error for unrecognised names.
inline BoundStrategy bs_from_string(std::string s)
{
    BoundStrategy parsed;
    std::istringstream input(s);
    input >> parsed;
    return parsed;
}
// Converts the textual name of a traverse strategy into the enum value by
// delegating to operator>>(std::istream&, TraverseStrategy&); that operator
// throws std::runtime_error for unrecognised names.
inline TraverseStrategy ts_from_string(std::string s)
{
    TraverseStrategy parsed;
    std::istringstream input(s);
    input >> parsed;
    return parsed;
}
// Parameters of a matching distance computation. Besides the input settings,
// the struct also carries fields that are filled in by the computation
// (see n_hera_calls below).
template<class Real>
struct CalculationParams {
// special value for dim: take the maximum over all dimensions
static constexpr int ALL_DIMENSIONS = -1;
Real hera_epsilon {0.001}; // relative error in hera call
Real delta {0.1}; // relative error for matching distance
int max_depth {8}; // maximal number of refinements
// NOTE(review): presumably the refinement depth of the initial grid -- confirm
int initialization_depth {2};
int dim {0}; // in which dim to calculate the distance; use ALL_DIMENSIONS to get max over all dims
BoundStrategy bound_strategy {BoundStrategy::local_combined};
TraverseStrategy traverse_strategy {TraverseStrategy::breadth_first};
// NOTE(review): presumably suppresses the failure when the iteration/depth limit is hit -- confirm
bool tolerate_max_iter_exceeded {false};
// NOTE(review): actual_error and actual_max_depth look like output fields
// (error actually achieved / depth actually reached) -- confirm
Real actual_error {std::numeric_limits<Real>::max()};
int actual_max_depth {0};
int n_hera_calls {0}; // for experiments only; is set in matching_distance function, input value is ignored
// stop looping over points immediately, if current point's displacement is too large
// to prune the cell
// if true, cells are pruned immediately, and bounds may increase
// (just return something large enough to not prune the cell)
bool stop_asap { true };
// print statistics on each quad-tree level
bool print_stats { false };
#ifdef MD_PRINT_HEAT_MAP
// heat maps; member exists only when MD_PRINT_HEAT_MAP is defined
HeatMaps<Real> heat_maps;
#endif
};
// Declares the calculator for the distance between two DiagramProvider objects.
// Only declarations are given here; the member definitions are not part of
// this section (NOTE(review): presumably in matching_distance.hpp, which is
// included at the end of this header -- confirm).
template<class Real_, class DiagramProvider>
class DistanceCalculator {
using Real = Real_;
using CellValueVector = std::vector<CellWithValue<Real>>;
public:
// params is held by reference (see params_), so the CalculationParams
// object passed here must outlive the calculator.
DistanceCalculator(const DiagramProvider& a,
const DiagramProvider& b,
CalculationParams<Real>& params);
Real distance();
int get_hera_calls_number() const;
// Everything below is private unless MD_TEST_CODE is defined, in which case
// it stays public (so that test code can access it).
#ifndef MD_TEST_CODE
private:
#endif
// the calculator stores its own DiagramProvider objects (by value)
DiagramProvider module_a_;
DiagramProvider module_b_;
CalculationParams<Real>& params_;
int n_hera_calls_;
// NOTE(review): presumably maps a quad-tree level to the number of hera calls made on it -- confirm
std::map<int, int> n_hera_calls_per_level_;
Real distance_;
// if calculate_on_intermediate, then weighted distance
// will be calculated on centers of each grid in between
CellValueVector get_refined_grid(int init_depth, bool calculate_on_intermediate, bool calculate_on_last = true);
// lower_bound is passed by non-const reference
// (NOTE(review): presumably an output parameter -- confirm)
CellValueVector get_initial_dual_grid(Real& lower_bound);
#ifdef MD_PRINT_HEAT_MAP
void heatmap_in_dimension(int dim, int depth);
#endif
// NOTE(review): the int module argument presumably selects module_a_ or module_b_ -- confirm
Real get_max_x(int module) const;
Real get_max_y(int module) const;
void set_cell_central_value(CellWithValue<Real>& dual_cell);
Real get_distance();
Real get_distance_pq();
// takes a pointer to the first cell plus a cell count
// (NOTE(review): presumably a contiguous array of n_cells cells -- confirm)
Real get_max_possible_value(const CellWithValue<Real>* first_cell_ptr, int n_cells);
Real get_upper_bound(const CellWithValue<Real>& dual_cell, Real good_enough_upper_bound) const;
Real get_single_dgm_bound(const CellWithValue<Real>& dual_cell, ValuePoint vp, int module,
Real good_enough_value) const;
// this bound depends only on dual box
Real get_local_dual_bound(int module, const DualBox<Real>& dual_box) const;
Real get_local_dual_bound(const DualBox<Real>& dual_box) const;
// this bound depends only on dual box, is more accurate
Real get_local_refined_bound(int module, const DualBox<Real>& dual_box) const;
Real get_local_refined_bound(const DualBox<Real>& dual_box) const;
Real get_good_enough_upper_bound(Real lower_bound) const;
Real get_max_displacement_single_point(const CellWithValue<Real>& dual_cell, ValuePoint value_point,
const Point<Real>& p) const;
void check_upper_bound(const CellWithValue<Real>& dual_cell) const;
// non-const and const variants taking the same DualPoint argument
// (NOTE(review): the non-const one presumably also updates the hera call counters -- confirm)
Real distance_on_line(DualPoint<Real> line);
Real distance_on_line_const(DualPoint<Real> line) const;
Real current_error(Real lower_bound, Real upper_bound);
};
// Matching distance for two bifiltrations (declaration only).
// params is taken by non-const reference: according to the comment on
// CalculationParams::n_hera_calls, that field is set by matching_distance
// and its input value is ignored.
template<class Real>
Real matching_distance(const Bifiltration<Real>& bif_a, const Bifiltration<Real>& bif_b,
CalculationParams<Real>& params);
// Matching distance for two module presentations (declaration only).
// params is taken by non-const reference: according to the comment on
// CalculationParams::n_hera_calls, that field is set by matching_distance
// and its input value is ignored.
template<class Real>
Real matching_distance(const ModulePresentation<Real>& mod_a, const ModulePresentation<Real>& mod_b,
CalculationParams<Real>& params);
// for upper bound experiment
// One record of that experiment; the stream operator<< for this type prints
// every field except cell.
struct UbExperimentRecord {
double error;
double lower_bound;
double upper_bound;
CellWithValue<double> cell;
long long int time; // NOTE(review): time unit is not visible here -- confirm
long long int n_hera_calls;
};
inline std::ostream& operator<<(std::ostream& os, const UbExperimentRecord& r)
{
os << r.time << "\t" << r.n_hera_calls << "\t" << r.error << "\t" << r.lower_bound << "\t" << r.upper_bound;
return os;
}
// Prints every entry of dic to std::cout, one per line, as "key -> value",
// in the map's iteration (key) order.
// K and V must be streamable with operator<<.
template<class K, class V>
void print_map(const std::map<K, V>& dic)
{
    // Iterate by const reference: taking each entry by value would copy
    // the key and the mapped value on every iteration.
    for(const auto& kv : dic) {
        std::cout << kv.first << " -> " << kv.second << "\n";
    }
}
} // namespace md
#include "matching_distance.hpp"
|