1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
#include <torch/csrc/lazy/core/ir_util.h>
#include <stack>
#include <c10/util/Logging.h>
namespace torch::lazy {
std::vector<const Node*> Util::ComputePostOrder(
const Node* node,
EmissionMap* emap) {
std::vector<const Node*> post_order;
std::stack<const Node*> node_stack;
node_stack.push(node);
while (!node_stack.empty()) {
node = node_stack.top();
auto it = emap->find(node);
if (it == emap->end()) {
(*emap)[node] = kEmitting;
for (auto& output : node->operands()) {
auto oit = emap->find(output.node);
if (oit == emap->end()) {
node_stack.push(output.node);
} else {
TORCH_CHECK(
oit->second != kEmitting,
"Graph loop found at ",
output.node->ToString());
}
}
} else if (it->second == kEmitting) {
for (auto& output : node->operands()) {
auto oit = emap->find(output.node);
TORCH_CHECK(
oit != emap->end() && oit->second == kEmitted,
"Graph loop found at ",
output.node->ToString());
}
(*emap)[node] = kEmitted;
post_order.push_back(node);
node_stack.pop();
} else {
TORCH_CHECK(it->second == kEmitted);
node_stack.pop();
}
}
return post_order;
}
// Concatenates the post-orders of every root in `nodes`, in the given order.
// Because `emap` is shared across roots, a node reachable from several roots
// appears only once, in the segment of the first root that reaches it.
std::vector<const Node*> Util::ComputePostOrder(
    c10::ArrayRef<const Node*> nodes,
    EmissionMap* emap) {
  std::vector<const Node*> combined;
  for (const Node* root : nodes) {
    std::vector<const Node*> root_order = ComputePostOrder(root, emap);
    combined.insert(combined.end(), root_order.begin(), root_order.end());
  }
  return combined;
}
// Convenience overload: computes the combined post-order of `nodes` using a
// fresh EmissionMap local to this call, so no prior visitation state applies.
std::vector<const Node*> Util::ComputePostOrder(
    c10::ArrayRef<const Node*> nodes) {
  EmissionMap scratch_emap;
  return ComputePostOrder(nodes, &scratch_emap);
}
// Returns the number of distinct nodes reachable from `nodes`, roots included.
// It is computed as the length of a full post-order traversal.
size_t Util::GetGraphSize(c10::ArrayRef<const Node*> nodes) {
  const std::vector<const Node*> ordering = ComputePostOrder(nodes);
  return ordering.size();
}
} // namespace torch::lazy
|