1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/Exceptions.h>
namespace torch {

// Per-thread flag returned by torch_function_enabled(). It is cleared and
// restored by the DisableTorchFunction context manager and by
// THPModule_disable_torch_function in this file.
static thread_local bool enable_torch_function = true;

// Object registered through set_disabled_torch_function_impl(). Held as a raw
// pointer; no reference-count adjustment is performed here.
PyObject* disabled_torch_function = nullptr;

// Reports the current thread's torch-function enable flag.
bool torch_function_enabled() { return enable_torch_function; }

// Returns the object most recently passed to set_disabled_torch_function_impl(),
// or nullptr if nothing has been registered yet.
PyObject* disabled_torch_function_impl() { return disabled_torch_function; }

// Records `value` without incrementing its reference count.
// NOTE(review): presumably the caller passes an owned reference that lives for
// the rest of the process — confirm at the call site.
void set_disabled_torch_function_impl(PyObject* value) {
  disabled_torch_function = value;
}

} // namespace torch
// Instance layout of torch._C.DisableTorchFunction: the standard Python object
// header followed by the flag value saved when the context is entered.
struct DisableTorchFunction {
  PyObject_HEAD
  // Value of torch::enable_torch_function captured by __enter__ and written
  // back by __exit__.
  bool old_state;
};
// __enter__: stash the current thread's flag on the instance, then turn
// torch-function handling off for this thread. Returns None.
PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* /*unused*/) {
  auto* const ctx = reinterpret_cast<DisableTorchFunction*>(self);
  ctx->old_state = torch::enable_torch_function;
  torch::enable_torch_function = false;
  Py_RETURN_NONE;
}
// __exit__(exc_type, exc_value, traceback): put back the flag saved by
// __enter__. The exception arguments are ignored; returning None (falsy)
// means any in-flight exception propagates.
PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* /*unused*/) {
  const auto* const ctx = reinterpret_cast<DisableTorchFunction*>(self);
  torch::enable_torch_function = ctx->old_state;
  Py_RETURN_NONE;
}
// Python binding: returns True if torch-function handling is enabled on the
// calling thread, False otherwise.
PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* /*unused*/) {
  if (!torch::enable_torch_function) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// Context-manager protocol for torch._C.DisableTorchFunction. Both handlers
// already have the PyCFunction signature, so no cast is required. __exit__ is
// METH_VARARGS so it accepts the (exc_type, exc_value, traceback) triple.
static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
    {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
    {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}, // sentinel terminating the table
};
// Static type object for torch._C.DisableTorchFunction.
// The slots are given positionally in PyTypeObject field order, so entries
// must not be reordered or removed. Instances use the DisableTorchFunction
// layout and expose __enter__/__exit__ via DisableTorchFunction_methods.
// tp_new is PyType_GenericNew, so Python code can instantiate the type
// directly. Inheritable slots left null (e.g. tp_dealloc) are filled in from
// the base type when PyType_Ready runs (see THPModule_DisableTorchFunctionType).
PyTypeObject DisableTorchFunctionType = {
PyVarObject_HEAD_INIT(nullptr, 0)
"torch._C.DisableTorchFunction", /* tp_name */
sizeof(DisableTorchFunction), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset (tp_print before Python 3.8) */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved (tp_as_async since Python 3.5) */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
DisableTorchFunction_methods, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
PyType_GenericAlloc, /* tp_alloc */
PyType_GenericNew, /* tp_new */
};
// Finalizes DisableTorchFunctionType with PyType_Ready and returns it as a
// PyObject*. Returns nullptr (with the Python error set by PyType_Ready) if
// finalization fails. No new reference is created for the returned pointer.
PyObject* THPModule_DisableTorchFunctionType() {
  const bool ready = PyType_Ready(&DisableTorchFunctionType) >= 0;
  return ready ? reinterpret_cast<PyObject*>(&DisableTorchFunctionType)
               : nullptr;
}
// Python binding with the call shape (func, types, args=(), kwargs=None).
//
// Calls func(*args, **kwargs) with torch-function handling disabled on the
// calling thread, then restores the previous setting. `types` is accepted for
// signature compatibility but is not used here.
//
// `args` may be omitted, a tuple, or a list (converted to a tuple).
// `kwargs` may be omitted, None, or a dict. PyObject_Call requires a genuine
// tuple and a dict-or-NULL and does not validate them, so any other value
// raises TypeError here instead of reaching PyObject_Call.
//
// Returns the call's result (new reference), or nullptr with a Python error
// set on failure.
PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) {
  HANDLE_TH_ERRORS
  PyObject* func = nullptr;
  PyObject* types = nullptr;
  PyObject* args = nullptr;
  PyObject* kwargs = nullptr;
  if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
    return nullptr;
  }

  py::tuple py_args;
  if (args == nullptr) {
    py_args = py::make_tuple();
  } else if (PyTuple_Check(args)) {
    py_args = py::reinterpret_borrow<py::tuple>(args);
  } else if (PyList_Check(args)) {
    // PyList_AsTuple returns a new reference (or nullptr with an error set).
    PyObject* as_tuple = PyList_AsTuple(args);
    if (as_tuple == nullptr) {
      return nullptr;
    }
    py_args = py::reinterpret_steal<py::tuple>(as_tuple);
  } else {
    PyErr_Format(PyExc_TypeError,
                 "expected args to be a tuple or list (got %s)",
                 Py_TYPE(args)->tp_name);
    return nullptr;
  }

  // A Python-level `kwargs=None` means "no keyword arguments"; PyObject_Call
  // expresses that as a null kwargs pointer.
  if (kwargs == Py_None) {
    kwargs = nullptr;
  } else if (kwargs != nullptr && !PyDict_Check(kwargs)) {
    PyErr_Format(PyExc_TypeError,
                 "expected kwargs to be a dict or None (got %s)",
                 Py_TYPE(kwargs)->tp_name);
    return nullptr;
  }

  // PyObject_Call is a C-API call and does not throw C++ exceptions, so a
  // plain save/restore of the thread-local flag suffices (no RAII guard).
  const bool old_value = torch::enable_torch_function;
  torch::enable_torch_function = false;
  PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs);
  torch::enable_torch_function = old_value;
  return result;
  END_HANDLE_TH_ERRORS
}
|