File: RFClassificationEngine.cxx

package info (click to toggle)
itksnap 3.4.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 10,196 kB
  • ctags: 9,196
  • sloc: cpp: 62,895; sh: 175; makefile: 13
file content (249 lines) | stat: -rw-r--r-- 7,608 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#include "RFClassificationEngine.h"
#include "RandomForestClassifier.h"

#include "SNAPImageData.h"
#include "ImageWrapper.h"
#include "ImageCollectionToImageFilter.h"
#include "itkImageRegionIterator.h"

// Includes from the random forest library.
// The data_t and label_t typedefs are declared *before* the library headers
// are included; presumably those headers use these names as their feature
// and label types (NOTE(review): confirm against Library/classification.h).
// Feature values are therefore GreyType and labels are LabelType.
typedef GreyType data_t;
typedef LabelType label_t;

#include "Library/classification.h"
#include "Library/data.h"

RFClassificationEngine::RFClassificationEngine()
{
  // Training parameters: maximum tree depth and number of trees
  m_TreeDepth = 30;
  m_ForestSize = 50;

  // Feature options: zero patch radius, voxel coordinates not used as features
  m_UseCoordinateFeatures = false;
  m_PatchRadius.Fill(0);

  // Nothing attached yet: no training sample and no image data
  m_Sample = NULL;
  m_DataSource = NULL;

  // Start from a fresh, untrained classifier; TrainClassifier() populates it
  m_Classifier = RandomForestClassifier::New();
}

RFClassificationEngine::~RFClassificationEngine()
{
  // Release the training sample allocated by TrainClassifier(). Deleting a
  // null pointer is a no-op, so no explicit null check is required.
  delete m_Sample;
}

void RFClassificationEngine::SetDataSource(SNAPImageData *imageData)
{
  // Re-setting the same data source leaves the trained classifier intact
  if(m_DataSource == imageData)
    return;

  // Remember the new data source (only the pointer is stored; the destructor
  // does not delete it) and discard the classifier trained on the old one
  m_DataSource = imageData;
  m_Classifier->Reset();
}

void RFClassificationEngine::ResetClassifier()
{
  m_Classifier->Reset();
}

