1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417
|
#pragma once
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <system_error>
#include <ATen/detail/FunctionTraits.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/cpp_stacktraces.h>
#include <torch/csrc/utils/pybind.h>
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#include <torch/csrc/distributed/c10d/exception.h>
#endif
/// Convenience overload of the CPython `PyErr_SetString` that accepts a
/// `std::string` message instead of a C string.
///
/// @param type    Python exception type object to raise.
/// @param message Text of the error; its NUL-terminated buffer is handed to
///                the `const char*` overload.
static inline void PyErr_SetString(PyObject* type, const std::string& message) {
  // A `const char*` argument selects the C API function by exact match, so
  // this call does not recurse into the overload being defined here.
  const char* const text = message.c_str();
  PyErr_SetString(type, text);
}
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
/// as it requires acquiring the GIL (potential deadlock)
/// and would need to cleanly exit if the warning raised a
/// python error. To solve this, we buffer the warnings and
/// process them when we go back to python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
/// - If there is no Error raised in the inner try/catch, the
/// buffered warnings are processed as python warnings.
/// - If they don't raise an error, the function proceeds with the
/// original return code.
/// - If any of them raise an error, the error is set (PyErr_*) and
/// the destructor will raise a cpp exception python_error() that
/// will be caught by the outer try/catch that will be able to change
/// the return value of the function to reflect the error.
/// - If an Error was raised in the inner try/catch, the inner try/catch
/// must set the python error. The buffered warnings are then
/// processed as cpp warnings as we cannot predict beforehand
/// whether a python warning will raise an error or not and we
/// cannot handle two errors at the same time.
/// This advanced handler will only be used in the current thread.
/// If any other thread is used, warnings will be processed as
/// cpp warnings.
// Opens the error/warning-handling scaffold described in
// NOTE [ Conversion Cpp Python Warning ]: an outer `try`, a local
// torch::PyWarningHandler named `__enforce_warning_buffer`, and an inner `try`
// that is deliberately left open so the wrapped function body can follow.
// The expansion is unbalanced on its own; it must be closed by one of the
// END_HANDLE_TH_ERRORS* macros, which supply the matching braces and catch
// clauses and refer to `__enforce_warning_buffer` by name.
#define HANDLE_TH_ERRORS \
try { \
torch::PyWarningHandler __enforce_warning_buffer; \
try {
// Expands to one catch clause for `c10::ErrorType`.
//   ErrorType       - name of the exception class inside namespace c10.
//   PythonErrorType - Python exception type object passed to PyErr_SetString.
//   retstmnt        - statement executed after the Python error has been set
//                     (typically a `return ...`).
// The message is e.what() when torch::get_cpp_stacktraces_enabled() is true
// and e.what_without_backtrace() otherwise; either way it is passed through
// torch::processErrorMsg before being set as the Python error text.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
catch (const c10::ErrorType& e) { \
auto msg = torch::get_cpp_stacktraces_enabled() \
? e.what() \
: e.what_without_backtrace(); \
PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
retstmnt; \
}
// Only catch torch-specific exceptions
//
// Expands to a sequence of catch clauses, each of which ends by executing
// `retstmnt`:
//   - python_error: the saved Python error state is re-installed via
//     e.restore().
//   - c10::IndexError / ValueError / TypeError / NotImplementedError /
//     LinAlgError / OutOfMemoryError / Error: mapped to the Python exception
//     type given alongside each entry (see _CATCH_GENERIC_ERROR).
//   - torch::PyTorchError: the Python type is taken from e.python_type() and
//     the message from e.what(), passed through torch::processErrorMsg.
// Catch clauses are tried in the order written, so the position of the
// c10::Error entry after the more specific c10 types matters.
// NOTE(review): presumably those types derive from c10::Error - confirm
// before reordering.
#define CATCH_CORE_ERRORS(retstmnt) \
catch (python_error & e) { \
e.restore(); \
retstmnt; \
} \
_CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
_CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
_CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
_CATCH_GENERIC_ERROR( \
NotImplementedError, PyExc_NotImplementedError, retstmnt) \
_CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
_CATCH_GENERIC_ERROR( \
OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
_CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
catch (torch::PyTorchError & e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(e.python_type(), msg); \
retstmnt; \
}
// Catch clauses for c10d (distributed) exceptions.
// When both USE_DISTRIBUTED and USE_C10D are defined:
//   - c10d::TimeoutError sets a Python TimeoutError,
//   - c10d::C10dError sets a Python RuntimeError,
// each with e.what() passed through torch::processErrorMsg, followed by
// `retstmnt`. TimeoutError is listed first so it is tried before C10dError.
// Otherwise the macro expands to nothing, so CATCH_TH_ERRORS can use it
// unconditionally.
#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
#define CATCH_C10D_ERRORS(retstmnt) \
catch (const c10d::TimeoutError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_TimeoutError, msg); \
retstmnt; \
} \
catch (const c10d::C10dError& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_RuntimeError, msg); \
retstmnt; \
}
#else
#define CATCH_C10D_ERRORS(retstmnt)
#endif
// Catch clauses for all torch-specific exceptions: the core handlers from
// CATCH_CORE_ERRORS followed by the (possibly empty) c10d handlers from
// CATCH_C10D_ERRORS. `retstmnt` is forwarded to both.
#define CATCH_TH_ERRORS(retstmnt) \
CATCH_CORE_ERRORS(retstmnt) \
CATCH_C10D_ERRORS(retstmnt)
// Same handlers as CATCH_TH_ERRORS plus a final fallback: any other
// std::exception sets a Python RuntimeError whose text is e.what() passed
// through torch::processErrorMsg, and then executes `retstmnt`.
// The fallback must stay last; it would otherwise shadow the more specific
// handlers for types that derive from std::exception.
#define CATCH_ALL_ERRORS(retstmnt) \
CATCH_TH_ERRORS(retstmnt) \
catch (const std::exception& e) { \
auto msg = torch::processErrorMsg(e.what()); \
PyErr_SetString(PyExc_RuntimeError, msg); \
retstmnt; \
}
// Closes a HANDLE_TH_ERRORS block for functions exposed through pybind11,
// where errors are reported by throwing rather than by a return value.
//   - Inner catch (...): marks `__enforce_warning_buffer` as being in an
//     exception (set_in_exception) and rethrows to the outer handlers.
//   - Outer handlers: py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged; any other
//     std::exception is converted with torch::translate_exception_to_python
//     and replaced by a thrown py::error_already_set.
// The expansion refers to the namespace alias `py`, so `py` must name
// pybind11 at the point of use.
#define END_HANDLE_TH_ERRORS_PYBIND \
} \
catch (...) { \
__enforce_warning_buffer.set_in_exception(); \
throw; \
} \
} \
catch (py::error_already_set & e) { \
throw; \
} \
catch (py::builtin_exception & e) { \
throw; \
} \
catch (torch::jit::JITException & e) { \
throw; \
} \
catch (const std::exception& e) { \
torch::translate_exception_to_python(std::current_exception()); \
throw py::error_already_set(); \
}
// Closes a HANDLE_TH_ERRORS block for functions that report failure through
// their return value.
//   - Inner catch (...): marks `__enforce_warning_buffer` as being in an
//     exception (set_in_exception) and rethrows to the outer handler.
//   - Outer catch (const std::exception&): converts the in-flight exception
//     with torch::translate_exception_to_python and returns `retval`.
// Exceptions that do not derive from std::exception are not caught by the
// outer handler and propagate to the caller.
#define END_HANDLE_TH_ERRORS_RET(retval) \
} \
catch (...) { \
__enforce_warning_buffer.set_in_exception(); \
throw; \
} \
} \
catch (const std::exception& e) { \
torch::translate_exception_to_python(std::current_exception()); \
return retval; \
}
// Shorthand for END_HANDLE_TH_ERRORS_RET with `nullptr` as the error return
// value, for wrapped functions whose return type is a pointer.
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
// Python exception type objects used by this header; declared here and
// defined elsewhere. THPException_LinAlgError and
// THPException_OutOfMemoryError are the Python types set by
// CATCH_CORE_ERRORS for c10::LinAlgError and c10::OutOfMemoryError.
extern PyObject *THPException_FatalError, *THPException_LinAlgError,
*THPException_OutOfMemoryError;
// Throwing this exception means that the python error flags have been already
// set and control should be immediately returned to the interpreter.
//
// An instance can additionally hold a saved Python error (type / value /
// traceback) captured with persist() and re-installed with restore(). The
// object owns one reference to each non-null pointer it holds; every member
// that changes reference counts acquires the GIL first.
struct python_error : public std::exception {
  python_error() = default;

