File: otbSVMModel.h

package info (click to toggle)
otb 5.8.0%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 38,496 kB
  • ctags: 40,282
  • sloc: cpp: 306,573; ansic: 3,575; python: 450; sh: 214; perl: 74; java: 72; makefile: 70
file content (457 lines) | stat: -rw-r--r-- 12,324 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef otbSVMModel_h
#define otbSVMModel_h

#include "itkObjectFactory.h"
#include "itkDataObject.h"
#include "itkVariableLengthVector.h"
#include "itkTimeProbe.h"
#include "svm.h"

namespace otb
{

/** \class SVMModel
 * \brief Wrapper around a libsvm model, its training parameters and
 * its training samples.
 *
 * SVMModel owns three libsvm structures:
 *  - the trained model (struct svm_model), accessible through GetModel();
 *  - the training parameters (struct svm_parameter), accessible through
 *    GetParameters() and the individual Set/Get accessors;
 *  - the training problem (struct svm_problem), accessible through
 *    GetProblem() and built by BuildProblem().
 *
 * Typical workflow, based on the public interface of this class:
 *  - provide training samples with AddSample() or SetSamples(); each
 *    sample is a measurement vector paired with a label;
 *  - configure the SVM type, the kernel and the training parameters with
 *    the Set*() methods (each of them marks the model as out of date and
 *    calls Modified());
 *  - call Train(), or CrossValidation() to estimate the accuracy;
 *  - query the model with EvaluateLabel(), EvaluateHyperplanesDistances()
 *    or EvaluateProbabilities();
 *  - persist the model with SaveModel() / LoadModel().
 *
 * The Evaluate*() methods rely on internal caching and are therefore not
 * thread safe: do not share one instance between threads that evaluate
 * concurrently.
 *
 * \ingroup ClassificationFilters
 *
 * \ingroup OTBSVMLearning
 */
template <class TValue, class TLabel>
class ITK_EXPORT SVMModel : public itk::DataObject
{
public:
  /** Standard class typedefs. */
  typedef SVMModel                      Self;
  typedef itk::DataObject               Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Type of a single measurement component. */
  typedef TValue ValueType;
  /** Label type. */
  typedef TLabel                                LabelType;
  /** A measurement vector (one value per feature). */
  typedef std::vector<ValueType>                MeasurementType;
  /** A training sample: a measurement vector paired with its label. */
  typedef std::pair<MeasurementType, LabelType> SampleType;
  /** The list of training samples. */
  typedef std::vector<SampleType>               SamplesVectorType;
  /** Cache vector type */
  typedef std::vector<struct svm_node *> CacheVectorType;

  /** Per-class probabilities (see EvaluateProbabilities()). */
  typedef itk::VariableLengthVector<double> ProbabilitiesVectorType;
  /** Hyperplane distances (see EvaluateHyperplanesDistances()). */
  typedef itk::VariableLengthVector<double> DistancesVectorType;

  typedef struct svm_node * NodeCacheType;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);
  /** Run-time type information (and related methods). */
  itkTypeMacro(SVMModel, itk::DataObject);

  /** Get the number of classes of the model, or 0 if no model exists. */
  unsigned int GetNumberOfClasses(void) const
  {
    if (m_Model == ITK_NULLPTR)
      {
      return 0;
      }
    return static_cast<unsigned int>(m_Model->nr_class);
  }

  /** Get the number of hyperplanes, i.e. one per pair of classes:
   * nr_class * (nr_class - 1) / 2. Returns 0 if no model exists. */
  unsigned int GetNumberOfHyperplane(void) const
  {
    if (m_Model == ITK_NULLPTR)
      {
      return 0;
      }
    return static_cast<unsigned int>(m_Model->nr_class * (m_Model->nr_class - 1) / 2);
  }

  /** Get the underlying libsvm model (may be ITK_NULLPTR if no model
   * has been trained or loaded). */
  const struct svm_model* GetModel() const
  {
    return m_Model;
  }

  /** Get the training parameters (mutable access). */
  struct svm_parameter& GetParameters()
  {
    return m_Parameters;
  }
  /** Get the training parameters (read-only access). */
  const struct svm_parameter& GetParameters() const
  {
    return m_Parameters;
  }

  /** Save the model to a file. */
  void SaveModel(const char* model_file_name) const;
  /** Save the model to a file (forwards to the const char* overload). */
  void SaveModel(const std::string& model_file_name) const
  {
    this->SaveModel(model_file_name.c_str());
  }

