File: otbRandomForestsMachineLearningModel.h

package info (click to toggle)
otb 7.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,005,476 kB
  • sloc: cpp: 270,143; xml: 128,722; ansic: 4,367; sh: 1,768; python: 1,084; perl: 92; makefile: 72
file content (226 lines) | stat: -rw-r--r-- 9,023 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
/*
 * Copyright (C) 2005-2020 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbRandomForestsMachineLearningModel_h
#define otbRandomForestsMachineLearningModel_h

#include "otbRequiresOpenCVCheck.h"

#include "itkLightObject.h"
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"
#include "itkVariableSizeMatrix.h"
#include "otbCvRTreesWrapper.h"

namespace otb
{

template <class TInputValue, class TTargetValue>
class ITK_EXPORT RandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
{
public:
  /** Standard class typedefs. */
  typedef RandomForestsMachineLearningModel Self;
  typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  typedef typename Superclass::InputValueType       InputValueType;
  typedef typename Superclass::InputSampleType      InputSampleType;
  typedef typename Superclass::InputListSampleType  InputListSampleType;
  typedef typename Superclass::TargetValueType      TargetValueType;
  typedef typename Superclass::TargetSampleType     TargetSampleType;
  typedef typename Superclass::TargetListSampleType TargetListSampleType;
  typedef typename Superclass::ConfidenceValueType  ConfidenceValueType;
  typedef typename Superclass::ProbaSampleType      ProbaSampleType;
  // Matrix type used to report per-variable importance scores
  typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType;


  // opencv typedef
  typedef CvRTreesWrapper RFType;

  /** Standard macro for object creation through the ITK factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(RandomForestsMachineLearningModel, MachineLearningModel);

  /** Train the machine learning model */
  void Train() override;

  /** Save the model to file */
  void Save(const std::string& filename, const std::string& name = "") override;

  /** Load the model from file */
  void Load(const std::string& filename, const std::string& name = "") override;

  /**\name Classification model file compatibility tests */
  //@{
  /** Is the input model file readable and compatible with the corresponding classifier ? */
  bool CanReadFile(const std::string&) override;

  /** Is the input model file writable and compatible with the corresponding classifier ? */
  bool CanWriteFile(const std::string&) override;
  //@}

  // Setters of RT parameters (documentation get from opencv doxygen 2.4)
  itkGetMacro(MaxDepth, int);
  itkSetMacro(MaxDepth, int);

  itkGetMacro(MinSampleCount, int);
  itkSetMacro(MinSampleCount, int);

  itkGetMacro(RegressionAccuracy, double);
  itkSetMacro(RegressionAccuracy, double);

  itkGetMacro(ComputeSurrogateSplit, bool);
  itkSetMacro(ComputeSurrogateSplit, bool);

  itkGetMacro(MaxNumberOfCategories, int);
  itkSetMacro(MaxNumberOfCategories, int);

  /** Return the array of a priori class probabilities (see m_Priors). */
  std::vector<float> GetPriors() const
  {
    return m_Priors;
  }

  /** Set the array of a priori class probabilities, sorted by class
   * label value (see m_Priors). */
  void SetPriors(const std::vector<float>& priors)
  {
    m_Priors = priors;
  }

  itkGetMacro(CalculateVariableImportance, bool);
  itkSetMacro(CalculateVariableImportance, bool);

  itkGetMacro(MaxNumberOfVariables, int);
  itkSetMacro(MaxNumberOfVariables, int);

  itkGetMacro(MaxNumberOfTrees, int);
  itkSetMacro(MaxNumberOfTrees, int);

  itkGetMacro(ForestAccuracy, float);
  itkSetMacro(ForestAccuracy, float);

  itkGetMacro(TerminationCriteria, int);
  itkSetMacro(TerminationCriteria, int);

  itkGetMacro(ComputeMargin, bool);
  itkSetMacro(ComputeMargin, bool);

  /** Returns a matrix containing variable importance.
   * Only meaningful when CalculateVariableImportance is enabled
   * before training. */
  VariableImportanceMatrixType GetVariableImportance();

  /** Returns the training (OOB) error of the trained forest. */
  float GetTrainError();

