File: RFClassificationEngine.h

package info (click to toggle)
itksnap 3.4.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 10,196 kB
  • ctags: 9,196
  • sloc: cpp: 62,895; sh: 175; makefile: 13
file content (96 lines) | stat: -rw-r--r-- 2,244 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#ifndef RFCLASSIFICATIONENGINE_H
#define RFCLASSIFICATIONENGINE_H

#include <itkObject.h>
#include <itkObjectFactory.h>
#include <itkSize.h>
#include "SNAPCommon.h"
#include "PropertyModel.h"

// Forward declarations. In this header these types are only referred to
// through raw pointers and smart pointers, so no definitions are pulled in.
class RandomForestClassifier;
class SNAPImageData;

// Sample container template; instantiated below as MLData<GreyType, LabelType>.
template <typename TData, typename TLabel> class MLData;

/**
 * This class serves as the high-level interface between ITK-SNAP and the
 * random forest code.
 *
 * It holds the training parameters (forest size, tree depth, patch radius,
 * whether voxel coordinates are used as features), a pointer to the image
 * data used as the data source, and the resulting classifier.
 *
 * The constructor and destructor are protected, so instances cannot be
 * created directly on the stack by client code.
 * NOTE(review): irisITKObjectMacro (defined outside this file) presumably
 * supplies the usual ITK Self/Pointer typedefs and New() factory -- confirm
 * in SNAPCommon.h.
 */
class RFClassificationEngine : public itk::Object
{
public:

  // Standard ITK class stuff
  irisITKObjectMacro(RFClassificationEngine, itk::Object)

  // Patch radius type (a 3-D itk::Size; presumably one radius per image axis)
  typedef itk::Size<3> RadiusType;

  /**
   * Set the data source for the classification.
   * NOTE(review): the pointer is kept as a raw member (m_DataSource); this
   * header does not show who owns it or how long it must stay valid --
   * confirm against the implementation and callers.
   */
  void SetDataSource(SNAPImageData *imageData);

  /** Reset the classifier */
  void ResetClassifier();

  /** Train the classifier */
  void TrainClassifier();

  /** Set the classifier */
  void SetClassifier(RandomForestClassifier *rf);

  /** Access the trained classifier (generated getter for m_Classifier) */
  itkGetMacro(Classifier, RandomForestClassifier *)

  /** Number of trees in the random forest (main parameter) */
  itkGetMacro(ForestSize, int)
  itkSetMacro(ForestSize, int)

  /**
   * Depth of the trees in the random forest.
   * NOTE(review): presumably the maximum depth each tree may reach --
   * confirm against the training code.
   */
  itkGetMacro(TreeDepth, int)
  itkSetMacro(TreeDepth, int)

  /** Patch radius for sampling features */
  itkGetMacro(PatchRadius, const RadiusType &)
  itkSetMacro(PatchRadius, RadiusType)

  /** Whether coordinates of the voxels are used as features */
  itkGetMacro(UseCoordinateFeatures, bool)
  itkSetMacro(UseCoordinateFeatures, bool)

  /** Get the number of components passed to the classifier */
  int GetNumberOfComponents() const;


protected:

  // Construction and destruction are restricted to this class and subclasses
  RFClassificationEngine();
  virtual ~RFClassificationEngine();

  // The trained classifier (held through a smart pointer)
  SmartPtr<RandomForestClassifier> m_Classifier;

  // The data source (raw pointer; ownership is not visible in this header)
  SNAPImageData *m_DataSource;

  // The foreground label (no public accessor is declared in this header)
  LabelType m_ForegroundLabel;

  // Number of trees in the forest
  int m_ForestSize;

  // Depth of the trees in the forest
  int m_TreeDepth;

  // Patch radius used when sampling features
  RadiusType m_PatchRadius;

  // Are coordinates included as features
  bool m_UseCoordinateFeatures;

  // Cached samples used to train the classifier.
  // NOTE(review): raw pointer -- allocation and release are not visible in
  // this header; verify they are handled in the implementation file.
  typedef MLData<GreyType, LabelType> SampleType;
  SampleType *m_Sample;

};

#endif // RFCLASSIFICATIONENGINE_H