File: itkStandardGradientDescentOptimizer.h

package info (click to toggle)
elastix 5.2.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 42,480 kB
  • sloc: cpp: 68,403; lisp: 4,118; python: 1,013; xml: 182; sh: 177; makefile: 33
file content (169 lines) | stat: -rw-r--r-- 5,906 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/*=========================================================================
 *
 *  Copyright UMC Utrecht and contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

#ifndef itkStandardGradientDescentOptimizer_h
#define itkStandardGradientDescentOptimizer_h

#include "itkGradientDescentOptimizer2.h"

namespace itk
{

/**
 * \class StandardGradientDescentOptimizer
 * \brief This class implements a gradient descent optimizer with a decaying gain.
 *
 * If \f$C(x)\f$ is a cost function that has to be minimised, the following iterative
 * algorithm is used to find the optimal parameters \f$x\f$:
 *
 *   \f[ x(k+1) = x(k) - a(k) dC/dx \f]
 *
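 * The gain \f$a(k)\f$ at each iteration \f$k\f$ is defined by:
 *
 *   \f[ a(k) =  a / (A + k + 1)^\alpha \f]
 *
 * where \f$a\f$, \f$A\f$ and \f$\alpha\f$ correspond to the parameters
 * \c Param_a, \c Param_A and \c Param_alpha of this class.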
 *
 * It is very suitable to be used in combination with a stochastic estimate
 * of the gradient \f$dC/dx\f$. For example, in image registration problems it is
 * often advantageous to compute the metric derivative (\f$dC/dx\f$) on a new set
 * of randomly selected image samples in each iteration. You may set the parameter
 * \c NewSamplesEveryIteration to \c "true" to achieve this effect.
 * For more information on this strategy, you may have a look at:
 *
 * S. Klein, M. Staring, J.P.W. Pluim,
 * "Comparison of gradient approximation techniques for optimisation of mutual information in nonrigid registration",
 * in: SPIE Medical Imaging: Image Processing,
 * Editor(s): J.M. Fitzpatrick, J.M. Reinhardt, SPIE press, 2005, vol. 5747, Proceedings of SPIE, pp. 192-203.
 *
 * Or:
 *
 * S. Klein, M. Staring, J.P.W. Pluim,
 * "Evaluation of Optimization Methods for Nonrigid Medical Image Registration using Mutual Information and B-Splines"
 * IEEE Transactions on Image Processing, 2007, nr. 16(12), December.
 *
 * This class also serves as a base class for other GradientDescent type
 * algorithms, like the AcceleratedGradientDescentOptimizer.
 *
 * \sa StandardGradientDescent, AcceleratedGradientDescentOptimizer
 * \ingroup Optimizers
 */

class StandardGradientDescentOptimizer : public GradientDescentOptimizer2
{
public:
  /** Instances of this class are neither copyable nor movable. */
  ITK_DISALLOW_COPY_AND_MOVE(StandardGradientDescentOptimizer);

  /** Standard ITK type aliases. */
  using Self = StandardGradientDescentOptimizer;
  using Superclass = GradientDescentOptimizer2;

  using Pointer = SmartPointer<Self>;
  using ConstPointer = SmartPointer<const Self>;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(StandardGradientDescentOptimizer, GradientDescentOptimizer2);

  /** Typedefs inherited from the superclass. */
  using Superclass::MeasureType;
  using Superclass::ParametersType;
  using Superclass::DerivativeType;
  using Superclass::CostFunctionType;
  using Superclass::ScalesType;
  using Superclass::ScaledCostFunctionType;
  using Superclass::ScaledCostFunctionPointer;
  using Superclass::StopConditionType;

  /** Set/Get the parameter \f$a\f$ of the gain sequence
   * \f$a(k) = a / (A + k + 1)^\alpha\f$.
   * Default: 1.0 */
  itkSetMacro(Param_a, double);
  itkGetConstMacro(Param_a, double);

  /** Set/Get the parameter \f$A\f$ of the gain sequence
   * \f$a(k) = a / (A + k + 1)^\alpha\f$.
   * Default: 1.0 */
  itkSetMacro(Param_A, double);
  itkGetConstMacro(Param_A, double);

  /** Set/Get the parameter \f$\alpha\f$ of the gain sequence
   * \f$a(k) = a / (A + k + 1)^\alpha\f$.
   * Default: 0.602 */
  itkSetMacro(Param_alpha, double);
  itkGetConstMacro(Param_alpha, double);

  /** Sets a new LearningRate before calling the Superclass'
   * implementation, and updates the current time. */
  void
  AdvanceOneStep() override;

  /** Set current time to 0 and call superclass' implementation.
   * NOTE(review): the implementation is not visible in this header;
   * presumably the current time is reset to InitialTime (which
   * defaults to 0.0) rather than to a literal 0 -- confirm against
   * the implementation file. */
  void
  StartOptimization() override;

  /** Set/Get the initial time. Should be >= 0. In this class the
   * function is superfluous, since Param_A does effectively the same.
   * However, in inheriting classes, like the AcceleratedGradientDescent,
   * the initial time may have a different function than Param_A.
   * Default: 0.0 */
  itkSetMacro(InitialTime, double);
  itkGetConstMacro(InitialTime, double);

  /** Get the current time. This equals the CurrentIteration in this base class,
   * but may be different in inheriting classes, such as the
   * AcceleratedGradientDescentOptimizer. */
  itkGetConstMacro(CurrentTime, double);

  /** Set the current time to the initial time. This can be useful
   * to 'reset' the optimisation, for example if you changed the
   * cost function during optimisation. Be careful with this function. */
  virtual void
  ResetCurrentTimeToInitialTime()
  {
    // Plain assignment: no Modified() call or other bookkeeping is done here.
    this->m_CurrentTime = this->m_InitialTime;
  }


protected:
  /** Protected constructor/destructor: the public section provides
   * itkNewMacro(Self) for creation. */
  StandardGradientDescentOptimizer();
  ~StandardGradientDescentOptimizer() override = default;

  /** Function to compute the parameter \f$a(k)\f$ at time/iteration \c k.
   * NOTE(review): presumably evaluates the gain formula
   * \f$a / (A + k + 1)^\alpha\f$ from the class documentation; the
   * implementation is not visible in this header -- confirm.
   * Virtual, so inheriting classes may override it. */
  virtual double
  Compute_a(double k) const;

  /** Function to update the current time.
   * This function just increments the CurrentTime by 1.
   * Inheriting classes may implement something smarter,
   * for example, dependent on the progress. */
  virtual void
  UpdateCurrentTime();

  /** The current time, which serves as input for Compute_a.
   * Initialised to 0.0. */
  double m_CurrentTime{ 0.0 };

  /** Flag for using a constant step size. Initialised to false.
   * NOTE(review): this member is neither read nor written anywhere in
   * this header and has no Set/Get accessor here; presumably it is
   * consumed by the implementation file or by inheriting classes --
   * confirm the exact semantics there. */
  bool m_UseConstantStep{ false };

private:
  /** Parameters of the gain sequence, as described by Spall. */
  double m_Param_a{ 1.0 };
  double m_Param_A{ 1.0 };
  double m_Param_alpha{ 0.602 };

  /** Settings: the time value that ResetCurrentTimeToInitialTime()
   * copies into m_CurrentTime. */
  double m_InitialTime{ 0.0 };
};

} // end namespace itk

#endif // end #ifndef itkStandardGradientDescentOptimizer_h