#include <torch/csrc/jit/passes/check_strict_fusion.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <unordered_map>
#include <unordered_set>
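// Implements the graph check behind the `torch.jit.strict_fusion` context
// manager: every op inside the guarded `prim::Enter`/`prim::Exit` region must
// be covered by a fusion guard, otherwise an ErrorReport is thrown.
//
// Illustrative TorchScript usage (assumes a fusion backend is enabled):
//
//   @torch.jit.script
//   def fn(x):
//       with torch.jit.strict_fusion():
//           return x + x + x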
namespace torch {
namespace jit {
namespace {
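// Returns true if `value` is an instance of the `torch.jit.strict_fusion`
// context manager, i.e. the object entered by the guarded `prim::Enter`.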
bool isStrictFusion(Value* value) {
const auto class_name = getModuleName(value);
return class_name.has_value() &&
(*class_name == "__torch__.torch.jit.strict_fusion");
}
} // namespace
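// Returns true if `k` is one of the guard node kinds inserted by the fusion
// and autodiff passes (NNC's TensorExprDynamicGuard, prim::TypeCheck,
// NVFuser's prim::CudaFusionGuard, prim::RequiresGradCheck).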
bool fusionGuardCheck(Symbol k) {
return k == Symbol::prim("TensorExprDynamicGuard") || k == prim::TypeCheck ||
k == prim::CudaFusionGuard || k == prim::RequiresGradCheck;
}
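// DFS backwards from the guarding prim::If through its inputs, collecting the
// nodes inside the `prim::Enter` block that feed the guard's condition; these
// are later excluded from the unfused-operator report.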
std::unordered_set<Node*> collectValuesUsedInGuard(
Node* guarding_if,
Node* enter_node) {
// DFS to collect
std::unordered_set<Node*> visited_nodes;
std::vector<Node*> queue = {guarding_if};
while (!queue.empty()) {
    Node* curr = queue.back();
queue.pop_back();
visited_nodes.insert(curr);
    // These nodes directly test Tensor inputs and are not part of the
    // additional guards inserted, so stop the traversal here.
if (fusionGuardCheck(curr->kind())) {
continue;
}
for (Value* v : curr->inputs()) {
Node* inp_node = v->node();
if (inp_node->isBefore(enter_node) ||
inp_node->owningBlock() != enter_node->owningBlock()) {
continue;
}
if (visited_nodes.count(inp_node)) {
continue;
}
queue.push_back(inp_node);
}
}
return visited_nodes;
}
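// Walks the nodes between `enter_node` (the strict_fusion prim::Enter) and the
// matching prim::Exit. Anything that is neither the fusion-guarding prim::If
// nor an input to its guard checks is reported as an unfused operator at the
// context manager's source range.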
void checkForUnfusedOps(Node* enter_node) {
std::vector<Node*> unsupported_nodes;
std::vector<Node*> guarding_ifs; // if multiple, we will throw
for (Node* node = enter_node->next(); node->kind() != prim::Exit;
node = node->next()) {
if (node->kind() == prim::If &&
fusionGuardCheck(node->input()->node()->kind())) {
guarding_ifs.push_back(node);
continue;
}
unsupported_nodes.push_back(node);
}
if (guarding_ifs.size() > 1) {
std::stringstream ss;
ss << "Found multiple fusions: \n";
for (Node* n : guarding_ifs) {
ss << *n << "\n";
}
throw ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str();
}
  // NVFuser/autodiff/NNC all insert a number of guards; see
  // `CudaFusionViewGuard Example Graph`.
  // To check for unfused nodes, look at nodes whose outputs
  // are not depended on by the fusion guard.
  // Restrict the search to values defined after the first
  // node in the prim::Enter block.
std::unordered_set<Node*> guarding_check_nodes;
if (guarding_ifs.size() == 1) {
guarding_check_nodes =
collectValuesUsedInGuard(guarding_ifs[0], enter_node);
}
std::vector<Node*> unfused_nodes_not_used_in_guard;
for (Node* unfused : unsupported_nodes) {
if (!guarding_check_nodes.count(unfused)) {
unfused_nodes_not_used_in_guard.push_back(unfused);
}
}
  if (!unfused_nodes_not_used_in_guard.empty()) {
std::stringstream ss;
ss << "Found unfused operators: \n";
for (Node* unfused : unfused_nodes_not_used_in_guard) {
ss << "\t";
      if (unfused->maybeSchema()) {
        ss << unfused->schema();
      } else {
        ss << unfused->kind().toDisplayString();
      }
ss << "\n";
}
    auto range = enter_node->input()->node()->sourceRange();
    throw ErrorReport(range) << ss.str();
}
}
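// Entry point of the pass: scans the graph for prim::Enter nodes guarded by
// `torch.jit.strict_fusion` and verifies that every node inside the region
// was claimed by a fusion guard.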
void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
DepthFirstGraphNodeIterator it(graph);
Node* n = nullptr;
while ((n = it.next()) != nullptr) {
if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
checkForUnfusedOps(n);
}
}
  // TODO: remove the context manager nodes after the checks
  // TODO: improve handling of control flow that is not taken; right now it
  // always errors
}
} // namespace jit
} // namespace torch