File: trie.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (87 lines) | stat: -rw-r--r-- 2,261 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <torch/csrc/lazy/core/trie.h>

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <fstream>
#include <sstream>
#include <vector>

namespace torch {
namespace lazy {
namespace {

// Writes the subtree rooted at `node` to `ss` as Graphviz DOT statements:
//   * one `<id>[label="<op>, <n> hits"]` line for every node that carries an
//     IR node (nodes without one, e.g. a default-constructed root, get no
//     label line), and
//   * one `<parent> -> <child>` line for every successor link.
// Output order is a pre-order walk visiting successors front-to-back, i.e. the
// same order a straightforward recursive walk would produce.
//
// An explicit stack is used instead of recursion: every TrieCache::Insert adds
// a child below the current node and descends into it, so trie depth grows with
// the number of IR nodes recorded along one path. A recursive walk could
// therefore overflow the call stack when dumping long traces.
void TraverseTrie(TrieNode* node, std::stringstream& ss) {
  if (!node) {
    return;
  }

  using SuccessorIter = decltype(node->successors.begin());
  // One frame per node on the current root-to-leaf path, remembering which of
  // that node's successors is to be visited next.
  struct Frame {
    TrieNode* node;
    SuccessorIter next;
  };
  std::vector<Frame> stack;

  // Emits the node's label (if it has an IR node) and schedules its children.
  auto enter = [&ss, &stack](TrieNode* n) {
    if (n->ir_node) {
      ss << n->unique_id << "[label=\"" << n->ir_node->op().ToString()
         << ", " << n->hit_counter << " hits\"]\n";
    }
    stack.push_back(Frame{n, n->successors.begin()});
  };

  enter(node);
  while (!stack.empty()) {
    Frame& top = stack.back();
    if (top.next == top.node->successors.end()) {
      stack.pop_back();
      continue;
    }
    TrieNode* child = (top.next++)->get();
    ss << top.node->unique_id << " -> " << child->unique_id << "\n";
    // enter() may reallocate `stack`; `top` must not be used after this call.
    enter(child);
  }
}

} // namespace

// Returns the calling thread's TrieCache, creating it on first use.
// A block-scope `thread_local` variable is implicitly static, so each thread
// initializes `instance` exactly once. The cache is heap-allocated and nothing
// in this file deletes it, so it stays alive for the rest of the process
// (presumably to sidestep destruction-order issues at thread exit — confirm).
TrieCache* TrieCache::Get() {
  thread_local TrieCache* const instance = new TrieCache();
  return instance;
}

TrieCache::TrieCache()
    : root_(std::make_shared<TrieNode>()), current_(root_.get()) {}

// Returns the node the cursor is on. Within this file the cursor is moved by
// Insert() and SetCurrent(), and put back on the root by the constructor,
// ResetCurrent() and Clear().
TrieNode* TrieCache::Current() const {
  return this->current_;
}

void TrieCache::SetCurrent(
    std::list<std::shared_ptr<TrieNode>>::iterator& iter) {
  auto& successors = current_->successors;
  // Update current_ before iter gets destroyed
  current_ = (*iter).get();

  // Insert this node to the front of its parent's successor list
  if (iter != successors.begin()) {
    successors.push_front(std::move(*iter));
    successors.erase(iter);
  }
}

void TrieCache::ResetCurrent() {
  current_ = root_.get();
}

// Adds a new trie node wrapping `ir_node` as the first successor of the current
// node, then moves the cursor onto it. If the current node already had
// successors, this creates a branch, which is recorded by the "TrieForked"
// counter.
void TrieCache::Insert(NodePtr ir_node) {
  TORCH_CHECK(current_);
  auto& children = current_->successors;
  const bool forks_existing_path = !children.empty();
  if (forks_existing_path) {
    TORCH_LAZY_COUNTER("TrieForked", 1);
  }
  children.push_front(std::make_shared<TrieNode>(std::move(ir_node)));
  // Descend into the node that was just created.
  current_ = children.front().get();
}

// Drops every node below the root and puts the cursor back on the root.
void TrieCache::Clear() {
  // Move the cursor off any node that is about to be released.
  ResetCurrent();
  // Trie nodes are held through shared_ptr successor lists, so releasing the
  // root's children frees the whole subtree, assuming nothing else holds a
  // shared_ptr to those nodes (TODO confirm).
  auto& children = root_->successors;
  children.clear();
}

// Writes the whole trie, starting at the root, to `file_name` as a Graphviz
// `digraph G`. The graph is rendered in memory first and then written through
// a default-mode std::ofstream (which truncates an existing file). Failure to
// open or write the file is not reported; the stream state is never checked.
void TrieCache::DumpToDotFile(const std::string& file_name) {
  std::stringstream dot;
  dot << "digraph G {\n";
  TraverseTrie(root_.get(), dot);
  dot << "}\n";

  std::ofstream out(file_name);
  out << dot.str();
}

} // namespace lazy
} // namespace torch