File: otbMeanShiftSmoothingImageFilter.hxx

package info (click to toggle)
otb 7.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,005,476 kB
  • sloc: cpp: 270,143; xml: 128,722; ansic: 4,367; sh: 1,768; python: 1,084; perl: 92; makefile: 72
file content (826 lines) | stat: -rw-r--r-- 32,807 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
/*
 * Copyright (C) 2005-2020 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


#ifndef otbMeanShiftSmoothingImageFilter_hxx
#define otbMeanShiftSmoothingImageFilter_hxx

#include "otbMeanShiftSmoothingImageFilter.h"
#include "itkImageRegionIterator.h"
#include "otbUnaryFunctorWithIndexWithOutputSizeImageFilter.h"
#include "otbMacro.h"

#include "itkProgressReporter.h"


namespace otb
{
// Default constructor: installs the default parameter values and creates the
// four output images of the filter.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::MeanShiftSmoothingImageFilter()
  // Members that do not appear in this list (m_SpatialRadius, m_Kernel,
  // m_JointImage, m_ModeTable) are default-initialized.
  : m_RangeBandwidth(16.),
    m_RangeBandwidthRamp(0),
    m_SpatialBandwidth(3),
    m_Threshold(1e-3),
    m_MaxIterationNumber(10),
    m_NumberOfComponentsPerPixel(0),
    m_ModeSearch(false),
    m_ThreadIdNumberOfBits(0)
#if 0
      , m_BucketOptimization(false)
#endif
{
  // No global shift by default
  m_GlobalShift.Fill(0);

  // Output layout: 0 = range image, 1 = spatial image,
  // 2 = iteration image, 3 = label image
  this->SetNumberOfRequiredOutputs(4);
  this->SetNthOutput(0, OutputImageType::New());
  this->SetNthOutput(1, OutputSpatialImageType::New());
  this->SetNthOutput(2, OutputIterationImageType::New());
  this->SetNthOutput(3, OutputLabelImageType::New());
}

// Destructor: nothing to release explicitly here.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::~MeanShiftSmoothingImageFilter()
{}

// Read-only access to the spatial output image (filter output #1).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
const typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputSpatialImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetSpatialOutput() const
{
  const unsigned int spatialOutputIndex = 1;
  return static_cast<const OutputSpatialImageType*>(this->itk::ProcessObject::GetOutput(spatialOutputIndex));
}

// Mutable access to the spatial output image (filter output #1).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputSpatialImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetSpatialOutput()
{
  const unsigned int spatialOutputIndex = 1;
  return static_cast<OutputSpatialImageType*>(this->itk::ProcessObject::GetOutput(spatialOutputIndex));
}

// Read-only access to the range output image (filter output #0).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
const typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetRangeOutput() const
{
  const unsigned int rangeOutputIndex = 0;
  return static_cast<const OutputImageType*>(this->itk::ProcessObject::GetOutput(rangeOutputIndex));
}

// Mutable access to the range output image (filter output #0).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetRangeOutput()
{
  const unsigned int rangeOutputIndex = 0;
  return static_cast<OutputImageType*>(this->itk::ProcessObject::GetOutput(rangeOutputIndex));
}

// Mutable access to the iteration output image (filter output #2).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputIterationImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetIterationOutput()
{
  const unsigned int iterationOutputIndex = 2;
  return static_cast<OutputIterationImageType*>(this->itk::ProcessObject::GetOutput(iterationOutputIndex));
}

// Read-only access to the iteration output image (filter output #2).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
const typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputIterationImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetIterationOutput() const
{
  // In a const member function the const overload of ProcessObject::GetOutput()
  // is selected and yields a pointer to const; the cast target must therefore be
  // const-qualified too, otherwise the static_cast would drop constness.
  return static_cast<const OutputIterationImageType*>(this->itk::ProcessObject::GetOutput(2));
}

// Mutable access to the label output image (filter output #3).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputLabelImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetLabelOutput()
{
  const unsigned int labelOutputIndex = 3;
  return static_cast<OutputLabelImageType*>(this->itk::ProcessObject::GetOutput(labelOutputIndex));
}

// Read-only access to the label output image (filter output #3).
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
const typename MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::OutputLabelImageType*
MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GetLabelOutput() const
{
  // In a const member function the const overload of ProcessObject::GetOutput()
  // is selected and yields a pointer to const; the cast target must therefore be
  // const-qualified too, otherwise the static_cast would drop constness.
  return static_cast<const OutputLabelImageType*>(this->itk::ProcessObject::GetOutput(3));
}

// Allocates the pixel buffers of the four outputs. Every output is buffered
// over exactly its own requested region.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::AllocateOutputs()
{
  // Spatial output
  {
    typename OutputSpatialImageType::Pointer spatialImage = this->GetSpatialOutput();
    spatialImage->SetBufferedRegion(spatialImage->GetRequestedRegion());
    spatialImage->Allocate();
  }

  // Range output
  {
    typename OutputImageType::Pointer rangeImage = this->GetRangeOutput();
    rangeImage->SetBufferedRegion(rangeImage->GetRequestedRegion());
    rangeImage->Allocate();
  }

  // Iteration output
  {
    typename OutputIterationImageType::Pointer iterationImage = this->GetIterationOutput();
    iterationImage->SetBufferedRegion(iterationImage->GetRequestedRegion());
    iterationImage->Allocate();
  }

  // Label output
  {
    typename OutputLabelImageType::Pointer labelImage = this->GetLabelOutput();
    labelImage->SetBufferedRegion(labelImage->GetRequestedRegion());
    labelImage->Allocate();
  }
}

// Propagates meta-information to the outputs and fixes the number of
// components per pixel of the spatial and range outputs.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GenerateOutputInformation()
{
  Superclass::GenerateOutputInformation();

  // Cache the number of bands of the input image
  m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();

  // The spatial output carries one component per image dimension (image lattice)
  OutputSpatialImageType* spatialImage = this->GetSpatialOutput();
  if (spatialImage)
  {
    spatialImage->SetNumberOfComponentsPerPixel(ImageDimension);
  }

  // The range output has as many components as the input
  OutputImageType* rangeImage = this->GetRangeOutput();
  if (rangeImage)
  {
    rangeImage->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
  }
}

// Computes the input region needed to produce the requested output region:
// the requested region of the range output, padded by a margin derived from
// the spatial radius and the maximum number of iterations, then cropped to
// the largest possible region of the input.
// Throws itk::InvalidRequestedRegionError when the padded region cannot be
// cropped to the input's largest possible region.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::GenerateInputRequestedRegion()
{
  Superclass::GenerateInputRequestedRegion();

  InputImagePointerType  inPtr       = const_cast<TInputImage*>(this->GetInput());
  OutputImagePointerType outRangePtr = this->GetRangeOutput();

  // Nothing to do without both an input and a range output
  if (!inPtr || !outRangePtr)
  {
    return;
  }

  // The spatial radius follows from the kernel and the spatial bandwidth
  m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));

  // Margin per dimension: (maximum number of iterations x spatial radius) + 1
  InputSizeType padding;
  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
  {
    padding[dim] = (m_MaxIterationNumber * m_SpatialRadius[dim]) + 1;
  }

  // Start from the region requested on the range output and enlarge it
  RegionType inputRequestedRegion = outRangePtr->GetRequestedRegion();
  inputRequestedRegion.PadByRadius(padding);

  // Crop against the largest possible region of the input. The resulting
  // region is stored on the input whether or not the crop succeeded, so that
  // a failure reports what was attempted.
  const bool cropSucceeded = inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion());
  inPtr->SetRequestedRegion(inputRequestedRegion);

  if (cropSucceeded)
  {
    return;
  }

  // The requested region lies outside the largest possible region
  itk::InvalidRequestedRegionError e(__FILE__, __LINE__);
  e.SetLocation(ITK_LOCATION);
  e.SetDescription("Requested region is (at least partially) outside the largest possible region.");
  e.SetDataObject(inPtr);
  throw e;
}

// Single-threaded preparation step executed before the threaded processing:
//  - computes the spatial radius and caches the number of input components,
//  - allocates the outputs and zero-fills the iteration and spatial outputs,
//  - builds m_JointImage (input expressed in the joint spatial-range domain),
//  - allocates and clears the mode table,
//  - when mode search is enabled, initializes the per-thread label counters.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::BeforeThreadedGenerateData()
{
  typename InputImageType::ConstPointer      inputPtr        = this->GetInput();
  typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
  typename OutputSpatialImageType::Pointer   spatialOutput   = this->GetSpatialOutput();

  m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));

  m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();

  // Allocate output images
  this->AllocateOutputs();

  // Initialize output images to zero
  iterationOutput->FillBuffer(0);
  OutputSpatialPixelType zero(spatialOutput->GetNumberOfComponentsPerPixel());
  zero.Fill(0);
  spatialOutput->FillBuffer(zero);

  // m_JointImage is the input data expressed in the joint spatial-range
  // domain, i.e. spatial coordinates are concatenated to the range values.
  // Moreover, pixel components in this image are normalized by their respective
  // (spatial or range) bandwidth.
  typedef Meanshift::SpatialRangeJointDomainTransform<InputImageType, RealVectorImageType> FunctionType;
  typedef otb::UnaryFunctorWithIndexWithOutputSizeImageFilter<InputImageType, RealVectorImageType, FunctionType> JointImageFunctorType;

  typename JointImageFunctorType::Pointer jointImageFunctor = JointImageFunctorType::New();

  jointImageFunctor->SetInput(inputPtr);
  jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_GlobalShift);
  // The joint image is computed over the whole buffered region of the input
  jointImageFunctor->GetOutput()->SetRequestedRegion(this->GetInput()->GetBufferedRegion());
  jointImageFunctor->Update();
  m_JointImage = jointImageFunctor->GetOutput();

#if 0
  if (m_BucketOptimization)
    {
    // Create bucket image
    // Note: because values in the input m_JointImage are normalized, the
    // rangeRadius argument is just 1
    m_BucketImage = BucketImageType(static_cast<typename RealVectorImageType::ConstPointer> (m_JointImage),
                                    m_JointImage->GetRequestedRegion(), m_Kernel.GetRadius(m_SpatialBandwidth), 1,
                                    ImageDimension);
    }
#endif
  /*
   Legacy, hand-written construction of the joint image (kept for reference):

   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
   typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
   InputIndexType index;

   // Allocate the joint domain image
   m_JointImage = RealVectorImageType::New();
   m_JointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
   m_JointImage->SetRegions(inputPtr->GetRequestedRegion());
   m_JointImage->Allocate();

   InputIteratorWithIndexType inputIt(inputPtr, inputPtr->GetRequestedRegion());
   JointImageIteratorType jointIt(m_JointImage, inputPtr->GetRequestedRegion());

   // Initialize the joint image with scaled values
   inputIt.GoToBegin();
   jointIt.GoToBegin();

   while (!inputIt.IsAtEnd())
   {
   typename InputImageType::PixelType const& inputPixel = inputIt.Get();
   index = inputIt.GetIndex();

   RealVector & jointPixel = jointIt.Get();
   for(unsigned int comp = 0; comp < ImageDimension; comp++)
   {
   jointPixel[comp] = index[comp] / m_SpatialBandwidth;
   }
   for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
   {
   jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
   }
   // jointIt.Set(jointPixel);

   ++inputIt;
   ++jointIt;
   }
   */

  // Image storing the mode search status of each pixel:
  // 0 : no mode has been found yet
  // 1 : a mode has been assigned to this pixel
  // 2 : a mode will be assigned to this pixel
  // TODO don't create mode table iterator when ModeSearch is set to false
  m_ModeTable = ModeTableImageType::New();
  m_ModeTable->SetRegions(inputPtr->GetRequestedRegion());
  m_ModeTable->Allocate();
  m_ModeTable->FillBuffer(0);

  if (m_ModeSearch)
  {
    // Initialize counters for mode (also used for mode labeling)
    // Most significant bits of label counters are used to identify the thread
    // Id.
    const unsigned int numThreads = this->GetNumberOfThreads();

    // Count the bits required to encode the largest thread id, i.e.
    // numThreads - 1. Counting the bits of floor(log2(numThreads)) instead is
    // one bit short whenever numThreads is not a power of two (e.g. 3 threads
    // would get a single bit, and the prefix of thread 2 would overflow and
    // collide with the one of thread 0).
    // NOTE(review): code that decodes the thread id from a label is expected
    // to rely on m_ThreadIdNumberOfBits as well - confirm.
    unsigned int maxThreadId = (numThreads > 0) ? (numThreads - 1) : 0;
    m_ThreadIdNumberOfBits   = 0;
    while (maxThreadId != 0)
    {
      maxThreadId >>= 1;
      m_ThreadIdNumberOfBits++;
    }
    if (m_ThreadIdNumberOfBits == 0)
      m_ThreadIdNumberOfBits = 1; // minimum 1 bit

    // Each thread starts counting its labels from its own id placed in the
    // most significant bits of the label
    m_NumLabels.SetSize(numThreads);
    for (unsigned int i = 0; i < numThreads; i++)
    {
      m_NumLabels[i] = static_cast<LabelType>(i) << (sizeof(LabelType) * 8 - m_ThreadIdNumberOfBits);
    }
  }
}

