1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/onnx/onnx_exporter.h"
#include <gtest/gtest.h>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
// Builds Conv(X, W, b) -> Y followed by an in-place Relu(Y) -> Y, runs
// SsaRewrite on it, and checks that:
//  * no op reads and writes the same blob afterwards,
//  * the Conv output still feeds the Relu input,
//  * the returned mapping sends the (rewritten) external input back to "X",
//  * the external output keeps its original name "Y".
TEST(SsaTest, ConvReluInplace) {
  caffe2::NetDef net;
  auto* op = net.add_op();
  op->set_type("Conv");
  op->add_input("X");
  op->add_input("W");
  op->add_input("b");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("Relu");
  op->add_input("Y");
  op->add_output("Y");
  net.add_external_input("X");
  net.add_external_output("Y");

  // Looked up below with the rewritten external-input name; expected to
  // yield the original name "X".
  const std::unordered_map<std::string, std::string> input_mapping =
      caffe2::onnx::SsaRewrite(nullptr, &net);

  // Guard every indexed access below so an unexpected rewrite fails cleanly
  // instead of indexing past the end of a repeated protobuf field.
  ASSERT_EQ(2, net.op_size());
  ASSERT_EQ(1, net.external_input_size());
  ASSERT_EQ(1, net.external_output_size());

  // SSA property: no op may consume a blob that it also produces.
  for (const auto& net_op : net.op()) {
    const std::unordered_set<std::string> inputs(
        net_op.input().begin(), net_op.input().end());
    for (const auto& o : net_op.output()) {
      EXPECT_EQ(0u, inputs.count(o))
          << net_op.type() << " both reads and writes blob " << o;
    }
  }

  // The Conv -> Relu data edge must survive renaming.
  EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
  // Check membership first: .at() would throw rather than fail the test.
  ASSERT_EQ(1u, input_mapping.count(net.external_input(0)));
  EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
  EXPECT_EQ("Y", net.external_output(0));
}
// Builds FC(X, W0, b0) -> Y, an in-place Relu(Y) -> Y, and FC(Y, W2, b2) -> Z,
// with both Y and Z declared as external outputs. After SsaRewrite it checks
// that:
//  * no op reads and writes the same blob,
//  * the first FC's output feeds the Relu, whose input is renamed to "Y_0",
//  * the second FC still reads the external-output name "Y",
//  * the returned mapping sends the (rewritten) external input back to "X",
//  * both external outputs keep their original names.
TEST(SsaTest, FC_Relu_FC_InPlace_Output) {
  caffe2::NetDef net;
  auto* op = net.add_op();
  op->set_type("FC");
  op->add_input("X");
  op->add_input("W0");
  op->add_input("b0");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("Relu");
  op->add_input("Y");
  op->add_output("Y");
  op = net.add_op();
  op->set_type("FC");
  op->add_input("Y");
  op->add_input("W2");
  op->add_input("b2");
  op->add_output("Z");
  net.add_external_input("X");
  net.add_external_output("Y");
  net.add_external_output("Z");

  // Looked up below with the rewritten external-input name; expected to
  // yield the original name "X".
  const std::unordered_map<std::string, std::string> input_mapping =
      caffe2::onnx::SsaRewrite(nullptr, &net);

  // Guard every indexed access below so an unexpected rewrite fails cleanly
  // instead of indexing past the end of a repeated protobuf field.
  ASSERT_EQ(3, net.op_size());
  ASSERT_EQ(1, net.external_input_size());
  ASSERT_EQ(2, net.external_output_size());

  // SSA property: no op may consume a blob that it also produces.
  for (const auto& net_op : net.op()) {
    const std::unordered_set<std::string> inputs(
        net_op.input().begin(), net_op.input().end());
    for (const auto& o : net_op.output()) {
      EXPECT_EQ(0u, inputs.count(o))
          << net_op.type() << " both reads and writes blob " << o;
    }
  }

  // The FC -> Relu data edge must survive renaming.
  EXPECT_EQ(net.op(0).output(0), net.op(1).input(0));
  // Y is an external output, so the last writer of Y (the Relu) keeps the
  // original name and its consumer reads "Y"; the intermediate copy is "Y_0".
  EXPECT_EQ("Y", net.op(2).input(0));
  EXPECT_EQ("Y_0", net.op(1).input(0));
  // Check membership first: .at() would throw rather than fail the test.
  ASSERT_EQ(1u, input_mapping.count(net.external_input(0)));
  EXPECT_EQ("X", input_mapping.at(net.external_input(0)));
  EXPECT_EQ("Y", net.external_output(0));
  EXPECT_EQ("Z", net.external_output(1));
}
|