File: python_custom_class.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (102 lines) | stat: -rw-r--r-- 4,031 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#include <torch/csrc/jit/python/python_custom_class.h>

#include <torch/csrc/jit/frontend/sugared_value.h>

#include <fmt/format.h>

namespace torch {
namespace jit {

// Forward declarations only: neither type is defined or referenced anywhere
// else in this file.
// NOTE(review): these look vestigial -- confirm nothing else depends on them
// before removing.
struct CustomMethodProxy;
struct CustomObjectProxy;

/// Python-facing "constructor call" for a custom C++ class.
///
/// Allocates a new object of the wrapped class type, looks up the class's
/// `__init__` method, invokes it with the supplied Python arguments, and
/// hands the newly created object back to Python (the result of invoking
/// `__init__` itself is discarded).
///
/// A TORCH_CHECK failure is raised when the class has no `__init__` bound.
py::object ScriptClass::__call__(py::args args, py::kwargs kwargs) {
  Object self_obj(at::ivalue::Object::create(class_type_, /*numSlots=*/1));

  Function* const ctor = self_obj.type()->findMethod("__init__");
  TORCH_CHECK(
      ctor,
      fmt::format(
          "Custom C++ class: '{}' does not have an '__init__' method bound. "
          "Did you forget to add '.def(torch::init<...>)' to its registration?",
          self_obj.type()->repr_str()));

  // Pair `__init__` with the fresh object so the call runs against it.
  Method bound_init(self_obj._ivalue(), ctor);
  // NOLINTNEXTLINE(performance-move-const-arg)
  invokeScriptMethodFromPython(bound_init, std::move(args), std::move(kwargs));

  return py::cast(self_obj);
}

/// Variant of StrongFunctionPtr, but for static methods of custom classes.
/// They do not belong to compilation units (the custom class method registry
/// serves that purpose in this case), so StrongFunctionPtr cannot be used here.
/// While it is usually unsafe to carry a raw pointer like this, the custom
/// class method registry that owns the pointer is never destroyed.
struct ScriptClassFunctionPtr {
  /// Wraps `function`, which must be non-null (enforced by an internal
  /// assert). The constructor is `explicit` so that a bare `Function*` is
  /// never silently converted into a wrapper.
  explicit ScriptClassFunctionPtr(Function* function) : function_(function) {
    TORCH_INTERNAL_ASSERT(function_);
  }

  /// Non-owning pointer to the wrapped static method.
  Function* function_;
};

/// Registers the Python bindings for custom classes on `module`.
///
/// Three things are exposed:
///   - `ScriptClassFunction`: a callable wrapper around a static method of a
///     custom class.
///   - `ScriptClass`: a callable wrapper around a custom class type, whose
///     `__getattr__` resolves static methods and whose `__doc__` is read-only.
///   - `_get_custom_class_python_wrapper(ns, qualname)`: looks up a custom
///     class by name and returns its `ScriptClass` wrapper.
void initPythonCustomClassBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Calling a ScriptClassFunction: positional argument 0 is the wrapper
  // itself; the remaining positionals and all keywords go to the callee.
  auto call_static_method =
      [](py::args args, const py::kwargs& kwargs) -> py::object {
    auto self_wrapper = py::cast<ScriptClassFunctionPtr>(args[0]);
    Function& target = *self_wrapper.function_;
    return invokeScriptFunctionFromPython(
        target, tuple_slice(std::move(args), 1), kwargs);
  };

  // Attribute lookup on a ScriptClass, so that static functions of custom
  // classes can be used in regular Python. Unknown names raise
  // AttributeError.
  auto lookup_static_method =
      [](ScriptClass& self,
         const std::string& name) -> ScriptClassFunctionPtr {
    auto cls = self.class_type_.type_->castRaw<ClassType>();
    TORCH_INTERNAL_ASSERT(cls);
    auto* static_fn = cls->findStaticMethod(name);
    if (!static_fn) {
      throw AttributeError("%s does not exist", name.c_str());
    }
    return ScriptClassFunctionPtr(static_fn);
  };

  // Getter backing the read-only `__doc__` property.
  auto read_doc_string = [](const ScriptClass& self) {
    const auto& cls = self.class_type_.type_->expectRef<ClassType>();
    return cls.doc_string();
  };

  // Builds a ScriptClass that wraps the constructor of the class identified
  // by namespace + qualified name.
  //
  // In Python, instantiating a class means calling the class object, which
  // in turn calls __init__. Calling __init__ directly would yield its None
  // result, so a wrapper is needed that returns the instance instead.
  auto make_python_wrapper = [](const std::string& ns,
                                const std::string& qualname) {
    const std::string mangled_name =
        "__torch__.torch.classes." + ns + "." + qualname;
    auto registered = getCustomClass(mangled_name);
    TORCH_CHECK(
        registered,
        fmt::format(
            "Tried to instantiate class '{}.{}', but it does not exist! "
            "Ensure that it is registered via torch::class_",
            ns,
            qualname));
    c10::ClassTypePtr cls = registered->cast<ClassType>();
    // The wrapper is created with an empty CompilationUnit pointer.
    return ScriptClass(c10::StrongTypePtr(
        std::shared_ptr<CompilationUnit>(), std::move(cls)));
  };

  py::class_<ScriptClassFunctionPtr> function_binding(
      m, "ScriptClassFunction", py::dynamic_attr());
  function_binding.def("__call__", call_static_method);

  py::class_<ScriptClass> class_binding(m, "ScriptClass");
  class_binding.def("__call__", &ScriptClass::__call__);
  class_binding.def("__getattr__", lookup_static_method);
  class_binding.def_property_readonly("__doc__", read_doc_string);

  m.def("_get_custom_class_python_wrapper", make_python_wrapper);
}

} // namespace jit
} // namespace torch