1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
|
#pragma once
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/concrete_module_type.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
// Declared here, defined elsewhere. NOTE(review): by its name this presumably
// returns a human-readable name for the Python type of `h` (e.g. for use in
// diagnostics) — confirm against the definition.
std::string typeString(py::handle h);
inline std::shared_ptr<SugaredValue> toSimple(Value* v) {
return std::make_shared<SimpleValue>(v);
}
// NB: This should be the single entry-point for instantiating a SugaredValue
// from a Python object. If you are adding support for converting a new Python
// type, *add it in this function's implementation*.
//
// `loc` is a SourceRange, the type ErrorReport is built from throughout this
// file. NOTE(review): `m` is presumably the function currently being emitted,
// and `is_constant` (default false) presumably asks for `obj` to be treated
// as a constant — confirm against the definition.
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
GraphFunction& m,
const SourceRange& loc,
bool is_constant = false);
// NOTE(review): presumably yields the scripted function wrapped by `obj`, or
// c10::nullopt when `obj` is not one — confirm against the definition.
c10::optional<StrongFunctionPtr> as_function(const py::object& obj);
// SugaredValue wrapping an arbitrary Python object that the script frontend
// encounters. PythonModuleValue and CUDAPythonModuleValue derive from it to
// customize attribute lookup.
struct VISIBILITY_HIDDEN PythonValue : public SugaredValue {
  // the_self:    the wrapped Python object (stored in `self`).
  // rcb:         optional Python object; NOTE(review): presumably the
  //              resolution callback used to resolve free names — confirm.
  // module_self: optional IR value of the owning module; may be null.
  PythonValue(
      py::object the_self,
      c10::optional<py::object> rcb = c10::nullopt,
      Value* module_self = nullptr)
      : self(std::move(the_self)),
        rcb(std::move(rcb)),
        moduleSelf_(module_self) {}

  // Defined elsewhere. NOTE(review): presumably builds the schema used to
  // call `self` with n_args inputs and n_binders outputs — confirm.
  FunctionSchema getSchema(
      const size_t n_args,
      const size_t n_binders,
      const SourceRange& loc);

  // call it like a function, e.g. `outputs = this(inputs)`
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      GraphFunction& m,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

  std::string kind() const override;

  std::vector<std::shared_ptr<SugaredValue>> asTuple(
      const SourceRange& loc,
      GraphFunction& m,
      const c10::optional<size_t>& size_hint = {}) override;

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // A wrapped Python object has no IR value, so using it as one always
  // raises a compile error pointing at `loc`.
  Value* asValue(const SourceRange& loc, GraphFunction& m) override {
    throw ErrorReport(loc)
        << kind() << " cannot be used as a value. "
        << "Perhaps it is a closed over global variable? If so, please "
        << "consider passing it in as an argument or use a local variable "
        << "instead.";
  }

 protected:
  // Defined elsewhere; looks up attribute `name` on `self` (per its name).
  py::object getattr(const SourceRange& loc, const std::string& name);
  // Defined elsewhere. NOTE(review): presumably appends a hint about adding
  // the attribute to `__constants__` to `ss` — confirm.
  void checkForAddToConstantsError(std::stringstream& ss);

  // The wrapped Python object.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  py::object self;
  // Optional Python callback passed at construction.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::optional<py::object> rcb;
  // IR value of the owning module, or nullptr when there is none.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Value* moduleSelf_ = nullptr;
};
// PythonValue specialization for a Python module object. Only attribute
// access (`module.field`) differs from PythonValue; the override is defined
// elsewhere.
struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
  explicit PythonModuleValue(py::object py_module)
      : PythonValue(std::move(py_module)) {}

  SugaredValuePtr attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;
};
// Desugars uses of the `torch.cuda` module: every `torch.cuda.*` API that
// script code touches is resolved through this value's attr() override
// (defined elsewhere).
struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
  explicit CUDAPythonModuleValue(py::object cuda_module)
      : PythonValue(std::move(cuda_module)) {}

  SugaredValuePtr attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;
};
// Represents all the parameters of a module as a List[Tensor]. The list value
// is built up front. Calling this sugared value returns that list wrapped in
// a SimpleValue; the call's arguments are ignored.
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
  // `explicit` so an arbitrary Value* is never implicitly converted into a
  // ConstantParameterList.
  explicit ConstantParameterList(Value* the_list) : the_list_(the_list) {}

  std::string kind() const override {
    return "constant parameter list";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& /*loc*/,
      GraphFunction& /*caller*/,
      at::ArrayRef<NamedValue> /*args*/,
      at::ArrayRef<NamedValue> /*kwargs*/,
      size_t /*n_binders*/) override {
    return toSimple(the_list_);
  }

 private:
  // IR value holding the list of parameters.
  Value* the_list_;
};
// A zero-argument method whose result is an iterable that is already built.
// The method's name is kept for diagnostics. NOTE(review): presumably used
// for dict-style methods such as keys()/values()/items() on module
// containers — confirm at the construction sites.
struct VISIBILITY_HIDDEN ModuleDictMethod : public SugaredValue {
  explicit ModuleDictMethod(SugaredValuePtr iterable, std::string name)
      : iterable_(std::move(iterable)), name_(std::move(name)) {}

