File: antsRegistrationOptimizerCommandIterationUpdate.h

package info (click to toggle)
ants 2.5.4%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,672 kB
  • sloc: cpp: 85,685; sh: 15,850; perl: 863; xml: 115; python: 111; makefile: 68
file content (456 lines) | stat: -rw-r--r-- 19,807 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#ifndef antsRegistrationOptimizerCommandIterationUpdate__h_
#define antsRegistrationOptimizerCommandIterationUpdate__h_

namespace ants
{
/** \class antsRegistrationOptimizerCommandIterationUpdate
 *  \brief observe the optimizer for traditional registration methods
 *
 *  Once attached with SetOptimizer(), this command reacts to itk::IterationEvent by:
 *  - applying the iteration count given to SetNumberOfIterations() for the current level whenever the
 *    optimizer reports its first iteration (which this class treats as the start of a new level);
 *  - printing one "2DIAGNOSTIC" line per iteration (iteration, metric value, convergence value, timing);
 *  - if SetComputeFullScaleCCInterval(N) was given N != 0, computing a radius-4 neighborhood cross-correlation
 *    between the original fixed and moving images on the first, every N-th and the last iteration;
 *  - if SetWriteIterationsOutputsInIntervals(N) was given N != 0, writing the warped original moving image
 *    to disk on the same kind of schedule.
 *  Progress output goes to the stream set with SetLogStream() (std::cout by default); failures to write
 *  an interval volume are reported on std::cerr.
 */
template <typename ParametersValueType, unsigned VImageDimension, typename TOptimizer>
class antsRegistrationOptimizerCommandIterationUpdate final : public itk::Command
{
public:
  typedef antsRegistrationOptimizerCommandIterationUpdate Self;
  typedef itk::Command                                    Superclass;
  typedef itk::SmartPointer<Self>                         Pointer;
  itkNewMacro(Self);

  typedef ParametersValueType                                                                     RealType;
  typedef ParametersValueType                                                                     PixelType;
  typedef typename itk::Image<PixelType, VImageDimension>                                         ImageType;
  typedef itk::ImageToImageMetricv4<ImageType, ImageType, ImageType, RealType>                    ImageMetricType;
  typedef itk::ObjectToObjectMultiMetricv4<VImageDimension, VImageDimension, ImageType, RealType> MultiMetricType;
  typedef typename ImageMetricType::MeasureType                                                   MeasureType;
  typedef itk::CompositeTransform<RealType, VImageDimension> CompositeTransformType;
  typedef typename CompositeTransformType::TransformType     TransformBaseType;

protected:
  antsRegistrationOptimizerCommandIterationUpdate()
  {
    // Record the clock's running total at construction so the first SINCE_LAST value is measured from here,
    // then keep the clock running until the first iteration event.
    m_clock.Start();
    m_clock.Stop();
    const itk::RealTimeClock::TimeStampType now = m_clock.GetTotal();
    this->m_lastTotalTime = now;
    m_clock.Start();
    this->m_LogStream = &std::cout;
    this->m_origFixedImage = ImageType::New();
    this->m_origMovingImage = ImageType::New();
    this->m_ComputeFullScaleCCInterval = 0;
    this->m_WriteIterationsOutputsInIntervals = 0;
    this->m_CurrentStageNumber = 0;
    this->m_CurrentLevel = itk::NumericTraits<unsigned int>::ZeroValue();
  }

public:
  /** Non-const overload required by itk::Command; forwards to the const overload. */
  void
  Execute(itk::Object * caller, const itk::EventObject & event) final
  {
    Execute(static_cast<const itk::Object *>(caller), event);
  }

