1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457
|
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbSVMModel_h
#define otbSVMModel_h
#include "itkObjectFactory.h"
#include "itkDataObject.h"
#include "itkVariableLengthVector.h"
#include "itkTimeProbe.h"
#include "svm.h"
namespace otb
{
/** \class SVMModel
 * \brief Container for a libsvm model, its training parameters and its
 * training samples.
 *
 * SVMModel holds the three libsvm structures involved in learning and
 * prediction:
 *  - a \c svm_model (the model itself; it may be null, in which case the
 *    model accessors below return 0 or a null pointer),
 *  - a \c svm_parameter (kernel type and training parameters),
 *  - a \c svm_problem (the training problem handed to libsvm),
 * together with a list of training samples. A sample is a pair made of a
 * measurement vector (\c std::vector<TValue>) and a label (\c TLabel).
 *
 * Samples are supplied with AddSample() / SetSamples(). The model is
 * estimated with Train() or assessed with CrossValidation(). It is queried
 * with EvaluateLabel(), EvaluateHyperplanesDistances() and
 * EvaluateProbabilities(), and persisted with SaveModel() / LoadModel().
 *
 * Every parameter setter marks the model as not up to date and calls
 * Modified().
 *
 * \warning The Evaluate* methods rely on internal caching and are therefore
 * not thread safe; see their individual documentation.
 *
 * \ingroup ClassificationFilters
 *
 * \ingroup OTBSVMLearning
 */
template <class TValue, class TLabel>
class ITK_EXPORT SVMModel : public itk::DataObject
{
public:
  /** Standard class typedefs. */
  typedef SVMModel                      Self;
  typedef itk::DataObject               Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Type of one component of a measurement vector */
  typedef TValue ValueType;
  /** Label type */
  typedef TLabel LabelType;

  /** A measurement vector, a (measurement, label) sample, and a list of samples */
  typedef std::vector<ValueType>                MeasurementType;
  typedef std::pair<MeasurementType, LabelType> SampleType;
  typedef std::vector<SampleType>               SamplesVectorType;

  /** Cache vector type */
  typedef std::vector<struct svm_node *> CacheVectorType;

  /** Class probabilities vector */
  typedef itk::VariableLengthVector<double> ProbabilitiesVectorType;
  /** Hyperplane distances vector */
  typedef itk::VariableLengthVector<double> DistancesVectorType;

  typedef struct svm_node * NodeCacheType;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(SVMModel, itk::DataObject);

  /** Get the number of classes of the model (0 if there is no model). */
  unsigned int GetNumberOfClasses(void) const
  {
    if (m_Model) return static_cast<unsigned int>(m_Model->nr_class);
    return 0;
  }

  /** Get the number of hyperplanes, i.e. nr_class * (nr_class - 1) / 2
   * (one per pair of classes), or 0 if there is no model. */
  unsigned int GetNumberOfHyperplane(void) const
  {
    if (m_Model) return static_cast<unsigned int>(m_Model->nr_class * (m_Model->nr_class - 1) / 2);
    return 0;
  }

  /** Get read-only access to the underlying libsvm model (may be null).
   * Declared const so that it is also usable on a const SVMModel. */
  const struct svm_model* GetModel() const
  {
    return m_Model;
  }

  /** Get the training parameters (modifiable). Note that modifying them
   * through this reference does not mark the model as out of date. */
  struct svm_parameter& GetParameters()
  {
    return m_Parameters;
  }

  /** Get the training parameters (read-only). */
  const struct svm_parameter& GetParameters() const
  {
    return m_Parameters;
  }

  /** Save the model to a file. */
  void SaveModel(const char* model_file_name) const;

  /** Save the model to a file (std::string overload). */
  void SaveModel(const std::string& model_file_name) const
  {
    // Implemented in terms of the const char * version
    this->SaveModel(model_file_name.c_str());
  }

  /** Load the model from a file. */
  void LoadModel(const char* model_file_name);

  /** Load the model from a file (std::string overload). */
  void LoadModel(const std::string& model_file_name)
  {
    // Implemented in terms of the const char * version
    this->LoadModel(model_file_name.c_str());
  }

