File: otbSampleAugmentationFilter.h

package info (click to toggle)
otb 8.1.1+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,030,436 kB
  • sloc: xml: 231,007; cpp: 224,490; ansic: 4,592; sh: 1,790; python: 1,131; perl: 92; makefile: 72
file content (173 lines) | stat: -rw-r--r-- 5,405 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*
 * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbSampleAugmentationFilter_h
#define otbSampleAugmentationFilter_h

#include "itkProcessObject.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentation.h"
#include "OTBSamplingExport.h"
#include <set>
#include <string>
#include <vector>

namespace otb
{


/**
 * \class SampleAugmentationFilter
 *
 * \brief Filter to generate synthetic samples from existing ones
 *
 * This class generates synthetic samples from existing ones either by
 * replication, jitter (adding gaussian noise to the features of
 * existing samples) or SMOTE (linear combination of pairs of
 * neighbouring samples of the same class).
 *
 * \ingroup OTBSampling
 */

class OTBSampling_EXPORT SampleAugmentationFilter : public itk::ProcessObject
{
public:
  /** Standard class type aliases. */
  using Self         = SampleAugmentationFilter;
  using Superclass   = itk::ProcessObject;
  using Pointer      = itk::SmartPointer<Self>;
  using ConstPointer = itk::SmartPointer<const Self>;

  /** Method for management of the object factory. */
  itkNewMacro(Self);

  /** Return the name of the class. */
  itkTypeMacro(SampleAugmentationFilter, ProcessObject);

  using OGRDataSourceType = ogr::DataSource;
  // No `typename` needed: this is not a dependent name (we are not in a template).
  using OGRDataSourcePointerType = OGRDataSourceType::Pointer;
  using OGRLayerType             = ogr::Layer;

  using DataObjectPointerArraySizeType = itk::ProcessObject::DataObjectPointerArraySizeType;

  using SampleType       = sampleAugmentation::SampleType;
  using SampleVectorType = sampleAugmentation::SampleVectorType;

  /** Strategy used to generate the synthetic samples. */
  enum class Strategy
  {
    Replicate, ///< replication of existing samples
    Jitter,    ///< gaussian noise added to the features of existing samples
    Smote      ///< linear combination of pairs of neighbouring samples of the same class
  };

  /** Set/Get the input OGRDataSource of this process object.  */
  using Superclass::SetInput;
  virtual void SetInput(const OGRDataSourceType* ds);
  const OGRDataSourceType* GetInput(unsigned int idx);

  /** Set the output samples data source.
   * NOTE(review): presumably the destination of the generated samples — confirm in the implementation.
   */
  virtual void SetOutputSamples(ogr::DataSource* data);

  /** Set the Field Name in which labels will be written. (default is "class")
   * A field "ClassFieldName" of type integer is created in the output memory layer.
   */
  itkSetMacro(ClassFieldName, std::string);
  /**
   * Return the Field name in which labels have been written.
   */
  itkGetMacro(ClassFieldName, std::string);

  /** Set/Get the layer selector (a size_t; presumably the input layer index — confirm). */
  itkSetMacro(Layer, size_t);
  itkGetMacro(Layer, size_t);

  /** Set/Get the class label whose samples are augmented (presumably — confirm). */
  itkSetMacro(Label, int);
  itkGetMacro(Label, int);

  /** Set the augmentation strategy.
   * Like the itkSetMacro-generated setters, the filter is marked as modified
   * only when the value actually changes, so the pipeline sees the new setting.
   */
  void SetStrategy(Strategy s)
  {
    if (m_Strategy != s)
    {
      m_Strategy = s;
      this->Modified();
    }
  }
  /** Get the augmentation strategy. */
  Strategy GetStrategy() const
  {
    return m_Strategy;
  }

  /** Set/Get the number of samples (presumably the number of synthetic samples to generate — confirm). */
  itkSetMacro(NumberOfSamples, int);
  itkGetMacro(NumberOfSamples, int);

  /** Set the names of the fields to exclude.
   * Like the itkSetMacro-generated setters, the filter is marked as modified
   * only when the list actually changes.
   */
  void SetExcludedFields(const std::vector<std::string>& ef)
  {
    if (m_ExcludedFields != ef)
    {
      m_ExcludedFields = ef;
      this->Modified();
    }
  }
  /** Get (a copy of) the names of the excluded fields. */
  std::vector<std::string> GetExcludedFields() const
  {
    return m_ExcludedFields;
  }

  /** Set/Get the standard deviation factor (presumably used by the Jitter strategy — confirm). */
  itkSetMacro(StdFactor, double);
  itkGetMacro(StdFactor, double);
  /** Set/Get the number of neighbours (presumably used by the Smote strategy — confirm). */
  itkSetMacro(SmoteNeighbors, size_t);
  itkGetMacro(SmoteNeighbors, size_t);
  /** Set/Get the random seed. */
  itkSetMacro(Seed, int);
  itkGetMacro(Seed, int);

  /**
   * Get the output \c ogr::DataSource which is a "memory" datasource.
   */
  const OGRDataSourceType* GetOutput();

protected:
  SampleAugmentationFilter();
  ~SampleAugmentationFilter() override = default;

  /** Generate Data method*/
  void GenerateData() override;

  /** DataObject pointer */
  using DataObjectPointer = itk::DataObject::Pointer;

  DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) override;
  using Superclass::MakeOutput;

  // Helpers implemented out of line.
  // NOTE(review): the summaries below are inferred from names/signatures only — confirm against the .cxx.

  /** Extract the samples of class \a label from layer \a layerIndex of \a vectors,
   * ignoring \a excludedFields (inferred from names). */
  SampleVectorType ExtractSamples(const ogr::DataSource::Pointer vectors, size_t layerIndex, const std::string& classField, const int label,
                                  const std::vector<std::string>& excludedFields = {});

  /** Write \a samples as OGR features into \a output (inferred from names). */
  void SampleToOGRFeatures(const ogr::DataSource::Pointer& vectors, ogr::DataSource* output, const SampleVectorType& samples, const size_t layerIndex,
                           const std::string& classField, int label, const std::vector<std::string>& excludedFields = {});

  /** Map excluded field names to field ids of \a inputLayer (inferred from names). */
  std::set<size_t> GetExcludedFieldsIds(const std::vector<std::string>& excludedFields, const ogr::Layer& inputLayer);
  /** Whether field \a idx of \a feature is numeric (inferred from names). */
  bool IsNumericField(const ogr::Feature& feature, const int idx);

  /** Pick a feature of class \a label to use as template for new features (inferred from names). */
  ogr::Feature SelectTemplateFeature(const ogr::Layer& inputLayer, const std::string& classField, int label);

private:
  SampleAugmentationFilter(const Self&) = delete;
  void operator=(const Self&) = delete;

  // Value-initialized as a safety net; the out-of-line constructor may set
  // different defaults in its member-initializer list, which take precedence.
  std::string              m_ClassFieldName;
  size_t                   m_Layer{};
  int                      m_Label{};
  std::vector<std::string> m_ExcludedFields;
  Strategy                 m_Strategy{Strategy::Replicate};
  int                      m_NumberOfSamples{};
  double                   m_StdFactor{};
  size_t                   m_SmoteNeighbors{};
  int                      m_Seed{};
};


} // end namespace otb

#endif