1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
#pragma once
#include <c10/core/impl/TorchDispatchModeTLS.h>
namespace torch {
namespace torch_dispatch_mode {
/// Scope guard that stashes the thread-local torch dispatch mode for the
/// lifetime of the guard and restores it on destruction.
///
/// Construction passes the (initially empty) `saved_mode_` to
/// `c10::impl::TorchDispatchModeTLS::swap_mode`. Presumably this exchanges the
/// active mode with the argument, leaving "no mode" installed while the guard
/// is alive. NOTE(review): confirm against TorchDispatchModeTLS.
/// Destruction hands the stashed mode back via
/// `TorchDispatchModeTLS::set_mode`.
///
/// The guard is intentionally non-copyable and non-movable. Every live guard
/// calls `set_mode()` from its destructor. A copy would therefore restore the
/// same stashed mode a second time, potentially overwriting a mode that was
/// installed after the first guard went out of scope.
struct StashTorchDispatchModeGuard {
 public:
  StashTorchDispatchModeGuard() {
    // saved_mode_ is default-constructed (empty) at this point; the swap
    // moves the currently active mode into it.
    c10::impl::TorchDispatchModeTLS::swap_mode(saved_mode_);
  }

  ~StashTorchDispatchModeGuard() {
    // Restore whatever mode was active when this guard was constructed.
    c10::impl::TorchDispatchModeTLS::set_mode(std::move(saved_mode_));
  }

  // Exactly one restore per stash: forbid copies and moves so no second
  // destructor can re-apply (or null out) the thread-local mode.
  StashTorchDispatchModeGuard(const StashTorchDispatchModeGuard&) = delete;
  StashTorchDispatchModeGuard& operator=(const StashTorchDispatchModeGuard&) =
      delete;
  StashTorchDispatchModeGuard(StashTorchDispatchModeGuard&&) = delete;
  StashTorchDispatchModeGuard& operator=(StashTorchDispatchModeGuard&&) =
      delete;

 private:
  // Mode that was active before this guard was constructed.
  std::shared_ptr<at::SafePyObject> saved_mode_;
};
} // namespace torch_dispatch_mode
} // namespace torch
|