1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
|
#include <torch/csrc/jit/frontend/convert_to_ssa.h>
#include <torch/csrc/jit/frontend/exit_transforms.h>
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/mini_environment.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
namespace torch {
namespace jit {
// At the beginning of the pass the Graph has already undergone type checking,
// and writes or reads to a variable are emitted as Loads and Stores in the
// graph.
// a = 1
// print(a)
// is represented as:
// %a.1 : int = prim::Constant[value=1]()
// prim::Store[name="a"](%a.1)
// %a : int = prim::Load[name="a"]()
// prim::Print(%a)
//
// First, this pass recursively adds the Loads & Stores to control flow nodes.
// Then the graph is converted to SSA form.
// Environment mapping a variable name to the Value* currently bound to it;
// used by EraseLoadStores below when replacing Loads with the stored value.
using ValueEnvironment = MiniEnvironment<Value*>;
// Environment mapping a variable name to the type of the value last stored to
// it; used by ControlFlowLoadStores below to type control flow inputs/outputs.
using TypeEnvironment = MiniEnvironment<TypePtr>;
// Adds Loads & Stores to Loops & Ifs
struct ControlFlowLoadStores {
static void addBlockInput(
Block* b,
const TypePtr& type,
const std::string& name) {
auto g = b->owningGraph();
g->createStore(name, b->addInput(name)->setType(type))
->insertAfter(b->param_node());
}
static void addBlockOutput(
Block* exit_block,
const TypePtr& type,
const std::string& name) {
WithInsertPoint insert(exit_block);
auto g = exit_block->owningGraph();
auto block_exit = g->insertNode(g->createLoad(name, type))->output();
exit_block->registerOutput(block_exit);
}
static void addNodeOutput(
Node* n,
const TypePtr& type,
const std::string& name) {
auto out = n->addOutput()->setType(type);
if (meaningfulName(name)) {
out->setDebugName(name);
}
auto g = n->owningGraph();
g->createStore(name, out)->insertAfter(n);
}
static void addNodeInput(
Node* n,
const TypePtr& type,
const std::string& name) {
auto g = n->owningGraph();
auto inp = g->createLoad(name, type)->insertBefore(n)->output();
n->addInput(inp);
}
void addIfLoadStores(Node* n) {
auto true_block = n->blocks().at(0);
auto false_block = n->blocks().at(1);
auto true_vars = addControlFlowLoadStores(true_block);
auto false_vars = addControlFlowLoadStores(false_block);
std::set<std::string> mutated_variables;
for (auto& v : true_vars->definedVariables()) {
if (false_vars->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
for (auto& v : false_vars->definedVariables()) {
if (true_vars->findInAnyFrame(v)) {
mutated_variables.insert(v);
}
}
// Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
// we emit a node output if the variable is defined in each block
// and the types of each block can be unified
for (const auto& x : mutated_variables) {
auto true_type = true_vars->findInAnyFrame(x);
auto false_type = false_vars->findInAnyFrame(x);
auto unified =
unifyTypes(true_type, false_type, /*default_to_union=*/true);
addBlockOutput(true_block, true_type, x);
addBlockOutput(false_block, false_type, x);
addNodeOutput(n, *unified, x);
}
}
// loop_carried_outputs* = Loop(max_trip_count, start_condition,
// loop_carried_inputs*)
// block0(loop_counter, loop_carried_block*) {
// <body>
// -> (continue_condition, loop_carried_block_outputs*)
// }
// all loop_carried_... lists are the same length and represent the value of
// loop-carried variables whose definitions are updated as the loop executes
// in a way that ensure single static assignment.
void addLoopLoadStores(Node* n) {
auto body_block = n->blocks().at(0);
auto loop_vars = addControlFlowLoadStores(body_block);
for (const auto& name : loop_vars->definedVariables()) {
// if the variable local to the loop body, then
// we do not need a loop carried variable for it
auto parent_type = environment_stack->findInAnyFrame(name);
if (!parent_type) {
continue;
}
// since the loop may execute 0 or many times, the output types
// of the loop and the input loop carried dependencies are conservatively
// the union of the output of the body and the input to the loop
auto block_type = loop_vars->findInThisFrame(name);
auto unified_type = unifyTypes(parent_type, block_type).value();
// Insert a store at the beginning of the loop block, so that all
// loads of the variable will use the loop carried value
addNodeInput(n, parent_type, name);
addBlockInput(body_block, unified_type, name);
addBlockOutput(body_block, block_type, name);
addNodeOutput(n, unified_type, name);
}
}
std::shared_ptr<TypeEnvironment> addControlFlowLoadStores(Block* block) {
pushFrame(block);
for (Node* n : block->nodes()) {
switch (n->kind()) {
case prim::If: {
addIfLoadStores(n);
} break;
case prim::Loop: {
addLoopLoadStores(n);
} break;
case prim::Closure: {
for (auto b : n->blocks()) {
addControlFlowLoadStores(b);
}
} break;
case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input()->type());
} break;
case prim::ComprehensionScope: {
addControlFlowLoadStores(n->blocks().at(0));
} break;
}
}
return popFrame();
}
void pushFrame(Block* b) {
environment_stack = std::make_shared<TypeEnvironment>(b, environment_stack);
}
std::shared_ptr<TypeEnvironment> popFrame() {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
return old_frame;
}
void run(std::shared_ptr<Graph>& graph) {
addControlFlowLoadStores(graph->block());
}
std::shared_ptr<TypeEnvironment> environment_stack = nullptr;
};
// Given a graph where 1) outputs have been added to control flow nodes and
// 2) loads and stores are represented in the graph, erase the Loads & Stores.
struct EraseLoadStores {
void eraseBlockLoadStores(Block* block) {
pushFrame(block);
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
auto n = *it;
it++;
switch (n->kind()) {
case prim::Store: {
environment_stack->setVar(n->s(attr::name), n->input());
n->destroy();
} break;
case prim::Load: {
auto name = n->s(attr::name);
auto var = environment_stack->findInAnyFrame(name);
TORCH_INTERNAL_ASSERT(
var, "Typechecking should ensure the variable name is set");
n->output()->replaceAllUsesWith(var);
n->destroy();
} break;
case prim::ComprehensionScope: {
// writes within a local variable scope do not leak into
// the rest of the graph
auto body = n->blocks().at(0);
eraseBlockLoadStores(body);
// inline the local variable scope into the graph
for (auto it_cmpr = body->nodes().begin();
it_cmpr != body->nodes().end();) {
Node* body_node = *it_cmpr;
it_cmpr++;
body_node->moveBefore(n);
}
n->destroy();
} break;
default: {
for (auto b : n->blocks()) {
eraseBlockLoadStores(b);
}
} break;
}
}
popFrame();
}
void pushFrame(Block* b) {
environment_stack =
std::make_shared<ValueEnvironment>(b, environment_stack);
}
std::shared_ptr<ValueEnvironment> popFrame() {
auto old_frame = environment_stack;
environment_stack = environment_stack->next;
return old_frame;
}
void run(std::shared_ptr<Graph>& graph) {
eraseBlockLoadStores(graph->block());
}
std::shared_ptr<ValueEnvironment> environment_stack = nullptr;
};
// This pass transforms Breaks & Continues to be LoopContinuations,
// of the form LoopContinuations(%loop_continue_condition, *loop_carried_vars)
// Break Statements have the condition set to false, and Continue statements
// inline the loop condition as the first input.
struct LoopContinuations {
public:
void run(std::shared_ptr<Graph>& graph) {
run(graph->block());
}
private:
void addLoopCarriedOutputs(Node* n) {
auto g = n->owningGraph();
WithInsertPoint insert(n);
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
auto continuation = curr_loop_->blocks().at(0)->return_node();
for (auto out : continuation->inputs()) {
auto load_node = out->node();
TORCH_INTERNAL_ASSERT(load_node->kind() == prim::Load);
auto new_load =
g->insertNode(g->createClone(load_node, [](Value* v) { return v; }));
n->addInput(new_load->output());
}
}
void assignExitContinuations(Block* block) {
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++;
switch (n->kind()) {
case prim::If: {
assignExitContinuations(n->blocks().at(0));
assignExitContinuations(n->blocks().at(1));
} break;
case prim::Closure: {
LoopContinuations closure_block;
closure_block.run(n->blocks().at(0));
} break;
case prim::Loop: {
Node* prev_loop = curr_loop_;
curr_loop_ = n;
assignExitContinuations(n->blocks().at(0));
curr_loop_ = prev_loop;
} break;
case prim::ContinueStmt: {
auto loop_continuation =
graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
auto header_block = loop_continuation->addBlock();
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
auto pre_header = curr_loop_->blocks().at(1);
header_block->cloneFrom(pre_header, [](Value* v) { return v; });
InlineBlockBeforeNode(n, header_block);
loop_continuation->addInput(header_block->outputs().at(0));
loop_continuation->eraseBlock(0);
addLoopCarriedOutputs(loop_continuation);
n->destroy();
} break;
case prim::BreakStmt: {
auto loop_exit =
graph_->create(prim::LoopContinuation, 0)->insertAfter(n);
// first input is the loop continue condition - break sets false
loop_exit->addInput(false_val_);
addLoopCarriedOutputs(loop_exit);
n->destroy();
} break;
}
}
}
void run(Block* b) {
{
graph_ = b->owningGraph();
WithInsertPoint guard(b->nodes().front());
false_val_ = graph_->insertConstant(false);
}
assignExitContinuations(b);
}
Graph* graph_ = nullptr;
Value* false_val_ = nullptr;
Node* curr_loop_ = nullptr;
};
// Converting to SSA works in multiple parts, run in this order:
//  1. Add the control flow Loads and Stores to the graph.
//  2. Now that control flow outputs are set, rewrite Break & Continue so that
//     they have the correct continuations to the end of the block
//     (LoopContinuation).
//  3. Inline the loop condition into the graph.
//  4. Erase the Loads & Stores.
//  5. Remove the LoopContinuations from the graph.
void ConvertToSSA(std::shared_ptr<Graph>& graph) {
  ControlFlowLoadStores().run(graph);
  LoopContinuations().run(graph);
  InlineLoopCondition(graph);
  EraseLoadStores().run(graph);
  TransformExits(graph);
}
} // namespace jit
} // namespace torch
|