1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463
|
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/utils/memory.h>
#include <unordered_map>
namespace torch {
namespace jit {
// Makes `prim::X` written inside torch::jit resolve to the interned symbols in
// ::c10::prim (e.g. prim::fork, prim::Loop, prim::If used below).
namespace prim {
using namespace ::c10::prim;
} // namespace prim
// Dead code elimination for TorchScript IR.
//
// The pass is an inverse mark-and-sweep over a block. Values needed by the
// block's outputs, and by nodes that must be kept, are marked live. Every node
// that is neither marked nor used is then destroyed.
//
// The constructor selects one of two modes:
//  * Graph constructor: alias analysis is enabled. An AliasDb is built lazily
//    on first use. It is consulted to keep nodes that write to memory aliasing
//    a live value, and to detect writes to the wildcard alias set.
//  * Policy-only constructor: no alias information is available. Any
//    prim::SetAttr, or any node whose schema is mutable, is treated as an
//    untracked mutation.
class DeadCodeEliminator {
 public:
  explicit DeadCodeEliminator(
      std::shared_ptr<Graph> graph,
      DCESideEffectPolicy sideEffectPolicy)
      : sideEffectPolicy_(sideEffectPolicy),
        graph_(std::move(graph)),
        useAliasDb_(true) {}
  explicit DeadCodeEliminator(DCESideEffectPolicy sideEffectPolicy)
      : sideEffectPolicy_(sideEffectPolicy) {}

  // The algorithm is an inverse mark-and-sweep. Starting from the return node,
  // we mark "live" nodes that are necessary for the output. Nodes that have
  // side effects are also marked, when the policy is
  // DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS.
  //
  // block:   the block to clean up.
  // recurse: when true, fork-input pruning and the sweep also descend into
  //          nested blocks. Marking always descends into nested blocks.
  void run(Block* block, bool recurse) {
    // clean up unused fork inputs before starting the main algorithm
    eliminateDeadForkInputs(block, recurse);

    // Seed liveness from the block's return node. mark(Block*) does this
    // through markReturnNode(), which treats a block output as live only when
    // the enclosing node needs it.
    //
    // Do NOT call mark(Node*) on the return node first. That would make every
    // block output live unconditionally. For a nested block it would also mark
    // the enclosing nodes outside the block, leaking their inputs into the set
    // handed to deleteCallback_.
    mark(block);

    // Report the live set before anything is deleted.
    deleteCallback_(liveValues_);

    sweep(block, recurse);
  }

  // Installs a callback that run() invokes once with the set of live values.
  // The call happens after marking and before sweeping.
  void setDeleteCallback(
      std::function<void(const std::unordered_set<const Value*>&)>
          deleteCallback) {
    deleteCallback_ = std::move(deleteCallback);
  }

 private:
  // Removes prim::fork inputs that the forked subgraph never uses. Both the
  // subgraph input and the fork node's input at the same index are removed.
  // This relies on fork input i corresponding to subgraph input i.
  void eliminateDeadForkInputs(Block* block, bool recurse) {
    for (Node* node : block->nodes()) {
      if (recurse) {
        for (Block* sb : node->blocks()) {
          eliminateDeadForkInputs(sb, recurse);
        }
      }
      if (node->kind() != prim::fork) {
        continue;
      }
      Graph& g = *node->g(attr::Subgraph);
      // WARNING: Do not use a ranged loop, because the loop body erases inputs.
      // Walk the indices from last to first, so that erasing input i only
      // shifts inputs that were already visited. A forward walk that does ++i
      // after an erase would skip the input that slid into slot i.
      for (size_t i_1 = g.inputs().size(); i_1 > 0; --i_1) {
        const size_t i = i_1 - 1;
        if (!g.inputs().at(i)->hasUses()) {
          GRAPH_UPDATE(
              "Dead ",
              i,
              "-th input ",
              node->inputs().at(i)->debugName(),
              "(",
              g.inputs().at(i)->debugName(),
              " in a subgraph) will be removed");
          g.eraseInput(i);
          node->removeInput(i);
        }
      }
    }
  }