// Computes the mean shift vector at the joint-domain position jointPixel.
//
// jointImage      : image in the joint spatial-range domain
// jointPixel      : current position in the joint domain
// outputRegion    : region the search window is restricted to
// bandwidth       : per-component bandwidth used to normalize the differences
// meanShiftVector : result; must already have ImageDimension +
//                   m_NumberOfComponentsPerPixel components. It is the
//                   kernel-weighted average of (neighbor - jointPixel) over the
//                   search window, or all zeros when the weights sum to zero.
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::CalculateMeanShiftVector(
    const typename RealVectorImageType::Pointer jointImage, const RealVector& jointPixel, const OutputRegionType& outputRegion, const RealVector& bandwidth,
    RealVector& meanShiftVector)
{
  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;

  assert(meanShiftVector.GetSize() == jointDimension);
  meanShiftVector.Fill(0);

  // Search window: the pixel nearest to jointPixel (global shift removed),
  // extended by the spatial radius plus one pixel on each side, and clamped
  // to outputRegion.
  InputIndexType centerIndex;
  InputIndexType windowStart;
  InputSizeType  windowExtent;

  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
  {
    centerIndex[dim] = std::floor(jointPixel[dim] + 0.5) - m_GlobalShift[dim];

    const long int regionFirst = static_cast<long int>(outputRegion.GetIndex().GetElement(dim));
    const long int regionLast  = static_cast<long int>(outputRegion.GetIndex().GetElement(dim) + outputRegion.GetSize().GetElement(dim) - 1);
    const long int windowFirst = static_cast<long int>(centerIndex[dim] - m_SpatialRadius[dim] - 1);
    const long int windowLast  = static_cast<long int>(centerIndex[dim] + m_SpatialRadius[dim] + 1);

    windowStart[dim] = std::max(regionFirst, windowFirst);

    const long int clampedLast = std::min(regionLast, windowLast);

    // An empty intersection results in a zero extent
    windowExtent[dim] = std::max(0l, clampedLast - static_cast<long int>(windowStart[dim]) + 1);
  }

  RegionType searchWindow;
  searchWindow.SetIndex(windowStart);
  searchWindow.SetSize(windowExtent);

  RealType   totalWeight = 0;
  RealVector difference(jointDimension);

  // Walk over the neighbors of the current pixel (in the joint
  // spatial-range domain)
  otb::Meanshift::FastImageRegionConstIterator<RealVectorImageType> neighborIt(jointImage, searchWindow);

  for (neighborIt.GoToBegin(); !neighborIt.IsAtEnd(); ++neighborIt)
  {
    const RealType* neighbor = neighborIt.GetPixelPointer();

    // Squared L2 norm of the bandwidth-normalized difference
    // TODO: replace by the templated norm
    RealType squaredDistance = 0;
    for (unsigned int c = 0; c < jointDimension; c++)
    {
      difference[c]     = neighbor[c] - jointPixel[c];
      double normalized = difference[c] / bandwidth[c];
      squaredDistance += normalized * normalized;
    }

    // Weight of this neighbor, given by the kernel.
    // (An earlier alternative, not used here, evaluated the spatial and the
    // range squared norms separately: weight 0 when the spatial norm is >= 1,
    // otherwise weight 1 when the range norm is <= 1 and 0 if not.)
    const RealType neighborWeight = m_Kernel(squaredDistance);

    totalWeight += neighborWeight;

    // Accumulate the weighted (non-normalized) difference
    for (unsigned int c = 0; c < jointDimension; c++)
    {
      meanShiftVector[c] += neighborWeight * difference[c];
    }
  }

  // Normalize by the sum of weights, when there is something to normalize
  if (totalWeight > 0)
  {
    for (unsigned int c = 0; c < jointDimension; c++)
    {
      meanShiftVector[c] = meanShiftVector[c] / totalWeight;
    }
  }
}

