File: insert_guards.cpp

#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <memory>
#include <unordered_set>

namespace torch::jit {

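// Replaces prim::profile nodes that recorded a TensorType with prim::Guard
// nodes on the profiled value, then strips the remaining profiling nodes.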
struct GuardInserter {
  GuardInserter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

  void run() {
    insertGuards(graph_->block());
    ProfilingRecord::removeProfilingNodes(graph_->block());
  }

 private:
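  // Walks every node in the block, recursing into nested blocks. Profile
  // nodes with a recorded TensorType are turned into guards; profile nodes
  // without one are removed and their uses rewired to the unprofiled input.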
  void insertGuards(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::profile) {
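        // The profiler records the observed type on the profile node; only
        // TensorTypes can be guarded.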
        auto pttp = n->ty(attr::profiled_type)->cast<TensorType>();
        if (pttp) {
          auto guard = graph_->create(prim::Guard, {n->input()}, 1);
          auto go = guard->output();
          go->setType(pttp);
          guard->insertBefore(n);
          n->output()->replaceAllUsesWith(go);
        } else {
          // this path was never taken during profiling, i.e. no profiling
          // information is available, so bypass the profile node entirely
          n->output()->replaceAllUsesWith(n->input());
        }
        it.destroyCurrent();
      } else {
        for (Block* ib : n->blocks()) {
          insertGuards(ib);
        }
      }
    }
  }

  std::shared_ptr<Graph> graph_;
};

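// Public entry point: rewrites the graph so that profiled tensor types are
// asserted by prim::Guard nodes and profiling nodes are removed.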
void InsertGuards(std::shared_ptr<Graph> graph) {
  GuardInserter gi(std::move(graph));
  gi.run();
}

} // namespace torch::jit