/*
* Copyright (C) 2005-2020 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbRandomForestsMachineLearningModel_h
#define otbRandomForestsMachineLearningModel_h
#include "otbRequiresOpenCVCheck.h"
#include "itkLightObject.h"
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"
#include "itkVariableSizeMatrix.h"
#include "otbCvRTreesWrapper.h"
namespace otb
{
/** \class RandomForestsMachineLearningModel
 * \brief Random Forests classifier/regressor backed by the OpenCV
 *        random trees implementation (via otb::CvRTreesWrapper).
 *
 * Declaration only: Train/Save/Load/DoPredict/etc. are implemented in
 * otbRandomForestsMachineLearningModel.hxx (included at the bottom of
 * this header unless OTB_MANUAL_INSTANTIATION is defined).
 *
 * \tparam TInputValue  value type of the input samples
 * \tparam TTargetValue value type of the target (label/regressand) samples
 */
template <class TInputValue, class TTargetValue>
class ITK_EXPORT RandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
{
public:
  /** Standard class typedefs. */
  typedef RandomForestsMachineLearningModel Self;
  typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Sample/list-sample types re-exported from the Superclass. */
  typedef typename Superclass::InputValueType InputValueType;
  typedef typename Superclass::InputSampleType InputSampleType;
  typedef typename Superclass::InputListSampleType InputListSampleType;
  typedef typename Superclass::TargetValueType TargetValueType;
  typedef typename Superclass::TargetSampleType TargetSampleType;
  typedef typename Superclass::TargetListSampleType TargetListSampleType;
  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
  typedef typename Superclass::ProbaSampleType ProbaSampleType;

  /** Matrix type returned by GetVariableImportance() (one row of floats). */
  typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType;

  /** Underlying OpenCV random-trees wrapper type. */
  typedef CvRTreesWrapper RFType;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(RandomForestsMachineLearningModel, MachineLearningModel);

  /** Train the machine learning model */
  void Train() override;

  /** Save the model to file */
  void Save(const std::string& filename, const std::string& name = "") override;

  /** Load the model from file */
  void Load(const std::string& filename, const std::string& name = "") override;

  /**\name Classification model file compatibility tests */
  //@{
  /** Is the input model file readable and compatible with the corresponding classifier ? */
  bool CanReadFile(const std::string&) override;
  /** Is the input model file writable and compatible with the corresponding classifier ? */
  bool CanWriteFile(const std::string&) override;
  //@}

  /**\name Accessors for the OpenCV random-trees training parameters
   * (parameter semantics documented on the corresponding members below;
   * descriptions originally taken from OpenCV 2.4 doxygen). */
  //@{
  itkGetMacro(MaxDepth, int);
  itkSetMacro(MaxDepth, int);

  itkGetMacro(MinSampleCount, int);
  itkSetMacro(MinSampleCount, int);

  // NOTE(review): accessors use double while m_RegressionAccuracy is float,
  // so the Set path silently narrows — confirm this asymmetry is intended.
  itkGetMacro(RegressionAccuracy, double);
  itkSetMacro(RegressionAccuracy, double);

  itkGetMacro(ComputeSurrogateSplit, bool);
  itkSetMacro(ComputeSurrogateSplit, bool);

  itkGetMacro(MaxNumberOfCategories, int);
  itkSetMacro(MaxNumberOfCategories, int);

  /** A priori class probabilities, sorted by class label (see m_Priors). */
  std::vector<float> GetPriors() const
  {
    return m_Priors;
  }
  void SetPriors(const std::vector<float>& priors)
  {
    m_Priors = priors;
  }

  itkGetMacro(CalculateVariableImportance, bool);
  itkSetMacro(CalculateVariableImportance, bool);

  itkGetMacro(MaxNumberOfVariables, int);
  itkSetMacro(MaxNumberOfVariables, int);

  itkGetMacro(MaxNumberOfTrees, int);
  itkSetMacro(MaxNumberOfTrees, int);

  itkGetMacro(ForestAccuracy, float);
  itkSetMacro(ForestAccuracy, float);

  itkGetMacro(TerminationCriteria, int);
  itkSetMacro(TerminationCriteria, int);

