1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
|
#include <torch/csrc/jit/passes/check_strict_fusion.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
namespace torch::jit {
namespace {
// Returns true when `getModuleName(value)` produces a name and that name is
// exactly "__torch__.torch.jit.strict_fusion"; returns false when no name is
// available or the name differs.
bool isStrictFusion(Value* value) {
  const auto module_name = getModuleName(value);
  if (!module_name.has_value()) {
    return false;
  }
  return *module_name == "__torch__.torch.jit.strict_fusion";
}
} // namespace
// Returns true when `k` is one of the node kinds this pass treats as a fusion
// guard check: TensorExprDynamicGuard, TypeCheck, CudaFusionGuard, or
// RequiresGradCheck.
static bool fusionGuardCheck(Symbol k) {
  if (k == Symbol::prim("TensorExprDynamicGuard")) {
    return true;
  }
  if (k == prim::TypeCheck) {
    return true;
  }
  if (k == prim::CudaFusionGuard) {
    return true;
  }
  return k == prim::RequiresGradCheck;
}
// Collects the nodes that `guarding_if` transitively depends on, following
// input edges backwards with an explicit stack (depth-first).
//
// The walk does not descend through:
//   - nodes whose kind satisfies fusionGuardCheck() (they are recorded, but
//     their inputs are not followed, since they test Tensor inputs directly
//     and are not part of the additional guards inserted);
//   - producers that come before `enter_node` or that live in a different
//     block than `enter_node`.
//
// Returns the set of every node popped from the stack, `guarding_if` included.
static std::unordered_set<Node*> collectValuesUsedInGuard(
    Node* guarding_if,
    Node* enter_node) {
  std::unordered_set<Node*> reached;
  std::vector<Node*> pending{guarding_if};

  while (!pending.empty()) {
    Node* current = pending.back();
    pending.pop_back();
    reached.insert(current);

    // Guard-check nodes are leaves of this search.
    if (fusionGuardCheck(current->kind())) {
      continue;
    }

    for (Value* operand : current->inputs()) {
      Node* producer = operand->node();
      const bool outside_region = producer->isBefore(enter_node) ||
          producer->owningBlock() != enter_node->owningBlock();
      // Only enqueue producers inside the region that were not yet recorded.
      if (!outside_region && reached.find(producer) == reached.end()) {
        pending.push_back(producer);
      }
    }
  }

  return reached;
}
// Validates the node sequence that follows `enter_node` up to (but excluding)
// the first node of kind prim::Exit.
//
// Every node in that range is classified as either:
//   - a "guarding if": a prim::If whose condition is produced by a node kind
//     accepted by fusionGuardCheck(), or
//   - an unsupported (unfused) node.
//
// Throws an ErrorReport, located at the source range of the node producing
// `enter_node`'s input, when:
//   - more than one guarding if is found ("Found multiple fusions"), or
//   - any unfused node is not among the nodes the single guarding if depends
//     on ("Found unfused operators").
//
// NOTE(review): the scan assumes a prim::Exit node follows `enter_node` in the
// same node list; there is no other loop termination condition.
static void checkForUnfusedOps(Node* enter_node) {
  std::vector<Node*> unsupported_nodes;
  std::vector<Node*> guarding_ifs; // if multiple, we will throw
  for (Node* node = enter_node->next(); node->kind() != prim::Exit;
       node = node->next()) {
    if (node->kind() == prim::If &&
        fusionGuardCheck(node->input()->node()->kind())) {
      guarding_ifs.push_back(node);
      continue;
    }
    unsupported_nodes.push_back(node);
  }

  if (guarding_ifs.size() > 1) {
    std::stringstream ss;
    ss << "Found multiple fusions: \n";
    for (Node* n : guarding_ifs) {
      ss << *n << "\n";
    }
    throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
  }

  // autodiff/nnc both insert a number of guards, see
  // `CudaFusionViewGuard Example Graph`.
  // To check for unfused nodes, look at nodes whose outputs are not depended
  // on by the fusion guard. The search is restricted to values defined after
  // the first node in the prim::Enter block.
  std::unordered_set<Node*> guarding_check_nodes;
  if (guarding_ifs.size() == 1) {
    guarding_check_nodes =
        collectValuesUsedInGuard(guarding_ifs[0], enter_node);
  }

  std::vector<Node*> unfused_nodes_not_used_in_guard;
  for (Node* unfused : unsupported_nodes) {
    if (!guarding_check_nodes.count(unfused)) {
      unfused_nodes_not_used_in_guard.push_back(unfused);
    }
  }

  if (!unfused_nodes_not_used_in_guard.empty()) {
    std::stringstream ss;
    ss << "Found unfused operators: \n";
    for (Node* unfused : unfused_nodes_not_used_in_guard) {
      ss << "\t";
      if (unfused->maybeSchema()) {
        ss << unfused->schema();
      } else {
        // Nodes without a schema are reported by their kind. The display
        // string must be streamed into the message; previously the result was
        // computed and discarded, leaving an empty line in the report.
        ss << unfused->kind().toDisplayString();
      }
      ss << "\n";
    }
    throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
  }
}
void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
checkForUnfusedOps(n);
}
}
// TODO: remove context manager after checks
// TODO: improve control flow not taken, right now always errors
}
} // namespace torch::jit
|