protected:
  /** Constructor */
  RandomForestsMachineLearningModel();

  /** Destructor */
  ~RandomForestsMachineLearningModel() override = default;

  /** Predict values using the model */
  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;

  /** PrintSelf method */
  void PrintSelf(std::ostream& os, itk::Indent indent) const override;

  /* /\** Input list sample *\/ */
  /* typename InputListSampleType::Pointer m_InputListSample; */

  /* /\** Target list sample *\/ */
  /* typename TargetListSampleType::Pointer m_TargetListSample; */

private:
  RandomForestsMachineLearningModel(const Self&) = delete;
  void operator=(const Self&) = delete;

  // Underlying OpenCV random-forest model (wrapped to expose margins/votes)
  cv::Ptr<CvRTreesWrapper> m_RFModel;

  /** The depth of the tree. A low value will likely underfit and conversely a
   * high value will likely overfit. The optimal value can be obtained using cross
   * validation or other suitable methods. */
  int m_MaxDepth;
  /** minimum samples required at a leaf node for it to be split. A reasonable
   * value is a small percentage of the total data e.g. 1%. */
  int m_MinSampleCount;
  /** Termination criteria for regression trees. If all absolute differences
   * between an estimated value in a node and values of train samples in this node
   * are less than this parameter then the node will not be split.
   * Stored as double to match the public Get/SetRegressionAccuracy(double)
   * accessors (was float, which silently narrowed every Set). */
  double m_RegressionAccuracy;
  bool   m_ComputeSurrogateSplit;
  /** Cluster possible values of a categorical variable into
   * \f$ K \leq MaxCategories \f$
   * clusters to find a suboptimal split. If a discrete variable,
   * on which the training procedure tries to make a split, takes more than
   * max_categories values, the precise best subset estimation may take a very
   * long time because the algorithm is exponential. Instead, many decision
   * trees engines (including ML) try to find sub-optimal split in this case by
   * clustering all the samples into max categories clusters that is some
   * categories are merged together. The clustering is applied only in n>2-class
   * classification problems for categorical variables with N > max_categories
   * possible values. In case of regression and 2-class classification the
   * optimal split can be found efficiently without employing clustering, thus
   * the parameter is not used in these cases.
   */
  int m_MaxNumberOfCategories;
  /** The array of a priori class probabilities, sorted by the class label
   * value. The parameter can be used to tune the decision tree preferences toward
   * a certain class. For example, if you want to detect some rare anomaly
   * occurrence, the training base will likely contain much more normal cases than
   * anomalies, so a very good classification performance will be achieved just by
   * considering every case as normal. To avoid this, the priors can be specified,
   * where the anomaly probability is artificially increased (up to 0.5 or even
   * greater), so the weight of the misclassified anomalies becomes much bigger,
   * and the tree is adjusted properly. You can also think about this parameter as
   * weights of prediction categories which determine relative weights that you
   * give to misclassification. That is, if the weight of the first category is 1
   * and the weight of the second category is 10, then each mistake in predicting
   * the second category is equivalent to making 10 mistakes in predicting the
   * first category. */
  std::vector<float> m_Priors;
  /** If true then variable importance will be calculated and then it can be
   * retrieved by CvRTreesWrapper::get_var_importance(). */
  bool m_CalculateVariableImportance;
  /** The size of the randomly selected subset of features at each tree node and
   * that are used to find the best split(s). If you set it to 0 then the size will
   * be set to the square root of the total number of features. */
  int m_MaxNumberOfVariables;
  /** The maximum number of trees in the forest (surprise, surprise). Typically
   * the more trees you have the better the accuracy. However, the improvement in
   * accuracy generally diminishes and asymptotes pass a certain number of
   * trees. Also to keep in mind, the number of tree increases the prediction time
   *linearly. */
  int m_MaxNumberOfTrees;
  /** Sufficient accuracy (OOB error) */
  float m_ForestAccuracy;
  /** The type of the termination criteria */
  int m_TerminationCriteria;
  /** Whether to compute margin (difference in probability between the
   * 2 most voted classes) instead of confidence (probability of the most
   * voted class) in prediction*/
  bool m_ComputeMargin;
};
} // end namespace otb

#ifndef OTB_MANUAL_INSTANTIATION
#include "otbRandomForestsMachineLearningModel.hxx"
#endif

#endif