1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
#include <torch/csrc/jit/passes/lift_closures.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// Closures are initially emitted as prim::Closure nodes with a single block.
// Here, we convert that block into a subgraph. All closed-over variables are
// packed into a context tuple, which becomes the subgraph's only input. In the
// enclosing graph, the closure's output is paired with that context tuple.
// At this point the closure has already been converted to SSA, so closed-over
// variables are simply Value*s that are used in, but not defined in, the
// closure block.
// Within the closure subgraph, the context tuple is unpacked, and the unpacked
// values stand in for the closed-over values.
// Lifts a single prim::Closure node out of block form.
//
// The closure's block is cloned into a fresh Graph that is stored on the node
// as attr::Subgraph. Each outer value referenced by the block is routed through
// a "context" tuple in two halves:
//   - a prim::TupleConstruct inserted after the closure packs the captured
//     values in the owning graph;
//   - a prim::TupleUnpack at the top of the subgraph unpacks them again.
// Existing uses of the closure's output are redirected to a new
// (closure output, context) tuple. The block is then erased, and
// runCleanupPasses is applied to the new subgraph.
void liftClosure(Node* closure) {
  auto body = closure->blocks().at(0);
  auto lifted = std::make_shared<Graph>();
  // Closures and forks may be nested, so insert into the graph that actually
  // owns this closure node rather than any top-level graph.
  Graph* outer = closure->owningGraph();

  // Outer-side tuple collecting captured values. Its inputs are appended
  // lazily, as captures are discovered while cloning the body.
  Node* ctx_pack =
      outer->create(prim::TupleConstruct, {}, 1)->insertAfter(closure);
  Value* ctx_in = lifted->addInput("context");
  // createTupleUnpack cannot be used because the tuple type is not known yet;
  // instead, one output is appended per capture.
  Node* ctx_unpack =
      lifted->insertNode(lifted->create(prim::TupleUnpack, {ctx_in}, 0));

  // Maps each captured outer value to its unpacked counterpart in `lifted`.
  std::unordered_map<Value*, Value*> remap;
  auto capture = [&](Value* outer_value) -> Value* {
    auto hit = remap.find(outer_value);
    if (hit == remap.end()) {
      // First time this value is seen: thread it through both tuple halves.
      ctx_pack->addInput(outer_value);
      Value* inner_value =
          ctx_unpack->addOutput()->copyMetadata(outer_value);
      hit = remap.emplace(outer_value, inner_value).first;
    }
    return hit->second;
  };
  lifted->block()->cloneFrom(body, capture);

  // The capture set is only complete after cloning, so type the tuple now.
  auto ctx_type = TupleType::create(fmap(
      ctx_pack->inputs(), [](Value* captured) { return captured->type(); }));
  ctx_in->setType(ctx_type);
  ctx_pack->output()->setType(ctx_type);

  // Pair the closure's output with the context tuple in the outer graph.
  Node* fn_and_ctx =
      outer->create(prim::TupleConstruct, {}, 1)->insertAfter(ctx_pack);
  Value* fn_value = closure->output();
  // Redirect existing users first, so that the new tuple's own use of
  // fn_value (added just below) is not rewritten as well.
  fn_value->replaceAllUsesWith(fn_and_ctx->output());
  fn_and_ctx->addInput(fn_value);
  fn_and_ctx->addInput(ctx_pack->output());
  fn_and_ctx->output()->setType(
      TupleType::create({fn_value->type(), ctx_type}));

  // Replace the block with the lifted subgraph, then tidy the subgraph up.
  closure->eraseBlock(0);
  closure->g_(attr::Subgraph, std::move(lifted));
  runCleanupPasses(closure->g(attr::Subgraph));
}
// Lifts every prim::Closure node found in `block`. For any other node, this
// recurses into all of that node's nested blocks.
void liftClosures(Block* block) {
  for (auto cursor = block->nodes().begin(); cursor != block->nodes().end();) {
    // Advance before touching the node: liftClosure inserts TupleConstruct
    // nodes right after the closure, and stepping past first means those
    // freshly inserted nodes are skipped rather than visited.
    Node* node = *cursor;
    ++cursor;
    if (node->kind() == prim::Closure) {
      liftClosure(node);
      continue;
    }
    for (Block* child : node->blocks()) {
      liftClosures(child);
    }
  }
}
// Entry point: lifts all closures reachable from the graph's top-level block.
void liftClosures(const std::shared_ptr<Graph>& to_clean) {
  auto* top_level = to_clean->block();
  liftClosures(top_level);
}
} // namespace jit
} // namespace torch
|