1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
#include <c10/core/DispatchKeySet.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
namespace c10 {
namespace impl {
// Storage behind every TorchDispatchModeTLS static accessor in this file.
// Being thread_local, each thread has its own mode and its own mode stack.
thread_local TorchDispatchModeTLS torchDispatchModeState{};
// MODE
// Installs `mode` as the calling thread's dispatch mode.
// A non-null mode first enables DispatchKey::Python and
// DispatchKey::PythonTLSSnapshot through tls_set_dispatch_key_included and is
// then stored. A null mode is delegated to reset_mode(), which clears the
// stored mode and disables both keys.
void TorchDispatchModeTLS::set_mode(std::shared_ptr<SafePyObject> mode) {
  if (!mode) {
    // Nothing to install: go straight to the fully-cleared state.
    TorchDispatchModeTLS::reset_mode();
    return;
  }
  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, true);
  torchDispatchModeState.mode_ = std::move(mode);
}
// Returns the calling thread's dispatch mode; the pointer is null when no mode
// is installed. The reference aliases this thread's stored shared_ptr, so
// later set_mode/reset_mode/swap_mode calls on the same thread change what it
// observes.
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_mode() {
  const std::shared_ptr<SafePyObject>& current = torchDispatchModeState.mode_;
  return current;
}
// Drops the calling thread's dispatch mode. It then disables
// DispatchKey::Python and DispatchKey::PythonTLSSnapshot through
// tls_set_dispatch_key_included.
void TorchDispatchModeTLS::reset_mode() {
  auto& current = torchDispatchModeState.mode_;
  current.reset();
  constexpr bool kIncluded = false;
  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, kIncluded);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, kIncluded);
}
// Exchanges the calling thread's dispatch mode with `mode`.
// DispatchKey::Python and DispatchKey::PythonTLSSnapshot are enabled or
// disabled according to whether the incoming `mode` is non-null. They
// therefore match the mode that is installed after the swap. On return,
// `mode` holds the previously installed mode, which may be null.
void TorchDispatchModeTLS::swap_mode(std::shared_ptr<SafePyObject>& mode) {
  const bool has_incoming_mode = static_cast<bool>(mode);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::Python, has_incoming_mode);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, has_incoming_mode);
  torchDispatchModeState.mode_.swap(mode);
}
// STACK
// Appends `mode` to the top of the calling thread's mode stack. The caller's
// reference is moved in rather than copied. Dispatch keys are not touched.
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
  auto& stack = torchDispatchModeState.stack_;
  stack.push_back(std::move(mode));
}
// Removes and returns the top (most recently pushed) entry of the calling
// thread's mode stack. Fails a TORCH_CHECK with
// "trying to pop from empty mode stack" when the stack is empty.
// Dispatch keys are not touched.
const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
  auto& stack = torchDispatchModeState.stack_;
  TORCH_CHECK(!stack.empty(), "trying to pop from empty mode stack");
  // Move the entry out instead of copying it. A shared_ptr copy costs an
  // atomic refcount increment that pop_back() would immediately undo.
  std::shared_ptr<SafePyObject> out = std::move(stack.back());
  stack.pop_back();
  return out;
}
// Returns the entry at position `idx` of the calling thread's mode stack.
// Index 0 is the oldest (bottom) entry and stack_len() - 1 is the top one,
// which is the entry pop_stack() would remove. Fails a TORCH_CHECK when `idx`
// is negative or not smaller than stack_len().
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
    int64_t idx) {
  const auto& stack = torchDispatchModeState.stack_;
  // Without this lower-bound check a negative idx would be converted to a huge
  // unsigned index below and read out of bounds.
  TORCH_CHECK(idx >= 0, "Tried to get stack at negative idx ", idx);
  TORCH_CHECK(
      idx < static_cast<int64_t>(stack.size()),
      "Tried to get stack at idx that's too big");
  return stack[idx];
}
// Number of entries currently on the calling thread's mode stack.
int64_t TorchDispatchModeTLS::stack_len() {
  const auto depth = torchDispatchModeState.stack_.size();
  return static_cast<int64_t>(depth);
}
// STATE
// Returns a read-only reference to the calling thread's whole
// TorchDispatchModeTLS object. The reference aliases the thread_local
// instance, not a snapshot copy.
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
  const TorchDispatchModeTLS& state = torchDispatchModeState;
  return state;
}
// Overwrites the calling thread's TorchDispatchModeTLS object with a copy of
// `state`.
// NOTE(review): unlike set_mode/swap_mode, this does not call
// tls_set_dispatch_key_included. Presumably callers restore the dispatch-key
// TLS separately when restoring a snapshot; confirm before relying on it.
void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) {
  auto& tls = torchDispatchModeState;
  tls = state;
}
// UTIL
// True when the calling thread has a dispatch mode installed, i.e. when
// TorchDispatchModeTLS::get_mode() holds a non-null pointer. The mode stack
// is not consulted.
bool dispatch_mode_enabled() {
  return c10::impl::TorchDispatchModeTLS::get_mode() != nullptr;
}
} // namespace impl
} // namespace c10
|