1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
|
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
// Builds a cache entry from a Python `guarded_code` object.
//
// Reads the `guard_manager`, `code`, `compile_id` and `trace_annotation`
// attributes of `guarded_code`, stores `backend` as given, and resolves the
// root guard managers from the guard manager's `root` and `diff_guard_root`
// attributes.
//
// A missing attribute surfaces as the exception pybind11 throws from the
// attribute lookup. A `trace_annotation` that cannot be converted to UTF-8
// (for example, because it is not a `str`) is not an error: the annotation
// falls back to "Unknown".
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
    : backend{backend} {
  this->guard_manager = guarded_code.attr("guard_manager");
  this->code = guarded_code.attr("code");
  this->compile_id = guarded_code.attr("compile_id");
  py::object trace_annotation = guarded_code.attr("trace_annotation");
  // PyUnicode_AsUTF8 returns nullptr with a Python exception set when the
  // conversion fails, so the C string itself is what must be checked.
  // Constructing std::string from nullptr is undefined behaviour.
  const char* trace_annotation_str = PyUnicode_AsUTF8(trace_annotation.ptr());
  if (trace_annotation_str != nullptr) {
    this->trace_annotation = std::string(trace_annotation_str);
  } else {
    // Drop the pending exception so it does not leak into unrelated
    // Python calls made later on this thread.
    PyErr_Clear();
    this->trace_annotation = "Unknown";
  }
  this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
      this->guard_manager.attr("root"));
  this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager(
      this->guard_manager.attr("diff_guard_root"));
}
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
    "-Wdeprecated-copy-with-user-provided-dtor")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
// Detaches this entry from its Python-side guard manager on destruction.
// NOLINTNEXTLINE(bugprone-exception-escape)
CacheEntry::~CacheEntry() {
  // Reset the guard manager's back-references to None so that it cannot be
  // used to reach this entry (use-after-free) once the entry is gone.
  py::setattr(guard_manager, "cache_entry", py::none());
  py::setattr(guard_manager, "extra_state", py::none());
}
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
// Returns the entry that follows this one in the owner's `cache_entry_list`,
// or Python `None` when this entry is the last one.
//
// The successor is handed to Python with `return_value_policy::reference`,
// i.e. the returned object refers to the existing C++ entry.
py::object CacheEntry::next() {
  // NOTE(review): NULL_CHECK comes from debug_macros.h; presumably it fails
  // hard when the entry has no owner -- confirm against the macro.
  NULL_CHECK(this->_owner);
  auto successor = this->_owner_loc;
  ++successor;
  if (successor != this->_owner->cache_entry_list.end()) {
    return py::cast(*successor, py::return_value_policy::reference);
  }
  return py::none();
}
// Turns this entry into an inert placeholder without destroying it.
//
// The current guard manager's `cache_entry` and `extra_state` attributes are
// set to None, then the entry's own fields are overwritten: `code` becomes
// None, `guard_manager` is replaced by `deleted_guard_manager`, `root_mgr`
// is cleared and the trace annotation is set to "Invalidated".
void CacheEntry::invalidate(py::object deleted_guard_manager) {
  // Detach the old guard manager first; it is replaced a few lines below.
  py::setattr(guard_manager, "cache_entry", py::none());
  py::setattr(guard_manager, "extra_state", py::none());

  code = py::none();
  guard_manager = std::move(deleted_guard_manager);
  root_mgr = nullptr;
  // NOTE(review): diff_guard_root_mgr is left untouched here although the
  // guard manager it was derived from has just been swapped out -- confirm
  // that this is intentional.
  trace_annotation = "Invalidated";
}
void CacheEntry::update_diff_guard_root_manager() {
this->diff_guard_root_mgr = torch::dynamo::convert_to_root_guard_manager(
this->guard_manager.attr("diff_guard_root"));
}
// Returns the entry's `code` object viewed as a `PyCodeObject*`.
// The pointer is borrowed: no reference count is changed here.
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
  PyObject* const raw_code = e->code.ptr();
  return reinterpret_cast<PyCodeObject*>(raw_code);
}
// Returns the entry's trace annotation as a C string.
// The pointer refers to storage held by the entry's `trace_annotation`
// member, so it must not be used after that member changes or the entry dies.
const char* CacheEntry_get_trace_annotation(CacheEntry* e) {
  const auto& annotation = e->trace_annotation;
  return annotation.c_str();
}
// Converts a `CacheEntry*` to a Python object pointer.
//
// A null entry maps to Python `None`. Otherwise the entry is cast with
// `return_value_policy::reference`, so the Python object refers to the
// existing C++ entry. In both cases the pybind11 wrapper is released, so the
// caller receives the raw pointer and is responsible for its reference.
PyObject* CacheEntry_to_obj(CacheEntry* e) {
  if (e != nullptr) {
    py::object wrapped = py::cast(e, py::return_value_policy::reference);
    return wrapped.release().ptr();
  }
  py::none none_obj;
  return none_obj.release().ptr();
}
// Follows the chain of `_torchdynamo_orig_callable` attributes starting at
// `callback` and returns the first object that does not have one.
// If `callback` itself lacks the attribute, `callback` is returned unchanged.
// The returned pointer is taken from a non-owning `py::handle`.
PyObject* get_backend(PyObject* callback) {
  static constexpr const char* kOrigCallableAttr = "_torchdynamo_orig_callable";
  py::handle current(callback);
  while (true) {
    if (!py::hasattr(current, kOrigCallableAttr)) {
      break;
    }
    current = current.attr(kOrigCallableAttr);
  }
  return current.ptr();
}
|