  itkGetMacro(ComputeMargin, bool);
  itkSetMacro(ComputeMargin, bool);
  //@}

  /** Returns a matrix containing variable importance */
  VariableImportanceMatrixType GetVariableImportance();

  /** Training error of the fitted forest (implemented in the .hxx). */
  float GetTrainError();

protected:
  /** Constructor */
  RandomForestsMachineLearningModel();

  /** Destructor */
  ~RandomForestsMachineLearningModel() override = default;

  /** Predict values using the model */
  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;

  /** PrintSelf method */
  void PrintSelf(std::ostream& os, itk::Indent indent) const override;

private:
  RandomForestsMachineLearningModel(const Self&) = delete;
  void operator=(const Self&) = delete;

  /** The trained OpenCV random-forest model (owned via cv::Ptr). */
  cv::Ptr<CvRTreesWrapper> m_RFModel;

  /** The depth of the tree. A low value will likely underfit and conversely a
   * high value will likely overfit. The optimal value can be obtained using cross
   * validation or other suitable methods. */
  int m_MaxDepth;

  /** minimum samples required at a leaf node for it to be split. A reasonable
   * value is a small percentage of the total data e.g. 1%. */
  int m_MinSampleCount;

  /** Termination criteria for regression trees. If all absolute differences
   * between an estimated value in a node and values of train samples in this node
   * are less than this parameter then the node will not be split */
  float m_RegressionAccuracy;

  /** If true, surrogate splits are built (used to handle missing data). */
  bool m_ComputeSurrogateSplit;

  /** Cluster possible values of a categorical variable into
   * \f$ K \leq MaxCategories \f$
   * clusters to find a suboptimal split. If a discrete variable,
   * on which the training procedure tries to make a split, takes more than
   * max_categories values, the precise best subset estimation may take a very
   * long time because the algorithm is exponential. Instead, many decision
   * trees engines (including ML) try to find sub-optimal split in this case by
   * clustering all the samples into max categories clusters that is some
   * categories are merged together. The clustering is applied only in n>2-class
   * classification problems for categorical variables with N > max_categories
   * possible values. In case of regression and 2-class classification the
   * optimal split can be found efficiently without employing clustering, thus
   * the parameter is not used in these cases.
   */
  int m_MaxNumberOfCategories;

  /** The array of a priori class probabilities, sorted by the class label
   * value. The parameter can be used to tune the decision tree preferences toward
   * a certain class. For example, if you want to detect some rare anomaly
   * occurrence, the training base will likely contain much more normal cases than
   * anomalies, so a very good classification performance will be achieved just by
   * considering every case as normal. To avoid this, the priors can be specified,
   * where the anomaly probability is artificially increased (up to 0.5 or even
   * greater), so the weight of the misclassified anomalies becomes much bigger,
   * and the tree is adjusted properly. You can also think about this parameter as
   * weights of prediction categories which determine relative weights that you
   * give to misclassification. That is, if the weight of the first category is 1
   * and the weight of the second category is 10, then each mistake in predicting
   * the second category is equivalent to making 10 mistakes in predicting the
   * first category. */
  std::vector<float> m_Priors;

  /** If true then variable importance will be calculated and then it can be
   * retrieved by CvRTreesWrapper::get_var_importance(). */
  bool m_CalculateVariableImportance;

  /** The size of the randomly selected subset of features at each tree node and
   * that are used to find the best split(s). If you set it to 0 then the size will
   * be set to the square root of the total number of features. */
  int m_MaxNumberOfVariables;

  /** The maximum number of trees in the forest (surprise, surprise). Typically
   * the more trees you have the better the accuracy. However, the improvement in
   * accuracy generally diminishes and asymptotes pass a certain number of
   * trees. Also to keep in mind, the number of tree increases the prediction time
   * linearly. */
  int m_MaxNumberOfTrees;

  /** Sufficient accuracy (OOB error) */
  float m_ForestAccuracy;

  /** The type of the termination criteria */
  int m_TerminationCriteria;

  /** Whether to compute margin (difference in probability between the
   * 2 most voted classes) instead of confidence (probability of the most
   * voted class) in prediction*/
  bool m_ComputeMargin;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbRandomForestsMachineLearningModel.hxx"
#endif
#endif