#if 0
// Calculates the mean shift vector at the position given by jointPixel,
// visiting only the points stored in the buckets neighboring jointPixel
// (m_BucketImage) rather than an image region.
//
// This definition is currently compiled out (enclosed in #if 0 / #endif).
//
// jointPixel      : current position in the joint spatial-range domain
// meanShiftVector : result; on exit holds (weighted mean of the neighbors) -
//                   jointPixel when the sum of weights is positive, and zeros
//                   otherwise
//
// Unlike CalculateMeanShiftVector, the differences are not divided by a
// bandwidth vector before evaluating the kernel.
template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::CalculateMeanShiftVectorBucket(
                                                                                                                              const RealVector& jointPixel,
                                                                                                                              RealVector& meanShiftVector)
{
  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;

  RealType weightSum = 0;

  // Reset the accumulator
  for (unsigned int comp = 0; comp < jointDimension; comp++)
    {
    meanShiftVector[comp] = 0;
    }

  RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);

  // Image index derived from the spatial part of jointPixel, scaled by the
  // spatial bandwidth (the + 0.5 presumably rounds to the nearest pixel -
  // assumes non-negative coordinates, TODO confirm)
  InputIndexType index;
  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
    {
    index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
    }

  // Indices of the buckets neighboring the bucket that contains jointPixel
  const std::vector<unsigned int>
      neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(
                                                                       m_BucketImage.BucketIndexToBucketListIndex(
                                                                                                                  m_BucketImage.GetBucketIndex(
                                                                                                                                               jointPixel,
                                                                                                                                               index)));

  unsigned int numNeighbors = m_BucketImage.GetNumberOfNeighborBuckets();
  for (unsigned int neighborIndex = 0; neighborIndex < numNeighbors; ++neighborIndex)
    {
    const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets[neighborIndex]);
    if (bucket.empty()) continue;
    typename BucketImageType::BucketType::const_iterator it = bucket.begin();
    while (it != bucket.end())
      {
      // Make jointNeighbor refer to the data of the current bucket entry
      // (presumably without copying - confirm SetData semantics)
      jointNeighbor.SetData(const_cast<RealType*> (*it));

      // Compute the squared norm of the difference
      // This is the L2 norm, TODO: replace by the templated norm
      RealType norm2 = 0;
      for (unsigned int comp = 0; comp < jointDimension; comp++)
        {
        const RealType d = jointNeighbor[comp] - jointPixel[comp];
        norm2 += d * d;
        }

      // Compute pixel weight from kernel
      const RealType weight = m_Kernel(norm2);

      // Update sum of weights
      weightSum += weight;

      // Update mean shift vector (accumulates weighted neighbor positions,
      // not differences; jointPixel is subtracted once at the end)
      for (unsigned int comp = 0; comp < jointDimension; comp++)
        {
        meanShiftVector[comp] += weight * jointNeighbor[comp];
        }

      ++it;
      }
    }

  // Turn the weighted sum into a weighted mean and express it relative to
  // jointPixel
  if (weightSum > 0)
    {
    for (unsigned int comp = 0; comp < jointDimension; comp++)
      {
      meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
      }
    }
}
#endif

