1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
#include <torch/script.h>
#include <memory>
#include <string>
#include <sstream>
#include <vector>
#include <iostream>
void test_module_forward_invocation_no_hooks_run(
const std::string &path_to_exported_script_module) {
std::cout << "testing: "
<< "test_module_forward_invocation_no_hooks_run" << std::endl;
torch::jit::Module module =
torch::jit::load(path_to_exported_script_module + "_" +
"test_module_forward_multiple_inputs" + ".pt");
std::vector<torch::jit::IValue> inputs = {torch::List<std::string>({"a"}),
torch::jit::IValue("no_pre_hook")};
auto output = module(inputs);
auto output_forward = module.forward(inputs);
torch::jit::IValue correct_direct_output =
std::tuple<torch::List<std::string>, std::string>(
{"a", "outer_mod_name", "inner_mod_name"}, "no_pre_hook_");
std::cout << "----- module output: " << output << std::endl;
std::cout << "----- module forward output: " << output_forward << std::endl;
AT_ASSERT(correct_direct_output == output_forward);
}
/// Loads the "test_submodule_to_call_directly_with_hooks" model, invokes its
/// first direct child module through operator(), and asserts that the value
/// it returns equals the expected string.
///
/// @param path_to_exported_script_module  Prefix of the exported ".pt" files;
///        "_test_submodule_to_call_directly_with_hooks.pt" is appended to it.
void test_submodule_called_directly_with_hooks(
    const std::string &path_to_exported_script_module) {
  std::cout << "testing: "
            << "test_submodule_to_call_directly_with_hooks" << std::endl;
  torch::jit::Module module =
      torch::jit::load(path_to_exported_script_module + "_" +
                       "test_submodule_to_call_directly_with_hooks" + ".pt");

  // children() enumerates direct children only. modules() also enumerates
  // the module itself (first), so `*modules().begin()` would select the outer
  // module rather than the submodule this test is about.
  // NOTE(review): assumes the exported model has at least one child and that
  // the first one is the submodule carrying the hooks -- confirm against the
  // script that exports the model.
  torch::jit::Module submodule = *module.children().begin();
  std::vector<torch::jit::IValue> inputs = {"a"};

  auto output = submodule(inputs);
  torch::jit::IValue correct_output = "pre_hook_override_name_inner_mod_fh";
  std::cout << "----- submodule's output: " << output << std::endl;
  std::cout << "----- expected output : " << correct_output << std::endl;
  // Compare against the actual output; comparing the expected value with
  // itself can never fail.
  AT_ASSERT(correct_output == output);
}
/// One entry of the table-driven hook tests run from main().
struct HooksTestCase {
  // Test name; main() uses it to build the "<prefix>_<name>.pt" file to load.
  std::string name;
  // Arguments passed to the loaded module's operator().
  std::vector<torch::jit::IValue> inputs;
  // Value the module's result is compared against.
  torch::jit::IValue output;

  /// All three parameters are sinks: taken by value and moved into the
  /// corresponding members.
  HooksTestCase(std::string name, std::vector<torch::jit::IValue> inputs,
                torch::jit::IValue output)
      : name(std::move(name)), inputs(std::move(inputs)),
        output(std::move(output)) {}
};
int main(int argc, const char *argv[]) {
if (argc != 2) {
std::cerr << "usage: test_jit_hooks <path-to-exported-script-module>\n";
return -1;
}
const std::string path_to_exported_script_module = argv[1];
std::cout << "path to exported module:" << path_to_exported_script_module
<< std::endl;
std::cout << "Tesing JIT Hooks in CPP" << std::endl;
// Note: Modules loaded in this file are produced in /test/jit_hooks/model.py
std::vector<HooksTestCase> test_cases = {
HooksTestCase("test_submodule_multiple_hooks_single_input",
{torch::jit::IValue("a")},
"pre_hook_override_name2_inner_mod_fwh1"),
HooksTestCase("test_submodule_hook_return_nothing",
{torch::jit::IValue("a")}, "a_outermod_inner_mod"),
HooksTestCase("test_submodule_same_hook_repeated",
{torch::jit::IValue("a")},
"a_outermod_ph_ph_inner_mod_fh_fh"),
HooksTestCase("test_submodule_forward_single_input",
{torch::jit::IValue("a")},
"pre_hook_override_name_inner_mod"),
HooksTestCase(
"test_submodule_multiple_hooks_multiple_inputs",
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
std::tuple<torch::List<std::string>, std::string>(
{"pre_hook_override_name", "inner_mod_name"},
"pre_hook_override2_fh1_fh2")),
HooksTestCase(
"test_submodule_forward_multiple_inputs",
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
std::tuple<torch::List<std::string>, std::string>(
{"pre_hook_override_name", "inner_mod_name"},
"pre_hook_override_fh")),
HooksTestCase("test_module_forward_single_input",
{torch::jit::IValue("a")},
"pre_hook_override_name_outermod_inner_mod_fh"),
HooksTestCase("test_module_multiple_hooks_single_input",
{torch::jit::IValue("a")},
"pre_hook_override_name2_outermod_inner_mod_fh1_fh2"),
HooksTestCase("test_module_hook_return_nothing",
{torch::jit::IValue("a")}, "a_outermod_inner_mod"),
HooksTestCase("test_module_same_hook_repeated", {torch::jit::IValue("a")},
"a_ph_ph_outermod_inner_mod_fh_fh"),
HooksTestCase(
"test_module_forward_multiple_inputs",
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
std::tuple<torch::List<std::string>, std::string>(
{"pre_hook_override_name", "outer_mod_name", "inner_mod_name"},
"pre_hook_override_fh")),
HooksTestCase(
"test_module_multiple_hooks_multiple_inputs",
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
std::tuple<torch::List<std::string>, std::string>(
{"pre_hook_override_name2", "outer_mod_name", "inner_mod_name"},
"pre_hook_override_fh1_fh2")),
HooksTestCase("test_module_no_forward_input", {}, torch::jit::IValue()),
HooksTestCase("test_forward_tuple_input", {std::tuple<int>(11)},
{std::tuple<int>(11)}),
};
for (HooksTestCase &test_case : test_cases) {
std::cout << "testing: " << test_case.name << std::endl;
torch::jit::Module module = torch::jit::load(
path_to_exported_script_module + "_" + test_case.name + ".pt");
torch::jit::IValue output = module(test_case.inputs);
std::cout << "----- module's output: " << output << std::endl;
std::cout << "----- expected output: " << test_case.output << std::endl;
AT_ASSERT(output == test_case.output);
}
// special test cases that don't call the imported module directly
test_module_forward_invocation_no_hooks_run(path_to_exported_script_module);
test_submodule_called_directly_with_hooks(path_to_exported_script_module);
std::cout << "JIT CPP Hooks okay!" << std::endl;
return 0;
}
|