  /**
   * Handle an optimizer event. Only itk::IterationEvent is acted upon; every other event is ignored.
   * The caller argument is not used: the optimizer is reached through the pointer given to SetOptimizer().
   */
  void
  Execute(const itk::Object *, const itk::EventObject & event) final
  {
#if 0
    if( typeid( event ) == typeid( itk::InitializeEvent ) )
      {
      const unsigned int currentLevel = this->m_Optimizer->GetCurrentLevel();

      typename TOptimizer::ShrinkFactorsPerDimensionContainerType shrinkFactors = this->m_Optimizer->GetShrinkFactorsPerDimension( currentLevel );
      typename TOptimizer::SmoothingSigmasArrayType smoothingSigmas = this->m_Optimizer->GetSmoothingSigmasPerLevel();
      typename TOptimizer::TransformParametersAdaptorsContainerType adaptors =
        this->m_Optimizer->GetTransformParametersAdaptorsPerLevel();
      bool smoothingSigmasAreInPhysicalUnits = this->m_Optimizer->GetSmoothingSigmasAreSpecifiedInPhysicalUnits();

      m_clock.Stop();
      const itk::RealTimeClock::TimeStampType now = m_clock.GetTotal();
      this->Logger() << "  Current level = " << currentLevel + 1 << " of " << this->m_NumberOfIterations.size()
                     << std::endl;
      this->Logger() << "    number of iterations = " << this->m_NumberOfIterations[currentLevel] << std::endl;
      this->Logger() << "    shrink factors = " << shrinkFactors << std::endl;
      this->Logger() << "    smoothing sigmas = " << smoothingSigmas[currentLevel];
      if( smoothingSigmasAreInPhysicalUnits )
        {
        this->Logger() << " mm" << std::endl;
        }
      else
        {
        this->Logger() << " vox" << std::endl;
        }
      this->Logger() << "    required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters()
                     << std::flush << std::endl;
      // this->Logger() << "\n  LEVEL_TIME_INDEX: " << now << " SINCE_LAST: " << (now-this->m_lastTotalTime) <<
      // std::endl;
      this->m_lastTotalTime = now;
      m_clock.Start();

      typedef itk::GradientDescentOptimizerv4<ParametersValueType> GradientDescentOptimizerType;
      GradientDescentOptimizerType * optimizer = reinterpret_cast<GradientDescentOptimizerType *>( this->m_Optimizer->GetModifiableOptimizer() );

      optimizer->SetNumberOfIterations( this->m_NumberOfIterations[currentLevel] );
      }
    else
#endif
    if (typeid(event) == typeid(itk::IterationEvent))
    {
      // currentIteration is 1-based; it is used for printing to the screen and for naming output files.
      const unsigned int currentIteration = this->m_Optimizer->GetCurrentIteration() + 1;
      if (currentIteration == 1)
      {
        // First iteration of a new level: apply that level's iteration budget. The index is checked so that a
        // missing or too-short iteration schedule cannot read past the end of m_NumberOfIterations; in that
        // case the optimizer keeps its current iteration count.
        if (this->m_CurrentLevel < this->m_NumberOfIterations.size())
        {
          this->m_Optimizer->SetNumberOfIterations(this->m_NumberOfIterations[this->m_CurrentLevel]);
        }
        this->m_CurrentLevel++;

        if (this->m_ComputeFullScaleCCInterval != 0)
        {
          // Print header line one time
          this->Logger()
            << "DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST,FullScaleCCInterval="
            << this->m_ComputeFullScaleCCInterval << std::flush << std::endl;
        }
        else
        {
          this->Logger() << "DIAGNOSTIC,Iteration,metricValue,convergenceValue,ITERATION_TIME_INDEX,SINCE_LAST"
                         << std::flush << std::endl;
        }
      }
      // Pause the clock so the optional full-scale metric and volume writing are excluded from the timings.
      m_clock.Stop();
      const itk::RealTimeClock::TimeStampType now = m_clock.GetTotal();

      MeasureType        metricValue = 0.0;
      const unsigned int lastIteration = this->m_Optimizer->GetNumberOfIterations();
      if ((this->m_ComputeFullScaleCCInterval != 0) &&
          (currentIteration == 1 || (currentIteration % this->m_ComputeFullScaleCCInterval == 0) ||
           currentIteration == lastIteration))
      {
        // Similarity between the original (full-resolution) fixed and moving images using a radius-4 CC
        // metric; lets the user follow the registration's progress at full scale.
        this->UpdateFullScaleMetricValue(this->m_Optimizer, metricValue);
      }

      if ((this->m_WriteIterationsOutputsInIntervals != 0) &&
          (currentIteration == 1 || (currentIteration % this->m_WriteIterationsOutputsInIntervals == 0) ||
           currentIteration == lastIteration))
      {
        // Write the warped moving image of this iteration to disk, e.g. to build a short movie of the
        // registration. WriteIntervalVolumes prints a '*'-prefixed line naming the file.
        this->WriteIntervalVolumes(this->m_Optimizer);
      }
      else
      {
        // No volume written: print a space instead of the '*' marker to keep the lines visually aligned.
        this->Logger() << " ";
      }

      this->Logger() << "2DIAGNOSTIC, " << std::setw(5) << currentIteration << ", " << std::scientific
                     << std::setprecision(12) << this->m_Optimizer->GetValue() << ", " << std::scientific
                     << std::setprecision(12) << this->m_Optimizer->GetConvergenceValue() << ", "
                     << std::setprecision(4) << now << ", " << std::setprecision(4) << (now - this->m_lastTotalTime)
                     << ", ";
      // The full-scale value is only printed when it was computed for this iteration (non-negligible value).
      if ((this->m_ComputeFullScaleCCInterval != 0) && std::fabs(metricValue) > static_cast<MeasureType>(1e-7))
      {
        this->Logger() << std::scientific << std::setprecision(12) << metricValue << std::flush << std::endl;
      }
      else
      {
        this->Logger() << std::flush << std::endl;
      }

      this->m_lastTotalTime = now;
      m_clock.Start();
    }
    else
    {
      // Unknown event type
      return;
    }
  }