void RFClassificationEngine:: TrainClassifier()
{
  assert(m_DataSource && m_DataSource->IsMainLoaded());

  typedef ImageCollectionConstRegionIteratorWithIndex<
      AnatomicScalarImageWrapper::ImageType,
      AnatomicImageWrapper::ImageType> CollectionIter;

  // TODO: in the future, we should only recompute the sample when we know
  // that the data has changed. However, currently, we are just going to
  // compute a new sample every time

  // Delete the sample
  if(m_Sample)
    delete m_Sample;

  // Get the segmentation image - which determines the samples
  LabelImageWrapper *wrpSeg = m_DataSource->GetSegmentation();
  LabelImageWrapper::ImagePointer imgSeg = wrpSeg->GetImage();
  typedef itk::ImageRegionConstIteratorWithIndex<LabelImageWrapper::ImageType> LabelIter;

  // Shrink the buffered region by radius because we can't handle BCs
  itk::ImageRegion<3> reg = imgSeg->GetBufferedRegion();
  reg.ShrinkByRadius(m_PatchRadius);

  // We need to iterate throught the label image once to determine the
  // number of samples to allocate.
  unsigned long nSamples = 0;
  for(LabelIter lit(imgSeg, reg); !lit.IsAtEnd(); ++lit)
    if(lit.Value())
      nSamples++;

  // Create an iterator for going over all the anatomical image data
  CollectionIter cit(reg);
  cit.SetRadius(m_PatchRadius);

  // Add all the anatomical images to this iterator
  for(LayerIterator it = m_DataSource->GetLayers(MAIN_ROLE | OVERLAY_ROLE);
      !it.IsAtEnd(); ++it)
    {
    cit.AddImage(it.GetLayer()->GetImageBase());
    }

  // Get the number of components
  int nComp = cit.GetTotalComponents();
  int nPatch = cit.GetNeighborhoodSize();
  int nColumns = nComp * nPatch;

  // Are we using coordinate informtion
  if(m_UseCoordinateFeatures)
    nColumns += 3;

  // Create a new sample
  m_Sample = new SampleType(nSamples, nColumns);

  // Now fill out the samples
  int iSample = 0;
  for(LabelIter lit(imgSeg, reg); !lit.IsAtEnd(); ++lit, ++cit)
    {
    LabelType label = lit.Value();
    if(label)
      {
      // Fill in the data
      std::vector<GreyType> &column = m_Sample->data[iSample];
      int k = 0;
      for(int i = 0; i < nComp; i++)
        for(int j = 0; j < nPatch; j++)
          column[k++] = cit.NeighborValue(i,j);

      // Add the coordinate features if used
      if(m_UseCoordinateFeatures)
        for(int d = 0; d < 3; d++)
          column[k++] = lit.GetIndex()[d];

      // Fill in the label
      m_Sample->label[iSample] = label;

      ++iSample;
      }
    }

  // Check that the sample has at least two distinct labels
  bool isValidSample = false;
  for(int iSample = 1; iSample < m_Sample->Size(); iSample++)
    if(m_Sample->label[iSample] != m_Sample->label[iSample-1])
      { isValidSample = true; break; }

  // Now there is a valid sample. The text task is to train the classifier
  if(!isValidSample)
    throw IRISException("A classifier cannot be trained because the training "
                        "data contain fewer than two classes. Please label "
                        "examples of two or more tissue classes in the image.");

  // Set up the classifier parameters
  TrainingParameters params;
  // TODO:
  params.treeDepth = m_TreeDepth;
  params.treeNum = m_ForestSize;
  params.candidateNodeClassifierNum = 10;
  params.candidateClassifierThresholdNum = 10;
  params.subSamplePercent = 0;
  params.splitIG = 0.1;
  params.leafEntropy = 0.05;
  params.verbose = true;

  // Cap the number of training voxels at some reasonable number
  if(m_Sample->Size() > 10000)
    params.subSamplePercent = 100 * 10000.0 / m_Sample->Size();
  else
    params.subSamplePercent = 0;

  // Create the classification engine
  typedef RandomForestClassifier::RFAxisClassifierType RFAxisClassifierType;
  typedef Classification<GreyType, LabelType, RFAxisClassifierType> ClassificationType;
  ClassificationType classification;

  // Before resetting the classifier, we want to retain whatever the
  // weighting of the classes was
  std::map<LabelType, double> old_label_weights;
  if(m_Classifier->IsValidClassifier())
    {
    // Get the class weights
    const RandomForestClassifier::WeightArray &class_weights = m_Classifier->GetClassWeights();

    // Convert them to label weights (since class to label mapping may change)
    for(RandomForestClassifier::MappingType::const_iterator it =
        m_Classifier->m_ClassToLabelMapping.begin();
        it != m_Classifier->m_ClassToLabelMapping.end(); ++it)
      {
      old_label_weights[it->second] = class_weights[it->first];
      }
    }

  // Prepare the classifier
  m_Classifier->Reset();

  // Perform classifier training
  classification.Learning(
        params, *m_Sample,
        *m_Classifier->m_Forest,
        m_Classifier->m_ValidLabel,
        m_Classifier->m_ClassToLabelMapping);

  // Reset the class weights to the number of classes and assign default
  int n_classes = m_Classifier->m_ClassToLabelMapping.size(), n_fore = 0, n_back = 0;
  m_Classifier->m_ClassWeights.resize(n_classes, -1.0);

  // Apply the old weight assignments if possible. Keep track of the number of fore and back classes
  for(RandomForestClassifier::MappingType::iterator it =
      m_Classifier->m_ClassToLabelMapping.begin();
      it != m_Classifier->m_ClassToLabelMapping.end(); ++it)
    {
    if(old_label_weights.find(it->second) != old_label_weights.end())
      {
      m_Classifier->m_ClassWeights[it->first] = old_label_weights[it->second];
      }
    if(m_Classifier->m_ClassWeights[it->first] < 0.0)
      n_back++;
    else if(m_Classifier->m_ClassWeights[it->first] > 0.0)
      n_fore++;
    }

  // Make sure that we have at least one foreground class and at least one background class
  if(n_classes >= 2)
    {
    if(n_fore == 0)
      m_Classifier->m_ClassWeights.front() = 1.0;
    if(n_back == 0)
      m_Classifier->m_ClassWeights.back() = -1.0;
    }

  // Store the patch radius in the classifier - this remains fixed until
  // training is repeated
  m_Classifier->m_PatchRadius = m_PatchRadius;
  m_Classifier->m_UseCoordinateFeatures = m_UseCoordinateFeatures;
}

void RFClassificationEngine::SetClassifier(RandomForestClassifier *rf)
{
  // Adopt the supplied classifier, then mirror its number of trees in the
  // engine's forest-size setting.
  // NOTE(review): m_TreeDepth, m_PatchRadius and m_UseCoordinateFeatures are
  // not synchronized from the new classifier here.
  m_Classifier = rf;
  m_ForestSize = rf->GetForest()->GetForestSize();
}

int RFClassificationEngine::GetNumberOfComponents() const
{
  assert(m_DataSource);

  int ncomp = 0;

  for(LayerIterator it = m_DataSource->GetLayers(MAIN_ROLE | OVERLAY_ROLE);
      !it.IsAtEnd(); ++it)
    ncomp += it.GetLayer()->GetNumberOfComponents();

  return ncomp;
}