  /** Load the model from a file. */
  void LoadModel(const char* model_file_name);
  /** Load the model from a file (forwards to the const char* overload). */
  void LoadModel(const std::string& model_file_name)
  {
    this->LoadModel(model_file_name.c_str());
  }

  /** Set the SVM type to C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR or NU_SVR. */
  void SetSVMType(int svmtype)
  {
    m_Parameters.svm_type = svmtype;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the SVM type (C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR). */
  int GetSVMType(void) const
  {
    return m_Parameters.svm_type;
  }

  /** Set the kernel type to LINEAR, POLY, RBF or SIGMOID:
   *  - linear: u'*v
   *  - polynomial: (gamma*u'*v + coef0)^degree
   *  - radial basis function: exp(-gamma*|u-v|^2)
   *  - sigmoid: tanh(gamma*u'*v + coef0) */
  void SetKernelType(int kerneltype)
  {
    m_Parameters.kernel_type = kerneltype;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the kernel type. */
  int GetKernelType(void) const
  {
    return m_Parameters.kernel_type;
  }

  /** Set the degree of the polynomial kernel. */
  void SetPolynomialKernelDegree(int degree)
  {
    m_Parameters.degree = degree;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the degree of the polynomial kernel. */
  int GetPolynomialKernelDegree(void) const
  {
    return m_Parameters.degree;
  }

  /** Set the gamma parameter for poly/rbf/sigmoid kernels. */
  virtual void SetKernelGamma(double gamma)
  {
    m_Parameters.gamma = gamma;
    m_ModelUpToDate = false;
    this->Modified();
  }
  /** Get the gamma parameter for poly/rbf/sigmoid kernels. */
  double GetKernelGamma(void) const
  {
    return m_Parameters.gamma;
  }

  /** Set the coef0 parameter for poly/sigmoid kernels. */
  void SetKernelCoef0(double coef0)
  {
    m_Parameters.coef0 = coef0;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the coef0 parameter for poly/sigmoid kernels. */
  double GetKernelCoef0(void) const
  {
    return m_Parameters.coef0;
  }

  /** Set the Nu parameter for the training. */
  void SetNu(double nu)
  {
    m_Parameters.nu = nu;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the Nu parameter for the training. */
  double GetNu(void) const
  {
    return m_Parameters.nu;
  }

  /** Set the cache size in MB for the training. */
  void SetCacheSize(int cSize)
  {
    m_Parameters.cache_size = static_cast<double>(cSize);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the cache size in MB for the training (truncated to int). */
  int GetCacheSize(void) const
  {
    return static_cast<int>(m_Parameters.cache_size);
  }

  /** Set the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR. */
  void SetC(double c)
  {
    m_Parameters.C = c;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR. */
  double GetC(void) const
  {
    return m_Parameters.C;
  }

  /** Set the tolerance of the stopping criterion for the training. */
  void SetEpsilon(double eps)
  {
    m_Parameters.eps = eps;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the tolerance of the stopping criterion for the training. */
  double GetEpsilon(void) const
  {
    return m_Parameters.eps;
  }

  /** Set the value of p for EPSILON_SVR. */
  void SetP(double p)
  {
    m_Parameters.p = p;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the value of p for EPSILON_SVR. */
  double GetP(void) const
  {
    return m_Parameters.p;
  }

  /** Enable or disable the shrinking heuristics for the training. */
  void DoShrinking(bool s)
  {
    m_Parameters.shrinking = static_cast<int>(s);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get whether the shrinking heuristics are used for the training. */
  bool GetDoShrinking(void) const
  {
    return static_cast<bool>(m_Parameters.shrinking);
  }

  /** Enable or disable probability estimates. */
  void DoProbabilityEstimates(bool prob)
  {
    m_Parameters.probability = static_cast<int>(prob);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get whether probability estimates are enabled. */
  bool GetDoProbabilityEstimates(void) const
  {
    return static_cast<bool>(m_Parameters.probability);
  }

  /** Test whether the model carries probability information.
   * Returns false when no model has been trained or loaded yet
   * (svm_check_probability_model dereferences its argument, so it must
   * not be called with a null model). */
  bool HasProbabilities(void) const
  {
    if (m_Model == ITK_NULLPTR)
      {
      return false;
      }
    return svm_check_probability_model(m_Model) != 0;
  }

  /** Return the number of support vectors, or 0 if no model exists. */
  int GetNumberOfSupportVectors(void) const
  {
    if (m_Model == ITK_NULLPTR)
      {
      return 0;
      }
    return m_Model->l;
  }

  /** Return the rho values, or ITK_NULLPTR if no model exists. */
  double * GetRho(void) const
  {
    if (m_Model == ITK_NULLPTR)
      {
      return ITK_NULLPTR;
      }
    return m_Model->rho;
  }
  /** Return the support vectors, or ITK_NULLPTR if no model exists. */
  svm_node ** GetSupportVectors(void)
  {
    if (m_Model == ITK_NULLPTR)
      {
      return ITK_NULLPTR;
      }
    return m_Model->SV;
  }
  /** Set the support vectors and change the number of support vectors (l)
   * accordingly. */
  void SetSupportVectors(svm_node ** sv, int nbOfSupportVector);

  /** Return the alpha values (SV coefficients), or ITK_NULLPTR if no model
   * exists. */
  double ** GetAlpha(void)
  {
    if (m_Model == ITK_NULLPTR)
      {
      return ITK_NULLPTR;
      }
    return m_Model->sv_coef;
  }
  /** Set the alpha values (SV coefficients). */
  void SetAlpha(double ** alpha, int nbOfSupportVector);

  /** Return the list of labels, or ITK_NULLPTR if no model exists. */
  int * GetLabels()
  {
    if (m_Model == ITK_NULLPTR)
      {
      return ITK_NULLPTR;
      }
    return m_Model->label;
  }

  /** Get the number of support vectors per class, or ITK_NULLPTR if no
   * model exists. */
  int * GetNumberOfSVPerClasse()
  {
    if (m_Model == ITK_NULLPTR)
      {
      return ITK_NULLPTR;
      }
    return m_Model->nSV;
  }

  /** Get the libsvm training problem (mutable access). */
  struct svm_problem& GetProblem()
  {
    return m_Problem;
  }

  /** Allocate the problem. */
  void BuildProblem();

  /** Check consistency (potentially throws an exception). */
  void ConsistencyCheck();

  /** Estimate the model. */
  void Train();

  /** Cross validation over nbFolders folds (returns the accuracy). */
  double CrossValidation(unsigned int nbFolders);

  /** Predict the label of a measurement.
   * Not thread safe because of internal caching: do not call it
   * concurrently on the same instance. */
  LabelType EvaluateLabel(const MeasurementType& measure) const;

  /** Evaluate the hyperplane distances of a measurement.
   * Not thread safe because of internal caching: do not call it
   * concurrently on the same instance. */
  DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType& measure) const;

  /** Evaluate the probability of each class. Returns a probability vector
   * ordered by increasing class label value.
   * Not thread safe because of internal caching: do not call it
   * concurrently on the same instance. */
  ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType& measure) const;

  /** Add a new sample to the list. */
  void AddSample(const MeasurementType& measure, const LabelType& label);

  /** Clear all samples. */
  void ClearSamples();

  /** Set the samples vector. */
  void SetSamples(const SamplesVectorType& samples);

  /** Reset the whole model, returning it to the state it had right after
   * construction. */
  void Reset();

protected:
  /** Constructor */
  SVMModel();
  /** Destructor */
  ~SVMModel() ITK_OVERRIDE;
  /** Display infos */
  void PrintSelf(std::ostream& os, itk::Indent indent) const ITK_OVERRIDE;

  /** Delete any allocated problem */
  void DeleteProblem();

  /** Delete any allocated model */
  void DeleteModel();

  /** Initializes default parameters */
  void Initialize() ITK_OVERRIDE;

private:
  SVMModel(const Self &); //purposely not implemented
  void operator =(const Self&); //purposely not implemented

  /** The libsvm model itself (null until trained or loaded). */
  struct svm_model* m_Model;

  /** True if the model is up-to-date with respect to the parameters. */
  mutable bool m_ModelUpToDate;

  /** Container of the SVM problem */
  struct svm_problem m_Problem;

  /** Container of the SVM parameters */
  struct svm_parameter m_Parameters;

  /** True if the problem is up-to-date with respect to the samples. */
  bool m_ProblemUpToDate;

  /** Training samples */
  SamplesVectorType m_Samples;
}; // class SVMModel

} // namespace otb

#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSVMModel.txx"
#endif

#endif