  itkSetMacro(ComputeFullScaleCCInterval, unsigned int);

  itkSetMacro(WriteIterationsOutputsInIntervals, unsigned int);

  itkSetMacro(CurrentStageNumber, unsigned int);

  /** Iteration count to apply at the start of each level, indexed by level (0-based). */
  void
  SetNumberOfIterations(const std::vector<unsigned int> & iterations)
  {
    this->m_NumberOfIterations = iterations;
  }

  /** Stream receiving the progress output; only its address is stored, so it must outlive this command. */
  void
  SetLogStream(std::ostream & logStream)
  {
    this->m_LogStream = &logStream;
  }

  /**
   * Type defining the optimizer
   */
  typedef TOptimizer OptimizerType;

  /**
   * Set the observed optimizer (held through a weak pointer) and register this command for its
   * itk::IterationEvent. Each call adds another observer registration.
   */
  void
  SetOptimizer(OptimizerType * optimizer)
  {
    this->m_Optimizer = optimizer;
    this->m_Optimizer->AddObserver(itk::IterationEvent(), this);
  }

  /** Original (full-resolution) fixed image used for the full-scale metric and as the output grid. */
  void
  SetOrigFixedImage(typename ImageType::Pointer origFixedImage)
  {
    this->m_origFixedImage = origFixedImage;
  }

  /** Original (full-resolution) moving image used for the full-scale metric and the written volumes. */
  void
  SetOrigMovingImage(typename ImageType::Pointer origMovingImage)
  {
    this->m_origMovingImage = origMovingImage;
  }

  /**
   * Compute a neighborhood cross-correlation (radius 4, no gradient filters) between the original fixed and
   * moving images, using fresh copies of the fixed/moving transforms currently held by the optimizer's image
   * metric (the first metric when a multi-metric is used). The registration's own metric is not modified.
   *
   * \param myOptimizer  optimizer whose metric supplies the current transforms
   * \param metricValue  receives the computed correlation metric value
   * \throws itk::ExceptionObject if no image metric can be obtained, if the fixed transform is neither a
   *         composite nor an identity transform, or if the moving transform is not a composite transform.
   */
  void
  UpdateFullScaleMetricValue(itk::WeakPointer<OptimizerType> myOptimizer, MeasureType & metricValue) const
  {
    // Get the registration metric from the optimizer (handles single and multi-metric setups).
    ImageMetricType * const inputMetric = this->GetImageMetric(myOptimizer);

    // Correlation metric used as a common reference between the original input fixed and moving images.
    typedef itk::ANTSNeighborhoodCorrelationImageToImageMetricv4<ImageType, ImageType, ImageType, MeasureType>
                                                 CorrelationImageMetricType;
    typename CorrelationImageMetricType::Pointer correlationMetric = CorrelationImageMetricType::New();
    {
      typename CorrelationImageMetricType::RadiusType radius;
      radius.Fill(4); // NOTE: This is just a common reference for fine-tuning parameters, so perhaps a smaller
                      // window would be sufficient.
      correlationMetric->SetRadius(radius);
    }
    correlationMetric->SetUseMovingImageGradientFilter(false);
    correlationMetric->SetUseFixedImageGradientFilter(false);
    typename ImageMetricType::Pointer metric = correlationMetric.GetPointer();

    // Fixed transform: copy a composite transform sub-transform by sub-transform, or use a fresh identity.
    typedef itk::IdentityTransform<RealType, VImageDimension> IdentityTransformType;
    typename TransformBaseType::Pointer                       fixedTransform;
    if (const CompositeTransformType * const compositeFixed =
          dynamic_cast<const CompositeTransformType *>(inputMetric->GetModifiableFixedTransform()))
    {
      typename CompositeTransformType::Pointer fixedCopy = this->CopyCompositeTransform(compositeFixed);
      fixedTransform = fixedCopy;
    }
    else if (dynamic_cast<const IdentityTransformType *>(inputMetric->GetModifiableFixedTransform()) != nullptr)
    {
      typename IdentityTransformType::Pointer identityTransform = IdentityTransformType::New();
      fixedTransform = identityTransform;
    }
    else
    {
      itkExceptionMacro(<< "Fixed Transform should be either \"Composite\" or \"Identity\" transform.");
    }

    // Moving transform: expected to be a composite transform (CopyCompositeTransform throws otherwise).
    typename CompositeTransformType::Pointer movingTransform = this->CopyCompositeTransform(
      dynamic_cast<const CompositeTransformType *>(inputMetric->GetModifiableMovingTransform()));

    metric->SetVirtualDomainFromImage(this->m_origFixedImage);
    metric->SetFixedImage(this->m_origFixedImage);
    metric->SetFixedTransform(fixedTransform);
    metric->SetMovingImage(this->m_origMovingImage);
    metric->SetMovingTransform(movingTransform);
    metric->Initialize();
    metricValue = metric->GetValue();
  }

