1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
|
#pragma once
#include <c10/core/impl/TorchDispatchModeTLS.h>
namespace torch::torch_dispatch_mode {
struct StashTorchDispatchModeGuard {
public:
StashTorchDispatchModeGuard() {
if (c10::impl::TorchDispatchModeTLS::any_modes_set(
/*skip_infra_modes=*/true)) {
saved_mode_ = c10::impl::TorchDispatchModeTLS::pop_stack();
} else {
auto mode_and_key =
c10::impl::TorchDispatchModeTLS::pop_highest_infra_mode();
saved_mode_ = std::move(std::get<0>(mode_and_key));
saved_mode_key_ = std::get<1>(mode_and_key);
}
}
~StashTorchDispatchModeGuard() {
if (saved_mode_key_.has_value()) {
c10::impl::TorchDispatchModeTLS::set_mode(
saved_mode_, saved_mode_key_.value());
} else {
c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
std::move(saved_mode_));
}
}
StashTorchDispatchModeGuard(const StashTorchDispatchModeGuard&) = delete;
StashTorchDispatchModeGuard(StashTorchDispatchModeGuard&&) = delete;
StashTorchDispatchModeGuard& operator=(const StashTorchDispatchModeGuard&) =
delete;
StashTorchDispatchModeGuard& operator=(StashTorchDispatchModeGuard&&) =
delete;
const std::shared_ptr<c10::impl::PyObject_TorchDispatchMode>& get_cur_mode() {
return saved_mode_;
}
private:
std::shared_ptr<c10::impl::PyObject_TorchDispatchMode> saved_mode_;
std::optional<c10::impl::TorchDispatchModeKey> saved_mode_key_;
};
struct StashTorchDispatchStackGuard {
public:
StashTorchDispatchStackGuard() {
auto old = c10::impl::TorchDispatchModeTLS::get_state();
c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
saved_state_ = std::move(old);
}
StashTorchDispatchStackGuard(const StashTorchDispatchStackGuard&) = delete;
StashTorchDispatchStackGuard(StashTorchDispatchStackGuard&&) = delete;
StashTorchDispatchStackGuard& operator=(const StashTorchDispatchStackGuard&) =
delete;
StashTorchDispatchStackGuard& operator=(StashTorchDispatchStackGuard&&) =
delete;
~StashTorchDispatchStackGuard() {
c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_));
}
private:
c10::impl::TorchDispatchModeTLS saved_state_;
};
} // namespace torch::torch_dispatch_mode
|