1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// This class facilitates depth-first iteration over all nodes in a graph.
class DepthFirstGraphNodeIterator {
Node* current_;
public:
// Constructor.
explicit DepthFirstGraphNodeIterator(std::shared_ptr<Graph>& graph)
: current_(*(graph->block()->nodes().begin())) {}
// Moves up and to the next node (may move up recursively).
void move_up() {
if (current_ == nullptr) {
return;
}
// Basically we start from the child block (which is current_)
// and we try to find the block that owns it. Now we need to check
// if that block is the graph root block, or if it is an If/Loop/etc
// block.
//
// If it's the graph root block we can stop because there is no "up"
// but if it is a node (e.g. If/Loop/etc) we need to apply logic
// based on where we are coming from to move to the next block.
// This might mean that we need to traverse up again (e.g. if we've
// reached the end of the else clause in an if block we need to go)
// up to the parent block that contains the if.
//
// Similarly if we've reached the end of the parent block containing
// the else clause we might need to go up again so this is a recursive
// function.
//
// BlockNode (if/loop/with)
// |
// [Block1] ... [Block2]
// |
// [ Node1, Node2, Node3, FromNode]
//
auto parent_block = current_->owningBlock();
TORCH_INTERNAL_ASSERT(parent_block, "Every node must be owned by a block");
// Get the node that owns the parent block. This node has to be an if,
// loop, or with.
auto parent_node = parent_block->owningNode();
if (parent_node == nullptr) {
// If there's no node that owns this current block then we're at the
// top of the graph and since we're trying to move up we have reached
// the end of the traversal.
current_ = nullptr;
return;
}
// Check the type of node this root is.
if (parent_node->kind() == prim::If) {
// Need to check if we came from the `then` branch or the `else` branch.
auto* then_block = parent_node->blocks().at(0);
auto* else_block = parent_node->blocks().at(1);
if (parent_block == else_block) {
// If else block then we move to the next node in the parent block.
current_ = parent_node->next();
if (current_->kind() == prim::Return) {
move_up();
}
} else {
// If then block then move to the else block if it is not empty.
TORCH_INTERNAL_ASSERT(parent_block == then_block);
bool else_block_empty =
else_block->nodes().begin() == else_block->nodes().end();
if (!else_block_empty) {
current_ = *(else_block->nodes().begin());
} else {
// Since it's empty we move to the next node.
current_ = parent_node->next();
if (current_->kind() == prim::Return) {
move_up();
}
}
}
} else if (
parent_node->kind() == prim::Loop ||
parent_node->kind() == prim::With) {
current_ = parent_node->next();
if (current_->kind() == prim::Return) {
move_up();
}
} else {
TORCH_INTERNAL_ASSERT(
false, "Only if/loop/with nodes should have child blocks");
}
}
// Moves to the next adjacent node or up in to the parent if that is not
// possible.
void move_next() {
if (current_ == nullptr) {
return;
}
// Increment to the next node in the current block.
current_ = current_->next();
// Check if we're at the end of the block. If so we need
// to move upwards (if it makes sense to).
if (current_->kind() == prim::Return) {
move_up();
}
}
// Moves to the next node in the graph into children if it can.
void move_into() {
if (current_ == nullptr) {
return;
}
// Check if we're currently on a node that contains sub-nodes.
if (current_->kind() == prim::If || current_->kind() == prim::Loop ||
current_->kind() == prim::With) {
auto* first_block = current_->blocks().at(0);
current_ = first_block->param_node();
// Move next will move up and out of the current node if the block is
// empty. `move_up` which is called by `move_next` will handle the
// difference between If, Loop, and With blocks appropriately.
move_next();
} else {
move_next();
}
}
// Get the next Node in the graph. \returns nullptr if there are no nodes
// left.
Node* next() {
auto result = current_;
// Try move into the existing node to set the next node to be returned.
// This will move to the next node if not possible, or move upwards and
// to the next.
move_into();
return result;
}
};
} // namespace jit
} // namespace torch
|