File: RandomForestClassifyImageFilter.txx

package info (click to toggle)
itksnap 3.4.0-2
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 10,196 kB
  • ctags: 9,196
  • sloc: cpp: 62,895; sh: 175; makefile: 13
file content (223 lines) | stat: -rw-r--r-- 7,477 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#ifndef RANDOMFORESTCLASSIFYIMAGEFILTER_TXX
#define RANDOMFORESTCLASSIFYIMAGEFILTER_TXX

#include "RandomForestClassifyImageFilter.h"
#include "itkImageRegionConstIterator.h"
#include "RandomForestClassifier.h"
#include "ImageCollectionToImageFilter.h"
#include <itkProgressReporter.h>
#include <cassert>

#include "Library/data.h"
#include "Library/forest.h"
#include "Library/statistics.h"
#include "Library/classifier.h"


// Default constructor. The body performs no work; any member set-up happens
// through the members' own default initialization.
template <class TInputImage, class TInputVectorImage, class TOutputImage>
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::RandomForestClassifyImageFilter()
{
  // Intentionally empty
}

// Destructor. Nothing is released explicitly here.
template <class TInputImage, class TInputVectorImage, class TOutputImage>
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::~RandomForestClassifyImageFilter()
{
  // Intentionally empty
}


// Append a scalar image to this filter's list of inputs.
//
//   image  the scalar image to add; it is handed to AddInput() as is
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::AddScalarImage(InputImageType *image)
{
  InputImageType *scalarInput = image;
  this->AddInput(scalarInput);
}

// Append a vector (multi-component) image to this filter's list of inputs.
//
//   image  the vector image to add; it is handed to AddInput() as is
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::AddVectorImage(InputVectorImageType *image)
{
  InputVectorImageType *vectorInput = image;
  this->AddInput(vectorInput);
}

// Store the classifier used by this filter and mark the filter as modified.
//
// Modified() is called on every invocation, including when the same
// classifier pointer is passed again.
//
//   classifier  the random forest classifier to use
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::SetClassifier(RandomForestClassifier *classifier)
{
  this->m_Classifier = classifier;

  // Always flag the change, without comparing against the previous value
  this->Modified();
}


// Set the requested region of every scalar and vector image input.
//
// After delegating to itk::ImageSource, each input that is an InputImageType
// or an InputVectorImageType is asked for the output's requested region,
// padded by the classifier's patch radius and cropped to that input's largest
// possible region. Inputs of any other type are left as the ImageSource call
// set them.
//
// A classifier must have been assigned with SetClassifier() before this is
// called, because the patch radius is read from it.
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::GenerateInputRequestedRegion()
{
  // Same precondition that ThreadedGenerateData() checks: the classifier is
  // dereferenced below to obtain the patch radius
  assert(m_Classifier);

  itk::ImageSource<TOutputImage>::GenerateInputRequestedRegion();

  for( itk::InputDataObjectIterator it( this ); !it.IsAtEnd(); it++ )
    {
    // Scalar image input?
    InputImageType *input = dynamic_cast< InputImageType * >( it.GetInput() );
    if (input)
      {
      InputImageRegionType inputRegion;
      this->CallCopyOutputRegionToInputRegion( inputRegion, this->GetOutput()->GetRequestedRegion() );

      // The patch around each output voxel needs extra margin from the input
      inputRegion.PadByRadius(m_Classifier->GetPatchRadius());
      inputRegion.Crop(input->GetLargestPossibleRegion());
      input->SetRequestedRegion(inputRegion);
      continue;
      }

    // Otherwise, vector image input? (only tested when the scalar cast failed)
    InputVectorImageType *vecInput = dynamic_cast< InputVectorImageType * >( it.GetInput() );
    if (vecInput)
      {
      InputImageRegionType inputRegion;
      this->CallCopyOutputRegionToInputRegion( inputRegion, this->GetOutput()->GetRequestedRegion() );

      // The patch around each output voxel needs extra margin from the input
      inputRegion.PadByRadius(m_Classifier->GetPatchRadius());
      inputRegion.Crop(vecInput->GetLargestPossibleRegion());
      vecInput->SetRequestedRegion(inputRegion);
      }
    }
}


// Write a one-line description of this filter to a stream.
//
// Only the indentation followed by the filter's name is written.
//
//   os      stream to write to
//   indent  indentation emitted before the name
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::PrintSelf(std::ostream &os, itk::Indent indent) const
{
  os << indent;
  os << "RandomForestClassifyImageFilter";
  os << std::endl;
}

