File: itkGradientDescentOptimizerBasev4.h

/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkGradientDescentOptimizerBasev4_h
#define itkGradientDescentOptimizerBasev4_h

#include "itkObjectToObjectOptimizerBase.h"
#include "itkWindowConvergenceMonitoringFunction.h"
#include "itkThreadedIndexedContainerPartitioner.h"
#include "itkDomainThreader.h"

namespace itk
{
/**
 * \class GradientDescentOptimizerBasev4
 *  \brief Abstract base class for gradient descent-style optimizers.
 *
 * Gradient modification is threaded in \c ModifyGradient.
 *
 * Derived classes must override \c ModifyGradientByScalesOverSubRange,
 * \c ModifyGradientByLearningRateOverSubRange and \c ResumeOptimization.
 *
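 * A minimal derived class might look roughly like the sketch below; the
 * class name \c MyOptimizer is illustrative only, and a real optimizer would
 * also provide the usual ITK object machinery (smart pointers, \c itkNewMacro,
 * etc.):
 *
 * \code
 * template <typename TInternalComputationValueType>
 * class MyOptimizer
 *   : public GradientDescentOptimizerBasev4Template<TInternalComputationValueType>
 * {
 * public:
 *   using Superclass = GradientDescentOptimizerBasev4Template<TInternalComputationValueType>;
 *   using IndexRangeType = typename Superclass::IndexRangeType;
 *
 *   void
 *   ResumeOptimization() override;
 *
 *   void
 *   ModifyGradientByScalesOverSubRange(const IndexRangeType & subrange) override;
 *
 *   void
 *   ModifyGradientByLearningRateOverSubRange(const IndexRangeType & subrange) override;
 * };
 * \endcode
 *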
 * \ingroup ITKOptimizersv4
 */
template <typename TInternalComputationValueType>
class ITK_TEMPLATE_EXPORT GradientDescentOptimizerBasev4Template
  : public ObjectToObjectOptimizerBaseTemplate<TInternalComputationValueType>
{
public:
  ITK_DISALLOW_COPY_AND_MOVE(GradientDescentOptimizerBasev4Template);

  /** Standard class type aliases. */
  using Self = GradientDescentOptimizerBasev4Template;
  using Superclass = ObjectToObjectOptimizerBaseTemplate<TInternalComputationValueType>;
  using Pointer = SmartPointer<Self>;
  using ConstPointer = SmartPointer<const Self>;

  /** \see LightObject::GetNameOfClass() */
  itkOverrideGetNameOfClassMacro(GradientDescentOptimizerBasev4Template);

#if !defined(ITK_LEGACY_REMOVE)
  /** Exposes enum values for backwards compatibility. */
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum MAXIMUM_NUMBER_OF_ITERATIONS =
    itk::StopConditionObjectToObjectOptimizerEnum::MAXIMUM_NUMBER_OF_ITERATIONS;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum COSTFUNCTION_ERROR =
    itk::StopConditionObjectToObjectOptimizerEnum::COSTFUNCTION_ERROR;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum UPDATE_PARAMETERS_ERROR =
    itk::StopConditionObjectToObjectOptimizerEnum::UPDATE_PARAMETERS_ERROR;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum STEP_TOO_SMALL =
    itk::StopConditionObjectToObjectOptimizerEnum::STEP_TOO_SMALL;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum CONVERGENCE_CHECKER_PASSED =
    itk::StopConditionObjectToObjectOptimizerEnum::CONVERGENCE_CHECKER_PASSED;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum GRADIENT_MAGNITUDE_TOLEARANCE =
    itk::StopConditionObjectToObjectOptimizerEnum::GRADIENT_MAGNITUDE_TOLEARANCE;
  static constexpr itk::StopConditionObjectToObjectOptimizerEnum OTHER_ERROR =
    itk::StopConditionObjectToObjectOptimizerEnum::OTHER_ERROR;
#endif

  /** Stop condition return string type */
  using typename Superclass::StopConditionReturnStringType;

  /** Stop condition internal string type */
  using typename Superclass::StopConditionDescriptionType;

  /** The internal computation value type, exposed so it can be derived from the class object. */
  using InternalComputationValueType = TInternalComputationValueType;

  /** Metric type over which this class is templated */
  using typename Superclass::MetricType;
  using MetricTypePointer = typename MetricType::Pointer;

  /** Derivative type */
  using typename Superclass::DerivativeType;

  /** Measure type */
  using typename Superclass::MeasureType;

  using typename Superclass::ScalesType;

  using typename Superclass::ParametersType;

  /** Type for the convergence checker */
  using ConvergenceMonitoringType = itk::Function::WindowConvergenceMonitoringFunction<TInternalComputationValueType>;

  /** Get the most recent gradient values. */
  itkGetConstReferenceMacro(Gradient, DerivativeType);

  /** Get stop condition enum */
  itkGetConstReferenceMacro(StopCondition, StopConditionObjectToObjectOptimizerEnum);

  /** Start and run the optimization */
  void
  StartOptimization(bool doOnlyInitialization = false) override;
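
  /* Typical client usage goes through a concrete subclass rather than this
   * abstract base. A rough sketch, assuming a concrete optimizer such as
   * itk::GradientDescentOptimizerv4 and an already configured metric (both
   * live in other ITK headers, not in this file):
   *
   *   auto optimizer = itk::GradientDescentOptimizerv4::New();
   *   optimizer->SetMetric(metric);        // an ObjectToObjectMetricBase-derived metric
   *   optimizer->SetNumberOfIterations(100);
   *   optimizer->SetLearningRate(1.0);
   *   optimizer->StartOptimization();      // initializes, then calls ResumeOptimization()
   *   std::cout << optimizer->GetStopConditionDescription() << std::endl;
   */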

  /** Resume optimization.
   * This runs the optimization loop and allows a stopped
   * optimization to be continued. */
  virtual void
  ResumeOptimization() = 0;

  /** Stop optimization. The object is left in a state so the
   * optimization can be resumed by calling ResumeOptimization. */
  virtual void
  StopOptimization();

  /** Get the reason for termination */
  const StopConditionReturnStringType
  GetStopConditionDescription() const override;

  /** Modify the gradient in place, to advance the optimization.
   * This call performs a threaded modification for transforms with
   * local support (assumed to be dense). Otherwise the modification
   * is performed without threading.
   * See EstimateLearningRate() to optionally perform learning rate
   * estimation.
   * At completion, m_Gradient can be used to update the transform
   * parameters. Derived classes may hold additional results in
   * other member variables.
   *
   * \sa EstimateLearningRate()
   */
  virtual void
  ModifyGradientByScales();
  virtual void
  ModifyGradientByLearningRate();
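
  /* Within a derived class, one iteration of ResumeOptimization() typically
   * evaluates the metric and then passes the gradient through both
   * modification steps before applying it. A rough sketch, using members of
   * this class and its ObjectToObjectOptimizerBase superclass (not the exact
   * loop of any particular subclass; EstimateLearningRate() is declared in
   * derived classes such as GradientDescentOptimizerv4):
   *
   *   this->m_Metric->GetValueAndDerivative(this->m_CurrentMetricValue, this->m_Gradient);
   *   this->ModifyGradientByScales();        // account for parameter scales
   *   if (this->m_DoEstimateLearningRateAtEachIteration)
   *   {
   *     this->EstimateLearningRate();        // optional automatic estimation
   *   }
   *   this->ModifyGradientByLearningRate();  // scale the update by the learning rate
   *   this->m_Metric->UpdateTransformParameters(this->m_Gradient);
   */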

  using IndexRangeType = ThreadedIndexedContainerPartitioner::IndexRangeType;

  /** Derived classes define this worker method to modify the gradient by scales.
   * Modifications must be performed over the index range defined in
   * \c subrange.
   * Called from ModifyGradientByScales(), either directly or via threaded
   * operation. */
  virtual void
  ModifyGradientByScalesOverSubRange(const IndexRangeType & subrange) = 0;

  /** Derived classes define this worker method to modify the gradient by learning rates.
   * Modifications must be performed over the index range defined in
   * \c subrange.
   * Called from ModifyGradientByLearningRate(), either directly or via threaded
   * operation.
   * This function is used in GradientDescentOptimizerBasev4ModifyGradientByScalesThreaderTemplate
   * and GradientDescentOptimizerBasev4ModifyGradientByLearningRateThreaderTemplate classes.
   */
  virtual void
  ModifyGradientByLearningRateOverSubRange(const IndexRangeType & subrange) = 0;
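
  /* An override of either worker method typically loops over the inclusive
   * index range subrange[0]..subrange[1] and modifies the corresponding
   * gradient entries in place. A rough, illustrative sketch for the scales
   * variant (not the body of any specific subclass):
   *
   *   const ScalesType & scales = this->GetScales();
   *   for (IndexValueType j = subrange[0]; j <= subrange[1]; ++j)
   *   {
   *     this->m_Gradient[j] = this->m_Gradient[j] / scales[j];
   *   }
   */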

protected:
  /** Default constructor */
  GradientDescentOptimizerBasev4Template();
  ~GradientDescentOptimizerBasev4Template() override = default;

  /** Flag to control use of the ScalesEstimator (if set) for
   * automatic learning step estimation at *each* iteration.
   */
  bool m_DoEstimateLearningRateAtEachIteration{};

  /** Flag to control use of the ScalesEstimator (if set) for
   * automatic learning step estimation only *once*, during the first iteration.
   */
  bool m_DoEstimateLearningRateOnce{};

  /** The maximum step size in physical units, used to restrict learning rates.
   * Only used with automatic learning rate estimation.
   * It may be initialized either explicitly by calling
   * SetMaximumStepSizeInPhysicalUnits or automatically via m_ScalesEstimator;
   * the explicit setting takes precedence. See the main documentation.
   */
  TInternalComputationValueType m_MaximumStepSizeInPhysicalUnits{};

  /** Flag to control use of convergence monitoring as a stop condition.
   *  This flag should always be set to true, except for the regular step
   *  gradient descent optimizer, which uses the minimum step length to check
   *  convergence.
   */
  bool m_UseConvergenceMonitoring{};

  /** Window size for the convergence checker.
   *  The convergence checker calculates the convergence value by fitting to
   *  a window of the energy (metric value) profile.
   */
  SizeValueType m_ConvergenceWindowSize{};

  /** The convergence checker. */
  typename ConvergenceMonitoringType::Pointer m_ConvergenceMonitoring{};
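
  /* A derived class that enables convergence monitoring typically feeds each
   * new metric value to the checker inside ResumeOptimization() and stops once
   * the convergence value falls below a tolerance. A rough sketch; the
   * convergenceTolerance variable is illustrative, while AddEnergyValue() and
   * GetConvergenceValue() come from WindowConvergenceMonitoringFunction:
   *
   *   this->m_ConvergenceMonitoring->AddEnergyValue(this->m_CurrentMetricValue);
   *   if (this->m_ConvergenceMonitoring->GetConvergenceValue() < convergenceTolerance)
   *   {
   *     this->m_StopCondition = StopConditionObjectToObjectOptimizerEnum::CONVERGENCE_CHECKER_PASSED;
   *     this->StopOptimization();
   *   }
   */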

  typename DomainThreader<ThreadedIndexedContainerPartitioner, Self>::Pointer m_ModifyGradientByScalesThreader{};
  typename DomainThreader<ThreadedIndexedContainerPartitioner, Self>::Pointer m_ModifyGradientByLearningRateThreader{};

  /* Common variables for optimization control and reporting */
  bool                                     m_Stop{ false };
  StopConditionObjectToObjectOptimizerEnum m_StopCondition{};
  StopConditionDescriptionType             m_StopConditionDescription{};

  /** Current gradient */
  DerivativeType m_Gradient{};

  void
  PrintSelf(std::ostream & os, Indent indent) const override;

private:
};

/** Alias provided for backward compatibility. */
using GradientDescentOptimizerBasev4 = GradientDescentOptimizerBasev4Template<double>;

} // end namespace itk

#ifndef ITK_MANUAL_INSTANTIATION
#  include "itkGradientDescentOptimizerBasev4.hxx"
#endif

#endif