File: robust_loss.h

package info (click to toggle)
poselib 2.0.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,592 kB
  • sloc: cpp: 15,023; python: 182; sh: 85; makefile: 16
file content (128 lines) | stat: -rw-r--r-- 4,397 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright (c) 2021, Viktor Larsson
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//     * Redistributions of source code must retain the above copyright
//       notice, this list of conditions and the following disclaimer.
//
//     * Redistributions in binary form must reproduce the above copyright
//       notice, this list of conditions and the following disclaimer in the
//       documentation and/or other materials provided with the distribution.
//
//     * Neither the name of the copyright holder nor the
//       names of its contributors may be used to endorse or promote products
//       derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#ifndef POSELIB_ROBUST_LOSS_H_
#define POSELIB_ROBUST_LOSS_H_

#include <algorithm>
#include <cmath>
#include <limits>

namespace poselib {

// Robust loss functions
/// Identity (non-robust) loss: the squared residual is used as-is and every
/// residual receives full weight in IRLS.
class TrivialLoss {
  public:
    /// The threshold is ignored; this overload only exists so that every loss
    /// in this file can be constructed with the same argument list.
    TrivialLoss(double /*threshold*/) {}
    TrivialLoss() = default;

    /// Returns the squared residual unchanged.
    double loss(double r2) const { return r2; }

    /// Constant unit weight, independent of the residual.
    double weight(double /*r2*/) const { return 1.0; }
};

/// Truncated least squares: quadratic below the threshold, constant above it.
class TruncatedLoss {
  public:
    /// @param threshold residual magnitude at which the loss saturates.
    TruncatedLoss(double threshold) : squared_thr(threshold * threshold) {}

    /// Squared residual capped at threshold^2.
    double loss(double r2) const {
        // Same selection rule as std::min(r2, squared_thr): keep r2 unless the cap is smaller.
        return (squared_thr < r2) ? squared_thr : r2;
    }

    /// 1 for residuals strictly inside the threshold, 0 otherwise (including on it).
    double weight(double r2) const {
        if (r2 < squared_thr) {
            return 1.0;
        }
        return 0.0;
    }

  private:
    const double squared_thr;
};

// The method from
//  Le and Zach, Robust Fitting with Truncated Least Squares: A Bilevel Optimization Approach, 3DV 2021
// for truncated least squares optimization with IRLS.
//
// loss() is the plain truncated quadratic min(r2, threshold^2). weight() returns the
// IRLS weight of the relaxed formulation, which depends on the penalty strength `mu`.
class TruncatedLossLeZach {
  public:
    /// @param threshold inlier threshold on the residual magnitude; mu starts at 0.5.
    TruncatedLossLeZach(double threshold) : squared_thr(threshold * threshold), mu(0.5) {}

    /// Squared residual capped at threshold^2.
    double loss(double r2) const { return std::min(r2, squared_thr); }

    /// IRLS weight for the squared residual r2.
    ///
    /// Residuals with r2 <= threshold^2 get the constant weight 0.5. Larger residuals
    /// get a weight that depends on how far r2 exceeds threshold^2 and on `mu`.
    double weight(double r2) const {
        // Squared residual normalized by the squared threshold (1.0 == exactly on the threshold).
        const double r2_hat = r2 / squared_thr;

        // The boundary r2_hat == 1 belongs to this branch. There, rho below would be 0 and
        // the outlier formula would evaluate 0/0 = NaN. Its limit as r2_hat -> 1 from
        // above is 0.5, so returning 0.5 keeps the weight finite and continuous.
        if (r2_hat <= 1.0) {
            return 0.5;
        }

        // Here r2_hat > 1, so zstar = min(r2_hat, 1) is exactly 1.
        // NOTE(review): the derivation note says this "assumes mu > 0.5", but mu defaults
        // to 0.5 -- confirm the intended range of mu against the paper.
        const double zstar = 1.0;
        const double r2m1 = r2_hat - 1.0; // > 0 in this branch, hence rho > 0 for mu > 0
        const double rho = (2.0 * r2m1 + std::sqrt(4.0 * r2m1 * r2m1 * mu * mu + 2.0 * mu * r2m1)) / mu;
        const double a = (r2_hat + mu * rho * zstar - 0.5 * rho) / (1.0 + mu * rho);
        const double zbar = std::max(0.0, std::min(a, 1.0));
        return (zstar - zbar) / rho;
    }

  private:
    const double squared_thr;

  public:
    // hyper-parameter for penalty strength; public and mutable, read by weight()
    double mu;
    // schedule for increasing mu in each iteration (not referenced inside this class;
    // presumably applied by the optimizer -- verify at the call site)
    static constexpr double alpha = 1.5;
};

/// Huber loss on the residual magnitude |r| = sqrt(r2): quadratic up to `threshold`,
/// linear beyond it.
class HuberLoss {
  public:
    /// @param threshold residual magnitude where the loss switches from quadratic to linear.
    HuberLoss(double threshold) : thr(threshold) {}

    /// r2 when |r| <= thr, otherwise thr * (2|r| - thr); both pieces equal thr^2 at |r| == thr.
    double loss(double r2) const {
        const double abs_r = std::sqrt(r2);
        return (abs_r <= thr) ? r2 : thr * (2.0 * abs_r - thr);
    }

    /// IRLS weight (derivative of loss() w.r.t. r2): 1 when |r| <= thr, otherwise thr / |r|.
    double weight(double r2) const {
        const double abs_r = std::sqrt(r2);
        return (abs_r <= thr) ? 1.0 : thr / abs_r;
    }

  private:
    const double thr;
};
/// Cauchy loss: thr^2 * log(1 + r2 / thr^2).
class CauchyLoss {
  public:
    /// @param threshold scale parameter; both cached members are derived from its square.
    /// inv_sq_thr is initialized from sq_thr, which is valid because sq_thr is declared first.
    CauchyLoss(double threshold) : sq_thr(threshold * threshold), inv_sq_thr(1.0 / sq_thr) {}

    /// thr^2 * log1p(r2 / thr^2); log1p keeps precision for small r2.
    double loss(double r2) const {
        const double scaled = r2 * inv_sq_thr;
        return sq_thr * std::log1p(scaled);
    }

    /// IRLS weight (derivative of loss() w.r.t. r2): 1 / (1 + r2 / thr^2).
    double weight(double r2) const {
        const double w = 1.0 / (1.0 + r2 * inv_sq_thr);
        // Clamp to the smallest positive normalized double so the weight is never exactly zero.
        return std::max(std::numeric_limits<double>::min(), w);
    }

  private:
    const double sq_thr;
    const double inv_sq_thr;
};

} // namespace poselib

#endif