/*=========================================================================
Program: Insight Segmentation & Registration Toolkit
Module: itkRBFLayer.h
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Insight Software Consortium. All rights reserved.
See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __itkRBFLayer_h
#define __itkRBFLayer_h
#include "itkCompletelyConnectedWeightSet.h"
#include "itkLayerBase.h"
#include "itkObject.h"
#include "itkMacro.h"
#include "itkRadialBasisFunctionBase.h"
#include "vnl/vnl_vector.h"
#include <vector>
#ifdef ITK_USE_REVIEW_STATISTICS
#include "itkEuclideanDistanceMetric.h"
#else
#include "itkEuclideanDistance.h"
#endif
namespace itk
{
namespace Statistics
{
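/** \class RBFLayer
 * \brief A neural network layer of radial basis function (RBF) nodes.
 *
 * Each node i has a center (SetCenter()/GetCenter()) and a radius
 * (SetRadii()/GetRadii()). The distance between an input vector and a
 * node's center is computed with DistanceMetricType, and the node output is
 * obtained by evaluating the radial basis function set with SetRBF().
 */
template<class TMeasurementVector, class TTargetVector>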
class RBFLayer : public LayerBase<TMeasurementVector, TTargetVector>
{
public:
/** Standard class typedefs. */
typedef RBFLayer Self;
typedef LayerBase<TMeasurementVector, TTargetVector> Superclass;
typedef SmartPointer<Self> Pointer;
typedef SmartPointer<const Self> ConstPointer;
/** Run-time type information (and related methods). */
itkTypeMacro(RBFLayer, LayerBase);
/** Method for creation through the object factory. */
itkNewMacro(Self);
typedef typename Superclass::ValueType ValueType;
typedef typename Superclass::ValuePointer ValuePointer;
typedef vnl_vector<ValueType> NodeVectorType;
typedef typename Superclass::InternalVectorType InternalVectorType;
typedef typename Superclass::OutputVectorType OutputVectorType;
typedef typename Superclass::LayerInterfaceType LayerInterfaceType;
typedef CompletelyConnectedWeightSet<TMeasurementVector,TTargetVector>
WeightSetType;
typedef typename Superclass::WeightSetInterfaceType WeightSetInterfaceType;
typedef typename Superclass::InputFunctionInterfaceType InputFunctionInterfaceType;
typedef typename Superclass::TransferFunctionInterfaceType TransferFunctionInterfaceType;
/** Distance metric used to compare an input vector with the node centers. */
#ifdef ITK_USE_REVIEW_STATISTICS
typedef EuclideanDistanceMetric<InternalVectorType> DistanceMetricType;
#else
typedef EuclideanDistance<InternalVectorType> DistanceMetricType;
#endif
typedef typename DistanceMetricType::Pointer DistanceMetricPointer;
typedef RadialBasisFunctionBase<ValueType> RBFType;
//Member Functions
itkGetConstReferenceMacro(RBF_Dim, unsigned int);
void SetRBF_Dim(unsigned int size);
virtual void SetNumberOfNodes(unsigned int numNodes);
virtual ValueType GetInputValue(unsigned int i) const;
void SetInputValue(unsigned int i, ValueType value);
virtual ValueType GetOutputValue(unsigned int) const;
virtual void SetOutputValue(unsigned int, ValueType);
virtual ValueType * GetOutputVector();
void SetOutputVector(TMeasurementVector value);
virtual void ForwardPropagate();
virtual void ForwardPropagate(TMeasurementVector input);
virtual void BackwardPropagate();
virtual void BackwardPropagate(TTargetVector itkNotUsed(errors)) {}
virtual void SetOutputErrorValues(TTargetVector);
virtual ValueType GetOutputErrorValue(unsigned int node_id) const;
virtual ValueType GetInputErrorValue(unsigned int node_id) const;
virtual ValueType * GetInputErrorVector();
virtual void SetInputErrorValue(ValueType, unsigned int node_id);
/** Get/Set the center of node i. */
InternalVectorType GetCenter(unsigned int i) const;
void SetCenter(TMeasurementVector c, unsigned int i);
/** Get/Set the radius of node i. */
ValueType GetRadii(unsigned int i) const;
void SetRadii(ValueType radius, unsigned int i);
/** Activation function and its derivative. */
virtual ValueType Activation(ValueType);
virtual ValueType DActivation(ValueType);
/** Set/Get the bias */
itkSetMacro( Bias, ValueType );
itkGetConstReferenceMacro( Bias, ValueType );
/** Set/Get the distance metric. */
void SetDistanceMetric(DistanceMetricType* f);
itkGetObjectMacro( DistanceMetric, DistanceMetricType );
itkGetConstObjectMacro( DistanceMetric, DistanceMetricType );
/** Set/Get the number of classes. */
itkSetMacro( NumClasses, unsigned int );
itkGetConstReferenceMacro( NumClasses, unsigned int );
/** Set/Get the radial basis function. */
void SetRBF(RBFType* f);
itkGetObjectMacro( RBF, RBFType );
itkGetConstObjectMacro( RBF, RBFType );
protected:
RBFLayer();
virtual ~RBFLayer();
/** Method to print the object. */
virtual void PrintSelf( std::ostream& os, Indent indent ) const;
private:
NodeVectorType m_NodeInputValues;
NodeVectorType m_NodeOutputValues;
NodeVectorType m_InputErrorValues;
NodeVectorType m_OutputErrorValues;
typename DistanceMetricType::Pointer m_DistanceMetric;
std::vector<InternalVectorType> m_Centers; // RBF node centers
InternalVectorType m_Radii;
unsigned int m_NumClasses;
ValueType m_Bias;
unsigned int m_RBF_Dim;
typename RBFType::Pointer m_RBF;
};
} // end namespace Statistics
} // end namespace itk
#ifndef ITK_MANUAL_INSTANTIATION
#include "itkRBFLayer.txx"
#endif
#endif
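For orientation, here is a minimal stand-alone sketch of the computation this layer declares: each node i has a center and a radius, the distance from the input to the center is measured with a Euclidean metric, and a radial basis function maps that distance to the node output. The Gaussian form exp(-d^2 / (2 r^2)) and every name below (`Vector`, `EuclideanDistance`, `GaussianBasis`, `RBFForward`) are assumptions for illustration only. In ITK the basis function is whatever `RBFType` is passed to `SetRBF()`, and this is not the code in `itkRBFLayer.txx`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration, not the ITK implementation.
using Vector = std::vector<double>;

// Euclidean distance between two equal-length vectors
// (the role DistanceMetricType plays in RBFLayer).
double EuclideanDistance(const Vector& a, const Vector& b)
{
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k)
  {
    const double diff = a[k] - b[k];
    sum += diff * diff;
  }
  return std::sqrt(sum);
}

// Assumed Gaussian basis: phi(d; r) = exp(-d^2 / (2 r^2)).
double GaussianBasis(double distance, double radius)
{
  return std::exp(-(distance * distance) / (2.0 * radius * radius));
}

// One forward pass: node i outputs phi(||x - c_i||; r_i).
Vector RBFForward(const Vector& input,
                  const std::vector<Vector>& centers,
                  const Vector& radii)
{
  assert(centers.size() == radii.size());
  Vector outputs(centers.size());
  for (std::size_t i = 0; i < centers.size(); ++i)
  {
    outputs[i] = GaussianBasis(EuclideanDistance(input, centers[i]), radii[i]);
  }
  return outputs;
}
```

With centers (0, 0) and (3, 4), radii 1 and 5, and input (0, 0), `RBFForward` returns {1, exp(-0.5)}, i.e. about {1, 0.607}: the first node sits exactly on the input, and the second is one "radius-scaled" distance of 5/5 away.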