File: RandomForestClassifier.h

package info (click to toggle)
itksnap 3.4.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 10,196 kB
  • ctags: 9,196
  • sloc: cpp: 62,895; sh: 175; makefile: 13
file content (93 lines) | stat: -rw-r--r-- 2,509 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#ifndef RANDOMFORESTCLASSIFIER_H
#define RANDOMFORESTCLASSIFIER_H

#include <itkDataObject.h>
#include <itkObjectFactory.h>
#include <itkSize.h>
#include <SNAPCommon.h>
#include <cstddef>
#include <map>
#include <vector>

// Forward declarations of the random forest implementation templates.
// This header only names them inside typedefs and through a pointer member,
// so complete definitions are not required here.
template <typename dataT, typename labelT> class Histogram;
template <typename dataT, typename labelT> class AxisAlignedClassifier;
template <typename HistT, typename ClassT, typename dataT> class DecisionForest;

/**
 * This class encapsulates a Random Forest classifier
 */
class RandomForestClassifier : public itk::DataObject
{
public:

  // Standard ITK stuff
  irisITKObjectMacro(RandomForestClassifier, itk::DataObject)

  // typedefs
  typedef Histogram<GreyType, LabelType> RFHistogramType;
  typedef AxisAlignedClassifier<GreyType, LabelType> RFAxisClassifierType;
  typedef DecisionForest<RFHistogramType, RFAxisClassifierType, GreyType> RandomForestType;
  typedef std::map<size_t, LabelType> MappingType;
  typedef itk::Size<3> SizeType;

  // A list of weights for each class - used to construct speed image
  typedef std::vector<double> WeightArray;

  // Reset the classifier
  void Reset();

  // Get the mapping from the class indices to labels
  irisGetMacro(ClassToLabelMapping, const MappingType &)

  // Get the random forest
  irisGetMacro(Forest, RandomForestType *)

  // Get the patch radius
  irisGetMacro(PatchRadius, const SizeType &)

  /** Whether coordinates of the voxels are used as features */
  itkGetMacro(UseCoordinateFeatures, bool)
  itkSetMacro(UseCoordinateFeatures, bool)

  // Set the bias parameter (adjusts the mapping of FG probability to speed)
  itkGetMacro(BiasParameter, double)
  itkSetMacro(BiasParameter, double)

  // Get a reference to the weight array
  irisGetMacro(ClassWeights, const WeightArray &)

  // Set the weight for a class
  void SetClassWeight(size_t class_id, double weight);

  // Test if the classifier is valid (has 2+ classes)
  bool IsValidClassifier() const;

protected:

  RandomForestClassifier();
  ~RandomForestClassifier();

  // The actual decision forest
  RandomForestType *m_Forest;

  // Whether the labels are valid (?)
  bool m_ValidLabel;

  // Mapping of index to label (?)
  MappingType m_ClassToLabelMapping;

  // Weight of each class
  WeightArray m_ClassWeights;

  // The patch radius
  SizeType m_PatchRadius;

  // Whether coordinate features are used
  bool m_UseCoordinateFeatures;

  // Bias parameter
  double m_BiasParameter;

  // Let the engine handle our data
  friend class RFClassificationEngine;
};

#endif // RANDOMFORESTCLASSIFIER_H