// Copyright (C) 2016-2019 Yixuan Qiu <yixuan.qiu@cos.name>
// Under MIT license
#ifndef LINE_SEARCH_BACKTRACKING_H
#define LINE_SEARCH_BACKTRACKING_H
#include <Eigen/Core>
#include <stdexcept> // std::invalid_argument, std::logic_error, std::runtime_error
namespace LBFGSpp {

///
/// The backtracking line search algorithm for LBFGS. Mainly for internal use.
///
template <typename Scalar>
class LineSearchBacktracking
{
private:
    typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> Vector;

public:
    ///
    /// Line search by backtracking.
    ///
    /// \param f      A function object such that `f(x, grad)` returns the
    ///               objective function value at `x`, and overwrites `grad` with
    ///               the gradient.
    /// \param fx     In: The objective function value at the current point.
    ///               Out: The function value at the new point.
    /// \param x      Out: The new point moved to.
    /// \param grad   In: The current gradient vector. Out: The gradient at the
    ///               new point.
    /// \param step   In: The initial step length. Out: The calculated step length.
    /// \param drt    The current moving direction.
    /// \param xp     The current point.
    /// \param param  Parameters for the LBFGS algorithm.
    ///
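    /// As a reading aid, the acceptance tests applied in the loop below can be
    /// written as follows. The notation is introduced here for illustration only
    /// and is not part of the library: \f$\phi(\alpha) = f(x_p + \alpha d)\f$,
    /// with \f$\alpha\f$ = `step`, \f$x_p\f$ = `xp`, \f$d\f$ = `drt`,
    /// \f$c_1\f$ = `param.ftol` and \f$c_2\f$ = `param.wolfe`, so that
    /// \f$\phi'(0)\f$ is the initial value of `grad.dot(drt)`.
    ///
    /// \f[
    /// \phi(\alpha) \le \phi(0) + c_1 \alpha \phi'(0) \quad \text{(Armijo)}
    /// \f]
    /// \f[
    /// \phi'(\alpha) \ge c_2 \phi'(0) \quad \text{(regular Wolfe, in addition to Armijo)}
    /// \f]
    /// \f[
    /// |\phi'(\alpha)| \le c_2 |\phi'(0)| \quad \text{(strong Wolfe, in addition to Armijo)}
    /// \f]
    ///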
    template <typename Foo>
    static void LineSearch(Foo& f, Scalar& fx, Vector& x, Vector& grad,
                           Scalar& step,
                           const Vector& drt, const Vector& xp,
                           const LBFGSParam<Scalar>& param)
    {
        // Decreasing and increasing factors
        const Scalar dec = 0.5;
        const Scalar inc = 2.1;

        // Check the value of step
        if(step <= Scalar(0))
            throw std::invalid_argument("'step' must be positive");

        // Save the function value at the current x
        const Scalar fx_init = fx;
        // Projection of gradient on the search direction
        const Scalar dg_init = grad.dot(drt);
        // Make sure drt points to a descent direction
        if(dg_init > 0)
            throw std::logic_error("the moving direction increases the objective function value");

        const Scalar dg_test = param.ftol * dg_init;
        Scalar width;

        int iter;
        for(iter = 0; iter < param.max_linesearch; iter++)
        {
            // x_{k+1} = x_k + step * d_k
            x.noalias() = xp + step * drt;
            // Evaluate this candidate
            fx = f(x, grad);

            if(fx > fx_init + step * dg_test)
            {
                // Armijo condition fails: shrink the step
                width = dec;
            } else {
                // Armijo condition is met
                if(param.linesearch == LBFGS_LINESEARCH_BACKTRACKING_ARMIJO)
                    break;

                const Scalar dg = grad.dot(drt);
                if(dg < param.wolfe * dg_init)
                {
                    // Curvature condition fails: enlarge the step
                    width = inc;
                } else {
                    // Regular Wolfe condition is met
                    if(param.linesearch == LBFGS_LINESEARCH_BACKTRACKING_WOLFE)
                        break;

                    if(dg > -param.wolfe * dg_init)
                    {
                        width = dec;
                    } else {
                        // Strong Wolfe condition is met
                        break;
                    }
                }
            }

            if(step < param.min_step)
                throw std::runtime_error("the line search step became smaller than the minimum value allowed");

            if(step > param.max_step)
                throw std::runtime_error("the line search step became larger than the maximum value allowed");

            step *= width;
        }

        // This check must sit after the loop: inside the loop body,
        // iter < param.max_linesearch always holds, so it could never fire.
        if(iter >= param.max_linesearch)
            throw std::runtime_error("the line search routine reached the maximum number of iterations");
    }
};
} // namespace LBFGSpp
#endif // LINE_SEARCH_BACKTRACKING_H