// Unit tests for the nativert GraphPassManager and GraphPassPipeline.
#include <gtest/gtest.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/graph/passes/pass_manager/PassManager.h>
#include <torch/csrc/jit/testing/file_check.h>
using namespace ::testing;
using namespace torch::nativert;
// Runs a manager configured with only the "EmptyPass" pass over a freshly
// created graph and expects run() to return false.
TEST(PassManagerTest, TestEmptyPass) {
  GraphPassManager pass_manager({"EmptyPass"});

  // Keep the new graph in a named local; the manager receives the raw pointer.
  auto graph = Graph::createGraph();
  EXPECT_FALSE(pass_manager.run(graph.get()));
}
// Checks that concat() leaves the existing entry at index 0 and that the
// concatenated entries are readable, in order, at indices 1 and 2.
TEST(PassPipelineTest, TestConcat) {
  GraphPassPipeline pipeline({"test"});

  // Starting state: a single entry.
  EXPECT_EQ(pipeline.size(), 1);
  EXPECT_EQ(pipeline.at(0), "test");

  pipeline.concat({"test1", "test2"});

  // State after concatenation.
  EXPECT_EQ(pipeline.at(0), "test");
  EXPECT_EQ(pipeline.at(1), "test1");
  EXPECT_EQ(pipeline.at(2), "test2");
}
// Checks that push_front() places the new entry at index 0 and that the
// previously first entry is then found at index 1.
TEST(PassPipelineTest, TestPushFront) {
  GraphPassPipeline pipeline({"test"});

  // Starting state: a single entry.
  EXPECT_EQ(pipeline.size(), 1);
  EXPECT_EQ(pipeline.at(0), "test");

  pipeline.push_front("test1");

  // State after the insertion at the front.
  EXPECT_EQ(pipeline.at(0), "test1");
  EXPECT_EQ(pipeline.at(1), "test");
}
|