1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
#include <gtest/gtest.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
namespace torch {
namespace jit {

// Runs CreateAutodiffSubgraphs over the graph returned by build_lstm() and
// uses FileCheck on the resulting IR to verify that:
//   * none of aten::mm / aten::sigmoid / aten::tanh / aten::mul is matched in
//     the printed graph ahead of the "DifferentiableGraph" match, i.e. they
//     have been folded into the differentiable subgraph, and
//   * the line right after the "DifferentiableGraph" match contains "return".
TEST(CreateAutodiffSubgraphsTest, Basic) {
  auto lstm_graph = build_lstm();

  // Value passed as the pass's `threshold` argument.
  // NOTE(review): presumably the minimum node count for a subgraph to be
  // formed — confirm against create_autodiff_subgraphs.h.
  constexpr size_t kThreshold = 2;
  CreateAutodiffSubgraphs(lstm_graph, kThreshold);

  // Ops expected to live only inside the DifferentiableGraph after the pass.
  // They are registered as consecutive check_not entries, in this order,
  // immediately before the "DifferentiableGraph" check.
  const char* const kOpsMovedIntoSubgraph[] = {
      "aten::mm", "aten::sigmoid", "aten::tanh", "aten::mul"};

  testing::FileCheck file_check;
  for (const char* op : kOpsMovedIntoSubgraph) {
    file_check.check_not(op);
  }
  file_check.check("DifferentiableGraph");
  file_check.check_next("return");
  file_check.run(*lstm_graph);
}

} // namespace jit
} // namespace torch
|