#include <torch/csrc/jit/python/python_custom_class.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <fmt/format.h>
namespace torch {
namespace jit {
struct CustomMethodProxy;
struct CustomObjectProxy;
py::object ScriptClass::__call__(py::args args, py::kwargs kwargs) {
  auto instance =
      Object(at::ivalue::Object::create(class_type_, /*numSlots=*/1));
  Function* init_fn = instance.type()->findMethod("__init__");
  TORCH_CHECK(
      init_fn,
      fmt::format(
          "Custom C++ class: '{}' does not have an '__init__' method bound. "
          "Did you forget to add '.def(torch::init<...>)' to its registration?",
          instance.type()->repr_str()));
  Method init_method(instance._ivalue(), init_fn);
  // NOLINTNEXTLINE(performance-move-const-arg)
  invokeScriptMethodFromPython(init_method, std::move(args), std::move(kwargs));
  return py::cast(instance);
}
/// Variant of StrongFunctionPtr, but for static methods of custom classes.
/// They do not belong to compilation units (the custom class method registry
/// serves that purpose in this case), so StrongFunctionPtr cannot be used here.
/// While it is usually unsafe to carry a raw pointer like this, the custom
/// class method registry that owns the pointer is never destroyed.
struct ScriptClassFunctionPtr {
  ScriptClassFunctionPtr(Function* function) : function_(function) {
    TORCH_INTERNAL_ASSERT(function_);
  }
  Function* function_;
};
void initPythonCustomClassBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<ScriptClassFunctionPtr>(
      m, "ScriptClassFunction", py::dynamic_attr())
      .def("__call__", [](py::args args, const py::kwargs& kwargs) {
        auto strongPtr = py::cast<ScriptClassFunctionPtr>(args[0]);
        Function& callee = *strongPtr.function_;
        py::object result = invokeScriptFunctionFromPython(
            callee, tuple_slice(std::move(args), 1), kwargs);
        return result;
      });
  py::class_<ScriptClass>(m, "ScriptClass")
      .def("__call__", &ScriptClass::__call__)
      .def(
          "__getattr__",
          [](ScriptClass& self, const std::string& name) {
            // Define __getattr__ so that static functions of custom classes can
            // be used in regular Python.
            auto type = self.class_type_.type_->castRaw<ClassType>();
            TORCH_INTERNAL_ASSERT(type);
            auto* fn = type->findStaticMethod(name);
            if (fn) {
              return ScriptClassFunctionPtr(fn);
            }
            throw AttributeError("%s does not exist", name.c_str());
          })
      .def_property_readonly("__doc__", [](const ScriptClass& self) {
        return self.class_type_.type_->expectRef<ClassType>().doc_string();
      });
  // This function returns a ScriptClass that wraps the constructor
  // of the given class, specified by the namespace and qualified name
  // passed in.
  //
  // This emulates the behavior in Python, where instantiating a class
  // is a call to a code object for the class, and that code object in
  // turn calls __init__. Rather than calling __init__ directly, we need
  // a wrapper that returns the instance instead of the None return
  // value from __init__.
  m.def(
      "_get_custom_class_python_wrapper",
      [](const std::string& ns, const std::string& qualname) {
        std::string full_qualname =
            "__torch__.torch.classes." + ns + "." + qualname;
        auto named_type = getCustomClass(full_qualname);
        TORCH_CHECK(
            named_type,
            fmt::format(
                "Tried to instantiate class '{}.{}', but it does not exist! "
                "Ensure that it is registered via torch::class_",
                ns,
                qualname));
        c10::ClassTypePtr class_type = named_type->cast<ClassType>();
        return ScriptClass(c10::StrongTypePtr(
            std::shared_ptr<CompilationUnit>(), std::move(class_type)));
      });
}
} // namespace jit
} // namespace torch