1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbPCAModel_h
#define otbPCAModel_h
#include "otbMachineLearningModelTraits.h"
#include "otbMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#if defined(__clang__)
#pragma clang diagnostic ignored "-Wheader-guard"
#pragma clang diagnostic ignored "-Wexpansion-to-defined"
#else
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#include "otb_shark.h"
#include <shark/Algorithms/Trainers/PCA.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
namespace otb
{
/** \class PCAModel
*
* This class wraps a PCA model implemented by Shark, in a otb::MachineLearningModel
*
* \ingroup OTBDimensionalityReductionLearning
*/
template <class TInputValue>
class ITK_EXPORT PCAModel
: public MachineLearningModel<
itk::VariableLengthVector< TInputValue >,
itk::VariableLengthVector< TInputValue > >
{
public:
typedef PCAModel Self;
typedef MachineLearningModel<
itk::VariableLengthVector< TInputValue >,
itk::VariableLengthVector< TInputValue> > Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename InputListSampleType::Pointer ListSamplePointerType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
// Confidence map related typedefs
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
itkNewMacro(Self);
itkTypeMacro(PCAModel, DimensionalityReductionModel);
itkSetMacro(DoResizeFlag,bool);
itkSetMacro(WriteEigenvectors, bool);
itkGetMacro(WriteEigenvectors, bool);
bool CanReadFile(const std::string & filename) override;
bool CanWriteFile(const std::string & filename) override;
void Save(const std::string & filename, const std::string & name="") override;
void Load(const std::string & filename, const std::string & name="") override;
void Train() override;
protected:
PCAModel();
~PCAModel() override;
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = ITK_NULLPTR) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = ITK_NULLPTR) const override;
private:
shark::LinearModel<> m_Encoder;
shark::LinearModel<> m_Decoder;
shark::PCA m_PCA;
bool m_DoResizeFlag;
bool m_WriteEigenvectors;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbPCAModel.txx"
#endif
#endif
|