File: trie.h

#pragma once

#include <atomic>
#include <list>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch {
namespace lazy {

// A node in the trie of traced IR nodes: it wraps one IR node together with a
// reuse hit counter and the list of trie nodes traced immediately after it.
struct TORCH_API TrieNode {
  static size_t GetNextUniqueId() {
    // Thread-local counter: ids are unique within a single tracing thread.
    static thread_local size_t id_generator = 0;
    return id_generator++;
  }

  size_t unique_id;
  size_t hit_counter;
  NodePtr ir_node;
  std::list<std::shared_ptr<TrieNode>> successors;

  TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {}
  explicit TrieNode(NodePtr node)
      : unique_id(GetNextUniqueId()),
        hit_counter(0),
        ir_node(std::move(node)) {}
};
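
// Illustrative sketch (assuming each trace is recorded as a root-to-leaf path
// of TrieNodes): if one trace creates IR nodes a, b, c and a later trace
// creates a, b, d, the trie would look like
//
//   root -- a -- b --+-- c
//                    +-- d
//
// so the shared prefix a, b is reused and only d has to be inserted.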

// Cache of traced IR nodes organized as a trie, used to look up and reuse
// previously created nodes while tracing (see LookupNodeFromTrieCache below).
class TORCH_API TrieCache {
 public:
  static TrieCache* Get();

  // Returns the trie node reached by the tracing so far.
  TrieNode* Current() const;
  // Takes an iterator as the input because we want to move the corresponding
  // node within the successor list to achieve an LRU caching effect.
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one trace.
  void ResetCurrent();

  // Create a new TrieNode for ir_node and insert it into the TrieCache.
  void Insert(NodePtr ir_node);

  // Clear all TrieCache nodes.
  // TODO: Because we don't expect users to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the size of the
  // cache to avoid holding tensors for too long.
  void Clear();

  // Dump the trie to file_name in Graphviz DOT format.
  void DumpToDotFile(const std::string& file_name);

 private:
  TrieCache();

  // Root of the trie.
  std::shared_ptr<TrieNode> root_;
  // Trie node reached by the tracing so far (see Current()/SetCurrent()).
  TrieNode* current_;
};
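
// Sketch of how the cache is driven during tracing (illustrative only; the
// surrounding tracing code is hypothetical and not part of this header):
//
//   TrieCache* cache = TrieCache::Get();
//   // While tracing, record each newly created IR node under the current
//   // trie node so later traces can find and reuse it.
//   cache->Insert(ir_node);
//   ...
//   // At the end of the trace (e.g. in MarkStep), reset the current position
//   // so the next tracing iteration starts matching from the beginning.
//   cache->ResetCurrent();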

// Looks for a reusable IR node of type T among the successors of the current
// trie node. On a hit, the node's hit counter is bumped, the cache advances to
// that successor, and the cached node is returned; otherwise returns nullptr.
template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(Args&&... args) {
  auto& successors = TrieCache::Get()->Current()->successors;
  for (auto it = successors.begin(); it != successors.end(); ++it) {
    NodePtr ir_node = (*it)->ir_node;
    const T* concrete_node = NodeCast<T>(ir_node.get());
    if (concrete_node &&
        concrete_node->CanBeReused(std::forward<Args>(args)...)) {
      TORCH_LAZY_COUNTER(
          "IrNodeReused_" + c10::demangle(typeid(T).name()), 1);
      (*it)->hit_counter++;
      TrieCache::Get()->SetCurrent(it);
      return ir_node;
    }
  }
  return nullptr;
}
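
// Typical call pattern (illustrative sketch; "MyOp", "ReuseOrMakeMyOp" and the
// single Value argument are hypothetical): probe the trie before constructing
// a new IR node, and insert the node on a miss so later traces can reuse it.
//
//   NodePtr ReuseOrMakeMyOp(const Value& input) {
//     if (NodePtr cached = LookupNodeFromTrieCache<MyOp>(input)) {
//       // Cache hit: SetCurrent() has already advanced the trie position.
//       return cached;
//     }
//     NodePtr node = std::make_shared<MyOp>(input);
//     // Cache miss: record the new node in the trie for future reuse.
//     TrieCache::Get()->Insert(node);
//     return node;
//   }
//
// MyOp here is assumed to derive from torch::lazy::Node and to implement
// CanBeReused(const Value&), which is what LookupNodeFromTrieCache checks.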

} // namespace lazy
} // namespace torch