1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
#pragma once
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/api/method.h>
namespace torch {
namespace jit {
// Reference-counted (c10::intrusive_ptr) handle to a TorchScript object
// instance; Object below wraps one of these.
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;

// Only forward-declared: this header needs Resolver solely through the
// shared_ptr alias used as the optional resolver argument of Object::define.
struct Resolver;
using ResolverPtr = std::shared_ptr<Resolver>;
// Throw this in C++ land if `attr` fails. This will be converted to a Python
// AttributeError by the Python binding code.
class ObjectAttributeError : public std::runtime_error {
 public:
  // `what` becomes the message returned by what(). The constructor is
  // explicit so a std::string is never silently converted into an exception
  // object.
  explicit ObjectAttributeError(const std::string& what)
      : std::runtime_error(what) {}
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TORCH_API Object {
Object() = default;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
Object(
c10::QualifiedName,
std::shared_ptr<CompilationUnit> cu,
bool shouldMangle = false);
ObjectPtr _ivalue() const {
TORCH_INTERNAL_ASSERT(_ivalue_);
return _ivalue_;
}
c10::ClassTypePtr type() const {
return _ivalue()->type();
}
struct Property {
std::string name;
Method getter_func;
c10::optional<Method> setter_func;
};
void setattr(const std::string& name, c10::IValue v) {
if (_ivalue()->type()->hasConstant(name)) {
TORCH_CHECK(
false,
"Can't set constant '",
name,
"' which has value:",
_ivalue()->type()->getConstant(name));
} else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
TORCH_CHECK(
v.type()->isSubtypeOf(*expected),
"Expected a value of type '",
expected->repr_str(),
"' for field '",
name,
"', but found '",
v.type()->repr_str(),
"'");
_ivalue()->setSlot(*slot, std::move(v));
} else {
TORCH_CHECK(false, "Module has no attribute '", name, "'");
}
}
c10::IValue attr(const std::string& name) const {
if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
return _ivalue()->getSlot(*r);
}
if (auto r = _ivalue()->type()->findConstantSlot(name)) {
return _ivalue()->type()->getConstant(*r);
}
std::stringstream err;
err << _ivalue()->type()->repr_str() << " does not have a field with name '"
<< name.c_str() << "'";
throw ObjectAttributeError(err.str());
}
c10::IValue attr(const std::string& name, c10::IValue or_else) const {
if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
return _ivalue()->getSlot(*r);
}
if (auto r = _ivalue()->type()->findConstantSlot(name)) {
return _ivalue()->type()->getConstant(*r);
}
return or_else;
}
bool hasattr(const std::string& name) const {
return _ivalue()->type()->hasAttribute(name) ||
_ivalue()->type()->hasConstant(name);
}
// each object owns its methods. The reference returned here
// is guaranteed to stay valid until this module has been destroyed
Method get_method(const std::string& name) const {
if (auto method = find_method(name)) {
return *method;
}
AT_ERROR("Method '", name, "' is not defined.");
}
const std::vector<Method> get_methods() const {
return c10::fmap(type()->methods(), [&](Function* func) {
return Method(_ivalue(), func);
});
}
bool has_property(const std::string& name) const {
for (const auto& prop : type()->properties()) {
if (prop.name == name) {
return true;
}
}
return false;
}
const Property get_property(const std::string& name) const {
for (const auto& prop : type()->properties()) {
if (prop.name == name) {
c10::optional<Method> setter = c10::nullopt;
if (prop.setter) {
setter = Method(_ivalue(), prop.setter);
}
return Property{prop.name, Method(_ivalue(), prop.getter), setter};
}
}
AT_ERROR("Property '", name, "' is not defined.");
}
const std::vector<Property> get_properties() const {
return c10::fmap(type()->properties(), [&](ClassType::Property prop) {
c10::optional<Method> setter = c10::nullopt;
if (prop.setter) {
setter = Method(_ivalue(), prop.setter);
}
return Property{prop.name, Method(_ivalue(), prop.getter), setter};
});
}
c10::optional<Method> find_method(const std::string& basename) const;
/// Run a method from this module.
///
/// For example:
/// @code
/// IValue output = module->run("relu_script", a, b);
/// @endcode
///
/// To get a compile a module from a source string, see torch::jit::compile
///
/// @param method_name The name of the method to run
/// @param args Arguments to be passed to the method
/// @return An IValue containing the return value (or values if it is a tuple)
/// from the method
template <typename... Types>
IValue run_method(const std::string& method_name, Types&&... args) {
return get_method(method_name)({IValue(std::forward<Types>(args))...});
}
// so that C++ users can easily add methods
void define(const std::string& src, const ResolverPtr& resolver = nullptr);
size_t num_slots() const {
return _ivalue()->slots().size();
}
// shallow copy the object
Object copy() const;
// Copies all the attributes of the object recursively without creating new
// `ClassType`, including deepcopy of Tensors
Object deepcopy() const;
private:
// mutable be we lazily initialize in module_object.
mutable ObjectPtr _ivalue_;
};
// Backward-compatibility shim. A `script::` namespace used to exist and was
// removed; this alias keeps `torch::jit::script::Object` spelling valid for
// existing users of the public API. New code should use torch::jit::Object.
namespace script {
using Object = ::torch::jit::Object;
} // namespace script
} // namespace jit
} // namespace torch
|