#ifndef RFCLASSIFICATIONENGINE_H
#define RFCLASSIFICATIONENGINE_H
#include <itkObject.h>
#include <itkObjectFactory.h>
#include <itkSize.h>
#include "SNAPCommon.h"
#include "PropertyModel.h"
class SNAPImageData;
class RandomForestClassifier;
template <class TData, class TLabel> class MLData;
/**
 * This class serves as the high-level interface between ITK-SNAP and the
 * random forest code. It holds the training parameters (forest size, tree
 * depth, patch radius, coordinate features), a pointer to the image data
 * used as the sample source, and the classifier.
 */
class RFClassificationEngine : public itk::Object
{
public:
  // Standard ITK class stuff (provided by the project's irisITKObjectMacro)
  irisITKObjectMacro(RFClassificationEngine, itk::Object)

  // Patch radius type: one radius value per axis of a 3D image
  typedef itk::Size<3> RadiusType;

  /**
   * Set the data source for the classification.
   * NOTE(review): presumably stored in m_DataSource, which is a raw pointer;
   * the caller likely has to keep imageData alive while the engine uses it.
   * Confirm in the implementation file.
   */
  void SetDataSource(SNAPImageData *imageData);

  /** Reset the classifier */
  void ResetClassifier();

  /** Train the classifier */
  void TrainClassifier();

  /** Set the classifier */
  void SetClassifier(RandomForestClassifier *rf);

  /**
   * Access the trained classifier. Returns the object held by the
   * m_Classifier smart pointer as a raw pointer.
   */
  itkGetMacro(Classifier, RandomForestClassifier *)

  /** Size of the random forest, i.e. the number of trees (main parameter) */
  itkGetMacro(ForestSize, int)
  itkSetMacro(ForestSize, int)

  /**
   * Depth of the trees in the random forest.
   * NOTE(review): presumably the maximum depth of each tree; confirm against
   * RandomForestClassifier.
   */
  itkGetMacro(TreeDepth, int)
  itkSetMacro(TreeDepth, int)

  /**
   * Patch radius for sampling features. The getter returns a const
   * reference to the stored radius rather than a copy.
   */
  itkGetMacro(PatchRadius, const RadiusType &)
  itkSetMacro(PatchRadius, RadiusType)

  /** Whether coordinates of the voxels are used as features */
  itkGetMacro(UseCoordinateFeatures, bool)
  itkSetMacro(UseCoordinateFeatures, bool)

  /** Get the number of components passed to the classifier */
  int GetNumberOfComponents() const;

protected:
  RFClassificationEngine();
  virtual ~RFClassificationEngine();

  // The trained classifier (held through the project's SmartPtr wrapper)
  SmartPtr<RandomForestClassifier> m_Classifier;

  // The data source (raw pointer; presumably non-owning -- confirm)
  SNAPImageData *m_DataSource;

  // The foreground label
  // NOTE(review): no public accessor is declared for it in this header.
  LabelType m_ForegroundLabel;

  // Number of trees in the forest
  int m_ForestSize;

  // Depth of the trees in the forest
  int m_TreeDepth;

  // Patch radius (per axis) used when sampling features
  RadiusType m_PatchRadius;

  // Are coordinates included as features
  bool m_UseCoordinateFeatures;

  // Cached samples used to train the classifier
  // NOTE(review): raw pointer; allocation and cleanup presumably happen in the
  // implementation file (constructor/destructor/TrainClassifier) -- confirm.
  typedef MLData<GreyType, LabelType> SampleType;
  SampleType *m_Sample;
};
#endif // RFCLASSIFICATIONENGINE_H