#include <torch/csrc/jit/backends/backend_init.h>
#include <pybind11/iostream.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_resolver.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace jit {
// Get all types that are shared in the module hierarchy rooted at \p mod.
std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
  // Maintain a set of all TypePtrs.
  std::unordered_set<TypePtr> types;
  // Maintain another set of TypePtrs that have been encountered more than
  // once.
  std::unordered_set<TypePtr> duplicate_types;

  // Iterate over all modules in the hierarchy, including the root.
  for (auto module : mod.modules()) {
    auto module_type = module.type();
    if (types.count(module_type) > 0) {
      duplicate_types.insert(module_type);
    }
    types.insert(module_type);
  }

  return duplicate_types;
}
// Selectively lower \p mod to a backend. \p to_backend
// is called to lower modules. \p modules_to_lower contains
// qualified names of submodules of \p mod that should be lowered.
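// Each entry of \p modules_to_lower is a dotted path relative to \p mod
// (e.g. "sub.subsub"; the names here are only illustrative).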
void toBackendSelectiveImpl(
    Module& mod,
    const py::function& to_backend,
    const std::vector<std::string>& modules_to_lower,
    const std::unordered_set<TypePtr>& duplicate_types) {
  // This map will be used later to remap types in ancestor module graphs for
  // all lowered submodules.
  std::unordered_map<TypePtr, TypePtr> type_remap;

  // For each module that should be lowered:
  for (const auto& module_to_lower : modules_to_lower) {
    // Use QualifiedName to parse the qualified module names.
    c10::QualifiedName qual_module_name(module_to_lower);
    auto& atoms = qual_module_name.atoms();

    // Search through the module hierarchy using the atoms of
    // qual_module_name until current points to the module to
    // be lowered and parent points to its parent.
    Module current = mod;
    Module parent;

    for (size_t i = 0, e = atoms.size(); i < e; ++i) {
      IValue submodule = current.attr(atoms[i]);
      if (submodule.isModule()) {
        if (i == e - 1) {
          parent = current;
        }
        current = submodule.toModule();
      } else {
        std::stringstream err;
        err << "Attribute named " << atoms[i] << " is not a Module";
        throw std::runtime_error(err.str());
      }
    }

    // Check that the parent type is not shared and therefore can be edited.
    if (duplicate_types.count(parent.type()) > 0) {
      throw py::cast_error(c10::str(
          "Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
          parent.type()->repr_str(),
          " is shared"));
    }
    // Call to_backend on the module that needs to be lowered. It needs to be
    // wrapped before doing so because to_backend accepts wrapped modules.
    // The result needs to be unwrapped in order to access its type below.
    auto lowered_submodule =
        py::cast<Module>(to_backend(py::module::import("torch.jit._recursive")
                                        .attr("wrap_cpp_module")(current))
                             .attr("_c"));

    // Adjust the parent's type so that the type of the submodule matches
    // the type of lowered_submodule.
    auto parent_type = parent.type();
    parent_type->unsafeChangeAttributeType(
        atoms.back(), lowered_submodule.type());
    parent.setattr(atoms.back(), lowered_submodule._ivalue());

    // Record the type mapping from old type -> lowered type.
    type_remap[current.type()] = lowered_submodule.type();
  }

  // Having lowered all of the modules that needed to be lowered, remap types
  // in all graphs in the hierarchy so that the graphs all use the new lowered
  // type.
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };

  // modules() iterates over all modules in the hierarchy including the root.
  for (auto module : mod.modules()) {
    auto module_type = module.type();
    for (auto& fn : module_type->methods()) {
      auto method = module.get_method(fn->name());
      auto graph = method.graph();
      graph->remapTypes(type_remap_fn);
      auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
      fn->setSchema(new_schema);
    }
  }
}
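// Generate the lowered Module for \p orig_module, targeting the backend named
// \p backend_name. \p method_compile_spec is a Python dict mapping method
// names to backend-specific compile specs.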
Module codegen_func(
    const std::string& backend_name,
    const Module& orig_module,
    const py::dict& method_compile_spec) {
  // Represents the Type Dict[str, Any].
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  return detail::codegen_backend_module(
      backend_name,
      orig_module,
      toIValue(method_compile_spec, any_dict_ty).toGenericDict(),
      any_dict_ty);
}
void initJitBackendBindings(PyObject* module) {
  // Bind a function for lowering to each JIT backend. The name of the backend
  // must be the first argument. For example, to lower a Module to
  // "example_backend", declared as
  //
  //  static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
  //
  // this function must be called like
  //
  //  torch._C._jit_to_backend("example_backend", module, spec)
  auto m = py::handle(module).cast<py::module>();
  m.def(
      "_jit_to_backend",
      [=](const std::string& backend_name,
          py::handle orig_module,
          const py::dict& method_compile_spec) {
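        // Redirect the C++ standard streams to Python's sys.stderr and
        // sys.stdout for the duration of this call so that output produced
        // while lowering is visible to the Python caller.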
        py::scoped_ostream_redirect cerr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect cout(
            std::cout, py::module_::import("sys").attr("stdout"));
        return py::module::import("torch.jit._recursive")
            .attr("wrap_cpp_module")(codegen_func(
                backend_name,
                py::cast<Module>(orig_module.attr("_c")),
                method_compile_spec));
      });
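  // Bind a function for selectively lowering submodules of a module to a JIT
  // backend. The first argument is the ScriptModule, the second is a Python
  // callable that takes a wrapped module and returns its lowered counterpart,
  // and the third is a list of qualified names of the submodules to lower.
  // A sketch of a call (the module, callable, and submodule names here are
  // purely illustrative):
  //
  //  torch._C._jit_to_backend_selective(
  //      scripted_module, to_backend_fn, ["encoder", "encoder.attention"])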
  m.def(
      "_jit_to_backend_selective",
      [=](py::handle orig_module,
          const py::function& to_backend,
          const std::vector<std::string>& modules_to_lower) {
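        // As above, surface C++ stream output to the Python caller while
        // lowering runs.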
        py::scoped_ostream_redirect cerr(
            std::cerr, py::module_::import("sys").attr("stderr"));
        py::scoped_ostream_redirect cout(
            std::cout, py::module_::import("sys").attr("stdout"));
        if (auto original_module =
                as_module(py::cast<py::object>(orig_module))) {
          // Clone the Module to avoid editing types that are shared with
          // Modules in other instances outside this hierarchy.
          Module& mod = original_module.value();
          auto cloned_mod = mod.clone();
          // Get all shared module types. Type sharing is only a problem if
          // the parent modules of the ones to lower are in this set.
          auto shared_types = getSharedModuleTypes(cloned_mod);
          toBackendSelectiveImpl(
              cloned_mod, to_backend, modules_to_lower, shared_types);
          // Wrap the result in a RecursiveScriptModule because that's what
          // the caller passed in.
          return py::module::import("torch.jit._recursive")
              .attr("wrap_cpp_module")(cloned_mod);
        }
        throw py::cast_error(c10::str(
            "Object ", py::str(orig_module), " is not a ScriptModule"));
      });
}
} // namespace jit
} // namespace torch