1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
|
#pragma once
#include <unordered_map>
#include <vector>
#include <torch/csrc/lazy/core/ir.h>
namespace torch {
namespace lazy {
// Post-order traversal and size helpers for lazy IR graphs. The class has no
// data members; every member function is static.
class TORCH_API Util {
 public:
  // Emission state of a node while a post-order is being generated. Keeping
  // "in progress" (kEmitting) distinct from "finished" (kEmitted) is what
  // allows a traversal to notice loops in the computation graph.
  enum EmitStatus {
    kNotEmitted,
    kEmitting,
    kEmitted,
  };

  // Per-node emission state, keyed by node address.
  using EmissionMap = std::unordered_map<const Node*, EmitStatus>;

  // Returns the post-order rooted at `node`, computed iteratively rather than
  // recursively. `emap` is state that survives across separate calls: if
  // `node` is already recorded there as emitted, the returned vector can be
  // empty. An error is raised when a loop is detected.
  static std::vector<Node*> ComputePostOrder(
      const Node* node,
      EmissionMap* emap);

  // Multi-root variant without a caller-supplied emission map: computes the
  // post order over the set of nodes passed in `nodes`.
  static std::vector<Node*> ComputePostOrder(c10::ArrayRef<Node*> nodes);

  // Multi-root variant of the single-node overload: same contract, but the
  // traversal starts from every node in `nodes`, with `emap` carrying
  // emission state between calls.
  static std::vector<Node*> ComputePostOrder(
      c10::ArrayRef<Node*> nodes,
      EmissionMap* emap);

  // Returns how many nodes make up the graph whose sinks are given in `nodes`.
  static size_t GetGraphSize(c10::ArrayRef<Node*> nodes);
};
} // namespace lazy
} // namespace torch
|