File: SafePyObject.h

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (120 lines) | stat: -rw-r--r-- 3,741 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
#include <c10/util/python_stub.h>
#include <utility>

namespace c10 {

// SafePyObject is a safe, owning holder for a PyObject, similar in spirit to
// pybind11's py::object. It differs from py::object in two important ways:
//
//  - It lives in c10/core, so it can be used in code that has no libpython
//    dependency.
//
//  - It is multi-interpreter safe (as in torchdeploy): fetching the
//    underlying PyObject* requires naming the current interpreter, and that
//    interpreter is checked against the one this object was created with.
//
// Do NOT use this to hold a reference to a Tensor object; use TensorImpl
// directly in that case.
struct C10_API SafePyObject {
  // Adopts (steals) the reference held by `data`; no incref is performed.
  // `pyinterpreter` is the interpreter used for all later refcount calls.
  SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : data_(data), pyinterpreter_(pyinterpreter) {}

  // Transfers the reference out of `other`. Afterwards `other` holds nullptr,
  // so its destructor releases nothing.
  SafePyObject(SafePyObject&& other) noexcept
      : data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
    other.data_ = nullptr;
  }

  // Move assignment has no current users, so it is deliberately unavailable.
  SafePyObject& operator=(SafePyObject&&) = delete;

  // Shares ownership with `other` by taking one more reference (if any).
  SafePyObject(SafePyObject const& other)
      : data_(other.data_), pyinterpreter_(other.pyinterpreter_) {
    if (data_) {
      pyinterpreter()->incref(data_);
    }
  }

  // Replaces the held reference with a new reference to `other`'s object.
  // Self-assignment is a no-op (no refcount traffic at all).
  SafePyObject& operator=(SafePyObject const& other) {
    if (this != &other) {
      // Acquire the incoming reference before releasing the current one,
      // the conventional ordering for refcounted assignment.
      if (other.data_) {
        other.pyinterpreter()->incref(other.data_);
      }
      if (data_) {
        pyinterpreter()->decref(data_, /*has_pyobj_slot*/ false);
      }
      data_ = other.data_;
      pyinterpreter_ = other.pyinterpreter_;
    }
    return *this;
  }

  // Drops the owned reference, if one is still held.
  ~SafePyObject() {
    if (data_) {
      pyinterpreter()->decref(data_, /*has_pyobj_slot*/ false);
    }
  }

  // The interpreter this object was created with (dereferences the stored
  // pointer unconditionally).
  c10::impl::PyInterpreter& pyinterpreter() const {
    return *pyinterpreter_;
  }
  // Defined out of line; per the class contract above, the caller supplies
  // the current interpreter so it can be checked.
  PyObject* ptr(const c10::impl::PyInterpreter*) const;

  // Gives up ownership without a decref: returns the held PyObject* (possibly
  // nullptr) and leaves this object empty.
  PyObject* release() {
    return std::exchange(data_, nullptr);
  }

 private:
  // Owned reference, or nullptr once moved-from / released.
  PyObject* data_;
  // Interpreter through which incref/decref are dispatched.
  c10::impl::PyInterpreter* pyinterpreter_;
};

// A newtype wrapper around SafePyObject for type safety when a python object
// represents a specific type. Note that `T` is only used as a tag and isn't
// used for anything else.
//
// The wrapper is move-only: copy construction is deleted, and both copy and
// move assignment are deleted (matching SafePyObject, whose move assignment
// is deleted too).
template <typename T>
struct SafePyObjectT : private SafePyObject {
  // Steals a reference to `data`, exactly like SafePyObject's constructor.
  SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : SafePyObject(data, pyinterpreter) {}
  ~SafePyObjectT() = default;
  // Transfers ownership from `other`, leaving it empty. `other` is a named
  // rvalue reference and therefore an lvalue, so it must be std::move'd:
  // passing it directly would pick SafePyObject's *copy* constructor (an
  // extra incref inside a noexcept "move", with `other` keeping its own
  // reference).
  SafePyObjectT(SafePyObjectT&& other) noexcept
      : SafePyObject(std::move(other)) {}
  SafePyObjectT(SafePyObjectT const&) = delete;
  SafePyObjectT& operator=(SafePyObjectT const&) = delete;
  SafePyObjectT& operator=(SafePyObjectT&&) = delete;

  using SafePyObject::ptr;
  using SafePyObject::pyinterpreter;
  using SafePyObject::release;
};

// Like SafePyObject, but non-owning.  Good for references to global PyObjects
// that will be leaked on interpreter exit.  You get a copy constructor/assign
// this way.
struct C10_API SafePyHandle {
  SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
  SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
      : data_(data), pyinterpreter_(pyinterpreter) {}

  c10::impl::PyInterpreter& pyinterpreter() const {
    return *pyinterpreter_;
  }
  PyObject* ptr(const c10::impl::PyInterpreter*) const;
  void reset() {
    data_ = nullptr;
    pyinterpreter_ = nullptr;
  }
  operator bool() {
    return data_;
  }

 private:
  PyObject* data_;
  c10::impl::PyInterpreter* pyinterpreter_;
};

} // namespace c10