File: otbSVMClassifier.h

package info (click to toggle)
otb 5.8.0+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 38,496 kB
  • ctags: 40,282
  • sloc: cpp: 306,573; ansic: 3,575; python: 450; sh: 214; perl: 74; java: 72; makefile: 70
file content (136 lines) | stat: -rw-r--r-- 4,273 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbSVMClassifier_h
#define otbSVMClassifier_h

#include "vcl_deprecated_header.h"

#include "itkSampleClassifierFilter.h"
#include "otbSVMModel.h"
#include "itkVectorImage.h"
#include "itkListSample.h"

namespace otb
{

/** \class SVMClassifier
 *  \brief SVM-based classifier of sample data.
 *
 * The first template argument (TSample) is the type of the target sample
 * data; the classifier assigns a class label to each measurement vector of
 * that sample. The second template argument (TLabel) is the pixel type of
 * the labels produced by the classifier.
 *
 * Before calling Update() to start the classification process, plug in all
 * the required parts: an SVM model (see SetModel()) and the target sample
 * data (the filter input).
 *
 * The classification result, returned by GetOutput(), is a MembershipSample.
 * Each class has its own class sample (a Subsample object) holding the
 * InstanceIdentifiers of all the measurement vectors that belong to that
 * class. The InstanceIdentifiers come from the target sample data, so the
 * Subsample objects act as separate class masks.
 *
 * The SVM hyperplane distances are exposed separately through
 * GetHyperplanesDistancesOutput(), as a ListSample of
 * itk::VariableLengthVector<float>.
 *
 * \deprecated This class is kept for backward compatibility only (this
 * header includes vcl_deprecated_header.h); see the classes referenced
 * below for the newer machine learning framework.
 *
 * \sa MachineLearningModel
 * \sa LibSVMMachineLearningModel
 * \sa ImageClassificationFilter
 *
 *
 * \ingroup OTBSVMLearning
 */

template<class TSample, class TLabel>
class ITK_EXPORT SVMClassifier :
  public itk::Statistics::SampleClassifierFilter<TSample>
{
public:
  /** Standard class typedefs */
  typedef SVMClassifier                                    Self;
  typedef itk::Statistics::SampleClassifierFilter<TSample> Superclass;
  typedef itk::SmartPointer<Self>                          Pointer;
  typedef itk::SmartPointer<const Self>                    ConstPointer;

  /** Standard macros. The superclass named here must match the actual base
   * class (SampleClassifierFilter), not the removed ITK3 SampleClassifier. */
  itkTypeMacro(SVMClassifier, itk::Statistics::SampleClassifierFilter);
  itkNewMacro(Self);

  /** Output type: the MembershipSample holding the classification result
   * (its per-class samples are obtained with GetClassSample()). */
  typedef itk::Statistics::MembershipSample<TSample>            OutputType;
  /** One vector of hyperplane distances per classified measurement vector. */
  typedef itk::VariableLengthVector<float>                      HyperplanesDistancesType;
  typedef itk::Statistics::ListSample<HyperplanesDistancesType> HyperplanesDistancesListSampleType;

  /** typedefs from TSample object */
  typedef typename TSample::MeasurementType       MeasurementType;
  typedef typename TSample::MeasurementVectorType MeasurementVectorType;

  /** typedefs from Superclass */
  typedef typename Superclass::MembershipFunctionVectorObjectPointer
    MembershipFunctionPointerVector; //FIXME adopt new naming convention

  /** typedef for label type */
  typedef TLabel ClassLabelType;

  /** Returns the classification result (a MembershipSample). */
  OutputType* GetOutput();

  /** Sets the object in which the classification result is stored. */
  void SetOutput(OutputType* output);

  /** Keep the superclass SetOutput overloads visible alongside the one
   * declared above (otherwise they would be hidden). */
  using Superclass::SetOutput;

  /** Returns the hyperplanes distances */
  HyperplanesDistancesListSampleType * GetHyperplanesDistancesOutput();

  /** Type definitions for the SVM Model. */
  typedef SVMModel<MeasurementType, ClassLabelType> SVMModelType;
  typedef typename SVMModelType::Pointer            SVMModelPointer;

  /** Set the SVM model used to classify the samples. */
  itkSetObjectMacro(Model, SVMModelType);

  /** Get the SVM model used to classify the samples. */
  itkGetObjectMacro(Model, SVMModelType);

  /** Trigger the classification. Overrides the superclass Update(); the
   * implementation is in otbSVMClassifier.txx. */
  void Update() ITK_OVERRIDE;

protected:
  SVMClassifier();
  ~SVMClassifier() ITK_OVERRIDE {}
  void PrintSelf(std::ostream& os, itk::Indent indent) const ITK_OVERRIDE;

  /** Starts the classification process */
  void GenerateData() ITK_OVERRIDE;

  /** Classification step; virtual so that subclasses can customize it. */
  virtual void DoClassification();

private:
  /** ITK objects are reference counted and handled through smart pointers:
   * copying is disabled. */
  SVMClassifier(const Self &);  //purposely not implemented
  void operator =(const Self&); //purposely not implemented

  /** Output pointer (MembershipSample) */
  typename OutputType::Pointer m_Output;

  /** Hyperplanes distances output */
  typename HyperplanesDistancesListSampleType::Pointer m_HyperplanesDistancesOutput;

  /** SVM model used for the classification */
  SVMModelPointer m_Model;
}; // end of class

} // end of namespace otb

#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSVMClassifier.txx"
#endif

#endif