// Compute the output ("speed") values for the part of the output image that
// is assigned to one thread.
//
// The whole thread region is first filled with zeros. The region is then
// cropped so that every processed voxel lies at least one patch radius inside
// the output's largest possible region; if nothing is left, the call returns
// with the zero fill in place.
//
// For every remaining voxel a feature row is assembled from the neighborhood
// values of all input components (plus, optionally, the voxel index), the
// forest is applied to it, and the per-tree class probabilities are summed.
// With p_fore being the fraction of the summed probability that belongs to
// classes whose weight is positive, and q the classifier's bias parameter,
// the value written is
//
//     clamp(2 * (p_fore - q), -1, 1) * 0x7fff
//
// converted to OutputPixelType. Voxels whose summed probability is zero keep
// the zero fill.
//
// A classifier must have been assigned with SetClassifier() beforehand.
//
//   outputRegionForThread  the output region this call is responsible for
//   threadId               not used by this implementation
template <class TInputImage, class TInputVectorImage, class TOutputImage>
void
RandomForestClassifyImageFilter<TInputImage, TInputVectorImage, TOutputImage>
::ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread,
                       itk::ThreadIdType threadId)
{
  assert(m_Classifier);

  OutputImagePointer outputPtr = this->GetOutput(0);

  // Fill the output region with zeros
  itk::ImageRegionIterator<OutputImageType> zit(outputPtr, outputRegionForThread);
  for(; !zit.IsAtEnd(); ++zit)
    zit.Set(static_cast<OutputPixelType>(0));

  // Adjust the output region so that we don't touch image boundaries.
  OutputImageRegionType crop_region = outputPtr->GetLargestPossibleRegion();
  crop_region.ShrinkByRadius(m_Classifier->GetPatchRadius());
  OutputImageRegionType out_region = outputRegionForThread;

  // Nothing to classify if the thread region lies entirely in the border
  if(!out_region.Crop(crop_region))
    return;

  // Create an iterator for the output
  typedef itk::ImageRegionIteratorWithIndex<TOutputImage> OutputIter;
  OutputIter it_out(outputPtr, out_region);

  // Create a collection iterator for the inputs
  typedef ImageCollectionConstRegionIteratorWithIndex<
      TInputImage, TInputVectorImage> CollectionIter;

  // Configure the input collection iterator
  CollectionIter cit(out_region);
  for( itk::InputDataObjectIterator it( this ); !it.IsAtEnd(); it++ )
    cit.AddImage(it.GetInput());

  // The neighborhood radius is taken from the classifier
  cit.SetRadius(m_Classifier->GetPatchRadius());

  // Classifier settings that do not change from voxel to voxel are read once
  // here rather than inside the per-voxel loop
  const bool use_coordinates = m_Classifier->GetUseCoordinateFeatures();
  const double q = m_Classifier->GetBiasParameter();

  // Get the number of components
  int nComp = cit.GetTotalComponents();
  int nPatch = cit.GetNeighborhoodSize();
  int nColumns = nComp * nPatch;

  // Are coordinate features used?
  // NOTE(review): the count of 3 is hard-coded here and in the loop below,
  // which indexes the output index with d = 0..2 - this presumably assumes
  // 3-dimensional images; confirm before using with other dimensions.
  if(use_coordinates)
    nColumns += 3;

  // Get the number of classes
  int nClass = static_cast<int>(m_Classifier->GetClassToLabelMapping().size());

  // Get the class weights (as they are assigned to foreground/background)
  const RandomForestClassifier::WeightArray &class_weights = m_Classifier->GetClassWeights();

  // Create the MLdata holding the single feature row that is classified
  typedef Histogram<InputPixelType,LabelType> HistogramType;
  typedef MLData<InputPixelType,HistogramType *> TestingDataType;
  TestingDataType testData(1, nColumns);

  // Get the number of trees
  int nTrees = static_cast<int>(m_Classifier->GetForest()->trees_.size());

  // Create and allocate the test result vector (one entry per tree)
  typedef Vector<Vector<HistogramType *> > TestingResultType;
  TestingResultType testResult;
  testResult.Resize(nTrees);
  for(int i = 0; i < nTrees; i++)
    testResult[i].Resize(1);

  // Some vectors that are allocated for speed
  std::vector<size_t> vIndex(1);
  std::vector<bool> vResult(1);

  // Iterate through all the voxels
  for(; !it_out.IsAtEnd(); ++it_out, ++cit)
    {
    // Assign the data to the testData vector
    int k = 0;
    for(int i = 0; i < nComp; i++)
      for(int j = 0; j < nPatch; j++)
        testData.data[0][k++] = cit.NeighborValue(i,j);

    // Add the coordinate features
    if(use_coordinates)
      for(int d = 0; d < 3; d++)
        testData.data[0][k++] = it_out.GetIndex()[d];

    // Perform classification on this data
    m_Classifier->GetForest()->ApplyFast(testData, testResult, vIndex, vResult);

    // Compute output map with a bias parameter. The bias parameter q is such
    // that p_fore = q maps to 0 speed value. For the time being we just shift the linear
    // mapping from p_fore to speed and cap speed between -1 and 1

    // First we compute p_fore - for some reason not all trees in the forest have probabilities
    // summing up to one (some are zero), so we need to use division
    double p_fore_total = 0, p_total = 0;
    for(int i = 0; i < testResult.Size(); i++)
      {
      HistogramType *hist = testResult[i][0];
      for(int j = 0; j < nClass; j++)
        {
        double p = hist->prob_[j];

        // Classes with positive weight count as foreground
        if(class_weights[j] > 0.0)
          p_fore_total += p;
        p_total += p;
        }
      }

    // Set output only if the total probability is non-zero; otherwise the
    // zero written during the initial fill is kept
    if(p_total > 0)
      {
      double p_fore = p_fore_total / p_total;
      double speed = 2 * (p_fore - q);
      if(speed < -1.0)
        speed = -1.0;
      else if(speed > 1.0)
        speed = 1.0;

      // Scale [-1, 1] to [-0x7fff, 0x7fff] before converting to the pixel type
      it_out.Set(static_cast<OutputPixelType>(speed * 0x7fff));
      }
    }
}


#endif