File: python_saved_variable_hooks.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (89 lines) | stat: -rw-r--r-- 3,179 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <ATen/SavedTensorHooks.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>

#include <torch/csrc/THP.h>

namespace py = pybind11;

namespace torch {
namespace autograd {
PySavedVariableHooks::PySavedVariableHooks(
    py::function& pack_hook,
    py::function& unpack_hook)
    : // steals the reference (we will decref ourselves)
      pack_hook_(pack_hook.release().ptr()),
      unpack_hook_(unpack_hook.release().ptr()) {}

// Runs the Python pack hook on `tensor` and keeps its result in `data_`.
//
// The tensor is wrapped as a Python object and passed to the stored pack
// hook. The returned object's reference is transferred into `data_`, which
// the destructor later decrefs. Python exceptions raised while wrapping or
// inside the hook propagate as C++ exceptions.
void PySavedVariableHooks::call_pack_hook(const at::Tensor& tensor) {
  py::gil_scoped_acquire acquire;
  auto pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
  PyObject* wrapped = THPVariable_Wrap(tensor);
  if (wrapped == nullptr) {
    // CPython-style failure: a null result with a Python error set. Raise it
    // now rather than handing a null argument to the hook call, which would
    // fail inside pybind11 with the original error still pending.
    throw py::error_already_set();
  }
  py::object obj = py::reinterpret_steal<py::object>(wrapped);
  py::object packed = pack_hook(obj);
  data_ = packed.release().ptr();
  // pack_hook and obj are decrefed on exit.
  // wrapped and packed had their references stolen.
  // pack_hook_ and data_ will be manually decrefed when the saved variable is
  // released.
}

// Feeds the previously packed object to the Python unpack hook and returns
// the Tensor it produces.
//
// Throws a type error (via TORCH_CHECK_TYPE) if the hook returns something
// that is not a Tensor.
at::Tensor PySavedVariableHooks::call_unpack_hook() {
  py::gil_scoped_acquire gil_guard;
  // Temporary owning handles: this object keeps its own references to
  // unpack_hook_ and data_ (dropped in the destructor), so borrow here.
  const auto hook = py::reinterpret_borrow<py::function>(unpack_hook_);
  const auto packed = py::reinterpret_borrow<py::object>(data_);
  const py::object result = hook(packed);
  PyObject* const raw = result.ptr();
  TORCH_CHECK_TYPE(
      THPVariable_Check(raw),
      "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ",
      THPUtils_typename(raw));
  // `raw` is only valid while `result` is alive; the Tensor is copied into the
  // return value before `result` (and the other handles) are decrefed.
  return THPVariable_Unpack(raw);
}

// Drops every Python reference owned by this object: the two hooks taken over
// in the constructor and the packed value (if call_pack_hook ran).
PySavedVariableHooks::~PySavedVariableHooks() {
  // Touching Python objects after interpreter finalization is not allowed, so
  // in that case the references are deliberately leaked.
  if (!Py_IsInitialized()) {
    return;
  }
  py::gil_scoped_acquire gil_guard;
  // Order kept as-is: a decref may run Python finalizers.
  Py_XDECREF(pack_hook_);
  Py_XDECREF(unpack_hook_);
  Py_XDECREF(data_);
}

// Registers a pack/unpack pair as the new default saved-tensor hooks.
//
// Both py::function arguments are emptied: their references are handed to
// at::SavedTensorDefaultHooks, and pop_hooks() below drops them again.
void PyDefaultSavedVariableHooks::push_hooks(
    py::function& pack_hook,
    py::function& unpack_hook) {
  at::SavedTensorDefaultHooks::lazy_initialize();
  PyObject* const owned_pack = pack_hook.release().ptr();
  PyObject* const owned_unpack = unpack_hook.release().ptr();
  at::SavedTensorDefaultHooks::push_hooks(owned_pack, owned_unpack);
}

// Removes the current default pack/unpack hooks and drops the references that
// push_hooks() handed to at::SavedTensorDefaultHooks.
//
// Asserts (TORCH_INTERNAL_ASSERT) that a hook pair is currently registered.
void PyDefaultSavedVariableHooks::pop_hooks() {
  PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
  std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
  TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
  // Take the entry off the stack *before* releasing our references. A decref
  // can run arbitrary Python code (e.g. finalizers), and that code must not be
  // able to reach the soon-to-be-freed hooks through get_hooks(). Also, if the
  // pop itself throws, the stack is left without dangling pointers.
  at::SavedTensorDefaultHooks::pop_hooks();
  // If python is already dead, leak the wrapped python objects.
  if (Py_IsInitialized()) {
    py::gil_scoped_acquire gil;
    Py_XDECREF(pack_hook);
    Py_XDECREF(unpack_hook);
  }
}

std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
  PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
  std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
  if (!pack_hook || !unpack_hook) {
    return nullptr;
  }
  py::gil_scoped_acquire gil;
  py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
  py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
  return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
}

} // namespace autograd
} // namespace torch