  // Special handling for block return nodes. Unlike other nodes, the block
  // return node doesn't really "use" its inputs. Consider:
  //
  // %a0 = aten::foo()
  // %b = aten::foo()
  // %a2, %b2 = prim::If(%cond) {
  //   block0() {
  //     %a1 = aten::foo(%.0)
  //     %b1 = aten::foo(%b)
  //   } -> (%a1, %b1)
  // }
  // return (%a2)
  //
  // We want to be able to DCE all the %b stuff. So when processing block
  // returns, we only mark producers for values that are live (i.e. used
  // outside the block).
  //
  // Returns true iff this marked something we haven't marked before.
  bool markReturnNode(Node* node) {
    if (marked_.count(node)) {
      return false;
    }

    AT_ASSERT(node->owningBlock()->return_node() == node);
    auto outerNode = node->owningBlock()->owningNode();
    if (outerNode == nullptr || outerNode->kind() == prim::Reverse) {
      // If there's no outer node, we're looking at the graph's top-level
      // return block. We consider all graph outputs to be "used", so just mark
      // this node normally.
      return mark(node);
    }

    // Collect all inputs that are actually live
    if (outerNode->kind() == prim::Loop ||
        outerNode->kind() == c10::onnx::Loop) {
      // Special handling to deal with loop carried dependencies.
      auto loop = LoopView(outerNode);
      for (const auto i : c10::irange(loop.carriedOutputs().size())) {
        if (outerNode->kind() == c10::onnx::Loop) {
          // Special handling for onnx loop.
          // The number of body carried inputs and outputs are different.
          // They cannot be mapped to each other easily by the same index.
          liveValues_.insert(loop.bodyCarriedOutputs().at(i));
          continue;
        }
        auto innerInput = loop.bodyCarriedInputs().at(i);
        auto innerOutput = loop.bodyCarriedOutputs().at(i);
        auto outerOutput = loop.carriedOutputs().at(i);
        // A carried value is needed if the loop result is live, or if the
        // next iteration reads it through the body's carried input.
        if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
          liveValues_.insert(innerOutput);
        }
      }

      // Also mark the loop next condition as live, since it will be used inside
      // the loop body.
      liveValues_.insert(loop.nextCond());
    } else {
      AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
      for (const auto i : c10::irange(outerNode->outputs().size())) {
        auto innerOutput = node->inputs()[i];
        auto outerOutput = outerNode->outputs()[i];
        if (liveValues_.count(outerOutput)) {
          liveValues_.insert(innerOutput);
        }
      }
    }

    marked_.insert(node);
    return true;
  }

  // Loops are special, because we need to run them to convergence.
  // Consider the following loop:
  //   for i in range(3):
  //     tot += a[0][0]
  //     b = a[0]
  //     b[0] += 1
  //   print(tot)
  //
  // If we only process the loop block once, we will conclude that `b[0]` and
  // `b` are dead, even though `b[0] += 1` mutates a live memory location (since
  // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
  // iteration
  //
  // We need to mark the loop again with the information that `a` is live, and
  // repeat until we're not marking new stuff anymore.
  //
  // Returns true iff this marked something we haven't marked before.
  bool markLoop(Node* node) {
    TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
    // Did a single iteration over the loop block mark anything new?
    // If this is false, we've converged.
    bool marked = false;
    // Did we ever mark anything new?
    bool anyMarked = false;
    do {
      marked = mark(node->blocks().at(0));
      anyMarked |= marked;
    } while (marked);
    return anyMarked;
  }

  // Marks everything in `block` that must be kept. This covers:
  // side-effecting nodes (subject to the policy), the live part of the return
  // node, and every node whose outputs are live or which writes to a live
  // alias. Nodes are visited in reverse order so that consumers are processed
  // before producers.
  //
  // Returns true iff this marked something we haven't marked before.
  bool mark(Block* block) {
    bool anyMarked = false;
    // Mark all nodes with side effects.
    for (auto node : block->nodes()) {
      if (sideEffectPolicy_ ==
              DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
          hasSideEffects(node)) {
        anyMarked |= mark(node);
      }
    }

    // Initialize by marking the return node
    anyMarked |= markReturnNode(block->return_node());

    for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
      auto node = *it;
      if (node->kind() == prim::Loop) {
        // Special casing for loops, see comment in markLoop.
        anyMarked |= markLoop(node);
      } else {
        // Other nodes with sub-blocks get marked normally.
        for (auto subBlock : node->blocks()) {
          anyMarked |= mark(subBlock);
        }
      }
      anyMarked |= markIfLive(node);
    }
    return anyMarked;
  }

  // If we output or write to a live memory location, mark this node.
  // The write check is only performed when alias analysis is enabled.
  // Returns true iff this marked something we haven't marked before.
  bool markIfLive(Node* node) {
    for (const auto output : node->outputs()) {
      if (liveValues_.count(output)) {
        return mark(node);
      }
    }

    if (useAliasDb_) {
      if (getOrCreateAliasDb()->writesToAlias(node, liveValues_)) {
        return mark(node);
      }
    }
    return false;
  }

