1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
#include <gtest/gtest.h>
#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "torch/jit.h"
#include "torch/script.h"
#include "torch/torch.h"
namespace torch {
namespace jit {
TEST(GraphExecutorTest, Basic_CUDA) {
constexpr int batch_size = 4;
constexpr int input_size = 256;
int hidden_size = 2 * input_size;
auto input = at::randn({batch_size, input_size}, at::kCUDA);
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
auto g = build_lstm();
GraphExecutor executor(g, "");
auto stack = createStack({input, hx, cx, w_ih, w_hh});
executor.run(stack);
ASSERT_EQ(stack.size(), 2);
at::Tensor r0, r1;
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
ASSERT_TRUE(almostEqual(stack[0].toTensor(), r0));
ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
}
TEST(GraphExecutorTest, runAsync_executor) {
/*
TODO: there are some problem with C++ parsing script program involving
fork. Use the test module below for now.
issue about this: github.com/pytorch/pytorch/issues/46368
The test module file is generated by following:
class DemoModule(torch.nn.Module):
def forward(self):
r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
return r1.wait() + r2.wait()
demo = DemoModule()
torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
*/
std::string filePath(__FILE__);
auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
testModelFile.append("test_interpreter_async.pt");
auto module = load(testModelFile);
auto graph = module.get_method("forward").graph();
GraphExecutor graphExecutor(graph, "");
auto asyncCounter = 0;
std::mutex mtx;
// a dummy executor which actually use at::launch, but add up a counter
auto launcher = [&](std::function<void()> f) {
mtx.lock();
++asyncCounter;
mtx.unlock();
at::launch(move(f));
};
std::vector<IValue> stack;
// NOLINTNEXTLINE(modernize-use-emplace)
stack.push_back(module._ivalue());
graphExecutor.runAsync(stack, launcher)->wait();
ASSERT_TRUE(asyncCounter > 0);
}
} // namespace jit
} // namespace torch
|