File: simpleSynRegistration.cxx

package info (click to toggle)
ants 2.5.4+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,672 kB
  • sloc: cpp: 85,685; sh: 15,850; perl: 863; xml: 115; python: 111; makefile: 68
file content (203 lines) | stat: -rw-r--r-- 8,217 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
// This program performs SyN registration with hard-coded parameters.
// The whole registration process is implemented in the "simpleSynReg" function.

// Usage:
// This program takes no flags. Pass exactly four positional arguments after the program name, in this order:
/*
~/simpleSynRegistration
fixed image
moving image
initial transform
output prefix file name (without extension)
*/
// Outputs: <output prefix>Warp.nii.gz and <output prefix>InverseWarp.nii.gz

#include "antsUtilities.h"
#include "itkantsRegistrationHelper.h"

namespace ants
{
using RegistrationHelperType = ants::RegistrationHelper<double, 3>;
using ImageType = RegistrationHelperType::ImageType;
using CompositeTransformType = RegistrationHelperType::CompositeTransformType;

CompositeTransformType::TransformTypePointer
simpleSynReg(ImageType::Pointer &                  fixedImage,
             ImageType::Pointer &                  movingImage,
             const CompositeTransformType::Pointer compositeInitialTransform)
{
  RegistrationHelperType::Pointer regHelper = RegistrationHelperType::New();

  const std::string                         whichMetric = "mattes";
  RegistrationHelperType::MetricEnumeration curMetric = regHelper->StringToMetricType(whichMetric);
  const float                               lowerQuantile(0.0F);
  const float                               upperQuantile(1.0F);
  const bool                                doWinsorize(false);

  regHelper->SetWinsorizeImageIntensities(doWinsorize, lowerQuantile, upperQuantile);

  const bool doHistogramMatch(true);
  regHelper->SetUseHistogramMatching(doHistogramMatch);

  const bool doEstimateLearningRateAtEachIteration = true;
  regHelper->SetDoEstimateLearningRateAtEachIteration(doEstimateLearningRateAtEachIteration);

  std::vector<std::vector<unsigned int>> iterationList;
  std::vector<double>                    convergenceThresholdList;
  std::vector<unsigned int>              convergenceWindowSizeList;
  std::vector<std::vector<unsigned int>> shrinkFactorsList;
  std::vector<std::vector<float>>        smoothingSigmasList;

  std::vector<unsigned int> iterations(3);
  iterations[0] = 100;
  iterations[1] = 70;
  iterations[2] = 20;
  std::cout << "  number of levels = 3 " << std::endl;
  iterationList.push_back(iterations);

  const double convergenceThreshold = 1e-6;
  convergenceThresholdList.push_back(convergenceThreshold);
  constexpr unsigned int convergenceWindowSize = 10;
  convergenceWindowSizeList.push_back(convergenceWindowSize);

  std::vector<unsigned int> factors(3);
  factors[0] = 3;
  factors[1] = 2;
  factors[2] = 1;
  shrinkFactorsList.push_back(factors);

  std::vector<float> sigmas(3);
  sigmas[0] = 2;
  sigmas[1] = 1;
  sigmas[2] = 0;
  smoothingSigmasList.push_back(sigmas);
  std::vector<bool> smoothingSigmasAreInPhysicalUnitsList;
  smoothingSigmasAreInPhysicalUnitsList.push_back(true); // Historical behavior before 2012-10-07
  constexpr float                          samplingPercentage = 1.0;
  RegistrationHelperType::SamplingStrategy samplingStrategy = RegistrationHelperType::none;
  constexpr unsigned int                   binOption = 200;
  bool                                     useGradientFilter = false;
  regHelper->AddMetric(curMetric, fixedImage, movingImage, 0, 1.0, samplingStrategy, binOption, 1, useGradientFilter, samplingPercentage);

  const float learningRate(0.25F);
  const float varianceForUpdateField(3.0F);
  const float varianceForTotalField(0.0F);
  regHelper->AddSyNTransform(learningRate, varianceForUpdateField, varianceForTotalField);

  regHelper->SetMovingInitialTransform(compositeInitialTransform);
  regHelper->SetIterations(iterationList);
  regHelper->SetConvergenceWindowSizes(convergenceWindowSizeList);
  regHelper->SetConvergenceThresholds(convergenceThresholdList);
  regHelper->SetSmoothingSigmas(smoothingSigmasList);
  regHelper->SetShrinkFactors(shrinkFactorsList);
  regHelper->SetSmoothingSigmasAreInPhysicalUnits(smoothingSigmasAreInPhysicalUnitsList);
  if (regHelper->DoRegistration() == EXIT_SUCCESS)
  {
    // Get the output transform
    CompositeTransformType::Pointer outputCompositeTransform = regHelper->GetModifiableCompositeTransform();
    // write out transform actually computed, so skip the initial transform
    CompositeTransformType::TransformTypePointer resultTransform = outputCompositeTransform->GetNthTransform(1);
    return resultTransform;
  }
  std::cerr << "FATAL ERROR: REGISTRATION PROCESS WAS UNSUCCESSFUL" << std::endl;
  CompositeTransformType::TransformTypePointer invalidTransform = nullptr;
  return invalidTransform; // Return an empty registration type.
}

int
simpleSynRegistration(std::vector<std::string> args, std::ostream * /*out_stream = nullptr */)
{
  // the arguments coming in as 'args' is a replacement for the standard (argc,argv) format
  // Just notice that the argv[i] equals to args[i-1]
  // and the argc equals:
  int argc = args.size() + 1;

  if (argc != 5)
  {
    std::cerr
      << "Usage: simpleSynRegistration\n"
      << "<Fixed Image> , <Moving Image> , <Initial Transform> , <Output prefix file name without any extension>"
      << std::endl;
    return EXIT_FAILURE;
  }

  // antscout->set_stream( out_stream );

  ImageType::Pointer fixedImage;
  ImageType::Pointer movingImage;
  // ========read the fixed image
  using ImageReaderType = itk::ImageFileReader<ImageType>;
  ImageReaderType::Pointer fixedImageReader = ImageReaderType::New();
  fixedImageReader->SetFileName(args[0]);
  fixedImageReader->Update();
  fixedImage = fixedImageReader->GetOutput();
  try
  {
    fixedImage->Update();
  }
  catch (const itk::ExceptionObject & excp)
  {
    std::cerr << excp << std::endl;
    return EXIT_FAILURE;
  }
  // ==========read the moving image
  ImageReaderType::Pointer movingImageReader = ImageReaderType::New();
  movingImageReader->SetFileName(args[1]);
  movingImageReader->Update();
  movingImage = movingImageReader->GetOutput();
  try
  {
    movingImage->Update();
  }
  catch (const itk::ExceptionObject & excp)
  {
    std::cerr << excp << std::endl;
    return EXIT_FAILURE;
  }

  std::cout << "  fixed image: " << args[0] << std::endl;
  std::cout << "  moving image: " << args[1] << std::endl;

  // ===========Read the initial transform and write that in a composite transform
  using TransformType = RegistrationHelperType::TransformType;
  TransformType::Pointer initialTransform = itk::ants::ReadTransform<double, 3>(args[2]);
  if (initialTransform.IsNull())
  {
    std::cerr << "Can't read initial transform " << std::endl;
    return EXIT_FAILURE;
  }
  CompositeTransformType::Pointer compositeInitialTransform = CompositeTransformType::New();
  compositeInitialTransform->AddTransform(initialTransform);

  // =========write the output transform
  // compute the output transform by calling the "simpleSynReg" function
  CompositeTransformType::TransformTypePointer outputTransform =
    simpleSynReg(fixedImage, movingImage, compositeInitialTransform);
  if (outputTransform.IsNull())
  {
    std::cerr << "ERROR: Registration FAILED TO PRODUCE VALID TRANSFORM ...\n" << std::endl;
    return EXIT_FAILURE;
  }
  std::cout << "***** Ready to write the results ...\n" << std::endl;
  std::stringstream outputFileName;
  outputFileName << args[3] << "Warp.nii.gz";
  itk::ants::WriteTransform<double, 3>(outputTransform, outputFileName.str());

  // compute and write the inverse of the output transform
  const bool writeInverse(true);
  if (writeInverse)
  {
    using DisplacementFieldTransformType = RegistrationHelperType::DisplacementFieldTransformType;
    DisplacementFieldTransformType::Pointer dispTransform =
      dynamic_cast<DisplacementFieldTransformType *>(outputTransform.GetPointer());
    using DisplacementFieldType = DisplacementFieldTransformType::DisplacementFieldType;
    std::stringstream outputInverseFileName;
    outputInverseFileName << args[3] << "InverseWarp.nii.gz";
    using InverseWriterType = itk::ImageFileWriter<DisplacementFieldType>;
    InverseWriterType::Pointer inverseWriter = InverseWriterType::New();
    inverseWriter->SetInput(dispTransform->GetInverseDisplacementField());
    inverseWriter->SetFileName(outputInverseFileName.str().c_str());
    inverseWriter->Update();
  }
  return EXIT_SUCCESS;
}
} // namespace ants