// Copyright (c) 2021, Viktor Larsson
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the
// names of its contributors may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
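
#ifndef POSELIB_ROBUST_LOSS_H_
#define POSELIB_ROBUST_LOSS_H_

#include <algorithm>
#include <cmath>
#include <limits>

namespace poselib {

// Robust loss functions for iteratively reweighted least squares (IRLS).
// Every loss operates on the squared residual r2 and provides
//   loss(r2):   the robust cost of the residual,
//   weight(r2): the IRLS weight assigned to the residual.
// Each loss is constructed from a threshold given in the units of the
// (unsquared) residual, so all losses share a consistent calling interface.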
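
// Standard (non-robust) least squares: loss(r2) = r2, and every residual has weight 1.
class TrivialLoss {
  public:
    TrivialLoss(double) {} // the threshold is ignored; accepted only to keep the calling interface consistent
    TrivialLoss() {}
    double loss(double r2) const { return r2; }
    double weight(double r2) const { return 1.0; }
};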
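
// Truncated least squares: loss(r2) = min(r2, thr^2). Residuals beyond the
// threshold contribute a constant cost and receive zero weight.
class TruncatedLoss {
  public:
    TruncatedLoss(double threshold) : squared_thr(threshold * threshold) {}
    double loss(double r2) const { return std::min(r2, squared_thr); }
    double weight(double r2) const { return (r2 < squared_thr) ? 1.0 : 0.0; }

  private:
    const double squared_thr;
};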
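
/*
Example (a sketch, not part of PoseLib). Because TrivialLoss also accepts a
threshold, generic code can construct any of these losses as
LossFunction(threshold). The snippet restates TrivialLoss and TruncatedLoss so
it compiles on its own; the helper total_cost is hypothetical and only
illustrates how a single large residual dominates the plain least-squares cost
but is capped by the truncated one.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Copies of the two classes above, restated so the snippet is self-contained.
class TrivialLoss {
  public:
    TrivialLoss(double) {}
    double loss(double r2) const { return r2; }
    double weight(double) const { return 1.0; }
};

class TruncatedLoss {
  public:
    TruncatedLoss(double threshold) : squared_thr(threshold * threshold) {}
    double loss(double r2) const { return std::min(r2, squared_thr); }
    double weight(double r2) const { return (r2 < squared_thr) ? 1.0 : 0.0; }

  private:
    const double squared_thr;
};

// Hypothetical helper: sums the cost of unsquared residuals.
// The same line constructs either loss from the threshold.
template <typename LossFunction>
double total_cost(const std::vector<double> &residuals, double threshold) {
    LossFunction loss_fn(threshold);
    double cost = 0.0;
    for (double r : residuals) {
        cost += loss_fn.loss(r * r);
    }
    return cost;
}
```

With residuals {0.5, -1.0, 20.0} and threshold 2.0, total_cost<TrivialLoss>
gives 0.25 + 1.0 + 400.0 = 401.25, while total_cost<TruncatedLoss> gives
0.25 + 1.0 + 4.0 = 5.25.
*/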

// The method from
// Le and Zach, Robust Fitting with Truncated Least Squares: A Bilevel Optimization Approach, 3DV 2021
// for truncated least squares optimization with IRLS.
class TruncatedLossLeZach {
  public:
    TruncatedLossLeZach(double threshold) : squared_thr(threshold * threshold), mu(0.5) {}
    double loss(double r2) const { return std::min(r2, squared_thr); }
    double weight(double r2) const {
        double r2_hat = r2 / squared_thr;
        double zstar = std::min(r2_hat, 1.0);

        // Use <= so that r2 == squared_thr takes this branch: in the other branch,
        // r2_hat == 1 gives rho == 0 and the weight evaluates to 0/0 (NaN).
        // The other branch tends to 0.5 as r2_hat -> 1 from above, so the weight stays continuous.
        if (r2_hat <= 1.0) {
            return 0.5;
        } else {
            // assumes mu > 0.5
            double r2m1 = r2_hat - 1.0;
            double rho = (2.0 * r2m1 + std::sqrt(4.0 * r2m1 * r2m1 * mu * mu + 2 * mu * r2m1)) / mu;
            double a = (r2_hat + mu * rho * zstar - 0.5 * rho) / (1 + mu * rho);
            double zbar = std::max(0.0, std::min(a, 1.0));
            return (zstar - zbar) / rho;
        }
    }

  private:
    const double squared_thr;

  public:
    // hyperparameter for the penalty strength
    double mu;
    // schedule for increasing mu in each iteration
    static constexpr double alpha = 1.5;
};
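
/*
Example (a sketch under stated assumptions). The snippet restates the weight of
TruncatedLossLeZach, with the threshold check written as r2_hat <= 1.0, and runs
a hypothetical schedule in which mu is multiplied by alpha once per IRLS
iteration. That multiplicative update is an assumption read off the comment on
alpha; the surrounding solver is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Restated weight of TruncatedLossLeZach for a given mu and squared threshold.
double lezach_weight(double r2, double squared_thr, double mu) {
    double r2_hat = r2 / squared_thr;
    double zstar = std::min(r2_hat, 1.0);
    if (r2_hat <= 1.0) {
        return 0.5;
    }
    double r2m1 = r2_hat - 1.0;
    double rho = (2.0 * r2m1 + std::sqrt(4.0 * r2m1 * r2m1 * mu * mu + 2 * mu * r2m1)) / mu;
    double a = (r2_hat + mu * rho * zstar - 0.5 * rho) / (1 + mu * rho);
    double zbar = std::max(0.0, std::min(a, 1.0));
    return (zstar - zbar) / rho;
}

// Hypothetical schedule: start from mu = 0.5 (the constructor default) and
// multiply by alpha = 1.5 after each of num_iters iterations.
double mu_after(int num_iters) {
    double mu = 0.5;
    for (int i = 0; i < num_iters; ++i) {
        mu *= 1.5;
    }
    return mu;
}
```

Residuals inside the threshold, including r2 exactly equal to the squared
threshold, get weight 0.5; just above the threshold the weight is still close
to 0.5, while r2 = 4 * squared_thr with mu = 0.5 gets a much smaller weight
(about 0.033). Under the assumed schedule, mu has grown to 1.6875 after three
iterations.
*/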
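
// Huber loss: quadratic for residuals up to thr and linear beyond it, i.e. with
// r = sqrt(r2), loss = r2 for r <= thr and thr * (2r - thr) otherwise.
// The weight is d loss / d r2: 1 inside the threshold and thr / r outside.
class HuberLoss {
  public:
    HuberLoss(double threshold) : thr(threshold) {}
    double loss(double r2) const {
        const double r = std::sqrt(r2);
        if (r <= thr) {
            return r2;
        } else {
            return thr * (2.0 * r - thr);
        }
    }
    double weight(double r2) const {
        const double r = std::sqrt(r2);
        if (r <= thr) {
            return 1.0;
        } else {
            return thr / r;
        }
    }

  private:
    const double thr;
};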
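
/*
Example (a sketch). The snippet restates HuberLoss and checks numerically that
weight(r2) is the derivative of loss(r2) with respect to the squared residual,
and that the two pieces of the loss meet at r = thr. The helper numeric_weight
is hypothetical and uses a central finite difference.

```cpp
#include <cassert>
#include <cmath>

// Copy of HuberLoss from above, restated so the snippet is self-contained.
class HuberLoss {
  public:
    HuberLoss(double threshold) : thr(threshold) {}
    double loss(double r2) const {
        const double r = std::sqrt(r2);
        return (r <= thr) ? r2 : thr * (2.0 * r - thr);
    }
    double weight(double r2) const {
        const double r = std::sqrt(r2);
        return (r <= thr) ? 1.0 : thr / r;
    }

  private:
    const double thr;
};

// Hypothetical helper: central finite-difference approximation of d loss / d r2.
double numeric_weight(const HuberLoss &loss_fn, double r2, double h = 1e-6) {
    return (loss_fn.loss(r2 + h) - loss_fn.loss(r2 - h)) / (2.0 * h);
}
```

For thr = 2, the analytic and numeric weights agree away from the kink (for
example at r2 = 1, where both are 1, and at r2 = 25, where both are 2 / 5 = 0.4),
and at r2 = thr^2 = 4 both branches of the loss give 4.
*/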
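
// Cauchy loss: loss = thr^2 * log(1 + r2 / thr^2), with weight
// d loss / d r2 = 1 / (1 + r2 / thr^2). The weight is clamped from below to
// std::numeric_limits<double>::min() (the smallest positive normalized double)
// so that it stays strictly positive.
class CauchyLoss {
  public:
    CauchyLoss(double threshold) : sq_thr(threshold * threshold), inv_sq_thr(1.0 / sq_thr) {}
    double loss(double r2) const { return sq_thr * std::log1p(r2 * inv_sq_thr); }
    double weight(double r2) const {
        return std::max(std::numeric_limits<double>::min(), 1.0 / (1.0 + r2 * inv_sq_thr));
    }

  private:
    const double sq_thr;
    const double inv_sq_thr;
};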
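
/*
Example (a sketch, not part of PoseLib). A common use of weight() is
iteratively reweighted least squares: solve a weighted least-squares problem,
recompute the weights from the new residuals, and repeat. The snippet restates
CauchyLoss and applies this to the simplest possible model, a 1D location
estimate; the function irls_location is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Copy of CauchyLoss from above, restated so the snippet is self-contained.
class CauchyLoss {
  public:
    CauchyLoss(double threshold) : sq_thr(threshold * threshold), inv_sq_thr(1.0 / sq_thr) {}
    double loss(double r2) const { return sq_thr * std::log1p(r2 * inv_sq_thr); }
    double weight(double r2) const {
        return std::max(std::numeric_limits<double>::min(), 1.0 / (1.0 + r2 * inv_sq_thr));
    }

  private:
    const double sq_thr;
    const double inv_sq_thr;
};

// Hypothetical helper: IRLS for the location x minimizing sum_i loss((x_i - x)^2).
// Each step is the weighted mean with weights w_i = weight((x_i - x)^2).
double irls_location(const std::vector<double> &data, double threshold, int num_iters) {
    CauchyLoss loss_fn(threshold);
    // Initialize with the ordinary (non-robust) mean.
    double x = 0.0;
    for (double v : data) {
        x += v;
    }
    x /= static_cast<double>(data.size());
    for (int iter = 0; iter < num_iters; ++iter) {
        double sum_w = 0.0;
        double sum_wv = 0.0;
        for (double v : data) {
            const double w = loss_fn.weight((v - x) * (v - x));
            sum_w += w;
            sum_wv += w * v;
        }
        x = sum_wv / sum_w;
    }
    return x;
}
```

On the data {1.0, 1.1, 0.9, 1.05, 0.95, 10.0} the plain mean is 2.5, while 20
IRLS iterations with threshold 0.5 end close to 1.0, because the point at 10.0
receives a weight of roughly 1 / (1 + 81 / 0.25), i.e. about 0.003.
*/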

} // namespace poselib

#endif // POSELIB_ROBUST_LOSS_H_