File: itkKdTreeBasedKmeansEstimator.h

package info (click to toggle)
insighttoolkit5 5.4.3-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 704,384 kB
  • sloc: cpp: 783,592; ansic: 628,724; xml: 44,704; fortran: 34,250; python: 22,874; sh: 4,078; pascal: 2,636; lisp: 2,158; makefile: 464; yacc: 328; asm: 205; perl: 203; lex: 146; tcl: 132; javascript: 98; csh: 81
file content (341 lines) | stat: -rw-r--r-- 12,486 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkKdTreeBasedKmeansEstimator_h
#define itkKdTreeBasedKmeansEstimator_h

#include <vector>
#include <unordered_map>

#include "itkObject.h"
#include "itkEuclideanDistanceMetric.h"
#include "itkDistanceToCentroidMembershipFunction.h"
#include "itkSimpleDataObjectDecorator.h"
#include "itkNumericTraitsArrayPixel.h"

namespace itk
{
namespace Statistics
{
/**
 * \class KdTreeBasedKmeansEstimator
 * \brief fast k-means algorithm implementation using k-d tree structure
 *
 * It returns k mean vectors that are centroids of k-clusters
 * using pre-generated k-d tree. k-d tree generation is done by
 * the WeightedCentroidKdTreeGenerator. The tree construction needs
 * to be done only once. The resulting k-d tree's non-terminal nodes
 * that have their children nodes have vector sums of measurement vectors
 * that belong to the nodes and the number of measurement vectors
 * in addition to the typical node boundary information and pointers to
 * children nodes. Instead of reassigning every measurement vector to
 * the nearest cluster centroid and recalculating the centroids, it maintains
 * a set of cluster centroid candidates and, using a pruning algorithm that
 * utilizes the k-d tree, updates the means of only the relevant candidates at
 * each iteration. It would be faster than a traditional implementation
 * of k-means algorithm. However, the k-d tree consumes a large amount
 * of memory. The tree construction time and pruning algorithm's performance
 * are important factors to the whole process's performance. If users
 * want to use k-d tree for some purpose other than k-means estimation,
 * they can use the KdTreeGenerator instead of the
 * WeightedCentroidKdTreeGenerator. It will save the tree construction
 * time and memory usage.
 *
 * Note: There is a second implementation of the k-means algorithm in ITK,
 * ImageKmeansModelEstimator, which is based on the GLA/LBG algorithm.
 * While the k-d tree based implementation is more time efficient, the GLA/LBG
 * based algorithm is more memory efficient.
 *
 * <b>Recent API changes:</b>
 * The static const macro to get the length of a measurement vector,
 * \c MeasurementVectorSize  has been removed to allow the length of a measurement
 * vector to be specified at run time. It is now obtained from the KdTree set
 * as input. You may query this length using the function GetMeasurementVectorSize().
 *
 * \sa ImageKmeansModelEstimator
 * \sa WeightedCentroidKdTreeGenerator, KdTree
 * \ingroup ITKStatistics
 */

template <typename TKdTree>
class ITK_TEMPLATE_EXPORT KdTreeBasedKmeansEstimator : public Object
{
public:
  /** Standard class type aliases. */
  using Self = KdTreeBasedKmeansEstimator;
  using Superclass = Object;
  using Pointer = SmartPointer<Self>;
  using ConstPointer = SmartPointer<const Self>;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** \see LightObject::GetNameOfClass() */
  itkOverrideGetNameOfClassMacro(KdTreeBasedKmeansEstimator);

  /** Types forwarded from the KdTree data structure */
  using KdTreeNodeType = typename TKdTree::KdTreeNodeType;
  using MeasurementType = typename TKdTree::MeasurementType;
  using MeasurementVectorType = typename TKdTree::MeasurementVectorType;
  using InstanceIdentifier = typename TKdTree::InstanceIdentifier;
  using SampleType = typename TKdTree::SampleType;
  using CentroidType = typename KdTreeNodeType::CentroidType;

  /** Type used for the length of a measurement vector */
  using MeasurementVectorSizeType = unsigned int;

  /** Parameters types.
   * ParameterType defines a single position in the optimization search
   * space; InternalParametersType is a collection of such positions.
   * ParametersType is the type exchanged through SetParameters() and
   * GetParameters(). */
  using ParameterType = Array<double>;
  using InternalParametersType = std::vector<ParameterType>;
  using ParametersType = Array<double>;

  /** Type aliases required to generate the data-object decorated output that
   * can be plugged into SampleClassifierFilter */
  using DistanceToCentroidMembershipFunctionType = DistanceToCentroidMembershipFunction<MeasurementVectorType>;

  using DistanceToCentroidMembershipFunctionPointer = typename DistanceToCentroidMembershipFunctionType::Pointer;

