#ifndef CAFFE2_OPT_CONVERTER_H
#define CAFFE2_OPT_CONVERTER_H
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/opt/annotations.h"
#include "caffe2/proto/caffe2_pb.h"
#include "nomnigraph/Graph/Graph.h"
#include "nomnigraph/Representations/ControlFlow.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace caffe2 {
// Paired helpers operating on a NetDef in place (the net is passed by mutable
// pointer). NOTE(review): the exact transformation lives in the implementation
// file; going by the names, inject* presumably adds data-edge marker ops and
// remove* strips them again -- confirm before relying on that.
TORCH_API void injectDataEdgeIndicators(caffe2::NetDef* netDef);
TORCH_API void removeDataEdgeIndicators(caffe2::NetDef* netDef);
// Converts `net` into an NNModule using the default conversion.
//
//   strict    -- when true, the conversion also checks various conditions on
//                the net's inputs and outputs.
//   opNodeVec -- optional out-parameter; when non-null it is updated so that
//                the operators of `net` map positionally (by index) to the
//                NodeRefs created for them in the returned NNModule.
TORCH_API nom::repr::NNModule convertToNNModule(
    const caffe2::NetDef& net,
    bool strict = false,
    std::vector<nom::repr::NNGraph::NodeRef>* opNodeVec = nullptr);
// Converts `module` back into a NetDef. Note that the module is taken by
// non-const reference.
TORCH_API caffe2::NetDef convertToCaffe2Proto(nom::repr::NNModule& module);

// Overload that also copies all attributes of `oldNet` into the result.
// Caveat: if a transformation changed the graph's inputs or outputs, those
// changes are NOT reflected in the returned net's external_input /
// external_output.
TORCH_API caffe2::NetDef convertToCaffe2Proto(
    nom::repr::NNModule& module,
    const caffe2::NetDef& oldNet);
// Single-operator conversion entry points. Call these rather than going to
// the converter registry directly.
TORCH_API std::unique_ptr<nom::repr::NeuralNetOperator>
convertToNeuralNetOperator(const caffe2::OperatorDef& opDef);
TORCH_API caffe2::OperatorDef convertToOperatorDef(
    const nom::repr::NNGraph::NodeRef& node);
// Returns the Caffe2Annotation of `node`; when the node has none yet, an
// attempt is made to add one. NOTE(review): whether a failed attempt yields
// nullptr is not visible from this header -- confirm before assuming non-null.
TORCH_API Caffe2Annotation* getOrAddCaffe2Annotation(
    nom::repr::NNGraph::NodeRef& node);
// Abstract base for per-operator converters between caffe2::OperatorDef and
// nom::repr::NeuralNetOperator. Subclasses must implement
// convertToNeuralNetOperator; convertToOperatorDef is virtual but not pure,
// so the base-class definition (out of line) is used unless overridden.
class TORCH_API Converter {
 public:
  // `explicit` on a zero-argument constructor has no useful effect, so it is
  // omitted; this is strictly more permissive for callers.
  Converter() = default;
  virtual ~Converter() = default;

  // Builds the nomnigraph operator corresponding to the given OperatorDef.
  virtual std::unique_ptr<nom::repr::NeuralNetOperator>
  convertToNeuralNetOperator(const OperatorDef&) = 0;

  // Builds an OperatorDef from a nomnigraph operator.
  virtual OperatorDef convertToOperatorDef(const nom::repr::NeuralNetOperator*);

  // Collects the arguments of `op` into a map keyed by std::string
  // (presumably the argument name -- confirm in the implementation).
  // NOTE(review): `op` is taken by value, copying the whole protobuf; switching
  // to `const caffe2::OperatorDef&` requires updating the out-of-line
  // definition at the same time.
  static std::map<std::string, caffe2::Argument> getArgumentsFromOperator(
      caffe2::OperatorDef op);

 protected:
  // Helper for subclasses: produces the caffe2::DeviceOption to use for
  // `nnOp` (defined out of line).
  caffe2::DeviceOption getDeviceOption(
      const nom::repr::NeuralNetOperator* nnOp) const;
};
// Registry of Converter subclasses, keyed by the name given to
// REGISTER_CONVERTER (presumably the caffe2 operator type -- confirm).
// Client code should use convertToNeuralNetOperator / convertToOperatorDef
// instead of querying this registry directly.
C10_DECLARE_REGISTRY(ConverterRegistry, Converter);

// Registers the Converter subclass `converterCls` under the key `opType`.
#define REGISTER_CONVERTER(opType, converterCls) \
  C10_REGISTER_CLASS(ConverterRegistry, opType, converterCls)
// Defines `opName##Converter`, a Converter whose convertToNeuralNetOperator
// ignores the incoming OperatorDef and returns a default-constructed
// nom::repr::opName. convertToOperatorDef is not overridden, so the
// Converter base implementation is used. This macro only defines the class;
// it must still be registered separately with REGISTER_CONVERTER.
// (Only block comments are used inside the macro body: a `//` comment would
// swallow the following continuation lines after line splicing.)
#define TRIVIAL_CONVERTER(opName)                                             \
  class opName##Converter : public Converter {                                \
   public:                                                                    \
    /* The OperatorDef is intentionally unused: no arguments are read. */     \
    std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator( \
        const OperatorDef& /*op*/) override {                                 \
      return std::make_unique<nom::repr::opName>();                           \
    }                                                                         \
    ~opName##Converter() override = default;                                  \
  };
} // namespace caffe2
#endif // CAFFE2_OPT_CONVERTER_H