  // The method name doubles as the kind, so error messages mention it.
  std::string kind() const override {
    return name_;
  }

  // Returns the stored iterable. Any positional or keyword argument is a
  // compile error.
  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      GraphFunction& f,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override {
    if (!args.empty() || !kwargs.empty()) {
      throw ErrorReport(loc)
          << name_ << " method does not accept any arguments";
    }
    return iterable_;
  }

  SugaredValuePtr iterable_;
  const std::string name_;
};
struct SugaredDict;

// Defines how modules and their methods behave inside the script subset.
// NOTE(review): the original note says this does not interact with Python
// for now. It also anticipates that resolving `self.foo` to Python
// functions/modules/constants may later require holding the actual
// nn.Module object here instead of a Module.
struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
  // self:         IR value of the module instance.
  // concreteType: the module's ConcreteModuleType.
  ModuleValue(Value* self, std::shared_ptr<ConcreteModuleType> concreteType)
      : self_(self), concreteType_(std::move(concreteType)) {}

  std::string kind() const override {
    return "module";
  }

  Value* asValue(const SourceRange& loc, GraphFunction& m) override;

  SugaredValuePtr asTupleValue(const SourceRange& loc, GraphFunction& m)
      override;

  // Attribute lookup, e.g. `this.field`. Not an override.
  // NOTE(review): presumably the non-throwing variant of attr() (returning
  // nullptr when `field` is absent) — confirm against the definition.
  SugaredValuePtr tryGetAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field);

  // Attribute lookup, e.g. `this.field`.
  SugaredValuePtr attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // Reports whether `field` can be resolved on this module.
  bool hasAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // Calls module.forward together with its pre_hooks and hooks.
  SugaredValuePtr call(
      const SourceRange& loc,
      GraphFunction& caller,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

  // Dictionary-style views of this module, each packaged as a SugaredDict.
  // NOTE(review): by name these cover submodules, named buffers, and named
  // parameters (list and dict forms) — confirm against the definitions.
  std::shared_ptr<SugaredDict> getSugaredDict(
      const SourceRange& loc,
      GraphFunction& m);
  std::shared_ptr<SugaredDict> getSugaredNamedBufferDict(
      const SourceRange& loc,
      GraphFunction& m);
  std::shared_ptr<SugaredDict> getSugaredNamedParameterList(
      const SourceRange& loc,
      GraphFunction& m);
  std::shared_ptr<SugaredDict> getSugaredNamedParameterDict(
      const SourceRange& loc,
      GraphFunction& m);

  void setAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field,
      Value* newValue) override;

  SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override;

  SugaredValuePtr getitem(
      const SourceRange& loc,
      GraphFunction& m,
      Value* idx,
      TypePtr type_hint) override;

 private:
  // Returns whether the type of every submodule is a subtype of `ty`. On a
  // false result, when `why_not` is non-null, an explanation (for example,
  // which submodule's type is not a subtype of `ty`) is written to it.
  bool areAllSubmodulesSubtypeOf(
      const TypePtr& ty,
      std::ostream* why_not = nullptr) const;

  Value* self_;
  std::shared_ptr<ConcreteModuleType> concreteType_;
};
// Declared here, defined elsewhere. NOTE(review): presumably reports whether
// `obj` is a NamedTuple class — confirm against the definition.
bool isNamedTupleClass(const py::object& obj);
// Declared here, defined elsewhere. NOTE(review): presumably converts the
// NamedTuple class `obj` into a TorchScript type and registers it, using
// `loc` for error reporting — confirm.
TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
// Declared here, defined elsewhere. Takes the parallel `keys` / `values`
// vectors by non-const reference, so it presumably appends one entry per
// nested module reachable from `self`. NOTE(review): `prefix` and `field`
// presumably form the dotted name of the current submodule — confirm.
void recurseThroughNestedModules(
const SourceRange& loc,
GraphFunction& m,
std::vector<SugaredValuePtr>& keys,
std::vector<SugaredValuePtr>& values,
std::shared_ptr<ModuleValue>& self,
const std::string& prefix,
const std::string& field);
// Used to support named_modules(). It is also the return type of ModuleValue's
// getSugaredNamedBufferDict / getSugaredNamedParameterList /
// getSugaredNamedParameterDict. It holds the originating ModuleValue plus
// parallel tuples of keys and values.
struct VISIBILITY_HIDDEN SugaredDict : public SugaredValue {
  // Members are initialized directly instead of being default-constructed
  // and then assigned in the constructor body.
  explicit SugaredDict(
      std::shared_ptr<ModuleValue> self,
      std::shared_ptr<SugaredTupleValue> keys,
      std::shared_ptr<SugaredTupleValue> modules)
      : self_(std::move(self)),
        keys_(std::move(keys)),
        modules_(std::move(modules)) {}

