#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/api/object.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/custom_class.h>
#include <unordered_map>
namespace torch {
namespace jit {
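
// A Slot identifies one attribute of a module instance: the object that owns
// the attribute and that attribute's offset within the object. Two Slots
// compare equal when they refer to the same attribute of the same object.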
struct Slot {
  c10::intrusive_ptr<c10::ivalue::Object> obj;
  size_t offset;
  bool operator==(const Slot& other) const {
    return (this->obj == other.obj && this->offset == other.offset);
  }
};

// Remove the first module argument, replacing any access of its
// parameters/attributes with extra_ivalue input Slots that hold the values to
// pass into the graph. Used for ONNX export to remove first-class modules
// so it can deal purely with parameters and inputs.
std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
    const ModulePtr& self,
    Graph& g_,
    size_t self_offset = 0) {
  std::shared_ptr<Graph> g = g_.copy();
  // Inline to remove method/function calls
  Inline(*g);

  std::vector<Slot> extra_ivalues;

  struct SlotHash {
    std::size_t operator()(const Slot& slot) const {
      auto obj_hash = std::hash<c10::ivalue::Object*>{}(slot.obj.get());
      auto offset_hash = std::hash<size_t>{}(slot.offset);
      return c10::hash_combine(obj_hash, offset_hash);
    }
  };
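  // Maps each Slot that has already been lifted into a graph input to its
  // index in extra_ivalues, so repeated accesses of the same attribute reuse
  // one input.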
  std::unordered_map<Slot, size_t, SlotHash> slot_to_offset;
  struct ToScan {
    ModulePtr mod;
    Node* n;
    size_t offset;
  };
  std::vector<ToScan> to_scan;
  std::vector<Node*> to_clean; // nodes that should be dead at the end
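
  // Return the graph input that carries the given Slot's value, appending a
  // new input (typed like the attribute) the first time the Slot is seen.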
  auto getOrAddSlot = [&](const Slot& slot) -> Value* {
    auto it = slot_to_offset.find(slot);
    if (it != slot_to_offset.end()) {
      size_t ivalues_start = g->inputs().size() - extra_ivalues.size();
      return g->inputs().at(ivalues_start + it->second);
    }
    extra_ivalues.emplace_back(slot);
    slot_to_offset[slot] = extra_ivalues.size() - 1;
    return g->addInput()->setType(slot.obj->getSlot(slot.offset).type());
  };
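
  // Seed the work list with every use of the module's `self` input; each use
  // is visited once and rewritten so that `self` ends up unused.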
  auto self_value = g->inputs().at(self_offset);

  for (Use use : self_value->uses()) {
    to_scan.emplace_back(ToScan{self, use.user, use.offset});
  }
  while (to_scan.size() > 0) {
    auto e = to_scan.back();
    to_scan.pop_back();

    // when we lambda lift forks, first-class modules may be passed across
    // forks. This code recursively lowers the module in the fork call.
    if (e.n->kind() == prim::fork) {
      auto subgraph = e.n->g(attr::Subgraph);
      std::vector<Slot> new_slots;
      std::tie(subgraph, new_slots) = lower_graph(e.mod, *subgraph, e.offset);
      e.n->g_(attr::Subgraph, subgraph);
      for (const Slot& slot : new_slots) {
        e.n->addInput(getOrAddSlot(slot));
      }
      e.n->removeInput(e.offset);
      continue;
    }
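    // Python ops cannot be lowered, and the only other use of a module value
    // handled here is an attribute lookup; reject anything else with a
    // source-located error.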
    if (e.n->kind() == prim::PythonOp) {
      throw ErrorReport(e.n->sourceRange()) << "Couldn't export Python method.";
    }
    if (e.n->kind() != prim::GetAttr) {
      throw ErrorReport(e.n->sourceRange())
          << "temporary: the only valid use of a module is looking up an "
             "attribute but found "
          << *e.n;
    }
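    // This use is a GetAttr: resolve which attribute slot it reads. If the
    // attribute is a submodule, queue its uses for scanning instead of
    // lifting it; otherwise rewrite the read to use the lifted graph input.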
    size_t slot_idx = e.mod->type()->getAttributeSlot(e.n->s(attr::name));
    auto iv = e.mod->getSlot(slot_idx);
    if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {
      if (c->is_module()) {
        for (Use use : e.n->output()->uses()) {
          to_scan.emplace_back(ToScan{iv.toObject(), use.user, use.offset});
        }
        to_clean.emplace_back(e.n);
        continue;
      }
    }
    e.n->output()->replaceAllUsesWith(getOrAddSlot({e.mod, slot_idx}));
    e.n->destroy();
  }
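
  // Submodule GetAttr nodes were deferred above; by now all of their uses
  // have been rewritten, so they must be dead and can be destroyed.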
  while (to_clean.size() > 0) {
    Node* n = to_clean.back();
    AT_ASSERT(!n->hasUses());
    n->destroy();
    to_clean.pop_back();
  }
  AT_ASSERT(!self_value->hasUses());
  g->eraseInput(self_offset);

  return std::make_pair(std::move(g), std::move(extra_ivalues));
}
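
// Convert the Slots collected by lower_graph into concrete IValues: plain
// tensors are passed through, while packed quantized parameters are unpacked
// by calling their __getstate__ method.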
static std::vector<IValue> loadTensors(const std::vector<Slot>& slots) {
  std::vector<IValue> result;
  result.reserve(slots.size());
  for (const Slot& slot : slots) {
    auto obj = slot.obj->getSlot(slot.offset);
    if (obj.isTensor()) {
      result.emplace_back(obj.toTensor());
    } else {
      // Unpack quantization packed tensor
      auto type = obj.type();
      TORCH_CHECK(
          (type ==
           getCustomClass(
               "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
              (type ==
               getCustomClass(
                   "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
          "Unknown type ",
          type->repr_str(),
          " encountered in graph lowering. This type is not supported in ONNX export.");
      result.emplace_back(
          script::Object(obj.toObject()).run_method("__getstate__"));
    }
  }
  return result;
}
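
// Rough usage sketch (illustrative only; the Module/Method calls below are
// assumptions about the caller, not part of this file):
//
//   Module m = ...;                       // a scripted module
//   Method forward = m.get_method("forward");
//   auto lowered = LowerGraph(*forward.graph(), m._ivalue());
//   // lowered.first:  graph with the module `self` input removed
//   // lowered.second: IValues to append as extra inputs when running/exporting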
std::pair<std::shared_ptr<Graph>, std::vector<IValue>> LowerGraph(
    Graph& graph,
    const ModulePtr& self) {
  auto result = lower_graph(self, graph);
  return std::make_pair(result.first, loadTensors(result.second));
}
} // namespace jit
} // namespace torch