  // A copy refers to the same Python objects as `other`, so it takes its own
  // references to them (under the GIL).
  python_error(const python_error& other)
      : type(other.type),
        value(other.value),
        traceback(other.traceback),
        message(other.message) {
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
  }

  // Moving transfers ownership of the references without touching any
  // reference count, so neither the GIL nor a throwing operation is involved;
  // `other` is left holding no Python objects.
  python_error(python_error&& other) noexcept
      : type(other.type),
        value(other.value),
        traceback(other.traceback),
        message(std::move(other.message)) {
    other.type = nullptr;
    other.value = nullptr;
    other.traceback = nullptr;
  }

  // Assignment was never available (the user-declared move constructor
  // suppresses it); deleting it explicitly documents that and keeps the
  // ownership rules limited to construction and destruction.
  python_error& operator=(const python_error&) = delete;
  python_error& operator=(python_error&&) = delete;

  // Releases the owned references. The GIL is only acquired when there is
  // something to release.
  ~python_error() override {
    if (type || value || traceback) {
      pybind11::gil_scoped_acquire gil;
      Py_XDECREF(type);
      Py_XDECREF(value);
      Py_XDECREF(traceback);
    }
  }

  // Returns the text produced by build_message(), or an empty string if no
  // message has been built.
  const char* what() const noexcept override {
    return message.c_str();
  }

