File: ObjectiveMetricUtil.cpp

package info (click to toggle)
bornagain 23.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 103,936 kB
  • sloc: cpp: 423,131; python: 40,997; javascript: 11,167; awk: 630; sh: 318; ruby: 173; xml: 130; makefile: 51; ansic: 24
file content (119 lines) | stat: -rw-r--r-- 4,035 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Sim/Fitting/ObjectiveMetricUtil.cpp
//! @brief     Implements ObjectiveMetric utilities.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Sim/Fitting/ObjectiveMetricUtil.h"
#include "Sim/Fitting/ObjectiveMetric.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <map>
#include <sstream>

namespace {

//! L1 norm of a single term: its absolute value.
const std::function<double(double)> l1_norm = [](double x) -> double { return std::abs(x); };

//! L2 norm of a single term: its square.
const std::function<double(double)> l2_norm = [](double x) -> double { return x * x; };

const std::map<std::string, std::function<std::unique_ptr<ObjectiveMetric>()>> metric_factory = {
    {"chi2", [] { return std::make_unique<Chi2Metric>(); }},
    {"poisson-like", [] { return std::make_unique<PoissonLikeMetric>(); }},
    {"log", [] { return std::make_unique<LogMetric>(); }},
    {"reldiff", [] { return std::make_unique<meanRelativeDifferenceMetric>(); }}};
// TODO restore rq4:
// https://jugit.fz-juelich.de/mlz/bornagain/-/issues/568
// {"rq4", [] { return std::make_unique<RQ4Metric>(); }}
const std::string default_metric_name = "log";

//! Registry mapping each (lower-case) norm name to the corresponding per-term norm function.
const std::map<std::string, std::function<double(double)>> norm_factory = {
    {"l1", l1_norm},
    {"l2", l2_norm},
};

//! Key of the norm_factory entry that is reported as the default norm.
const std::string default_norm_name{"l2"};

//! Collects the keys of a string-keyed map into a vector, preserving the map's iteration order.
template <typename U> std::vector<std::string> keys(const std::map<std::string, U>& map)
{
    std::vector<std::string> names;
    names.reserve(map.size());
    for (auto it = map.cbegin(); it != map.cend(); ++it)
        names.emplace_back(it->first);
    return names;
}

} // namespace

//! Returns a copy of the L1 norm function, which maps a term to its absolute value.
std::function<double(double)> ObjectiveMetricUtil::l1Norm()
{
    return l1_norm;
}

//! Returns a copy of the L2 norm function, which maps a term to its square.
std::function<double(double)> ObjectiveMetricUtil::l2Norm()
{
    return l2_norm;
}

//! Creates the metric registered under the given name, configured with the default norm
//! (see defaultNormName()). Delegates to the two-argument overload, so an unknown metric
//! name results in the same std::runtime_error.
std::unique_ptr<ObjectiveMetric> ObjectiveMetricUtil::createMetric(const std::string& metric)
{
    return createMetric(metric, defaultNormName());
}

//! Creates the metric registered under name @p metric and assigns it the norm registered
//! under name @p norm.
//!
//! Both names are converted to lower case before lookup, so matching is case-insensitive.
//!
//! @param metric name of the metric (one of metricNames())
//! @param norm   name of the norm (one of normNames())
//! @return newly created metric with the requested norm set
//! @throws std::runtime_error if either name is unknown; the message lists available options
std::unique_ptr<ObjectiveMetric> ObjectiveMetricUtil::createMetric(std::string metric,
                                                                   std::string norm)
{
    // std::tolower has undefined behavior for arguments that are neither EOF nor representable
    // as unsigned char; plain char may be negative, so convert through unsigned char first.
    const auto to_lower = [](unsigned char c) { return static_cast<char>(std::tolower(c)); };
    std::transform(metric.begin(), metric.end(), metric.begin(), to_lower);
    std::transform(norm.begin(), norm.end(), norm.begin(), to_lower);

    const auto metric_iter = metric_factory.find(metric);
    const auto norm_iter = norm_factory.find(norm);
    if (metric_iter == metric_factory.end() || norm_iter == norm_factory.end()) {
        std::stringstream ss;
        ss << "Error in ObjectiveMetricUtil::createMetric: either metric (" << metric
           << ") or norm (" << norm << ") name is unknown.\n";
        ss << availableMetricOptions();
        throw std::runtime_error(ss.str());
    }

    // Instantiate the metric through its factory function, then attach the selected norm.
    auto result = metric_iter->second();
    result->setNorm(norm_iter->second);
    return result;
}

//! Returns a multi-line, human-readable listing of all metric names and norm names,
//! each list followed by the name of its default entry.
std::string ObjectiveMetricUtil::availableMetricOptions()
{
    std::ostringstream out;

    // Writes one tab-indented name per line.
    const auto list_names = [&out](const std::vector<std::string>& names) {
        for (const std::string& name : names)
            out << "\t" << name << "\n";
    };

    out << "Available metrics:\n";
    list_names(metricNames());
    out << "default metric: " << defaultMetricName() << "\n";

    out << "Available norms:\n";
    list_names(normNames());
    out << "default norm: " << defaultNormName() << "\n";

    return out.str();
}

//! Returns the names of all registered norms, i.e. the keys of norm_factory
//! in the map's (sorted) iteration order.
std::vector<std::string> ObjectiveMetricUtil::normNames()
{
    return keys(norm_factory);
}

//! Returns the names of all registered metrics, i.e. the keys of metric_factory
//! in the map's (sorted) iteration order.
std::vector<std::string> ObjectiveMetricUtil::metricNames()
{
    return keys(metric_factory);
}

//! Returns the name of the default norm (the value of default_norm_name).
std::string ObjectiveMetricUtil::defaultNormName()
{
    return default_norm_name;
}

//! Returns the name of the default metric (the value of default_metric_name).
std::string ObjectiveMetricUtil::defaultMetricName()
{
    return default_metric_name;
}