1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328
|
#include <pybind11/pybind11.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Event.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <structmember.h>
#include <string>
// Pointer to the torch.Event type object; null until THPEvent_init stores
// &THPEventType here.
PyTypeObject* THPEventClass{nullptr};
// tp_new for torch.Event.
//
// Parses
//   Event(Device device=None, *, bool enable_timing=True,
//         bool blocking=False, bool interprocess=False)
// and constructs a c10::Event in place inside a freshly allocated object.
// When no device is given, the device type comes from
// at::getAccelerator(false), falling back to CPU if that yields nothing.
// Only `enable_timing` currently affects the created event; `blocking` and
// `interprocess` are parsed but ignored (see TODO below).
static PyObject* THPEvent_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser({
      "Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
  });
  torch::ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  const auto requested_device = r.deviceOptional(0);
  const at::Device device = requested_device.has_value()
      ? *requested_device
      : at::Device(at::getAccelerator(false).value_or(at::kCPU));

  const bool enable_timing = r.toBoolWithDefault(1, true);
  // TODO: blocking and interprocess are not supported yet. To support them, the
  // flag system of c10::Event needs to be refactored. C10::Event should also
  // provide a generic constructor to support blocking and interprocess events.
  [[maybe_unused]] const bool blocking = r.toBoolWithDefault(2, false);
  [[maybe_unused]] const bool interprocess = r.toBoolWithDefault(3, false);

  THPObjectPtr ptr(type->tp_alloc(type, 0));
  TORCH_CHECK(ptr, "Failed to allocate memory for Event");
  auto* self = reinterpret_cast<THPEvent*>(ptr.get());

  // See note [Flags defining the behavior of events]
  // BACKEND_DEFAULT is a enable-timing flag, and
  // PYTORCH_DEFAULT is a disable-timing flag.
  c10::EventFlag flag = c10::EventFlag::PYTORCH_DEFAULT;
  if (enable_timing) {
    flag = c10::EventFlag::BACKEND_DEFAULT;
  }
  new (&self->event) c10::Event(device.type(), flag);
  return ptr.release();
  END_HANDLE_TH_ERRORS
}
// Creates a new torch.Event from C++ with the given device type and flag.
// Returns a new reference; throws (via TORCH_CHECK) if allocation fails.
PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
  PyTypeObject* const event_type = &THPEventType;
  THPObjectPtr obj{event_type->tp_alloc(event_type, 0)};
  TORCH_CHECK(obj, "Failed to allocate memory for Event");
  // tp_alloc only provides raw storage; construct the C++ member in place.
  auto* const event_obj = reinterpret_cast<THPEvent*>(obj.get());
  new (&event_obj->event) c10::Event(device_type, flag);
  return obj.release();
}
// tp_dealloc for torch.Event: runs the destructor of the placement-new'd
// c10::Event, then frees the Python object through its type's tp_free.
static void THPEvent_dealloc(THPEvent* self) {
  {
    // The GIL is dropped only around the C++ destructor; presumably destroying
    // the backend event may block or call into a device runtime.
    pybind11::gil_scoped_release release_gil;
    self->event.~Event();
  }
  PyTypeObject* const event_type = Py_TYPE(self);
  event_type->tp_free(reinterpret_cast<PyObject*>(self));
}
// Getter for the read-only `device` property: wraps c10::Event::device() in a
// new torch.device object.
static PyObject* THPEvent_get_device(THPEvent* self, void* /*unused*/) {
  HANDLE_TH_ERRORS
  const auto event_device = self->event.device();
  return THPDevice_New(event_device);
  END_HANDLE_TH_ERRORS
}
// Event.record(stream=None)
//
// Records this event on `stream` when one is given. Otherwise it records on
// the stream returned by impl.getStream(impl.getDevice()) for the event's
// device type, i.e. the current stream of the current device.
//
// Raises TypeError if `stream` is neither None nor a torch.Stream (or
// subclass). Without that check, the C-style cast below would reinterpret an
// arbitrary object's memory as a THPStream.
static PyObject* THPEvent_record(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto self = (THPEvent*)_self;
  PyObject* _stream = Py_None;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* accepted_args[] = {"stream", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|O",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(accepted_args),
          &_stream)) {
    TORCH_WARN("Parsing THPEvent_record arg fails");
    return nullptr;
  }
  if (_stream != Py_None) {
    TORCH_CHECK_TYPE(
        THPStream_Check(_stream),
        "Event.record(): expected 'stream' to be a torch.Stream or None, but got ",
        Py_TYPE(_stream)->tp_name);
    auto stream = (THPStream*)_stream;
    self->event.record(c10::Stream::unpack3(
        stream->stream_id,
        static_cast<c10::DeviceIndex>(stream->device_index),
        static_cast<c10::DeviceType>(stream->device_type)));
  } else {
    c10::impl::VirtualGuardImpl impl{
        static_cast<c10::DeviceType>(self->event.device_type())};
    self->event.record(impl.getStream(impl.getDevice()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Event.from_ipc_handle(device, ipc_handle) -- classmethod.
//
// IPC events are not implemented: after argument parsing, this always throws
// through TORCH_CHECK_NOT_IMPLEMENTED. The construction code that follows is
// unreachable today and is kept as the intended path once support lands.
static PyObject* THPEvent_from_ipc_handle(
    PyObject* _type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser({
      "from_ipc_handle(Device device, std::string ipc_handle)",
  });
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  const at::Device device = r.device(0);

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "torch.Event ipc is not supported yet, please open an issue if you need this!");

  auto* const event_type = reinterpret_cast<PyTypeObject*>(_type);
  THPObjectPtr obj(event_type->tp_alloc(event_type, 0));
  if (!obj) {
    return nullptr;
  }
  auto* const event_obj = reinterpret_cast<THPEvent*>(obj.get());
  // TODO: for constructing event from ipc handle, the c10::Event needs to have
  // more general constructor to achieve that.
  new (&event_obj->event)
      c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
  return obj.release();
  END_HANDLE_TH_ERRORS
}
// Event.ipc_handle()
//
// Not implemented: always throws through TORCH_CHECK_NOT_IMPLEMENTED. The
// placeholder bytes return below is unreachable while that check is in place.
static PyObject* THPEvent_ipc_handle(
    PyObject* _self [[maybe_unused]],
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "torch.Event ipc is not supported yet, please open an issue if you need this!");
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  static constexpr char kPlaceholderHandle[] = "0";
  // sizeof includes the trailing NUL, so subtract it to get the payload size.
  return PyBytes_FromStringAndSize(
      kPlaceholderHandle, sizeof(kPlaceholderHandle) - 1);
  END_HANDLE_TH_ERRORS
}
// Event.wait(stream=None)
//
// Calls c10::Event::block on `stream` when one is given. Otherwise it uses the
// stream returned by impl.getStream(impl.getDevice()) for the event's device
// type, i.e. the current stream of the current device. Presumably this
// enqueues a device-side wait rather than blocking the host; see c10::Event.
//
// Raises TypeError if `stream` is neither None nor a torch.Stream (or
// subclass). Without that check, the C-style cast below would reinterpret an
// arbitrary object's memory as a THPStream.
static PyObject* THPEvent_wait(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  auto self = (THPEvent*)_self;
  PyObject* _stream = Py_None;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* accepted_args[] = {"stream", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "|O",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(accepted_args),
          &_stream)) {
    TORCH_WARN("Parsing THPEvent_wait arg fails");
    return nullptr;
  }
  if (_stream != Py_None) {
    TORCH_CHECK_TYPE(
        THPStream_Check(_stream),
        "Event.wait(): expected 'stream' to be a torch.Stream or None, but got ",
        Py_TYPE(_stream)->tp_name);
    auto stream = (THPStream*)_stream;
    self->event.block(c10::Stream::unpack3(
        stream->stream_id,
        static_cast<c10::DeviceIndex>(stream->device_index),
        static_cast<c10::DeviceType>(stream->device_type)));
  } else {
    c10::impl::VirtualGuardImpl impl{
        static_cast<c10::DeviceType>(self->event.device_type())};
    self->event.block(impl.getStream(impl.getDevice()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Event.query(): returns the result of c10::Event::query() as a Python bool.
static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* const event_obj = reinterpret_cast<THPEvent*>(_self);
  const bool completed = event_obj->event.query();
  return PyBool_FromLong(completed ? 1 : 0);
  END_HANDLE_TH_ERRORS
}
// Event.elapsed_time(other): returns c10::Event::elapsedTime(other.event) as a
// Python float.
//
// Registered as METH_O, so CPython hands `_other` over without any type
// checking. Verify it is a torch.Event (or subclass) before reinterpreting it
// as THPEvent*; otherwise a wrong argument type would read arbitrary memory.
static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyObject_TypeCheck(_other, &THPEventType),
      "Event.elapsed_time(): expected a torch.Event, but got ",
      Py_TYPE(_other)->tp_name);
  auto self = (THPEvent*)_self;
  auto other = (THPEvent*)_other;
  return PyFloat_FromDouble(self->event.elapsedTime(other->event));
  END_HANDLE_TH_ERRORS
}
// Event.synchronize(): calls c10::Event::synchronize() with the GIL released,
// so other Python threads can run while this call is in progress.
static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* const event_obj = reinterpret_cast<THPEvent*>(_self);
  {
    pybind11::gil_scoped_release release_gil;
    event_obj->event.synchronize();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Getter for the read-only `event_id` property: returns the opaque pointer
// from c10::Event::eventId() as a Python int (PyLong_FromVoidPtr).
//
// The second parameter is the PyGetSetDef closure (`void*`), which matches the
// CPython `getter` signature exactly. Declaring it as `PyObject*` meant calling
// through a mismatched function-pointer type.
static PyObject* THPEvent_evend_id(PyObject* _self, void* /*closure*/) {
  HANDLE_TH_ERRORS
  auto self = (THPEvent*)_self;
  return PyLong_FromVoidPtr(self->event.eventId());
  END_HANDLE_TH_ERRORS
}
// tp_repr for torch.Event. Produces
//   "torch.Event device_type=<name>, device_index=<i>, event_flag=<int>,
//    event_id=<int>"
// where event_flag is the numeric EventFlag value and event_id is the opaque
// backend event pointer rendered as an integer.
static PyObject* THPEvent_repr(THPEvent* self) {
  HANDLE_TH_ERRORS
  c10::Event& ev = self->event;
  std::string text = "torch.Event device_type=";
  text += c10::DeviceTypeName(
      static_cast<c10::DeviceType>(ev.device_type()), true);
  text += ", device_index=";
  text += std::to_string(ev.device_index());
  text += ", event_flag=";
  text += std::to_string(static_cast<int64_t>(ev.flag()));
  text += ", event_id=";
  text += std::to_string(reinterpret_cast<int64_t>(ev.eventId()));
  return THPUtils_packString(text);
  END_HANDLE_TH_ERRORS
}
// NOLINTNEXTLINE(*c-arrays*, *global-variables)
static struct PyGetSetDef THPEvent_properties[] = {
{"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
{"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
{nullptr}};
// Methods exposed on torch.Event. Each entry is
// {name, C function, calling-convention flags, docstring}; the all-null entry
// terminates the table. ml_name is `const char*`, so no casts are needed.
// NOLINTNEXTLINE(*c-arrays*, *global-variables)
static PyMethodDef THPEvent_methods[] = {
    // Class method (METH_CLASS): receives the type object as first argument.
    {"from_ipc_handle",
     castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
     METH_CLASS | METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"record",
     castPyCFunctionWithKeywords(THPEvent_record),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"wait",
     castPyCFunctionWithKeywords(THPEvent_wait),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"query", THPEvent_query, METH_NOARGS, nullptr},
    {"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
    {"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
    {"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
// Python type object backing torch.Event.
// - Construction from Python goes through tp_new (THPEvent_pynew); C++ code
//   creates instances through this same type via THPEvent_new.
// - Py_TPFLAGS_BASETYPE allows Python-level subclassing.
// - Slots left as nullptr/0 here (e.g. tp_alloc) are filled in or inherited by
//   PyType_Ready, which THPEvent_init calls before registering the type.
// The initializer is positional, so field order must match PyTypeObject.
PyTypeObject THPEventType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch.Event", /* tp_name */
    sizeof(THPEvent), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THPEvent_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    (reprfunc)THPEvent_repr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPEvent_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPEvent_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPEvent_pynew, /* tp_new */
};
// Finalizes the torch.Event type and registers it on `module` as "Event".
// Also publishes the type through THPEventClass.
// Throws python_error (with the Python exception already set) if
// PyType_Ready or module registration fails.
void THPEvent_init(PyObject* module) {
  THPEventClass = &THPEventType;
  if (PyType_Ready(&THPEventType) < 0) {
    throw python_error();
  }
  // PyModule_AddObject steals this reference only on success.
  Py_INCREF(&THPEventType);
  if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
    // On failure the reference was not stolen; drop it so it doesn't leak.
    Py_DECREF(&THPEventType);
    throw python_error();
  }
}
|