  // Mark this node as live and add this node's inputs and aliases to the live
  // value sets.
  // Returns true iff this marked something we haven't marked before.
  bool mark(Node* node) {
    if (marked_.count(node)) {
      return false;
    }

    marked_.insert(node);

    // Mark all nodes in this node's blockchain (since owning nodes are
    // considered live if they contain a live node)
    auto curNode = node;
    while (curNode) {
      if (!curNode->owningBlock()) {
        break;
      }

      mark(curNode);
      curNode = curNode->owningBlock()->owningNode();
    }

    // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
    for (const auto input : node->inputs()) {
      // insert() is a no-op for values that are already live.
      liveValues_.insert(input);
    }
    return true;
  }

  // Delete all unmarked nodes that also have no remaining uses.
  // Nodes are visited in reverse order (block->nodes().reverse()). Presumably
  // this is so that deleting a consumer first can leave its producers
  // use-free by the time they are visited.
  void sweep(Block* block, bool recurse) {
    auto nodes = block->nodes().reverse();
    for (auto it = nodes.begin(); it != nodes.end(); ++it) {
      auto node = *it;
      // note these occur before the recursion because we want to uncover
      // dead code in the blocks used to calculate the output
      removeDeadBlockOutputs(node);
      removeDeadLoopOutputs(node);
      if (recurse) {
        for (Block* subBlock : node->blocks()) {
          sweep(subBlock, true);
        }
      }
      // NB: Checking hasUses() is required. AD graphs are not perfectly
      // valid, as a node in grad_desc.f might be used in reverse_block.
      // Reverse_block is inlined in grad_desc.f before it's separated
      // to grad_desc.df.
      if (!(marked_.count(node) || node->hasUses())) {
        GRAPH_UPDATE(
            "Node ",
            it->kind().toQualString(),
            " which outputs ",
            (node->outputs().size() > 0 ? node->outputs().at(0)->debugName()
                                        : "n/a"),
            " will be removed");
        it.destroyCurrent();
      }
    }
  }

  // Returns true if `node` may mutate state that alias analysis cannot track.
  // With an AliasDb, this means the node writes to the wildcard set. Without
  // one, it means the node is prim::SetAttr or has a mutable schema.
  bool hasUntrackedMutation(Node* node) {
    if (!useAliasDb_) {
      // If we don't have alias information, all mutable ops have unknown
      // effects and can't be considered for elimination.

      if (node->kind() == prim::SetAttr) {
        // SetAttr is a special case: it doesn't have a schema, but does
        // have untracked mutations
        return true;
      }

      // onnx export calls EliminateDeadCode but sometimes passes invalid
      // aten operators. So we call maybeSchema so we handle the cases when
      // there is no valid schema for a node
      auto schema = node->maybeSchema();
      return schema && schema->is_mutable();
    } else {
      return getOrCreateAliasDb()->writesToWildcard(node);
    }
  }

  // Returns true if `node` has side effects. That includes its own side
  // effects, side effects of any node nested in its blocks, and untracked
  // mutation. Results are memoized per node in memo_.
  bool hasSideEffects(Node* node) {
    auto it = memo_.find(node);
    if (it != memo_.end())
      return it->second;
    bool has_side_effects = node->hasSideEffects() ||
        std::any_of(node->blocks().begin(),
                    node->blocks().end(),
                    [&](Block* b) {
                      return std::any_of(
                          b->nodes().begin(), b->nodes().end(), [&](Node* n) {
                            return hasSideEffects(n);
                          });
                    }) ||
        hasUntrackedMutation(node);

    memo_.emplace(node, has_side_effects);
    return has_side_effects;
  }

