/* Copyright (c) 2008-2022 the MRtrix3 contributors.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*
* Covered Software is provided under this License on an "as is"
* basis, without warranty of any kind, either expressed, implied, or
* statutory, including, without limitation, warranties that the
* Covered Software is free of defects, merchantable, fit for a
* particular purpose or non-infringing.
* See the Mozilla Public License v. 2.0 for more details.
*
* For more details, see http://www.mrtrix.org/.
*/
#ifndef __math_check_gradient_h__
#define __math_check_gradient_h__
#include "debug.h"
#include "datatype.h"
namespace MR {
namespace Math {
//! Numerically verify the analytical gradient of a cost function.
/*! Evaluates \a function at position \a x and compares the gradient it
 * reports against a central finite-difference approximation, printing both
 * to the console along with their normalised dot product (which should be
 * close to 1 if the analytical gradient is correct).
 *
 * \param function  cost functor; must provide \c value_type, \c size(),
 *                  \c init(gradient) (returning a suggested step size and
 *                  filling in a suggested initial position), and
 *                  \c operator()(x, gradient) returning the cost.
 * \param x         position at which to check the gradient (taken by value;
 *                  perturbed internally, caller's copy is unaffected).
 * \param increment base step size for the finite differences.
 * \param show_hessian  if true, also estimate the (symmetrised) Hessian by
 *                  finite differences of the gradient, print it and its
 *                  condition number.
 * \param conditioner  optional per-parameter scaling (same size as \a x);
 *                  each parameter's step is scaled by the square root of
 *                  its conditioner entry, and gradients are rescaled
 *                  accordingly when accumulating the Hessian.
 * \return the estimated Hessian (empty unless \a show_hessian is true).
 */
template <class Function>
Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, Eigen::Dynamic> check_function_gradient (
    Function& function,
    Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> x,
    typename Function::value_type increment,
    bool show_hessian = false,
    Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> conditioner = Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1>())
{
  using value_type = typename Function::value_type;
  const size_t N = function.size();
  Eigen::Matrix<value_type, Eigen::Dynamic, 1> g (N);

  CONSOLE ("checking gradient for cost function over " + str(N) +
      " parameters of type " + DataType::from<value_type>().specifier());
  value_type step_size = function.init (g);
  CONSOLE ("cost function suggests initial step size = " + str(step_size));
  CONSOLE ("cost function suggests initial position at [ " + str(g.transpose()) + "]");

  CONSOLE ("checking gradient at position [ " + str(x.transpose()) + "]:");
  Eigen::Matrix<value_type, Eigen::Dynamic, 1> g0 (N);
  value_type f0 = function (x, g0);
  CONSOLE ("  cost function = " + str(f0));
  CONSOLE ("  gradient from cost function = [ " + str(g0.transpose()) + "]");

  Eigen::Matrix<value_type, Eigen::Dynamic, 1> g_fd (N);
  Eigen::Matrix<value_type, Eigen::Dynamic, Eigen::Dynamic> hessian;
  if (show_hessian) {
    hessian.resize (N, N);
    if (conditioner.size()) {
      assert (conditioner.size() == (ssize_t) N && "conditioner size must equal number of parameters");
      // take the square root so that scaling the step AND the resulting
      // gradient each contribute half of the overall conditioning
      for (size_t n = 0; n < N; ++n)
        conditioner[n] = std::sqrt (conditioner[n]);
    }
  }

  for (size_t n = 0; n < N; ++n) {
    value_type old_x = x[n];
    value_type inc = increment;
    if (conditioner.size()) {
      assert (conditioner.size() == (ssize_t) N && "conditioner size must equal number of parameters");
      inc *= conditioner[n];
    }

    // forward evaluation at x[n] + inc
    x[n] += inc;
    value_type f1 = function (x, g);
    if (show_hessian) {
      // NOTE: cwiseProduct() returns a new expression, it does NOT modify
      // g in place — the result must be assigned back (previously the
      // return value was silently discarded, leaving g unscaled).
      if (conditioner.size())
        g = g.cwiseProduct (conditioner);
      hessian.col(n) = g;
    }

    // backward evaluation at x[n] - inc
    x[n] = old_x - inc;
    value_type f2 = function (x, g);
    g_fd[n] = (f1 - f2) / (2.0 * inc);   // central difference
    x[n] = old_x;

    if (show_hessian) {
      if (conditioner.size())
        g = g.cwiseProduct (conditioner);
      hessian.col(n) -= g;               // column now holds g(x+inc) - g(x-inc)
    }
  }

  CONSOLE ("gradient by central finite difference = [ " + str(g_fd.transpose()) + "]");
  CONSOLE ("normalised dot product = " + str(g_fd.dot(g0) / g_fd.squaredNorm()));

  if (show_hessian) {
    hessian /= 4.0 * increment;
    // symmetrise: average H with its transpose (upper triangle mirrors the
    // lower; diagonal and lower triangle accumulate both contributions)
    for (size_t j = 0; j < N; ++j) {
      size_t i;
      for (i = 0; i < j; ++i)
        hessian(i,j) = hessian(j,i);
      for (; i < N; ++i)
        hessian(i,j) += hessian(j,i);
    }
    MAT (hessian);
    CONSOLE ("\033[00;34mcondition number: " + str(condition_number (hessian)) + "\033[0m");
  }
  return hessian;
}
}
}
#endif