template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::ThreadedGenerateData(
    const OutputRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
  // at the first iteration


  // Retrieve output images pointers
  typename OutputSpatialImageType::Pointer   spatialOutput   = this->GetSpatialOutput();
  typename OutputImageType::Pointer          rangeOutput     = this->GetRangeOutput();
  typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
  typename OutputLabelImageType::Pointer     labelOutput     = this->GetLabelOutput();

  // Get input image pointer
  typename InputImageType::ConstPointer input = this->GetInput();

  // defines input and output iterators
  typedef itk::ImageRegionIterator<OutputImageType>          OutputIteratorType;
  typedef itk::ImageRegionIterator<OutputSpatialImageType>   OutputSpatialIteratorType;
  typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
  typedef itk::ImageRegionIterator<OutputLabelImageType>     OutputLabelIteratorType;

  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;

  typename OutputImageType::PixelType        rangePixel(m_NumberOfComponentsPerPixel);
  typename OutputSpatialImageType::PixelType spatialPixel(ImageDimension);

  RealVector jointPixel(jointDimension);

  RealVector bandwidth(jointDimension);
  for (unsigned int comp = 0; comp < ImageDimension; comp++)
    bandwidth[comp]      = m_SpatialBandwidth;

  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());

  RegionType const& requestedRegion = input->GetRequestedRegion();

  typedef itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> JointImageIteratorType;
  JointImageIteratorType                                              jointIt(m_JointImage, outputRegionForThread);

  OutputIteratorType          rangeIt(rangeOutput, outputRegionForThread);
  OutputSpatialIteratorType   spatialIt(spatialOutput, outputRegionForThread);
  OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
  OutputLabelIteratorType     labelIt(labelOutput, outputRegionForThread);

  typedef itk::ImageRegionIterator<ModeTableImageType> ModeTableImageIteratorType;
  ModeTableImageIteratorType                           modeTableIt(m_ModeTable, outputRegionForThread);

  jointIt.GoToBegin();
  rangeIt.GoToBegin();
  spatialIt.GoToBegin();
  iterationIt.GoToBegin();
  modeTableIt.GoToBegin();
  labelIt.GoToBegin();

  unsigned int iteration = 0;

  // Mean shift vector, updating the joint pixel at each iteration
  RealVector meanShiftVector(jointDimension);

  // Variables used by mode search optimization
  // List of indices where the current pixel passes through
  std::vector<InputIndexType> pointList;
  if (m_ModeSearch)
    pointList.resize(m_MaxIterationNumber);
  // Number of times an already processed candidate pixel is encountered, resulting in no
  // further computation (Used for statistics only)
  unsigned int numBreaks = 0;
  // index of the current pixel updated during the mean shift loop
  InputIndexType modeCandidate;

  for (; !jointIt.IsAtEnd(); ++jointIt, ++rangeIt, ++spatialIt, ++iterationIt, ++modeTableIt, ++labelIt, progress.CompletedPixel())
  {

    // if pixel has been already processed (by mode search optimization), skip
    typename ModeTableImageType::InternalPixelType const& currentPixelMode = modeTableIt.Get();
    if (m_ModeSearch && currentPixelMode == 1)
    {
      numBreaks++;
      continue;
    }

    bool hasConverged = false;

    // get input pixel in the joint spatial-range domain (with components
    // normalized by bandwidth)
    const RealVector& jointPixelVal = jointIt.Get(); // Pixel in the joint spatial-range domain
    for (unsigned int comp = 0; comp < jointDimension; comp++)
      jointPixel[comp]     = jointPixelVal[comp];

    for (unsigned int comp = ImageDimension; comp < jointDimension; comp++)
      bandwidth[comp]      = m_RangeBandwidthRamp * jointPixel[comp] + m_RangeBandwidth;

    // index of the currently processed output pixel
    InputIndexType currentIndex = jointIt.GetIndex();

    // Number of points currently in the pointList
    unsigned int pointCount = 0; // Note: used only in mode search optimization
    iteration               = 0;
    while ((iteration < m_MaxIterationNumber) && (!hasConverged))
    {

      if (m_ModeSearch)
      {
        // Find index of the pixel closest to the current jointPixel (not normalized by bandwidth)
        for (unsigned int comp = 0; comp < ImageDimension; comp++)
        {
          modeCandidate[comp] = std::floor(jointPixel[comp] - m_GlobalShift[comp] + 0.5);
        }
        // Check status of candidate mode

        // If pixel candidate has status 0 (no mode assigned) or 1 (mode assigned)
        // but not 2 (pixel in current search path), and pixel has actually moved
        // from its initial position, and pixel candidate is inside the output
        // region, then perform optimization tasks
        if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2 && outputRegionForThread.IsInside(modeCandidate))
        {
          // Obtain the data point to see if it close to jointPixel
          RealType          diff           = 0;
          RealVector const& candidatePixel = m_JointImage->GetPixel(modeCandidate);
          for (unsigned int comp = ImageDimension; comp < jointDimension; comp++)
          {
            const RealType d = (candidatePixel[comp] - jointPixel[comp]) / bandwidth[comp];
            diff += d * d;
          }

          if (diff < 0.5) // Spectral value is close enough
          {
            // If no mode has been associated to the candidate pixel then
            // associate it to the upcoming mode
            if (m_ModeTable->GetPixel(modeCandidate) == 0)
            {
              // Add the candidate to the list of pixels that will be assigned the
              // finally calculated mode value
              pointList[pointCount++] = modeCandidate;
              m_ModeTable->SetPixel(modeCandidate, 2);
            }
            else // == 1
            {
              // The candidate pixel has already been assigned to a mode
              // Assign the same value
              rangePixel = rangeOutput->GetPixel(modeCandidate);
              for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
              {
                jointPixel[ImageDimension + comp] = rangePixel[comp];
              }
              // Update the mode table because pixel will be assigned just now
              modeTableIt.Set(2); // m_ModeTable->SetPixel(currentIndex, 2);
              // bypass further calculation
              numBreaks++;
              break;
            }
          }
        }
      } // end if (m_ModeSearch)

// Calculate meanShiftVector
#if 0
      if (m_BucketOptimization)
        {
        this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
        }
      else
        {
#endif
      this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, bandwidth, meanShiftVector);

#if 0
        }
