File: check_gradient.h

package info (click to toggle)
mrtrix3 3.0.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,712 kB
  • sloc: cpp: 129,776; python: 9,494; sh: 593; makefile: 234; xml: 47
file content (110 lines) | stat: -rw-r--r-- 4,203 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/* Copyright (c) 2008-2022 the MRtrix3 contributors.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Covered Software is provided under this License on an "as is"
 * basis, without warranty of any kind, either expressed, implied, or
 * statutory, including, without limitation, warranties that the
 * Covered Software is free of defects, merchantable, fit for a
 * particular purpose or non-infringing.
 * See the Mozilla Public License v. 2.0 for more details.
 *
 * For more details, see http://www.mrtrix.org/.
 */

#ifndef __math_check_gradient_h__
#define __math_check_gradient_h__

#include "debug.h"
#include "datatype.h"

namespace MR {
  namespace Math {

    /** \brief Compare a cost function's analytic gradient against a central
     * finite-difference estimate, and optionally estimate its Hessian.
     *
     * This is a diagnostic helper: all results are reported via CONSOLE.
     *
     * \tparam Function must provide:
     *   - a nested \c value_type (the scalar type);
     *   - \c size(): the number of parameters N;
     *   - \c init(g): writes a suggested starting position into \a g (length N)
     *     and returns a suggested initial step size (both are only printed);
     *   - \c operator()(x, g): returns the cost at \a x and writes the analytic
     *     gradient into \a g.
     *
     * \param function     the cost function under test.
     * \param x            position (length N) at which to check the gradient.
     *                     It is taken by value, so the caller's vector is not
     *                     modified.
     * \param increment    base finite-difference step.
     * \param show_hessian if true, also estimate the Hessian from central
     *                     differences of the analytic gradient, print it and
     *                     its condition number, and return it.
     * \param conditioner  optional per-parameter scale vector (length N, or
     *                     empty for none). When given, the step for parameter n
     *                     is increment * sqrt(conditioner[n]), and the returned
     *                     Hessian is D*H*D with D = diag(sqrt(conditioner)).
     *                     NOTE(review): entries are presumably non-negative,
     *                     since std::sqrt of a negative value gives NaN — confirm.
     * \return the symmetrised finite-difference Hessian (N x N) if
     *         \a show_hessian is true, otherwise an empty matrix.
     */
    template <class Function>
      Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, Eigen::Dynamic> check_function_gradient (
          Function& function,
          Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> x,
          typename Function::value_type increment,
          bool show_hessian = false,
          Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1> conditioner = Eigen::Matrix<typename Function::value_type, Eigen::Dynamic, 1>())
      {
        using value_type = typename Function::value_type;
        using vector_type = Eigen::Matrix<value_type, Eigen::Dynamic, 1>;
        const size_t N = function.size();
        // Scratch buffer: first receives init()'s suggested position, then the
        // analytic gradient at each perturbed position.
        vector_type g (N);

        CONSOLE ("checking gradient for cost function over " + str(N) +
            " parameters of type " + DataType::from<value_type>().specifier());
        const value_type step_size = function.init (g);
        CONSOLE ("cost function suggests initial step size = " + str(step_size));
        CONSOLE ("cost function suggests initial position at [ " + str(g.transpose()) + "]");

        CONSOLE ("checking gradient at position [ " + str(x.transpose()) + "]:");
        vector_type g0 (N);
        const value_type f0 = function (x, g0);
        CONSOLE ("  cost function = " + str(f0));
        CONSOLE ("  gradient from cost function         = [ " + str(g0.transpose()) + "]");

        // Both the step scaling and the Hessian scaling use sqrt(conditioner).
        // Take the square root once, up front, so the step used for the
        // gradient check does not depend on whether the Hessian is requested.
        const bool conditioned = conditioner.size() > 0;
        if (conditioned) {
          assert (conditioner.size() == (ssize_t) N && "conditioner size must equal number of parameters");
          for (size_t n = 0; n < N; ++n)
            conditioner[n] = std::sqrt (conditioner[n]);
        }

        vector_type g_fd (N);
        Eigen::Matrix<value_type, Eigen::Dynamic, Eigen::Dynamic> hessian;
        if (show_hessian)
          hessian.resize (N, N);

        for (size_t n = 0; n < N; ++n) {
          const value_type old_x = x[n];
          const value_type inc = conditioned ? increment * conditioner[n] : increment;

          // Forward perturbation.
          x[n] = old_x + inc;
          const value_type f1 = function (x, g);
          if (show_hessian) {
            if (conditioned)
              g.array() *= conditioner.array();  // scale in place (a bare cwiseProduct() would discard its result)
            hessian.col(n) = g;
          }

          // Backward perturbation, then restore the parameter.
          x[n] = old_x - inc;
          const value_type f2 = function (x, g);
          x[n] = old_x;
          g_fd[n] = (f1 - f2) / (2.0 * inc);
          if (show_hessian) {
            if (conditioned)
              g.array() *= conditioner.array();
            hessian.col(n) -= g;
          }
        }

        CONSOLE ("gradient by central finite difference = [ " + str(g_fd.transpose()) + "]");
        CONSOLE ("normalised dot product = " + str(g_fd.dot(g0) / g_fd.squaredNorm()));

        if (show_hessian) {
          // Column n holds D*(g(x + inc_n e_n) - g(x - inc_n e_n)) ~= 2*increment*(D*H*D)(:,n),
          // where inc_n = increment*D(n,n). Dividing by 4*increment leaves half
          // of a (generally non-symmetric) estimate; adding its transpose gives
          // the symmetrised D*H*D. The .eval() avoids Eigen's transpose aliasing.
          hessian /= 4.0*increment;
          hessian += hessian.transpose().eval();
          MAT(hessian);
          CONSOLE("\033[00;34mcondition number: " + str(condition_number (hessian))+"\033[0m");
        }
        return hessian;
      }
  }
}

#endif