File: check_strict_fusion.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (128 lines) | stat: -rw-r--r-- 3,820 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

#include <torch/csrc/jit/passes/check_strict_fusion.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch::jit {

namespace {

// Returns true when `value` is an instance of the TorchScript class
// `__torch__.torch.jit.strict_fusion`, i.e. the object a `with
// torch.jit.strict_fusion():` block enters. Returns false when
// getModuleName() yields no name for `value` or yields a different name.
bool isStrictFusion(Value* value) {
  auto module_name = getModuleName(value);
  if (!module_name) {
    return false;
  }
  return *module_name == "__torch__.torch.jit.strict_fusion";
}

} // namespace

// Returns true if `k` is one of the guard node kinds recognised by this pass:
// TensorExprDynamicGuard, prim::TypeCheck, prim::CudaFusionGuard or
// prim::RequiresGradCheck. An `If` conditioned on such a node is treated as a
// fusion, and the guard nodes themselves are not traversed further when
// collecting guard dependencies.
static bool fusionGuardCheck(Symbol k) {
  // Presumably looked up by name because there is no predeclared `prim::`
  // constant for it; evaluated first, matching the original check order.
  const Symbol tensorexpr_dynamic_guard =
      Symbol::prim("TensorExprDynamicGuard");
  const bool is_guard = (k == tensorexpr_dynamic_guard) ||
      (k == prim::TypeCheck) || (k == prim::CudaFusionGuard) ||
      (k == prim::RequiresGradCheck);
  return is_guard;
}

// Collects every node that the guarding `prim::If` (`guarding_if`) transitively
// depends on through its inputs.
//
// Inputs are followed only to nodes that live in the same block as
// `enter_node` and are not before it, so the search stays inside the
// strict_fusion region. Traversal does not continue past guard nodes (see
// fusionGuardCheck): those directly test Tensor inputs and are not part of
// the additional guard computation inserted by the fusers.
//
// Returns the set of reached nodes; it always contains `guarding_if` itself.
static std::unordered_set<Node*> collectValuesUsedInGuard(
    Node* guarding_if,
    Node* enter_node) {
  // Iterative DFS. Nodes are marked visited when they are pushed (not when
  // popped) so that a node reachable along several paths is pushed and
  // expanded only once.
  std::unordered_set<Node*> visited_nodes = {guarding_if};
  std::vector<Node*> stack = {guarding_if};

  while (!stack.empty()) {
    Node* curr = stack.back();
    stack.pop_back();
    // these nodes directly test Tensor inputs, and are not part of additional
    // guards inserted
    if (fusionGuardCheck(curr->kind())) {
      continue;
    }
    for (Value* v : curr->inputs()) {
      Node* inp_node = v->node();
      if (inp_node->isBefore(enter_node) ||
          inp_node->owningBlock() != enter_node->owningBlock()) {
        continue;
      }
      // insert().second is true only the first time we see this node.
      if (visited_nodes.insert(inp_node).second) {
        stack.push_back(inp_node);
      }
    }
  }
  return visited_nodes;
}

// Checks that the region opened by a strict_fusion `enter_node` was fused.
// The region is the nodes after `enter_node` up to the first following
// prim::Exit.
//
// Each node in the region is classified in one of two ways:
//   * as a guarding `prim::If`: an If whose condition is produced by a fusion
//     guard node (see fusionGuardCheck), treated as a fusion;
//   * otherwise, as a candidate unfused node.
// Candidate nodes that the (single) guarding If depends on are part of the
// guard computation inserted by autodiff/nnc and are allowed. Any other
// candidate is reported as an unfused operator.
//
// Throws an ErrorReport in either of two cases:
//   * more than one guarding If is found;
//   * any unfused operator remains.
// The report is anchored at the source range of the node that produced
// `enter_node`'s input.
static void checkForUnfusedOps(Node* enter_node) {
  std::vector<Node*> unsupported_nodes;
  std::vector<Node*> guarding_ifs; // if multiple, we will throw
  for (Node* node = enter_node->next(); node->kind() != prim::Exit;
       node = node->next()) {
    if (node->kind() == prim::If &&
        fusionGuardCheck(node->input()->node()->kind())) {
      guarding_ifs.push_back(node);
      continue;
    }
    unsupported_nodes.push_back(node);
  }

  if (guarding_ifs.size() > 1) {
    std::stringstream ss;
    ss << "Found multiple fusions: \n";
    for (Node* n : guarding_ifs) {
      ss << *n << "\n";
    }
    throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
  }

  // autodiff/nnc both insert a number of guards, see
  // `CudaFusionViewGuard Example Graph`
  // to check for unfused nodes, look at node's whose outputs
  // are not depended on by the fusion guard
  // restrict search for all values after the first
  // node in the prim::Enter block

  std::unordered_set<Node*> guarding_check_nodes;
  if (guarding_ifs.size() == 1) {
    guarding_check_nodes =
        collectValuesUsedInGuard(guarding_ifs[0], enter_node);
  }
  std::vector<Node*> unfused_nodes_not_used_in_guard;
  for (Node* unfused : unsupported_nodes) {
    if (!guarding_check_nodes.count(unfused)) {
      unfused_nodes_not_used_in_guard.push_back(unfused);
    }
  }
  if (!unfused_nodes_not_used_in_guard.empty()) {
    std::stringstream ss;
    ss << "Found unfused operators: \n";
    for (Node* unfused : unfused_nodes_not_used_in_guard) {
      ss << "\t";
      if (unfused->maybeSchema()) {
        ss << unfused->schema();
      } else {
        // Nodes without an operator schema are identified by their kind.
        ss << unfused->kind().toDisplayString();
      }
      ss << "\n";
    }
    throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
  }
}

void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator it(graph);
  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
      checkForUnfusedOps(n);
    }
  }

  // TODO: remove context manager after checks
  // TODO: improve control flow not taken, right now always errors
}

} // namespace torch::jit