#include <algorithm>
#include <memory>
#include "test_util.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"
#include <gtest/gtest.h>
using namespace nom;
using namespace nom::repr;
using namespace nom::repr::nn;
// Tests for the NNGraph subgraph matching and replacement APIs.
// Builds input1/input2 -> Sum -> sumOutput -> Relu -> reluOutput, checks that
// a Sum+Relu pattern (rooted at the Relu operator) matches only at `relu`, and
// then uses NNMatchGraph::replaceSubgraph to fuse the pair into one SumRelu op.
TEST(NeuralNetGraph, ReplaceGraph) {
  NNGraph graph;
  auto input1 = graph.createNode(std::make_unique<Tensor>("input1"));
  auto input2 = graph.createNode(std::make_unique<Tensor>("input2"));

  // Renaming a blob after creation must be visible through the node's data.
  nn::get<Tensor>(input2)->setName("input2_renamed");
  EXPECT_EQ(nn::get<Tensor>(input2)->getName(), "input2_renamed");

  auto sum = graph.createNode(std::make_unique<Sum>());
  auto sumOutput = graph.createNode(std::make_unique<Tensor>("sumOutput"));
  auto relu = graph.createNode(std::make_unique<Relu>());
  auto reluOutput = graph.createNode(std::make_unique<Tensor>("reluOutput"));
  graph.createEdge(input1, sum);
  graph.createEdge(input2, sum);
  graph.createEdge(sum, sumOutput);
  graph.createEdge(sumOutput, relu);
  graph.createEdge(relu, reluOutput);
  /*
   *   input1   input2
   *        \   /
   *         sum
   *          |
   *      sumOutput
   *          |
   *         relu
   *          |
   *     reluOutput
   */

  // Pattern: external tensor(s) -> Sum -> Tensor -> Relu, rooted at Relu.
  auto mg = NNMatchGraph();
  // count(2) presumably lets this single pattern node stand for both of Sum's
  // inputs. std::move turns the reference returned by count() on the
  // temporary back into an rvalue for createNode.
  auto matchSumInput =
      mg.createNode(std::move(matchExternalTensorNode().count(2)));
  auto matchSum = mg.createNode(nn::is<Sum>);
  mg.createEdge(matchSumInput, matchSum);
  auto matchSumOutput = mg.createNode(nn::is<Tensor>);
  mg.createEdge(matchSum, matchSumOutput);
  auto matchRelu = mg.createNode(nn::is<Relu>);
  mg.createEdge(matchSumOutput, matchRelu);
  auto matchRoot = matchRelu;

  // Only the Relu operator node roots a complete match of the pattern.
  EXPECT_FALSE(mg.isSubgraphMatch(sum, matchRoot).isMatch());
  EXPECT_FALSE(mg.isSubgraphMatch(reluOutput, matchRoot).isMatch());
  EXPECT_FALSE(mg.isSubgraphMatch(input1, matchRoot).isMatch());
  EXPECT_TRUE(mg.isSubgraphMatch(relu, matchRoot).isMatch());

  mg.replaceSubgraph(
      graph,
      matchRoot,
      // The parameter is named reluNode (not relu) so it does not shadow the
      // outer `relu` node declared above.
      [&matchSumOutput](
          NNGraph& g,
          NNGraph::NodeRef reluNode,
          const NNMatchGraph::SubgraphMatchResultType& matchResult) {
        auto fusedNode = g.createNode(std::make_unique<SumRelu>());
        // Find the graph node bound to the pattern's sumOutput tensor; its
        // producer is the matched Sum operator.
        auto sumNode =
            getProducer(matchResult.getMatchNodeMap()->at(matchSumOutput));
        // Rewire: Relu's consumers and Sum's producers now attach to the
        // fused operator.
        g.replaceOutEdges(reluNode, fusedNode);
        g.replaceInEdges(sumNode, fusedNode);
        g.deleteNodes(matchResult.getMatchedSubgraph()->getNodes());
        return true;
      });
  /*
   * Fused graph:
   *   input1   input2
   *        \   /
   *       sumRelu
   *          |
   *     reluOutput
   */

  // Expected survivors: input1, input2, the fused SumRelu, and reluOutput.
  EXPECT_EQ(graph.getNodesCount(), 4u);
  auto fusedNode = getProducer(reluOutput);
  EXPECT_TRUE(is<SumRelu>(fusedNode));
  auto fusedInputs = getInputs(fusedNode);
  EXPECT_EQ(fusedInputs.size(), 2u);
  // The fused op must consume exactly the original Sum inputs.
  EXPECT_TRUE(
      std::find(fusedInputs.begin(), fusedInputs.end(), input1) !=
      fusedInputs.end());
  EXPECT_TRUE(
      std::find(fusedInputs.begin(), fusedInputs.end(), input2) !=
      fusedInputs.end());
}