1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
|
#include <torch/csrc/lazy/core/ir_util.h>
#include <c10/util/Logging.h>
namespace torch {
namespace lazy {
std::vector<Node*> Util::ComputePostOrder(const Node* node, EmissionMap* emap) {
std::vector<Node*> post_order;
std::vector<Node*> queue;
// std::vector<const T> to c10::ArrayRef<T> conversion is not supported,
// so we need to drop const in the return vector and use const_cast here.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
queue.push_back(const_cast<Node*>(node));
while (!queue.empty()) {
node = queue.back();
auto it = emap->find(node);
if (it == emap->end()) {
(*emap)[node] = kEmitting;
for (auto& output : node->operands()) {
auto oit = emap->find(output.node);
if (oit == emap->end()) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
queue.push_back(const_cast<Node*>(output.node));
} else {
TORCH_CHECK(
oit->second != kEmitting,
"Graph loop found at ",
output.node->ToString());
}
}
} else if (it->second == kEmitting) {
for (auto& output : node->operands()) {
auto oit = emap->find(output.node);
TORCH_CHECK(
oit != emap->end() && oit->second == kEmitted,
"Graph loop found at ",
output.node->ToString());
}
(*emap)[node] = kEmitted;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
post_order.push_back(const_cast<Node*>(node));
queue.pop_back();
} else {
TORCH_CHECK(it->second == kEmitted);
queue.pop_back();
}
}
return post_order;
}
// Returns the post-orders of every root in `nodes`, concatenated in the order
// the roots are given. All roots share `emap`, so a node that was recorded
// while walking an earlier root (or was already present in `emap` from the
// caller) is not visited again for a later root.
std::vector<Node*> Util::ComputePostOrder(
    c10::ArrayRef<Node*> nodes,
    EmissionMap* emap) {
  std::vector<Node*> combined;
  for (Node* root : nodes) {
    const std::vector<Node*> partial = ComputePostOrder(root, emap);
    combined.insert(combined.end(), partial.begin(), partial.end());
  }
  return combined;
}
// Returns the combined post-order of all nodes reachable from `nodes`, using a
// fresh emission map so that no node is treated as already visited.
std::vector<Node*> Util::ComputePostOrder(c10::ArrayRef<Node*> nodes) {
  EmissionMap fresh_map;
  return ComputePostOrder(nodes, &fresh_map);
}
// Returns the number of nodes reachable from `nodes`, measured as the length
// of their combined post-order.
size_t Util::GetGraphSize(c10::ArrayRef<Node*> nodes) {
  const std::vector<Node*> order = ComputePostOrder(nodes);
  return order.size();
}
} // namespace lazy
} // namespace torch
|