#endif

      // Compute mean shift vector squared norm (not normalized by bandwidth)
      // and add mean shift vector to current joint pixel
      double meanShiftVectorSqNorm = 0;
      for (unsigned int comp = 0; comp < jointDimension; comp++)
      {
        const double v = meanShiftVector[comp];
        meanShiftVectorSqNorm += v * v;
        jointPixel[comp] += meanShiftVector[comp];
      }

      // TODO replace SSD Test with templated metric
      hasConverged = meanShiftVectorSqNorm < m_Threshold;
      iteration++;
    }

    for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
    {
      rangePixel[comp] = jointPixel[ImageDimension + comp];
    }

    for (unsigned int comp = 0; comp < ImageDimension; comp++)
    {
      spatialPixel[comp] = jointPixel[comp] - currentIndex[comp] - m_GlobalShift[comp];
    }

    rangeIt.Set(rangePixel);
    spatialIt.Set(spatialPixel);

    const typename OutputIterationImageType::PixelType iterationPixel = iteration;
    iterationIt.Set(iterationPixel);

    if (m_ModeSearch)
    {
      // Update the mode table now that the current pixel has been assigned
      modeTableIt.Set(1); // m_ModeTable->SetPixel(currentIndex, 1);

      // If the loop exited with hasConverged or too many iterations, then we have a new mode
      LabelType label;
      if (hasConverged || iteration == m_MaxIterationNumber)
      {
        m_NumLabels[threadId]++;
        label = m_NumLabels[threadId];
      }
      else // the loop exited through a break. Use the already assigned mode label
      {
        label = labelOutput->GetPixel(modeCandidate);
      }
      labelIt.Set(label);

      // Also assign all points in the list to the same mode
      for (unsigned int i = 0; i < pointCount; i++)
      {
        rangeOutput->SetPixel(pointList[i], rangePixel);
        m_ModeTable->SetPixel(pointList[i], 1);
        labelOutput->SetPixel(pointList[i], label);
      }
    }
    else // if ModeSearch is not set LabelOutput can't be generated
    {
      LabelType labelZero = 0;
      labelIt.Set(labelZero);
    }
  }
  // std::cout << "numBreaks: " << numBreaks << " Break ratio: " << numBreaks / (RealType)outputRegionForThread.GetNumberOfPixels() << std::endl;
}