  using MembershipFunctionType = MembershipFunctionBase<MeasurementVectorType>;
  using MembershipFunctionPointer = typename MembershipFunctionType::ConstPointer;
  using MembershipFunctionVectorType = std::vector<MembershipFunctionPointer>;
  using MembershipFunctionVectorObjectType = SimpleDataObjectDecorator<MembershipFunctionVectorType>;
  using MembershipFunctionVectorObjectPointer = typename MembershipFunctionVectorObjectType::Pointer;

  /** Output membership function vector containing the membership functions
   * with the final optimized parameters */
  const MembershipFunctionVectorObjectType *
  GetOutput() const;

  /** Set/Get the position used to initialize the optimization. */
  itkSetMacro(Parameters, ParametersType);
  itkGetConstMacro(Parameters, ParametersType);

  /** Set/Get the maximum iteration limit. */
  itkSetMacro(MaximumIteration, int);
  itkGetConstMacro(MaximumIteration, int);

  /** Set/Get the termination threshold for the squared sum
   * of changes in centroid positions after one iteration */
  itkSetMacro(CentroidPositionChangesThreshold, double);
  itkGetConstMacro(CentroidPositionChangesThreshold, double);

  /** Set/Get the pointer to the KdTree */
  void
  SetKdTree(TKdTree * tree);

  const TKdTree *
  GetKdTree() const;

  /** Get the length of measurement vectors in the KdTree */
  itkGetConstMacro(MeasurementVectorSize, MeasurementVectorSizeType);

  /** Get the current iteration count. */
  itkGetConstMacro(CurrentIteration, int);

  /** Get the sum of squared centroid position changes at the current
   * iteration. */
  itkGetConstMacro(CentroidPositionChanges, double);

  /** Start optimization.
   * Optimization will stop when it meets either of two termination conditions,
   * the maximum iteration limit or epsilon (minimal changes in squared sum
   * of changes in centroid positions)  */
  void
  StartOptimization();

  /** Map keyed by instance identifier with an unsigned int value
   * (presumably the index of the cluster the instance is assigned to;
   * confirm against the implementation file). */
  using ClusterLabelsType = std::unordered_map<InstanceIdentifier, unsigned int>;

  /** Set/Get the UseClusterLabels flag (defaults to false). The effect of
   * the flag is implemented outside of this header. */
  itkSetMacro(UseClusterLabels, bool);
  itkGetConstMacro(UseClusterLabels, bool);
  itkBooleanMacro(UseClusterLabels);

protected:
  KdTreeBasedKmeansEstimator();
  ~KdTreeBasedKmeansEstimator() override = default;

  void
  PrintSelf(std::ostream & os, Indent indent) const override;

  /** NOTE(review): judging by the name and arguments this records
   * \c closestIndex as the label for the instances under \c node;
   * the definition is not in this header, confirm there. */
  void
  FillClusterLabels(KdTreeNodeType * node, int closestIndex);

  /**
   * \class CandidateVector
   * \brief Container of the k cluster centroid candidates.
   *
   * Each candidate stores its current centroid together with the running
   * vector sum (WeightedCentroid) and count (Size) from which
   * UpdateCentroids() computes the next centroid.
   * \ingroup ITKStatistics
   */
  class CandidateVector
  {
  public:
    CandidateVector() = default;

    struct Candidate
    {
      /** current centroid position */
      CentroidType Centroid;
      /** vector sum that UpdateCentroids() divides by Size */
      CentroidType WeightedCentroid;
      /** divisor used by UpdateCentroids(); a candidate with Size == 0
       * keeps its centroid */
      int Size{ 0 };
    }; // end of struct

    virtual ~CandidateVector() = default;

    /** returns the number of candidates = k */
    int
    Size() const
    {
      return static_cast<int>(m_Candidates.size());
    }

    /** Initialize the centroids with the argument.
     * At each iteration, this should be called before filtering.
     * The measurement vector length is taken from the first centroid.
     * An empty \c centroids leaves the container with no candidates and
     * a measurement vector length of zero. */
    void
    SetCentroids(const InternalParametersType & centroids)
    {
      if (centroids.empty())
      {
        // There is no first element to take the length from; indexing
        // centroids[0] here would be undefined behavior.
        m_MeasurementVectorSize = 0;
        m_Candidates.clear();
        return;
      }

      m_MeasurementVectorSize =
        static_cast<MeasurementVectorSizeType>(NumericTraits<ParameterType>::GetLength(centroids[0]));
      m_Candidates.resize(centroids.size());
      for (unsigned int i = 0; i < centroids.size(); ++i)
      {
        Candidate & candidate = m_Candidates[i];
        candidate.Centroid = centroids[i];
        // Reset the accumulators so that a new iteration starts from zero.
        NumericTraits<CentroidType>::SetLength(candidate.WeightedCentroid, m_MeasurementVectorSize);
        candidate.WeightedCentroid.Fill(0.0);
        candidate.Size = 0;
      }
    }

    /** gets the centroids (k-means); \c centroids is resized to k */
    void
    GetCentroids(InternalParametersType & centroids) const
    {
      centroids.resize(m_Candidates.size());
      for (unsigned int i = 0; i < m_Candidates.size(); ++i)
      {
        centroids[i] = m_Candidates[i].Centroid;
      }
    }