  /** Set the SVM type to C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR or NU_SVR. */
  void SetSVMType(int svmtype)
  {
    m_Parameters.svm_type = svmtype;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the SVM type (C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR or NU_SVR). */
  int GetSVMType(void) const
  {
    return m_Parameters.svm_type;
  }

  /** Set the kernel type to LINEAR, POLY, RBF or SIGMOID:
   *  - linear: u'*v
   *  - polynomial: (gamma*u'*v + coef0)^degree
   *  - radial basis function: exp(-gamma*|u-v|^2)
   *  - sigmoid: tanh(gamma*u'*v + coef0)
   */
  void SetKernelType(int kerneltype)
  {
    m_Parameters.kernel_type = kerneltype;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the kernel type. */
  int GetKernelType(void) const
  {
    return m_Parameters.kernel_type;
  }

  /** Set the degree of the polynomial kernel. */
  void SetPolynomialKernelDegree(int degree)
  {
    m_Parameters.degree = degree;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the degree of the polynomial kernel. */
  int GetPolynomialKernelDegree(void) const
  {
    return m_Parameters.degree;
  }

  /** Set the gamma parameter for poly/rbf/sigmoid kernels. */
  virtual void SetKernelGamma(double gamma)
  {
    m_Parameters.gamma = gamma;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the gamma parameter for poly/rbf/sigmoid kernels. */
  double GetKernelGamma(void) const
  {
    return m_Parameters.gamma;
  }

  /** Set the coef0 parameter for poly/sigmoid kernels. */
  void SetKernelCoef0(double coef0)
  {
    m_Parameters.coef0 = coef0;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the coef0 parameter for poly/sigmoid kernels. */
  double GetKernelCoef0(void) const
  {
    return m_Parameters.coef0;
  }

  /** Set the Nu parameter for the training. */
  void SetNu(double nu)
  {
    m_Parameters.nu = nu;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the Nu parameter for the training. */
  double GetNu(void) const
  {
    return m_Parameters.nu;
  }

  /** Set the cache size in MB for the training (stored as a double). */
  void SetCacheSize(int cSize)
  {
    m_Parameters.cache_size = static_cast<double>(cSize);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the cache size in MB for the training (truncated to int). */
  int GetCacheSize(void) const
  {
    return static_cast<int>(m_Parameters.cache_size);
  }

  /** Set the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR. */
  void SetC(double c)
  {
    m_Parameters.C = c;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR. */
  double GetC(void) const
  {
    return m_Parameters.C;
  }

  /** Set the tolerance of the stopping criterion for the training. */
  void SetEpsilon(double eps)
  {
    m_Parameters.eps = eps;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the tolerance of the stopping criterion for the training. */
  double GetEpsilon(void) const
  {
    return m_Parameters.eps;
  }

  /** Set the value of p for EPSILON_SVR. */
  void SetP(double p)
  {
    m_Parameters.p = p;
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get the value of p for EPSILON_SVR. */
  double GetP(void) const
  {
    return m_Parameters.p;
  }

  /** Enable or disable the shrinking heuristics for the training. */
  void DoShrinking(bool s)
  {
    m_Parameters.shrinking = static_cast<int>(s);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get whether the shrinking heuristics are used for the training. */
  bool GetDoShrinking(void) const
  {
    return static_cast<bool>(m_Parameters.shrinking);
  }

  /** Enable or disable probability estimates. */
  void DoProbabilityEstimates(bool prob)
  {
    m_Parameters.probability = static_cast<int>(prob);
    m_ModelUpToDate = false;
    this->Modified();
  }

  /** Get whether probability estimates are enabled. */
  bool GetDoProbabilityEstimates(void) const
  {
    return static_cast<bool>(m_Parameters.probability);
  }

  /** Test whether the model provides probability information.
   * Returns false when there is no model. libsvm's
   * svm_check_probability_model dereferences its argument, so a null
   * model must not be passed to it. */
  bool HasProbabilities(void) const
  {
    return m_Model != ITK_NULLPTR && svm_check_probability_model(m_Model) != 0;
  }

  /** Return the number of support vectors (0 if there is no model). */
  int GetNumberOfSupportVectors(void) const
  {
    if (m_Model) return m_Model->l;
    return 0;
  }

  /** Return the rho values of the model (null if there is no model). */
  double * GetRho(void) const
  {
    if (m_Model) return m_Model->rho;
    return ITK_NULLPTR;
  }

  /** Return the support vectors (null if there is no model). */
  svm_node ** GetSupportVectors(void)
  {
    if (m_Model) return m_Model->SV;
    return ITK_NULLPTR;
  }

  /** Set the support vectors and change the number l of support vectors
   * accordingly to nbOfSupportVector. */
  void SetSupportVectors(svm_node ** sv, int nbOfSupportVector);

  /** Return the alpha values (SV coefficients), or null if there is no model. */
  double ** GetAlpha(void)
  {
    if (m_Model) return m_Model->sv_coef;
    return ITK_NULLPTR;
  }

  /** Set the alpha values (SV coefficients). */
  void SetAlpha(double ** alpha, int nbOfSupportVector);

  /** Return the list of labels (null if there is no model). */
  int * GetLabels()
  {
    if (m_Model) return m_Model->label;
    return ITK_NULLPTR;
  }

  /** Get the number of support vectors per class (null if there is no model). */
  int * GetNumberOfSVPerClasse()
  {
    if (m_Model) return m_Model->nSV;
    return ITK_NULLPTR;
  }

  /** Get the libsvm training problem (modifiable). */
  struct svm_problem& GetProblem()
  {
    return m_Problem;
  }

  /** Allocate the problem. */
  void BuildProblem();

  /** Check consistency (potentially throws an exception). */
  void ConsistencyCheck();

  /** Estimate the model. */
  void Train();

  /** Cross validation with nbFolders folds (returns the accuracy). */
  double CrossValidation(unsigned int nbFolders);

  /** Predict the label of a measurement.
   * (Please note that due to caching this method is not
   * thread safe. If you want to run multiple concurrent instances of
   * this method, please consider using the GetCopy() method to clone the
   * model.) */
  LabelType EvaluateLabel(const MeasurementType& measure) const;

  /** Evaluate hyperplane distances.
   * (Please note that due to caching this method is not
   * thread safe. If you want to run multiple concurrent instances of
   * this method, please consider using the GetCopy() method to clone the
   * model.) */
  DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType& measure) const;

  /** Evaluate the probability of each class. Returns a probability vector
   * ordered by increasing class label value.
   * (Please note that due to caching this method is not thread safe.
   * If you want to run multiple concurrent instances of
   * this method, please consider using the GetCopy() method to clone the
   * model.) */
  ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType& measure) const;

  /** Add a new sample to the list. */
  void AddSample(const MeasurementType& measure, const LabelType& label);

  /** Clear all samples. */
  void ClearSamples();

  /** Set the samples vector. */
  void SetSamples(const SamplesVectorType& samples);

  /** Reset the whole model, leaving it in the same state as just
   * after the constructor call. */
  void Reset();

protected:
  /** Constructor */
  SVMModel();

  /** Destructor */
  ~SVMModel() ITK_OVERRIDE;

  /** Display information */
  void PrintSelf(std::ostream& os, itk::Indent indent) const ITK_OVERRIDE;

  /** Delete any allocated problem */
  void DeleteProblem();

  /** Delete any allocated model */
  void DeleteModel();

  /** Initialize default parameters */
  void Initialize() ITK_OVERRIDE;

private:
  SVMModel(const Self &); //purposely not implemented
  void operator =(const Self&); //purposely not implemented

  /** Container holding the SVM model itself (may be null) */
  struct svm_model* m_Model;

  /** True if the model is up to date */
  mutable bool m_ModelUpToDate;

  /** Container of the SVM problem */
  struct svm_problem m_Problem;

  /** Container of the SVM parameters */
  struct svm_parameter m_Parameters;

  /** True if the problem is up to date */
  bool m_ProblemUpToDate;

  /** Training samples */
  SamplesVectorType m_Samples;
}; // class SVMModel
} // namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSVMModel.txx"
#endif
#endif
|