File: otbSVMClassifierPointSet.cxx

package info (click to toggle)
otb 5.8.0%2Bdfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 38,496 kB
  • ctags: 40,282
  • sloc: cpp: 306,573; ansic: 3,575; python: 450; sh: 214; perl: 74; java: 72; makefile: 70
file content (142 lines) | stat: -rw-r--r-- 3,727 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/



#include <fstream>

#include "itkPoint.h"

#include "itkPointSetToListSampleAdaptor.h"
#include "itkSubsample.h"
#include "otbSVMClassifier.h"

int otbSVMClassifierPointSet(int argc, char* argv[])
{
  if (argc != 2)
    {
    std::cout << "Usage : " << argv[0] << " modelFile"
              << std::endl;
    return EXIT_FAILURE;
    }

  const char * modelFilename  = argv[1];

  std::cout << "Building the pointset" << std::endl;

  typedef double                      InputPixelType;
  typedef int                         LabelPixelType;
  typedef std::vector<InputPixelType> InputVectorType;
  const unsigned int Dimension = 2;

  typedef itk::PointSet<InputVectorType,  Dimension>
  MeasurePointSetType;

  MeasurePointSetType::Pointer mPSet = MeasurePointSetType::New();

  typedef MeasurePointSetType::PointType MeasurePointType;

  typedef MeasurePointSetType::PointsContainer MeasurePointsContainer;

  MeasurePointsContainer::Pointer mCont = MeasurePointsContainer::New();

  unsigned int pointId;

  for (pointId = 0; pointId < 20; pointId++)
    {

    MeasurePointType mP;

    mP[0] = pointId;
    mP[1] = pointId;

    InputVectorType measure;
    //measure.push_back(vcl_pow(pointId, 2.0));
    measure.push_back(double(2.0 * pointId));
    measure.push_back(double(-10));

    mCont->InsertElement(pointId, mP);
    mPSet->SetPointData(pointId, measure);

    }

  mPSet->SetPoints(mCont);

  std::cout << "PointSet built" << std::endl;

  typedef itk::Statistics::PointSetToListSampleAdaptor<MeasurePointSetType>
  SampleType;
  SampleType::Pointer sample = SampleType::New();
  sample->SetPointSet(mPSet);

  std::cout << "Sample set to Adaptor" << std::endl;

  /** preparing classifier and decision rule object */
  typedef otb::SVMModel<SampleType::MeasurementVectorType::ValueType, LabelPixelType> ModelType;

  ModelType::Pointer model = ModelType::New();

  model->LoadModel(modelFilename);

  std::cout << "Model loaded" << std::endl;

  int numberOfClasses = model->GetNumberOfClasses();

  typedef otb::SVMClassifier<SampleType, LabelPixelType> ClassifierType;

  ClassifierType::Pointer classifier = ClassifierType::New();

  classifier->SetNumberOfClasses(numberOfClasses);
  classifier->SetModel(model);
  classifier->SetInput(sample.GetPointer());
  classifier->Update();

  /* Build the class map */

  std::cout << "classifier get output" << std::endl;
  ClassifierType::OutputType* membershipSample =
    classifier->GetOutput();
  std::cout << "Sample iterators" << std::endl;
  ClassifierType::OutputType::ConstIterator m_iter =
    membershipSample->Begin();
  ClassifierType::OutputType::ConstIterator m_last =
    membershipSample->End();

  double error = 0.0;
  pointId = 0;
  while (m_iter != m_last)
    {
    ClassifierType::ClassLabelType label = m_iter.GetClassLabel();

    InputVectorType measure;

    mPSet->GetPointData(pointId, &measure);

    if (label != ((measure[0] + measure[1]) > 0)) error++;

    std::cout << label << "/" <<
    ((measure[0] + measure[1]) > 0) << std::endl;

    ++pointId;
    ++m_iter;
    }

  std::cout << "Error = " << error / pointId << std::endl;

  return EXIT_SUCCESS;
}