    /** updates the centroids using the vector sum of measurement vectors
     * that belongs to each centroid and the number of measurement vectors.
     * Candidates whose Size is not positive are left unchanged, which also
     * avoids a division by zero. */
    void
    UpdateCentroids()
    {
      for (Candidate & candidate : m_Candidates)
      {
        if (candidate.Size > 0)
        {
          const auto count = static_cast<double>(candidate.Size);
          for (MeasurementVectorSizeType j = 0; j < m_MeasurementVectorSize; ++j)
          {
            candidate.Centroid[j] = candidate.WeightedCentroid[j] / count;
          }
        }
      }
    }

    /** gets the index-th candidate; \c index is not range checked */
    Candidate & operator[](int index) { return m_Candidates[index]; }

  private:
    /** internal storage for the candidates */
    std::vector<Candidate> m_Candidates;

    /** Length of each measurement vector */
    MeasurementVectorSizeType m_MeasurementVectorSize{ 0 };
  }; // end of class

  /** gets the sum of squared differences between the previous positions
   * and the current positions of all centroids. This is the primary
   * termination condition for this algorithm: the result is meant to be
   * compared against the value that was set by the
   * SetCentroidPositionChangesThreshold method. */
  double
  GetSumOfSquaredPositionChanges(InternalParametersType & previous, InternalParametersType & current);

  /** gets the index of the candidate that is closest to the given
   * measurement vector, considering the candidates listed in
   * \c validIndexes */
  int
  GetClosestCandidate(ParameterType & measurements, std::vector<int> & validIndexes);

  /** returns true if the pointA is farther than pointB to the boundary */
  bool
  IsFarther(ParameterType &         pointA,
            ParameterType &         pointB,
            MeasurementVectorType & lowerBound,
            MeasurementVectorType & upperBound);

  /** recursive pruning algorithm. the validIndexes vector contains
   * only the indexes of the surviving candidates for the node.
   * NOTE(review): validIndexes is taken by value, presumably so that each
   * recursion level can prune its own copy; confirm before changing. */
  void
  Filter(KdTreeNodeType *        node,
         std::vector<int>        validIndexes,
         MeasurementVectorType & lowerBound,
         MeasurementVectorType & upperBound);

  /** copies the source parameters (k-means) to the target */
  void
  CopyParameters(InternalParametersType & source, InternalParametersType & target);

  /** copies the source parameters (k-means), given as a single
   * ParametersType array, to the target collection of positions */
  void
  CopyParameters(ParametersType & source, InternalParametersType & target);

  /** copies the source parameters (k-means), given as a collection of
   * positions, to the target ParametersType array */
  void
  CopyParameters(InternalParametersType & source, ParametersType & target);

  /** imports the data of the measurement vector \c measurements
   * into \c point */
  void
  GetPoint(ParameterType & point, MeasurementVectorType measurements);

  /** NOTE(review): presumably a debugging aid that prints \c point;
   * the definition is not in this header. */
  void
  PrintPoint(ParameterType & point);

private:
  /** current number of iteration */
  int m_CurrentIteration{ 0 };
  /** maximum number of iteration. termination criterion */
  int m_MaximumIteration{ 100 };
  /** sum of squared centroid position changes at the current iteration */
  double m_CentroidPositionChanges{ 0.0 };
  /** threshold for the sum of squared centroid position changes.
   * termination criterion */
  double m_CentroidPositionChangesThreshold{ 0.0 };
  /** pointer to the k-d tree */
  typename TKdTree::Pointer m_KdTree{};
  /** pointer to the euclidean distance function */
  typename EuclideanDistanceMetric<ParameterType>::Pointer m_DistanceMetric{};

  /** k-means */
  ParametersType m_Parameters{};

  /** the k centroid candidates */
  CandidateVector m_CandidateVector{};

  /** NOTE(review): looks like a scratch point reused by the implementation;
   * its use is not visible in this header. */
  ParameterType m_TempVertex{};

  /** m_UseClusterLabels is exposed through Set/GetUseClusterLabels.
   * m_GenerateClusterLabels and m_ClusterLabels have no accessors in this
   * header and are only used by the implementation. */
  bool                                  m_UseClusterLabels{ false };
  bool                                  m_GenerateClusterLabels{ false };
  ClusterLabelsType                     m_ClusterLabels{};
  /** length of the measurement vectors, see GetMeasurementVectorSize() */
  MeasurementVectorSizeType             m_MeasurementVectorSize{ 0 };
  /** decorated membership function vector, see GetOutput() */
  MembershipFunctionVectorObjectPointer m_MembershipFunctionsObject{};
}; // end of class
} // end of namespace Statistics
} // end of namespace itk

#ifndef ITK_MANUAL_INSTANTIATION
#  include "itkKdTreeBasedKmeansEstimator.hxx"
#endif

#endif