  std::string kind() const override {
    return "ModuleDict";
  }

  std::shared_ptr<SugaredTupleValue> getKeys() {
    return keys_;
  }

  std::shared_ptr<SugaredTupleValue> getModules() {
    return modules_;
  }

  // Attribute/method access on the dict; defined elsewhere.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  // Iterating the dict yields its keys, as with a Python dict.
  SugaredValuePtr iter(const SourceRange& loc, GraphFunction& m) override {
    return keys_;
  }

  std::shared_ptr<ModuleValue> self_;
  std::shared_ptr<SugaredTupleValue> keys_;
  std::shared_ptr<SugaredTupleValue> modules_;
};
// Callable backed by a Python dict describing a boolean-dispatched function.
// The dispatch itself happens in call(), which is defined elsewhere.
// NOTE(review): the dict presumably names the candidate implementations and
// the flag argument that selects between them — confirm.
struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
  // `explicit` so a py::dict is never implicitly converted into a
  // BooleanDispatchValue.
  explicit BooleanDispatchValue(py::dict dispatched_fn)
      : dispatched_fn_(std::move(dispatched_fn)) {}

  std::string kind() const override {
    return "boolean dispatch";
  }

  std::shared_ptr<SugaredValue> call(
      const SourceRange& loc,
      GraphFunction& caller,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

 private:
  py::dict dispatched_fn_;
};
// ClassValue paired with the Python type object it corresponds to. attr()
// and hasAttr() are overridden (defined elsewhere), presumably so lookups
// can consult `py_type_`.
struct VISIBILITY_HIDDEN PythonClassValue : public ClassValue {
  PythonClassValue(ClassTypePtr type, py::object py_type)
      : ClassValue(std::move(type)), py_type_(std::move(py_type)) {}

  std::string kind() const override {
    return "Python type";
  }

  SugaredValuePtr attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

  bool hasAttr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& field) override;

 private:
  // The Python type object supplied at construction.
  py::object py_type_;
};
// ExceptionValue for a Python exception class. The base receives the class's
// `__name__` (or "" if it has none). The fully qualified, unmangled name
// comes from torch._jit_internal._qualified_name.
struct VISIBILITY_HIDDEN PythonExceptionValue : public ExceptionValue {
  explicit PythonExceptionValue(const py::object& exception_class)
      : ExceptionValue(pyExceptionName(exception_class)),
        exception_class_qualified_name_(
            pyQualifiedExceptionName(exception_class)) {}

  std::string kind() const override {
    return "Python exception";
  }

  SugaredValuePtr call(
      const SourceRange& loc,
      GraphFunction& caller,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;

 private:
  // `cls.__name__`, or "" when the attribute is missing.
  static std::string pyExceptionName(const py::object& cls) {
    return py::str(py::getattr(cls, "__name__", py::str("")));
  }

  // torch._jit_internal._qualified_name(cls, mangle_name=False).
  static std::string pyQualifiedExceptionName(const py::object& cls) {
    const py::object qualify =
        py::module::import("torch._jit_internal").attr("_qualified_name");
    return py::str(qualify(cls, /*mangle_name=*/false));
  }

  std::string exception_class_qualified_name_;
};
// Sugared value standing for Python's builtin `slice` class. Calling it is
// handled by call(), which is defined elsewhere.
struct VISIBILITY_HIDDEN PythonSliceClass : public SugaredValue {
  explicit PythonSliceClass() = default;

  std::string kind() const override {
    return "Python slice class";
  }

  SugaredValuePtr call(
      const SourceRange& loc,
      GraphFunction& caller,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs,
      size_t n_binders) override;
};
} // namespace jit
} // namespace torch
|