  /**
   * Return the moving transform of the optimizer's image metric (the first metric when a multi-metric is
   * used). The result is null if that moving transform is not a composite transform.
   * \throws itk::ExceptionObject if no image metric can be obtained from the optimizer.
   */
  typename CompositeTransformType::ConstPointer
  GetMovingTransform(itk::WeakPointer<OptimizerType> myOptimizer)
  {
    ImageMetricType * const imageMetric = this->GetImageMetric(myOptimizer);

    typename CompositeTransformType::ConstPointer movingTransform =
      dynamic_cast<CompositeTransformType *>(imageMetric->GetModifiableMovingTransform());
    return movingTransform;
  }

  /**
   * Resample the original moving image onto the original fixed image grid (linear interpolation, default
   * value 0) with a copy of the current moving transform, and write it as
   * "Stage<S>_level<L>_Iter<NNNN>.nii.gz" in the current working directory. A line consisting of '*' followed
   * by the file name is printed to the log stream. A failed write is reported on std::cerr and otherwise
   * ignored.
   * \throws itk::ExceptionObject if the moving transform cannot be obtained or copied.
   */
  void
  WriteIntervalVolumes(itk::WeakPointer<OptimizerType> myOptimizer)
  {
    // First, build a copy of the optimizer's current moving transform.
    typename CompositeTransformType::ConstPointer inputMovingTransform = this->GetMovingTransform(myOptimizer);
    typename CompositeTransformType::Pointer      movingTransform =
      this->CopyCompositeTransform(inputMovingTransform.GetPointer());

    // Now we apply this output transform to get warped image
    typedef itk::LinearInterpolateImageFunction<ImageType, RealType> LinearInterpolatorType;
    typename LinearInterpolatorType::Pointer                         linearInterpolator = LinearInterpolatorType::New();

    typedef itk::ResampleImageFilter<ImageType, ImageType, RealType> ResampleFilterType;
    typename ResampleFilterType::Pointer                             resampler = ResampleFilterType::New();
    resampler->SetTransform(movingTransform);
    resampler->SetInput(this->m_origMovingImage);
    resampler->SetOutputParametersFromImage(this->m_origFixedImage);
    resampler->SetInterpolator(linearInterpolator);
    resampler->SetDefaultPixelValue(0);
    resampler->Update();

    // The iteration number is zero-padded to four digits so that files sort in iteration order
    // ("_Iter0001", "_Iter0002", ..., "_Iter0010", ... instead of "Iter1 Iter10 Iter2 Iter20").
    // Iterations >= 10000 are written without padding.
    const unsigned int currentIteration = myOptimizer->GetCurrentIteration() + 1;
    std::stringstream  currentFileName;
    currentFileName << "Stage" << this->m_CurrentStageNumber + 1 << "_level" << this->m_CurrentLevel << "_Iter"
                    << std::setfill('0') << std::setw(4) << currentIteration << ".nii.gz";

    // The '*' before the file name marks iterations whose output was written out; Execute() prints a
    // space instead for the other iterations, so both go to the same log stream.
    this->Logger() << "*" << currentFileName.str() << std::endl;

    typedef itk::ImageFileWriter<ImageType> WarpedImageWriterType;
    typename WarpedImageWriterType::Pointer writer = WarpedImageWriterType::New();
    writer->SetFileName(currentFileName.str().c_str());
    writer->SetInput(resampler->GetOutput());
    try
    {
      writer->Update();
    }
    catch (const itk::ExceptionObject & err)
    {
      // Best effort: report the failure without interrupting the registration.
      std::cerr << "Can't write warped image " << currentFileName.str() << std::endl;
      std::cerr << "Exception Object caught: " << std::endl;
      std::cerr << err << std::endl;
    }
  }

private:
  std::ostream &
  Logger() const
  {
    return *m_LogStream;
  }