  // Fills `message` from the saved exception value: the UTF-8 encoding of
  // str(value) when that can be obtained, otherwise the fallback text
  // "python_error". Must be called while no Python error is set.
  void build_message() {
    // Ensure we have the GIL.
    pybind11::gil_scoped_acquire gil;

    // No errors should be set when we enter the function since PyErr_Fetch
    // clears the error indicator.
    TORCH_INTERNAL_ASSERT(!PyErr_Occurred());

    // Default message.
    message = "python_error";

    // Try to retrieve the error message from the value.
    if (value != nullptr) {
      // Reference count should not be zero.
      TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);

      PyObject* pyStr = PyObject_Str(value);
      if (pyStr != nullptr) {
        PyObject* encodedString =
            PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
        if (encodedString != nullptr) {
          char* bytes = PyBytes_AS_STRING(encodedString);
          if (bytes != nullptr) {
            // Set the message.
            message = std::string(bytes);
          }
          Py_XDECREF(encodedString);
        }
        Py_XDECREF(pyStr);
      }
    }

    // Clear any errors since we don't want to propagate errors for functions
    // that are trying to build a string for the error message.
    PyErr_Clear();
  }

  /** Saves the exception so that it can be re-thrown on a different thread */
  inline void persist() {
    if (type)
      return; // Don't overwrite exceptions
    // PyErr_Fetch overwrites the pointers
    pybind11::gil_scoped_acquire gil;
    Py_XDECREF(type);
    Py_XDECREF(value);
    Py_XDECREF(traceback);
    PyErr_Fetch(&type, &value, &traceback);
    build_message();
  }

  /** Sets the current Python error from this exception */
  inline void restore() {
    if (!type)
      return;
    // PyErr_Restore steals references, so take an extra reference to each
    // object first; this instance keeps its own and can be restored again.
    pybind11::gil_scoped_acquire gil;
    Py_XINCREF(type);
    Py_XINCREF(value);
    Py_XINCREF(traceback);
    PyErr_Restore(type, value, traceback);
  }

  // Saved Python error state; null when nothing has been persisted.
  PyObject* type{nullptr};
  PyObject* value{nullptr};
  PyObject* traceback{nullptr};

  // Message to return to the user when 'what()' is invoked.
  std::string message;
};
// Declared here, defined elsewhere.
// NOTE(review): presumably creates the THPException_* exception types and
// registers them on `module`, returning whether that succeeded - confirm
// against the definition.
bool THPException_init(PyObject* module);
namespace torch {
// Set python current exception from a C++ exception
// (used by the END_HANDLE_TH_ERRORS* macros with std::current_exception()).
TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);
// Post-processes an error message before it is set as a Python error string
// (see the CATCH_* macros). Takes the message by value and returns the
// processed text; the transformation itself is defined elsewhere.
TORCH_PYTHON_API std::string processErrorMsg(std::string str);
// Abstract base for C++ exceptions that know which Python exception type they
// correspond to. Subclasses supply the type through python_type(); the
// CATCH_CORE_ERRORS handler uses it together with what() to set the Python
// error.
struct PyTorchError : std::exception {
  // Stores a copy of `msg_` (empty by default) as the exception text.
  // NOLINTNEXTLINE(modernize-pass-by-value)
  PyTorchError(const std::string& msg_ = std::string()) : msg{msg_} {}

