1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
/*
 * We have a Python unit test for exceptions in test/jit/test_exception.py.
 * This C++ version verifies that the expected exception types are thrown from
 * C++. That is hard to test from Python, because C++ exceptions are translated
 * into Python exceptions before Python code can see them.
 */
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/jit.h>
#include <iostream>
#include <stdexcept>
namespace torch {
namespace jit {
namespace py = pybind11;
// Compiles a TorchScript function that raises a builtin AssertionError and
// checks that running it from C++ throws a JITException. The exception is
// expected to carry no Python class name, and its message is expected to
// contain "RuntimeError: AssertionError: <text>".
TEST(TestException, TestAssertion) {
  std::string pythonCode = R"PY(
def foo():
    raise AssertionError("An assertion failed")
)PY";
  auto cu_ptr = torch::jit::compile(pythonCode);
  // get_function() returns the Function base class; the test assumes that
  // functions produced by torch::jit::compile are GraphFunctions.
  auto* gf =
      static_cast<torch::jit::GraphFunction*>(&cu_ptr->get_function("foo"));
  std::cerr << "Graph is\n" << *gf->graph() << '\n';

  bool is_jit_exception = false;
  std::string message;
  c10::optional<std::string> exception_class;
  try {
    cu_ptr->run_method("foo");
  } catch (JITException& e) {
    is_jit_exception = true;
    message = e.what();
    exception_class = e.getPythonClassName();
  }
  // Fatal: if nothing was thrown, the checks below would only add noise.
  ASSERT_TRUE(is_jit_exception);
  EXPECT_FALSE(exception_class);
  EXPECT_NE(
      message.find("RuntimeError: AssertionError: An assertion failed"),
      std::string::npos)
      << "unexpected exception message: " << message;
}
// Test-local stand-in for PythonExceptionValue. It remembers the Python
// exception class as "<__module__>.<__name__>". When the class is called from
// TorchScript, it yields an ExceptionMessageValue that pairs the message
// argument with that qualified class name.
struct MyPythonExceptionValue : public torch::jit::SugaredValue {
  explicit MyPythonExceptionValue(const py::object& exception_class) {
    // A missing attribute falls back to an empty string.
    const std::string module_part =
        py::str(py::getattr(exception_class, "__module__", py::str("")))
            .cast<std::string>();
    const std::string name_part =
        py::str(py::getattr(exception_class, "__name__", py::str("")))
            .cast<std::string>();
    qualified_name_ = module_part + "." + name_part;
  }

  std::string kind() const override { return "My Python exception"; }

  // Simplified from PythonExceptionValue::call. Exactly one positional
  // argument (the message) is required; kwargs and n_binders are not used.
  std::shared_ptr<torch::jit::SugaredValue> call(
      const torch::jit::SourceRange& loc,
      torch::jit::GraphFunction& caller,
      at::ArrayRef<torch::jit::NamedValue> args,
      at::ArrayRef<torch::jit::NamedValue> kwargs,
      size_t n_binders) override {
    TORCH_CHECK(args.size() == 1);
    auto graph = caller.graph();
    // Materialize the message before inserting the class-name constant, in
    // case value() itself emits nodes into the graph.
    Value* const message_value = args.at(0).value(*graph);
    Value* const class_name_value =
        insertConstant(*graph, qualified_name_, loc);
    return std::make_shared<ExceptionMessageValue>(
        message_value, class_name_value);
  }

 private:
  // "<module>.<name>" of the wrapped Python exception class.
  std::string qualified_name_;
};
// Minimal Resolver for the custom-exception test. It resolves only the name
// "SimpleValueError", which it looks up in the embedded interpreter's globals.
// Any other value name raises an error, and no type names are resolved.
class SimpleResolver : public torch::jit::Resolver {
 public:
  SimpleResolver() = default;

  // Follows toSugaredValue. toSugaredValue is defined in caffe2:_C, a Python
  // extension, which cannot be added as a cpp_binary dependency.
  //
  // Throws (via TORCH_CHECK) for any name other than "SimpleValueError".
  std::shared_ptr<torch::jit::SugaredValue> resolveValue(
      const std::string& name,
      torch::jit::GraphFunction& m,
      const torch::jit::SourceRange& loc) override {
    // Guard first so every path visibly either throws or returns; the old
    // message also carried a stray "{}" format placeholder.
    TORCH_CHECK(
        name == "SimpleValueError",
        "resolveValue: can not resolve '",
        name,
        "'");
    py::object obj = py::globals()["SimpleValueError"];
    return std::make_shared<MyPythonExceptionValue>(obj);
  }

  // No custom types are known to this resolver.
  torch::jit::TypePtr resolveType(
      const std::string& name,
      const torch::jit::SourceRange& loc) override {
    return nullptr;
  }
};
/*
 * - The Python-source parsing for TorchScript below is adapted from
 *   torch::jit::compile.
 * - Only one Def is parsed. If the source contains several, all but the
 *   first are skipped.
 */
// Defines a Python exception class in an embedded interpreter. It then compiles
// a TorchScript function that raises it, using a custom resolver, and checks
// that the resulting JITException carries the Python class name
// "__main__.SimpleValueError" and a message containing that name.
TEST(TestException, TestCustomException) {
  py::scoped_interpreter guard{};
  py::exec(R"PY(
class SimpleValueError(ValueError):
    def __init__(self, message):
        super(SimpleValueError, self).__init__(message)
)PY");

  std::string pythonCode = R"PY(
def foo():
    raise SimpleValueError("An assertion failed")
)PY";
  torch::jit::Parser p(
      std::make_shared<torch::jit::Source>(pythonCode, "<string>", 1));
  auto def = torch::jit::Def(p.parseFunction(/*is_method=*/false));
  std::cerr << "Def is:\n" << def << '\n';

  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  (void)cu->define(
      c10::nullopt,
      {},
      {},
      {def},
      // class PythonResolver is defined in
      // torch/csrc/jit/python/script_init.cpp. It is not in a header file, so
      // it cannot be used here; a SimpleResolver is used instead.
      {std::make_shared<SimpleResolver>()},
      nullptr);
  // The test assumes the defined function is a GraphFunction.
  auto* gf = static_cast<torch::jit::GraphFunction*>(&cu->get_function("foo"));
  std::cerr << "Graph is\n" << *gf->graph() << '\n';

  bool is_jit_exception = false;
  c10::optional<std::string> exception_class;
  std::string message;
  try {
    cu->run_method("foo");
  } catch (JITException& e) {
    is_jit_exception = true;
    exception_class = e.getPythonClassName();
    message = e.what();
  }
  ASSERT_TRUE(is_jit_exception);
  // Dereferencing an empty optional is undefined behavior, so make a missing
  // class name a fatal failure instead.
  ASSERT_TRUE(exception_class.has_value());
  EXPECT_EQ("__main__.SimpleValueError", *exception_class);
  EXPECT_NE(
      message.find("__main__.SimpleValueError: An assertion failed"),
      std::string::npos)
      << "unexpected exception message: " << message;
}
} // namespace jit
} // namespace torch
|