File: trie.h

package info (click to toggle)
pytorch 2.6.0%2Bdfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 161,672 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (77 lines) | stat: -rw-r--r-- 2,192 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#pragma once

#include <atomic>
#include <list>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch::lazy {

// One node of the trie cache: wraps an IR node together with bookkeeping
// (an id and a hit count) and the list of nodes that may follow it.
struct TORCH_API TrieNode {
  // Hands out consecutive ids starting at 0. The counter is thread_local, so
  // every thread has its own sequence and ids are only unique per thread.
  static size_t GetNextUniqueId() {
    static thread_local size_t next_id = 0;
    const size_t assigned = next_id;
    ++next_id;
    return assigned;
  }

  // Id obtained from GetNextUniqueId() at construction time.
  size_t unique_id;
  // Starts at 0; incremented externally each time this node is reused.
  size_t hit_counter;
  // The IR node held by this trie node; null for a default-constructed node.
  NodePtr ir_node;
  // Child trie nodes, held through shared ownership.
  std::list<std::shared_ptr<TrieNode>> successors;

  // Builds a node that holds no IR node (ir_node is null).
  TrieNode() : TrieNode(NodePtr(nullptr)) {}

  // Builds a node that takes over `node`, with a fresh id and a zero hit
  // count.
  explicit TrieNode(NodePtr node)
      : unique_id(GetNextUniqueId()),
        hit_counter(0),
        ir_node(std::move(node)) {}
};

// Cache of IR nodes organized as a trie of TrieNode objects. The class owns a
// root node (root_) and keeps a non-owning pointer to a "current" node
// (current_). Only declarations appear in this header; the member function
// definitions live elsewhere, so the notes below describe intent as stated by
// the declarations and existing comments rather than verified behavior.
class TORCH_API TrieCache {
 public:
  // Accessor for a TrieCache instance. The constructor is private, so this is
  // the only public way to obtain one.
  // NOTE(review): presumably returns a singleton (possibly per-thread) --
  // confirm against the definition.
  static TrieCache* Get();

  // Returns the node the cache is currently positioned at (presumably
  // current_).
  TrieNode* Current() const;
  // Take an iterator as the input because we want to move the corresponding
  // node in the successor list to achieve an LRU caching effect.
  // The iterator is taken by non-const reference, so the call may modify it.
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one tracing
  void ResetCurrent();

  // Create a new TrieNode for ir_node and insert into the TrieCache
  void Insert(NodePtr ir_node);

  // Clear all TrieCache nodes
  // TODO: Because we don't expect user to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the size of the cache
  // to avoid holding tensors for too long.
  void Clear();

  // Writes a dump of the cache to the file named file_name.
  // NOTE(review): the name suggests Graphviz DOT output -- confirm against
  // the definition.
  void DumpToDotFile(const std::string& file_name);

 private:
  // Private so that instances cannot be created directly by outside code.
  TrieCache();

  // Owning handle to the root node of the trie.
  std::shared_ptr<TrieNode> root_;
  // Non-owning pointer to the current node.
  TrieNode* current_;
};

// Searches the successors of the TrieCache's current node for a cached IR
// node of concrete type T that reports, via T::CanBeReused(args...), that it
// can be reused.
//
// Successors are visited in list order. On the first match:
//   * the counter "IrNodeReused_<demangled name of T>" is bumped by 1,
//   * the matching TrieNode's hit_counter is incremented,
//   * the cache's current node is updated through TrieCache::SetCurrent,
//   * the cached IR node is returned.
// Returns nullptr when no successor matches.
//
// `args` are deliberately passed to CanBeReused as lvalues instead of being
// std::forward-ed. The check can run once per successor, and forwarding an
// rvalue more than once would allow the first call to move from it, leaving
// later iterations (and the objects the caller still refers to on a miss) in
// a moved-from state.
template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(Args&&... args) {
  TrieCache* const cache = TrieCache::Get();
  auto& successors = cache->Current()->successors;
  for (auto it = successors.begin(); it != successors.end(); ++it) {
    TrieNode& candidate = **it;
    // NodeCast yields null when the cached node is not a T, in which case the
    // candidate is skipped.
    const T* concrete_node = NodeCast<T>(candidate.ir_node.get());
    if (concrete_node && concrete_node->CanBeReused(args...)) {
      TORCH_LAZY_COUNTER(
          "IrNodeReused_" + c10::demangle((typeid(T).name())), 1);
      candidate.hit_counter++;
      // Take the owning copy only on a hit, and before SetCurrent, which
      // receives the iterator by non-const reference.
      NodePtr reused = candidate.ir_node;
      cache->SetCurrent(it);
      return reused;
    }
  }
  return nullptr;
}

} // namespace torch::lazy