  // Python exception type object to raise for this error.
  virtual PyObject* python_type() = 0;

  // Exception text; the pointer stays valid while `msg` is unchanged.
  const char* what() const noexcept override {
    return msg.c_str();
  }

  std::string msg;
};
// Declare a printf-like function on gcc & clang
// The compiler can then warn on invalid format specifiers
//   FORMAT_INDEX  - 1-based position of the format-string parameter.
//   VA_ARGS_INDEX - 1-based position of the first variadic argument.
// For non-static member functions (including constructors) the implicit
// `this` counts as parameter 1, which is why the exception constructors in
// this header use TORCH_FORMAT_FUNC(2, 3).
// On compilers that do not define __GNUC__ the macro expands to nothing.
#ifdef __GNUC__
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
__attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
// C++ exception that surfaces in Python as IndexError.
struct IndexError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return PyExc_IndexError;
  }

  // Inherit the std::string-message constructor from the base.
  using PyTorchError::PyTorchError;

  // printf-style constructor (declared here, defined elsewhere); the
  // annotation lets gcc/clang check the arguments against `format`.
  IndexError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
};
// C++ exception that surfaces in Python as TypeError.
struct TypeError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return PyExc_TypeError;
  }

  // Inherit the std::string-message constructor from the base.
  using PyTorchError::PyTorchError;

  // printf-style constructor (declared here, defined elsewhere); the
  // annotation lets gcc/clang check the arguments against `format`.
  TORCH_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
};
// C++ exception that surfaces in Python as ValueError.
struct ValueError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return PyExc_ValueError;
  }

  // Inherit the std::string-message constructor from the base.
  using PyTorchError::PyTorchError;

  // printf-style constructor (declared here, defined elsewhere); the
  // annotation lets gcc/clang check the arguments against `format`.
  ValueError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
};
// C++ exception that surfaces in Python as NotImplementedError.
// Only default construction is offered, so the message inherited from
// PyTorchError is the empty string.
struct NotImplementedError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return PyExc_NotImplementedError;
  }

  NotImplementedError() = default;
};
// C++ exception that surfaces in Python as AttributeError.
struct AttributeError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return PyExc_AttributeError;
  }

  // printf-style constructor (declared here, defined elsewhere); the
  // annotation lets gcc/clang check the arguments against `format`.
  AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
};
// C++ exception that surfaces in Python as the LinAlgError type held in
// THPException_LinAlgError.
struct LinAlgError : PyTorchError {
  // Python type raised for this error.
  PyObject* python_type() override {
    return THPException_LinAlgError;
  }

