#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>

#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

#include <cstdint>
#include <mutex>
#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {

struct KernelCacheImpl {
  // Note: std::unordered_map does not invalidate references even if rehashing
  // occurs. This is a critical property for thread-safety.
  std::mutex mutex_;
  int64_t kernel_counter{0};

  // Map of fusion key to KernelSpec
  std::unordered_map<int64_t, KernelSpec> specMap_;

  // Map of pretty-printed graph string to fusion key
  // Used to check if a graph has already been cached in specMap_
  std::unordered_map<std::string, int64_t> graphToKey_;
};
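
// Together these two maps support the cache's two entry points: graphToKey_
// answers "has this (normalized) graph been cached?" for lookupGraph, and
// specMap_ answers "what is the spec for this key?" for retrieve.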

static KernelCacheImpl& getKernelCache() {
  static KernelCacheImpl cache;
  return cache;
}
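
// Note: since C++11, initialization of the function-local static above is
// itself thread-safe, so concurrent first calls to getKernelCache() are
// fine; only access to the cache's contents requires holding mutex_.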

int64_t debugNumCachedKernelSpecs() {
  auto& cache = getKernelCache();
  std::lock_guard<std::mutex> guard{cache.mutex_};
  return cache.specMap_.size();
}

std::shared_ptr<Graph> normalizeGraphForCache(
    const std::shared_ptr<Graph>& graph) {
  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
  EraseShapeInformation(result);
  return result;
}
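
// For example (illustrative IR; the exact printed form may differ), a graph
// recorded as
//
//   graph(%x.1 : Float(2, 2)):
//     %y.1 : Float(2, 2) = aten::relu(%x.1)
//     return (%y.1)
//
// normalizes to the same string as one recorded with different value names
// and no shape annotations, so both map to a single cache entry.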

// TODO: lookup by historic string key to start, then issue key
// as appropriate for faster lookup in the future
// precondition: graph has been normalized via normalizeGraphForCache
int64_t store(std::shared_ptr<Graph> graph) {
  auto& cache = getKernelCache();
  std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};
  const auto key = cache.kernel_counter++;
  cache.specMap_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(key),
      std::forward_as_tuple(key, graph));
  cache.graphToKey_.emplace(std::make_pair(std::move(repr), key));
  return key;
}

// XXX: Does not grab mutex
static at::optional<KernelSpec*> nolock_retrieve(
    KernelCacheImpl& cache,
    const int64_t key) {
  auto it = cache.specMap_.find(key);
  if (it == cache.specMap_.end())
    return at::nullopt;
  return &(it->second);
}
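
// Handing out a raw KernelSpec* is safe here because entries are never
// erased from specMap_ and std::unordered_map keeps references to its
// values stable across rehashes (see the note in KernelCacheImpl).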

at::optional<KernelSpec*> retrieve(const int64_t key) {
  auto& cache = getKernelCache();
  std::lock_guard<std::mutex> guard{cache.mutex_};
  return nolock_retrieve(cache, key);
}

// precondition: graph has been normalized via normalizeGraphForCache
at::optional<KernelSpec*> lookupGraph(std::shared_ptr<Graph> graph) {
  auto& cache = getKernelCache();
  std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};
  auto it = cache.graphToKey_.find(repr);
  if (it == cache.graphToKey_.end())
    return at::nullopt;
  return nolock_retrieve(cache, it->second);
}
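
// Putting the pieces together, a cache-aware caller (hypothetical) might do:
//
//   auto normalized = normalizeGraphForCache(graph);
//   auto spec = lookupGraph(normalized);   // fast path: graph seen before
//   if (!spec) {
//     spec = retrieve(store(normalized));  // slow path: cache it, then fetch
//   }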
} // namespace fuser
} // namespace jit
} // namespace torch