1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
|
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/framelocals_mapping.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/python_compat.h>
#if IS_PYTHON_3_12_PLUS
#define _PyCode_GetExtra PyUnstable_Code_GetExtra
#define _PyCode_SetExtra PyUnstable_Code_SetExtra
#endif
// Index of the per-code-object "extra" slot that holds ExtraState pointers;
// passed to _PyCode_GetExtra / _PyCode_SetExtra in this file. -1 means no
// slot index has been assigned yet.
// NOTE(review): presumably set elsewhere during module initialization
// (e.g. from _PyEval_RequestCodeExtraIndex) — confirm.
Py_ssize_t extra_index = -1;
// Returns the head of the cache list, which is the entry that lookup tries
// first, or nullptr when no cache entries exist yet.
CacheEntry* ExtraState::get_first_entry() {
  auto& entries = this->cache_entry_list;
  return entries.empty() ? nullptr : &entries.front();
}
ExtraState::ExtraState(PyCodeObject* orig_code_arg)
: orig_code(orig_code_arg) {}
// Relinks `cache_entry` to the head of this state's cache list, so the next
// lookup tries it first. Ownership and the stored list position are verified
// with CHECK before the list is modified. No entries are copied; splice only
// relinks the node.
void ExtraState::move_to_front(CacheEntry* cache_entry) {
  auto& entries = this->cache_entry_list;
  const auto position = cache_entry->_owner_loc;
  CHECK(cache_entry->_owner == this);
  CHECK(!entries.empty());
  CHECK(cache_entry == &*position);
  entries.splice(entries.begin(), entries, position);
}
// Relinks `cache_entry` to the tail of this state's cache list, so lookup
// tries it last. Ownership and the stored list position are verified with
// CHECK before the list is modified. No entries are copied; splice only
// relinks the node.
void ExtraState::move_to_back(CacheEntry* cache_entry) {
  auto& entries = this->cache_entry_list;
  const auto position = cache_entry->_owner_loc;
  CHECK(cache_entry->_owner == this);
  CHECK(!entries.empty());
  CHECK(cache_entry == &*position);
  entries.splice(entries.end(), entries, position);
}
// Invalidates `cache_entry`, which must belong to this ExtraState, and moves
// it to the back of the cache list. An invalidated entry always fails its
// guards, so it should be tried after live entries.
//
// @param cache_entry            entry owned by this state (checked with CHECK)
// @param deleted_guard_manager  forwarded to CacheEntry::invalidate
//
// Invalidating the entry (e.g. setting cache_entry->code to None) can drop the
// last reference to orig_code. Freeing orig_code runs destroy_extra_state,
// which deletes this ExtraState and all of its cache entries and leaves `this`
// dangling mid-call. To prevent that, a strong reference to orig_code is held
// for the whole call. The reference lives in an RAII py::object, so it is
// released on every exit path, including when an exception propagates. The
// manual Py_INCREF/Py_DECREF pair used previously leaked the reference in that
// case. Releasing the keep-alive is the last thing that happens, so `this` is
// never used after it may have been freed.
void ExtraState::invalidate(
    CacheEntry* cache_entry,
    py::object deleted_guard_manager) {
  py::object orig_code_keepalive =
      py::reinterpret_borrow<py::object>((PyObject*)this->orig_code);
  CHECK(cache_entry->_owner == this);
  CHECK(!this->cache_entry_list.empty());
  CHECK(cache_entry == &*cache_entry->_owner_loc);
  cache_entry->invalidate(std::move(deleted_guard_manager));
  // Move the cache entry to the end of the list because these will always
  // return False.
  cache_entry->_owner->move_to_back(cache_entry);
  // orig_code_keepalive is destroyed here; this may free `this`.
}
// True when `extra_state` does not point at a real ExtraState: either nothing
// is attached (nullptr) or the slot holds one of the SKIP_CODE /
// SKIP_CODE_RECURSIVE marker values.
static bool is_extra_state_unset(ExtraState* extra_state) {
  const bool is_skip_marker =
      extra_state == SKIP_CODE || extra_state == SKIP_CODE_RECURSIVE;
  return is_skip_marker || extra_state == nullptr;
}
// Returns the first cache entry (the one lookup tries first) of
// `extra_state`. Returns nullptr when no real ExtraState is attached or it
// has no entries.
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
  return is_extra_state_unset(extra_state) ? nullptr
                                           : extra_state->get_first_entry();
}
// Returns the raw frame_state object stored on `extra_state`, reinterpreted
// as a FrameState*. Returns nullptr when no real ExtraState is attached.
FrameState* extract_frame_state(ExtraState* extra_state) {
  FrameState* frame_state = nullptr;
  if (!is_extra_state_unset(extra_state)) {
    frame_state = (FrameState*)extra_state->frame_state.ptr();
  }
  return frame_state;
}
// Reports the cache_limit_hit flag of `extra_state`. The caller must pass a
// real ExtraState; this function checks neither for nullptr nor for a
// SKIP_CODE marker.
bool extra_state_cache_limit_hit(ExtraState* extra_state) {
  const bool limit_hit = extra_state->cache_limit_hit;
  return limit_hit;
}
// Overwrites the cache_limit_hit flag of `extra_state` with `value`. The
// caller must pass a real ExtraState; this function checks neither for
// nullptr nor for a SKIP_CODE marker.
void set_extra_state_cache_limit_hit(ExtraState* extra_state, bool value) {
  ExtraState& state = *extra_state;
  state.cache_limit_hit = value;
}
// Reads the pointer stored in `code`'s extra slot (index extra_index) and
// returns it as an ExtraState*. Returns nullptr when nothing has been stored.
// The status code returned by _PyCode_GetExtra is not inspected.
ExtraState* get_extra_state(PyCodeObject* code) {
  void* slot_value = nullptr;
  _PyCode_GetExtra((PyObject*)code, extra_index, &slot_value);
  return static_cast<ExtraState*>(slot_value);
}
void destroy_extra_state(void* obj) {
ExtraState* extra = (ExtraState*)obj;
if (!is_extra_state_unset(extra)) {
delete extra;
}
}
// Stores `extra_state` in `code`'s extra slot. Re-storing the same real
// ExtraState that is already attached is rejected by CHECK. nullptr and
// SKIP_CODE markers may always be stored.
// NOTE(review): presumably this guards against CPython freeing the previous
// slot value when it is overwritten — confirm. The status code returned by
// _PyCode_SetExtra is not inspected.
void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
  ExtraState* const previous = get_extra_state(code);
  const bool storing_marker = is_extra_state_unset(extra_state);
  CHECK(storing_marker || previous != extra_state);
  _PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
}
// Allocates a new ExtraState for `code` and attaches it to the code object's
// extra slot, which must still be empty. The returned object is owned by the
// code object from then on and is released by destroy_extra_state.
ExtraState* init_and_set_extra_state(PyCodeObject* code) {
  // Invariant: no extra state may have been attached to `code` yet.
  CHECK(get_extra_state(code) == nullptr);
  auto* fresh_state = new ExtraState(code);
  NULL_CHECK(fresh_state);
  set_extra_state(code, fresh_state);
  // freed by destroy_extra_state (since we need to pass these objects to C)
  // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
  return fresh_state;
}
// Returns true when a cache entry compiled for `saved_backend` may be reused
// under `backend`.
//
// Identical pointers always match. Otherwise the two objects must have the
// same type and compare equal with `==`.
//
// PyObject_RichCompareBool returns 1 (equal), 0 (not equal), or -1 when the
// comparison itself fails, in which case a Python exception is set. -1
// converts to `true`, so the previous code counted a failed comparison as a
// match and left the exception pending. Only an explicit 1 is a match now.
// On failure the exception is cleared and the backend is treated as a
// mismatch. lookup() calls this outside its try block and has no way to
// report an error from here, so this is a deliberate best-effort outcome.
static bool backend_match(PyObject* saved_backend, PyObject* backend) {
  // Pointer equality check for common case
  if (saved_backend == backend) {
    return true;
  }
  // The Py_TYPE check should not be required but there is a pre-existing
  // issue where backend is possibly deallocated (or nullptr) and causes
  // segfaults. Check test - test_inplace_custom_op_intermediate
  if (Py_TYPE(saved_backend) != Py_TYPE(backend)) {
    return false;
  }
  const int cmp = PyObject_RichCompareBool(saved_backend, backend, Py_EQ);
  if (cmp < 0) {
    PyErr_Clear();
    return false;
  }
  return cmp == 1;
}
// Finds the first cache entry whose backend matches `backend` and whose guards
// pass on `f_locals`, and publishes its compiled code.
//
// `backend == Py_False` means "run only" mode, and the backend check is
// skipped. When `is_skip_guard_eval_unsafe` is set, each entry's
// diff_guard_root_mgr is evaluated instead of its root_mgr.
//
// Results written through the out-parameters:
//  - match:       *maybe_cached_code = the entry's code (borrowed pointer),
//                 *trace_annotation = the entry's annotation. The entry is
//                 moved to the front of the list so it is tried first next
//                 time.
//  - no match:    *maybe_cached_code = Py_None (borrowed);
//                 *trace_annotation is left untouched.
//  - guard error: *maybe_cached_code = nullptr, and the guard's Python
//                 exception is restored (left set) for the caller.
//
// This function is called from C, so no C++ exception may escape it. That
// includes one raised by the user-installed guard_error_hook: if the hook
// itself raises, its error is reported as unraisable, and the original guard
// error is still the one handed back.
void lookup(
    ExtraState* extra_state,
    PyObject* f_locals,
    PyObject* backend,
    PyObject** maybe_cached_code,
    const char** trace_annotation,
    bool is_skip_guard_eval_unsafe) {
  size_t index = 0;
  CacheEntry* found = nullptr;
  py::handle locals(f_locals);
  for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
    // Check backend. Py_False means run only mode.
    bool valid =
        backend == Py_False || backend_match(cache_entry.backend, backend);
    if (valid) {
      try {
        if (is_skip_guard_eval_unsafe) {
          valid = torch::dynamo::run_root_guard_manager(
              cache_entry.diff_guard_root_mgr, f_locals);
        } else {
          valid = torch::dynamo::run_root_guard_manager(
              cache_entry.root_mgr, f_locals);
        }
      } catch (py::error_already_set& e) {
        if (guard_error_hook) {
          // The hook is an arbitrary Python callable and may raise. Letting
          // that escape would propagate a C++ exception into C.
          try {
            py::handle guard_error_hook_handle(guard_error_hook);
            guard_error_hook_handle(
                cache_entry.guard_manager,
                cache_entry.code,
                locals,
                index,
                index == extra_state->cache_entry_list.size() - 1);
          } catch (py::error_already_set& hook_error) {
            hook_error.discard_as_unraisable("guard_error_hook");
          }
        }
        // this function is called from C, so we cannot repropagate
        // the exception
        e.restore();
        *maybe_cached_code = nullptr;
        return;
      }
    }
    if (valid) {
      found = &cache_entry;
      break;
    }
    ++index;
  }
  if (found) {
    extra_state->move_to_front(found);
    *maybe_cached_code = found->code.ptr();
    *trace_annotation = found->trace_annotation.c_str();
    return;
  }
  *maybe_cached_code = py::none().ptr();
}
// Builds a new cache entry from `guarded_code` and `backend` at the head of
// `extra_state`'s list (so it is tried first) and records its owner and list
// position. The entry's guard_manager attribute receives non-owning Python
// references to the new entry and to `extra_state`.
// Warning: lifetime is controlled by C++!
CacheEntry* create_cache_entry(
    ExtraState* extra_state,
    PyObject* guarded_code,
    PyObject* backend) {
  auto& entries = extra_state->cache_entry_list;
  entries.emplace_front(guarded_code, backend);
  const auto head = entries.begin();
  CacheEntry& entry = *head;
  entry._owner = extra_state;
  entry._owner_loc = head;
  py::handle manager = py::handle(guarded_code).attr("guard_manager");
  manager.attr("cache_entry") =
      py::cast(entry, py::return_value_policy::reference);
  manager.attr("extra_state") =
      py::cast(extra_state, py::return_value_policy::reference);
  return &entry;
}
// Debug helper: returns the cache entries attached to `code_obj` as a Python
// list of non-owning references, in lookup order. Returns an empty list when
// no real ExtraState is attached.
// Raises TypeError ("expected a code object!") if `code_obj` is not an
// instance of types.CodeType.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  py::object code_type = py::module::import("types").attr("CodeType");
  if (!py::isinstance(code_obj, code_type)) {
    throw py::type_error("expected a code object!");
  }
  ExtraState* extra = get_extra_state((PyCodeObject*)code_obj.ptr());
  py::list entries_out;
  if (is_extra_state_unset(extra)) {
    return entries_out;
  }
  for (CacheEntry& entry : extra->cache_entry_list) {
    entries_out.append(py::cast(entry, py::return_value_policy::reference));
  }
  return entries_out;
}
|