  // For prim::If / prim::GradOf, erases every node output that has no uses,
  // together with the matching output of each sub-block. The walk runs from
  // last to first so that erasing index i does not disturb indices not yet
  // visited.
  void removeDeadBlockOutputs(Node* node) {
    if (node->kind() != prim::If && node->kind() != prim::GradOf) {
      return;
    }

    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
      const size_t i = i_1 - 1;
      if (!node->outputs().at(i)->hasUses()) {
        GRAPH_UPDATE(
            "Dead ",
            i,
            "-th output ",
            node->outputs().at(i)->debugName(),
            " of node ",
            node->kind().toQualString(),
            " will be removed");
        node->eraseOutput(i);
        for (Block* b : node->blocks()) {
          GRAPH_UPDATE(
              "\tCorresponding block output ",
              b->outputs().at(i)->debugName(),
              " will be removed");
          b->eraseOutput(i);
        }
      }
    }
  }

  // For prim::Loop, removes a loop-carried dependency when both of the
  // following hold: the loop's output is unused, and the body's carried input
  // is unused. The matching node input, body input and body output are
  // removed with it. Iterates last-to-first for stable indices.
  void removeDeadLoopOutputs(Node* node) {
    if (node->kind() != prim::Loop)
      return;
    auto loop_body = node->blocks().at(0);
    // offset of loop carried deps in input list
    const size_t loop_input_offset = 2;
    // offset to the loop carried dependencies in block inputs/outputs
    const size_t loop_body_offset = 1;

    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
      const size_t i = i_1 - 1;
      if (!node->outputs().at(i)->hasUses() &&
          !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
        logDeadLoopOutputs(node, i, loop_input_offset, loop_body_offset);
        node->eraseOutput(i);
        node->removeInput(loop_input_offset + i);
        loop_body->eraseInput(loop_body_offset + i);
        loop_body->eraseOutput(loop_body_offset + i);
      }
    }
  }

  // Logs the four values about to be removed for dead loop-carried
  // dependency `i`. It must be called before anything is erased, so that every
  // index below is still valid. The node input is read at
  // loop_input_offset + i, the same index that removeDeadLoopOutputs removes
  // and that the message reports.
  void logDeadLoopOutputs(
      Node* node,
      size_t i,
      size_t loop_input_offset,
      size_t loop_body_offset) {
    auto loop_body = node->blocks().at(0);
    GRAPH_UPDATE(
        "Dead ",
        loop_input_offset + i,
        "-th input ",
        node->inputs().at(loop_input_offset + i)->debugName(),
        " will be removed");
    GRAPH_UPDATE(
        "Dead ",
        i,
        "-th output ",
        node->outputs().at(i)->debugName(),
        " will be removed");
    GRAPH_UPDATE(
        "\tDead block input ",
        loop_body->inputs().at(loop_body_offset + i)->debugName(),
        " at offset ",
        loop_body_offset + i,
        " will be removed");
    GRAPH_UPDATE(
        "\tDead block output ",
        loop_body->outputs().at(loop_body_offset + i)->debugName(),
        " at offset ",
        loop_body_offset + i,
        " will be removed");
  }

  // Builds the AliasDb for graph_ on first use. Only reached when useAliasDb_
  // is true, i.e. when the Graph constructor was used.
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  DCESideEffectPolicy sideEffectPolicy_;

  // Null when constructed without a graph (alias analysis disabled).
  std::shared_ptr<Graph> graph_;
  bool useAliasDb_ = false;
  // lazily initialized
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
  // Memoized results of hasSideEffects().
  std::unordered_map<Node*, bool> memo_;
  // Nodes marked live; these survive the sweep.
  std::unordered_set<Node*> marked_;
  // Values required by live nodes.
  std::unordered_set<const Value*> liveValues_;
  // Invoked once by run() with liveValues_ before sweeping; no-op by default.
  std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
      [](const std::unordered_set<const Value*>&) {};
};
void EliminateDeadCode(
const std::shared_ptr<Graph>& graph,
DCESideEffectPolicy sideEffectPolicy) {
DeadCodeEliminator(graph, sideEffectPolicy)
.run(graph->block(), /*recurse=*/true);
GRAPH_DUMP("After EliminateDeadCode: ", graph);
}
// Runs dead code elimination on `block` without alias analysis, since no Graph
// is supplied. `recurse` controls whether fork-input pruning and the sweep
// also descend into nested blocks.
void EliminateDeadCode(
    Block* block,
    bool recurse,
    DCESideEffectPolicy sideEffectPolicy) {
  DeadCodeEliminator eliminator(sideEffectPolicy);
  eliminator.run(block, recurse);
}
// Runs dead code elimination on `block` and all nested blocks, without alias
// analysis. `cb` is invoked once with the set of live values, after marking
// and before any node is deleted.
void EliminateDeadCode(
    Block* block,
    std::function<void(const std::unordered_set<const Value*>&)> cb,
    DCESideEffectPolicy sideEffectPolicy) {
  constexpr bool kRecurseIntoSubBlocks = true;
  DeadCodeEliminator dce{sideEffectPolicy};
  dce.setDeleteCallback(std::move(cb));
  dce.run(block, kRecurseIntoSubBlocks);
}
} // namespace jit
} // namespace torch
|