1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/core/tensor.h"
#include "caffe2/onnx/helper.h"
#include "caffe2/proto/caffe2_pb.h"
#include "onnx/onnx_pb.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace caffe2 {
namespace onnx {
namespace {
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::GraphProto;
using ::ONNX_NAMESPACE::ModelProto;
using ::ONNX_NAMESPACE::NodeProto;
using ::ONNX_NAMESPACE::TensorProto;
} // namespace
using ConvertedResult =
std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
// Useful utility function
void rewriteSubnet(
Argument* arg,
std::map<std::string, std::string> oldname_to_newname);
// Rewrite Caffe2 nets into SSA forms. Notice that we will preserve the external
// output names for predict net.
TORCH_API std::unordered_map<std::string, std::string> SsaRewrite(
caffe2::NetDef* init_net,
caffe2::NetDef* pred_net,
bool PreserveInPlaceOps = true);
::ONNX_NAMESPACE::TensorProto::DataType Caffe2TypeToOnnxType(
caffe2::TensorProto::DataType t);
class TORCH_API OnnxExporter {
using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
const caffe2::OperatorDef&,
const std::unordered_map<std::string, caffe2::TensorShape>&);
public:
OnnxExporter(DummyName* dummy = nullptr) {
if (dummy) {
dummy_ = std::shared_ptr<DummyName>(dummy, [](DummyName*) {});
} else {
dummy_ = std::make_shared<DummyName>();
}
}
ConvertedResult Caffe2OpToOnnxNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
void InitOpToTensorProto(const caffe2::OperatorDef& def, TensorProto* tensor);
private:
ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
ConvertedResult CreateArgMaxMinOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateBinaryElementwiseOpNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateCastNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateElementwiseLinearNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConvPoolNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateGemmNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReshapeNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateSliceNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateChannelShuffleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateReduceMeanNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateConcatNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateMergeDimNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateLrnNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
ConvertedResult CreateUpsampleNodes(
const caffe2::OperatorDef& def,
const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
// \brief Check block listed arguments where we won't pass down when
// converting to ONNX node
bool IsBlockListed(const caffe2::Argument& arg);
// \brief Convert Caffe2 argument to Onnx attribute
void CopyCaffe2ArgToOnnxAttr(
AttributeProto* attr,
const std::string& op_type,
const caffe2::Argument& arg);
// LUT getters
const std::unordered_map<std::string, std::string>& get_renamed_operators()
const;
const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
const std::
unordered_map<std::string, std::unordered_map<std::string, std::string>>&
get_per_op_renamed_attrs() const;
const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
get_special_operators() const;
// Dummy name generator
std::shared_ptr<DummyName> dummy_;
};
} // namespace onnx
} // namespace caffe2
|