/**
 * Relabeling pass executed once every thread has finished its work.
 *
 * During threaded processing each label stores the id of the producing
 * thread in its m_ThreadIdNumberOfBits most significant bits, the remaining
 * bits being that thread's own label counter. Here every label of the
 * requested region is rewritten as
 *   (per-thread counter) + (sum of the label counts of all lower thread ids)
 * so that the thread id bits are removed from the final labels.
 *
 * Labels are only computed when the mode search optimization is enabled;
 * otherwise the label output is left untouched.
 */
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::AfterThreadedGenerateData()
{
  typedef itk::ImageRegionIterator<OutputLabelImageType> LabelIteratorType;

  typename OutputLabelImageType::Pointer labelImage = this->GetLabelOutput();
  LabelIteratorType                      labelWalker(labelImage, labelImage->GetRequestedRegion());

  // Nothing to relabel when mode search is disabled
  if (!m_ModeSearch)
  {
    return;
  }

  // Number of low-order bits holding the per-thread counter, and the mask
  // that keeps only those bits (i.e. strips the thread id)
  const unsigned int counterBits = static_cast<unsigned int>(sizeof(LabelType) * 8 - m_ThreadIdNumberOfBits);
  const LabelType    counterMask = (static_cast<LabelType>(1) << counterBits) - static_cast<LabelType>(1);

  const itk::ThreadIdType threadCount = this->GetNumberOfThreads();

  // firstLabelOfThread[t] is the offset added to the labels produced by
  // thread t: the accumulated number of labels of threads 0 .. t-1
  itk::VariableLengthVector<LabelType> firstLabelOfThread;
  firstLabelOfThread.SetSize(threadCount);
  firstLabelOfThread[0] = 0;
  for (itk::ThreadIdType t = 1; t < threadCount; ++t)
  {
    // m_NumLabels still carries the thread id in its high bits: mask it out
    const LabelType previousCount = m_NumLabels[t - 1] & counterMask;
    firstLabelOfThread[t]         = previousCount + firstLabelOfThread[t - 1];
  }

  for (labelWalker.GoToBegin(); !labelWalker.IsAtEnd(); ++labelWalker)
  {
    const LabelType rawLabel = labelWalker.Get();

    // The producing thread is encoded in the most significant bits
    const itk::ThreadIdType owner = rawLabel >> counterBits;

    // Drop the thread id bits, then shift into the range reserved for 'owner'
    LabelType relabeled = rawLabel & counterMask;
    relabeled += firstLabelOfThread[owner];

    labelWalker.Set(relabeled);
  }
}

/**
 * Print the superclass state followed by the spatial and range bandwidths,
 * one per line, each prefixed with the given indentation.
 */
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterationImage>::PrintSelf(std::ostream& os, itk::Indent indent) const
{
  // Superclass members come first
  Superclass::PrintSelf(os, indent);

  // Then the two bandwidths held by this filter
  os << indent << "Spatial bandwidth: " << m_SpatialBandwidth << std::endl
     << indent << "Range bandwidth: " << m_RangeBandwidth << std::endl;
}

} // end namespace otb

#endif