  /**
   * Return the image metric used by \a myOptimizer. This is either the optimizer's metric itself or, for a
   * multi-metric, the first metric in its queue. Only the moving transform is read from it, which is presumably
   * shared by all metrics of a multi-metric.
   * \throws itk::ExceptionObject if the metric (or the first queued metric) is not an image metric,
   *         or if the multi-metric queue is empty.
   */
  ImageMetricType *
  GetImageMetric(itk::WeakPointer<OptimizerType> myOptimizer) const
  {
    typename OptimizerType::MetricType * const metric = myOptimizer->GetModifiableMetric();

    ImageMetricType * imageMetric = nullptr;
    // dynamic_cast yields null when the optimizer's metric is not a multi-metric.
    if (MultiMetricType * const multiMetric = dynamic_cast<MultiMetricType *>(metric))
    {
      if (!multiMetric->GetMetricQueue().empty())
      {
        imageMetric = dynamic_cast<ImageMetricType *>(multiMetric->GetMetricQueue()[0].GetPointer());
      }
    }
    else
    {
      imageMetric = dynamic_cast<ImageMetricType *>(metric);
    }

    if (imageMetric == nullptr)
    {
      itkExceptionMacro(<< "Invalid metric conversion.");
    }
    return imageMetric;
  }

  /**
   * Build a new composite transform containing, for each sub-transform of \a source (in order), a new
   * instance obtained via CreateAnother() and set from the source's parameters and then fixed parameters.
   * Only the most recently added sub-transform is flagged for optimization.
   * \throws itk::ExceptionObject if \a source is null or a sub-transform cannot be re-created.
   */
  typename CompositeTransformType::Pointer
  CopyCompositeTransform(const CompositeTransformType * source) const
  {
    if (source == nullptr)
    {
      itkExceptionMacro(<< "Expected a composite transform.");
    }

    typename CompositeTransformType::Pointer copy = CompositeTransformType::New();
    const unsigned int                       numberOfTransforms = source->GetNumberOfTransforms();
    for (unsigned int i = 0; i < numberOfTransforms; ++i)
    {
      // The source composite keeps its sub-transforms alive, so a raw pointer is sufficient here.
      const TransformBaseType * const sourceTransform = source->GetNthTransform(i);

      typename TransformBaseType::Pointer subTransform(
        dynamic_cast<TransformBaseType *>(sourceTransform->CreateAnother().GetPointer()));
      if (subTransform.IsNull())
      {
        itkExceptionMacro(<< "Could not create a copy of sub transform " << i << '.');
      }
      subTransform->SetParameters(sourceTransform->GetParameters());
      subTransform->SetFixedParameters(sourceTransform->GetFixedParameters());
      copy->AddTransform(subTransform);
    }
    copy->SetOnlyMostRecentTransformToOptimizeOn();
    return copy;
  }

  /**
   *  WeakPointer to the Optimizer
   */
  itk::WeakPointer<OptimizerType> m_Optimizer;

  std::vector<unsigned int>         m_NumberOfIterations;
  std::ostream *                    m_LogStream;
  itk::TimeProbe                    m_clock;
  itk::RealTimeClock::TimeStampType m_lastTotalTime;

  unsigned int m_ComputeFullScaleCCInterval;
  unsigned int m_WriteIterationsOutputsInIntervals;
  unsigned int m_CurrentStageNumber;
  unsigned int m_CurrentLevel; // number of levels started so far (1-based once the first level begins)

  typename ImageType::Pointer m_origFixedImage;
  typename ImageType::Pointer m_origMovingImage;
};
} // end namespace ants
#endif // antsRegistrationOptimizerCommandIterationUpdate__h_