1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
#include <gtest/gtest.h>
#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <iostream>
#include <string>
#include <vector>
namespace torch {
namespace jit {
// Checks that a torchbind custom-class object can be passed to a scripted
// method as an IValue, both when created through make_custom_class and when
// an intrusive_ptr is wrapped in an IValue directly. In each case the scripted
// `forward` pops one element and returns the stack, and the test asserts that
// the returned object is the very same instance that was passed in.
TEST(CustomClassTest, TorchbindIValueAPI) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  // The function body must be indented relative to its `def` line, as in
  // Python; otherwise the TorchScript parser rejects the definition.
  m.define(R"(
    def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
      return s.pop(), s
  )");

  // Runs `forward` on `obj` and verifies the (popped string, stack) result.
  auto test_with_obj = [&m](const IValue& obj, const std::string& expected) {
    auto res = m.run_method("forward", obj);
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    // Bind by reference: the string is owned by `tup`, which outlives it.
    const std::string& popped = tup->elements()[0].toStringRef();
    auto other_obj =
        tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(popped == expected);
    // The stack handed back must be the identical object, not a copy.
    auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
    AT_ASSERT(other_obj.get() == ref_obj.get());
  };

  test_with_obj(custom_class_obj, "bar");

  // test IValue() API
  auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
      std::vector<std::string>{"baz", "boo"});
  auto new_stack_ivalue = c10::IValue(my_new_stack);

  test_with_obj(new_stack_ivalue, "boo");
}
// Checks that a module holding a ScalarTypeClass custom object as an attribute
// can be saved to an in-memory stream and loaded back without raising.
TEST(CustomClassTest, ScalarTypeClass) {
  script::Module m("m");

  // test make_custom_class API
  auto cc = make_custom_class<ScalarTypeClass>(at::kFloat);
  m.register_attribute("s", cc.type(), cc, /*is_parameter=*/false);

  // Round-trip through string streams; loading reads `iss` directly, so no
  // separate stream adapter is needed.
  std::ostringstream oss;
  m.save(oss);
  std::istringstream iss(oss.str());
  auto loaded_module = torch::jit::load(iss, torch::kCPU);
}
// Minimal custom class used to exercise docstring propagation: it carries no
// state and exposes a single method that returns a fixed greeting.
class TorchBindTestClass : public torch::jit::CustomClassHolder {
 public:
  // Returns a constant greeting string.
  std::string get() {
    std::string greeting{"Hello, I am your test custom class"};
    return greeting;
  }
};
// Class-level docstring handed to the torch::class_ registration of
// TorchBindTestClass. TestDocString compares the registered type's
// doc_string() against this exact text, so the literal must not be altered.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char class_doc_string[] = R"(
I am docstring for TorchBindTestClass
Args:
What is an argument? Oh never mind, I don't take any.
Return:
How would I know? I am just a holder of some meaningless test methods.
)";
// Method-level docstring attached to the "get_with_docstring" registration.
// TestDocString compares the registered method's doc_string() against it.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
constexpr char method_doc_string[] =
"I am docstring for TorchBindTestClass get_with_docstring method";
namespace {
// Registers TorchBindTestClass under the qualified name
// `_TorchBindTest._TorchBindTestClass` with a class docstring. The same C++
// member function is exposed twice: once as "get" with no docstring and once
// as "get_with_docstring" carrying method_doc_string.
// The unnamed namespace already gives `reg` internal linkage.
auto reg = torch::class_<TorchBindTestClass>(
               "_TorchBindTest",
               "_TorchBindTestClass",
               class_doc_string)
               .def("get", &TorchBindTestClass::get)
               .def(
                   "get_with_docstring",
                   &TorchBindTestClass::get,
                   method_doc_string);
} // namespace
// Verifies that docstrings supplied at registration time are retrievable from
// the registered custom class: the class docstring, an empty docstring for a
// method registered without one, and the method docstring where one was given.
TEST(CustomClassTest, TestDocString) {
  const auto registered_type = getCustomClass(
      "__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
  AT_ASSERT(registered_type);
  AT_ASSERT(registered_type->doc_string() == class_doc_string);

  // "get" was registered without a docstring.
  const auto& plain_method = registered_type->getMethod("get");
  AT_ASSERT(plain_method.doc_string().empty());

  // "get_with_docstring" was registered with method_doc_string.
  const auto& documented_method =
      registered_type->getMethod("get_with_docstring");
  AT_ASSERT(documented_method.doc_string() == method_doc_string);
}
// Checks that a module holding a custom-class attribute behaves the same
// before and after freezing, and that both the original and the frozen module
// can be saved to an in-memory stream and loaded back without raising.
TEST(CustomClassTest, Serialization) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.register_attribute(
      "s",
      custom_class_obj.type(),
      custom_class_obj,
      // NOLINTNEXTLINE(bugprone-argument-comment)
      /*is_parameter=*/false);
  // The function body must be indented relative to its `def` line, as in
  // Python; otherwise the TorchScript parser rejects the definition.
  m.define(R"(
    def forward(self):
      return self.s.return_a_tuple()
  )");

  // Runs `forward` and checks the two-element tuple it returns; the second
  // element is expected to be the integer 123.
  auto test_with_obj = [](script::Module& mod) {
    auto res = mod.run_method("forward");
    auto tup = res.toTuple();
    AT_ASSERT(tup->elements().size() == 2);
    auto i = tup->elements()[1].toInt();
    AT_ASSERT(i == 123);
  };

  // Freeze a clone so the original module stays untouched for comparison.
  auto frozen_m = torch::jit::freeze_module(m.clone());

  test_with_obj(m);
  test_with_obj(frozen_m);

  // Round-trip the original module; loading reads the stream directly, so no
  // separate stream adapter is needed.
  std::ostringstream oss;
  m.save(oss);
  std::istringstream iss(oss.str());
  auto loaded_module = torch::jit::load(iss, torch::kCPU);

  // Round-trip the frozen module the same way.
  std::ostringstream oss_frozen;
  frozen_m.save(oss_frozen);
  std::istringstream iss_frozen(oss_frozen.str());
  auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
}
} // namespace jit
} // namespace torch
|