File: itkBSplineInterpolateVectorImageFunction.hxx

package info (click to toggle)
elastix 5.3.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 45,644 kB
  • sloc: cpp: 85,720; lisp: 4,118; python: 1,045; sh: 200; xml: 182; makefile: 33
file content (92 lines) | stat: -rw-r--r-- 3,714 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*=========================================================================
 *
 *  Copyright UMC Utrecht and contributors
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

#ifndef _itkBSplineInterpolateVectorImageFunction_hxx
#define _itkBSplineInterpolateVectorImageFunction_hxx

#include "itkBSplineInterpolateVectorImageFunction.h"
#include <itkVectorIndexSelectionCastImageFilter.h>

/**
 * ******************* SetInputImage ***********************
 *
 * Splits the multi-channel input image into one scalar float image per
 * channel (using itk::VectorIndexSelectionCastImageFilter) and creates one
 * interpolator of type TInterpolator, with spline order 3, for each channel.
 * After this call, m_Interpolators[i] interpolates channel i of vectorImage.
 *
 * Interpolators created by a previous call are discarded first, so channel
 * indices passed to Evaluate/EvaluateDerivative always refer to the most
 * recently set image.
 */
template <typename TImage, typename TInterpolator>
void
BSplineInterpolateVectorImageFunction<TImage, TInterpolator>::SetInputImage(typename TImage::Pointer vectorImage)
{
  using ScalarImageType = itk::Image<float, TImage::ImageDimension>;
  using SelectorType = itk::VectorIndexSelectionCastImageFilter<TImage, ScalarImageType>;

  // Without this, a second call would append after the old interpolators and
  // index i would keep resolving to a channel of the previous image.
  m_Interpolators.clear();

  const unsigned int numberOfChannels = vectorImage->GetVectorLength();
  for (unsigned int channel = 0; channel < numberOfChannels; ++channel)
  {
    // Extract this channel as a standalone scalar image.
    auto selector = SelectorType::New();
    selector->SetInput(vectorImage);
    selector->SetIndex(channel);
    selector->Update();

    auto interpolator = TInterpolator::New();
    // Set the order before the input: interpolators that precompute their
    // coefficients in SetInputImage (e.g. itk::BSplineInterpolateImageFunction)
    // then build them with this order.
    interpolator->SetSplineOrder(3);
    interpolator->SetInputImage(selector->GetOutput());
    m_Interpolators.push_back(interpolator);
  }
} // end SetInputImage

/**
 * ******************* Evaluate ***********************
 *
 * Interpolates a subset of channels at a physical point.
 *
 * \param point            Physical point at which the channels are interpolated.
 * \param subsetOfFeatures Channel indices to evaluate, in output order.
 * \return A 1-D float32 tensor of length subsetOfFeatures.size(); element k
 *         is the interpolated value of channel subsetOfFeatures[k].
 * \throws std::out_of_range if an index has no interpolator, i.e. it is not
 *         a channel of the image passed to SetInputImage.
 */
template <typename TImage, typename TInterpolator>
typename torch::Tensor
BSplineInterpolateVectorImageFunction<TImage, TInterpolator>::Evaluate(typename TImage::PointType point,
                                                                       std::vector<unsigned int> subsetOfFeatures) const
{
  std::vector<float> result;
  result.reserve(subsetOfFeatures.size());
  for (const unsigned int feature : subsetOfFeatures)
  {
    // at() instead of [] so an invalid channel index throws rather than
    // dereferencing an element past the end of the interpolator list.
    result.push_back(static_cast<float>(m_Interpolators.at(feature)->Evaluate(point)));
  }
  // from_blob only borrows result's buffer; clone() copies it into a tensor
  // that owns its memory and survives this function's return.
  return torch::from_blob(result.data(), { static_cast<int64_t>(result.size()) }, torch::kFloat32).clone();
} // end Evaluate

/**
 * ******************* EvaluateDerivative ***********************
 *
 * Computes the spatial gradient of a subset of channels at a physical point.
 *
 * \param point            Physical point at which the gradients are evaluated.
 * \param subsetOfFeatures Channel indices to evaluate, in output order.
 * \return A float32 tensor of shape (subsetOfFeatures.size(), ImageDimension);
 *         row k holds the gradient of channel subsetOfFeatures[k].
 * \throws std::out_of_range if an index has no interpolator, i.e. it is not
 *         a channel of the image passed to SetInputImage.
 */
template <typename TImage, typename TInterpolator>
typename torch::Tensor
BSplineInterpolateVectorImageFunction<TImage, TInterpolator>::EvaluateDerivative(
  typename ImageType::PointType point,
  std::vector<unsigned int>     subsetOfFeatures) const
{
  constexpr unsigned int dimension = TImage::ImageDimension;
  const std::size_t      numberOfFeatures = subsetOfFeatures.size();

  // Row-major (numberOfFeatures x dimension) buffer backing the output tensor.
  std::vector<float> derivative(numberOfFeatures * dimension, 0.0f);
  for (std::size_t k = 0; k < numberOfFeatures; ++k)
  {
    // at() instead of [] so an invalid channel index throws rather than
    // dereferencing an element past the end of the interpolator list.
    const auto gradient = m_Interpolators.at(subsetOfFeatures[k])->EvaluateDerivative(point);
    for (unsigned int d = 0; d < dimension; ++d)
    {
      derivative[k * dimension + d] = static_cast<float>(gradient[d]);
    }
  }
  // from_blob only borrows the buffer; clone() gives the tensor its own storage.
  return torch::from_blob(derivative.data(),
                          { static_cast<int64_t>(numberOfFeatures), static_cast<int64_t>(dimension) },
                          torch::kFloat32)
    .clone();
} // end EvaluateDerivative

#endif // end #ifndef _itkBSplineInterpolateVectorImageFunction_hxx