File: kernel_cache.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (92 lines) | stat: -rw-r--r-- 2,756 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

#include <cstdint>
#include <mutex>
#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {

// Shared state behind the fusion kernel cache. Every function in this file
// that touches these fields first locks mutex_.
struct KernelCacheImpl {
  // Serializes access to kernel_counter, specMap_ and graphToKey_.
  std::mutex mutex_;
  // Next fusion key to hand out; post-incremented by store().
  int64_t kernel_counter = 0;

  // Fusion key -> KernelSpec.
  // retrieve()/lookupGraph() return raw pointers into this map that are used
  // after the lock is released. That is only safe because std::unordered_map
  // never invalidates references to its elements, even when it rehashes
  // (and nothing in this file erases entries). Do not swap this container
  // for one without that guarantee.
  std::unordered_map<int64_t, KernelSpec> specMap_;

  // Printed form (toString(false)) of a normalized graph -> fusion key.
  // Consulted by lookupGraph() to find a graph already present in specMap_.
  std::unordered_map<std::string, int64_t> graphToKey_;
};

// Accessor for the single kernel cache instance used by this file.
// The function-local static is constructed on first call; since C++11 that
// initialization is thread-safe. The object has static storage duration.
static KernelCacheImpl& getKernelCache() {
  static KernelCacheImpl instance;
  return instance;
}

int64_t debugNumCachedKernelSpecs() {
  auto& cache = getKernelCache();
  std::lock_guard<std::mutex> guard{cache.mutex_};
  return cache.specMap_.size();
}

std::shared_ptr<Graph> normalizeGraphForCache(
    const std::shared_ptr<Graph>& graph) {
  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
  EraseShapeInformation(result);
  return result;
}

// TODO: lookup by historic string key to start, then issue key
// as appropriate for faster lookup in the future
// precondition: graph has been normalized via normalizeGraphForCache
//
// Registers `graph` in the kernel cache and returns its fusion key.
//
// If a graph with the same printed representation is already cached, its
// existing key is returned and no new KernelSpec is created. The check and
// the insertion happen under one lock. Without that, two threads that both
// miss in lookupGraph() and then call store() would each create a
// KernelSpec. graphToKey_ keeps only the first key, so the second spec
// would be unreachable through lookupGraph().
int64_t store(std::shared_ptr<Graph> graph) {
  auto& cache = getKernelCache();
  // Computed before taking the lock to keep the critical section short.
  std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};

  // Reuse an entry that was stored concurrently (or earlier) for this graph.
  const auto existing = cache.graphToKey_.find(repr);
  if (existing != cache.graphToKey_.end()) {
    return existing->second;
  }

  const auto key = cache.kernel_counter++;
  // Construct the KernelSpec in place from (key, graph).
  cache.specMap_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(key),
      std::forward_as_tuple(key, graph));
  cache.graphToKey_.emplace(std::move(repr), key);
  return key;
}

// Looks up the KernelSpec registered under `key` in `cache`.
// Returns a pointer to the stored spec, or nullopt if `key` is unknown.
// XXX: Does not grab mutex -- both callers in this file already hold
// cache.mutex_ when they call this.
static at::optional<KernelSpec*> nolock_retrieve(
    KernelCacheImpl& cache,
    const int64_t key) {
  const auto found = cache.specMap_.find(key);
  if (found != cache.specMap_.end()) {
    return &found->second;
  }
  return at::nullopt;
}

// Locked lookup of the KernelSpec registered under `key`.
// Returns nullopt when nothing has been stored with that key.
// The returned pointer is used after the lock is dropped. That relies on
// std::unordered_map keeping element references stable.
at::optional<KernelSpec*> retrieve(const int64_t key) {
  KernelCacheImpl& cache = getKernelCache();
  const std::lock_guard<std::mutex> lock{cache.mutex_};
  return nolock_retrieve(cache, key);
}

// precondition: graph has been normalized via normalizeGraphForCache
//
// Finds the KernelSpec previously stored for a graph whose printed form
// (toString(false)) matches `graph`'s. Returns nullopt on a miss.
at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
  auto& cache = getKernelCache();
  // Print before locking; only the map accesses need the mutex.
  const std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};
  const auto found = cache.graphToKey_.find(repr);
  if (found != cache.graphToKey_.end()) {
    return nolock_retrieve(cache, found->second);
  }
  return at::nullopt;
}

} // namespace fuser
} // namespace jit
} // namespace torch