  // printf-style constructor (declared here, defined elsewhere); the
  // annotation lets gcc/clang check the arguments against `format`.
  LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
};
// One buffered warning: where it was raised, its text, and the `verbatim`
// flag it was raised with. PyWarningHandler::InternalHandler keeps a vector
// of these.
struct WarningMeta {
  // Copies all three arguments into the corresponding members.
  WarningMeta(
      const c10::SourceLocation& _source_location,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      const std::string& _msg,
      const bool _verbatim)
      : source_location_(_source_location),
        msg_(_msg),
        verbatim_(_verbatim) {}

  c10::SourceLocation source_location_;
  std::string msg_;
  bool verbatim_;
};
// ATen warning handler for Python
// An instance is created at the start of every HANDLE_TH_ERRORS block (as
// `__enforce_warning_buffer`). See NOTE [ Conversion Cpp Python Warning ] for
// how buffered warnings are handled; the constructor and destructor are
// defined elsewhere.
struct PyWarningHandler {
// Move actual handler into a separate class with a noexcept
// destructor. Otherwise, we need to force all WarningHandler
// subclasses to have a noexcept(false) destructor.
struct InternalHandler : at::WarningHandler {
~InternalHandler() override = default;
// at::WarningHandler override; receives one warning (location, text and
// verbatim flag). Defined elsewhere.
void process(
const at::SourceLocation& source_location,
const std::string& msg,
const bool verbatim) override;
// Warnings collected so far, one WarningMeta per warning.
std::vector<WarningMeta> warning_buffer_;
};
public:
/// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
TORCH_API PyWarningHandler() noexcept(true);
// The destructor is declared noexcept(false): per the NOTE above it may
// throw python_error when processing a buffered warning raised an error.
// NOLINTNEXTLINE(bugprone-exception-escape)
TORCH_API ~PyWarningHandler() noexcept(false);
/** Call if an exception has been thrown
* Necessary to determine if it is safe to throw from the destructor since
* std::uncaught_exception is buggy on some platforms and generally
* unreliable across dynamic library calls.
*/
void set_in_exception() {
in_exception_ = true;
}
private:
// Handler object that does the buffering.
InternalHandler internal_handler_;
// NOTE(review): presumably the warning handler that was active before this
// object was constructed, kept so it can be reinstated - confirm against
// the constructor/destructor definitions.
at::WarningHandler* prev_handler_;
// Set by set_in_exception(); has no in-class initializer, so it relies on
// the out-of-line constructor to initialize it.
bool in_exception_;
};
namespace detail {

// Type of the i-th parameter of the callable type Func, as reported by
// function_traits.
template <typename Func, size_t i>
using Arg = typename function_traits<Func>::template arg<i>::type;

// Builds the wrapper returned by wrap_pybind_function.
//
// @param f  Callable to wrap; taken as a forwarding reference.
// @param    Index sequence 0..arity-1 used to spell out f's parameter types.
// @return   A lambda with the same parameter and result types as `f` that
//           runs `f` between HANDLE_TH_ERRORS and END_HANDLE_TH_ERRORS_PYBIND.
template <typename Func, size_t... Is>
auto wrap_pybind_function_impl_(Func&& f, std::index_sequence<Is...>) {
  using traits = function_traits<Func>;
  // END_HANDLE_TH_ERRORS_PYBIND refers to pybind11 through the alias `py`.
  namespace py = pybind11;

  // An init-capture is needed to handle function references on older
  // compilers. Forwarding into it moves rvalue callables into the closure
  // instead of copying them; lvalues are still copied.
  return [f = std::forward<Func>(f)](
             Arg<Func, Is>... args) -> typename traits::result_type {
    HANDLE_TH_ERRORS
    return f(std::forward<Arg<Func, Is>>(args)...);
    END_HANDLE_TH_ERRORS_PYBIND
  };
}

} // namespace detail
// Wrap a function with TH error and warning handling.
// Returns a function object suitable for registering with pybind11.
template <typename Func>
auto wrap_pybind_function(Func&& f) {
using traits = function_traits<Func>;
return torch::detail::wrap_pybind_function_impl_(
std::forward<Func>(f), std::make_index_sequence<traits::arity>{});
}
} // namespace torch
|