1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
/**
* Cache utils in this file are adapted from PyTorch/XLA
* https://github.com/pytorch/xla/blob/master/third_party/xla_client/cache.h
*/
#pragma once
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
namespace torch {
namespace lazy {
// Generic key and object cache with LRU expiration policy. The objects of type
// T will be stored as std::shared_ptr<T> and taken and returned as such, by the
// cache API.
// Generic key/object cache with a least-recently-used (LRU) eviction policy.
// Objects of type T are stored as std::shared_ptr<T>, and the cache API takes
// and returns them as such. Keys are hashed with H and compared with E.
//
// Every public member function acquires the internal mutex before touching
// the cache's bookkeeping, so one Cache instance can be shared by threads.
//
// Internally, element_list_ orders entries from most recently used (front)
// to least recently used (back), and element_map_ indexes those entries by
// pointers to the keys stored inside the list nodes. std::list nodes never
// move, so those key pointers stay valid until their node is erased.
template <
    typename K,
    typename T,
    typename H = std::hash<K>,
    typename E = std::equal_to<K>>
class Cache {
 public:
  using TypePtr = std::shared_ptr<T>;
  using Element = std::pair<K, TypePtr>;

  // Creates an empty cache holding at most `max_size` entries. With a
  // `max_size` of 0, Add() still returns the object but nothing is retained.
  explicit Cache(size_t max_size) : max_size_(max_size) {}

  // Adds `object` under `key`, unless an entry for `key` already exists.
  // Returns the object now associated with `key`:
  //  - if `key` was already cached, the existing object (which is promoted
  //    to most-recently-used; `object` is discarded);
  //  - otherwise `object` itself.
  // If the insertion grows the cache beyond the limit set at construction,
  // the least recently used entry is evicted.
  TypePtr Add(K key, TypePtr object) {
    std::lock_guard<std::mutex> slock(lock_);
    // Insert tentatively at the head so the map can key off the stable
    // address of the key stored inside the list node.
    element_list_.emplace_front(Element(std::move(key), std::move(object)));
    auto it = element_list_.begin();
    auto emplace_result = element_map_.emplace(&it->first, it);
    if (!emplace_result.second) {
      // Key already cached: drop the tentative node and refresh the old one.
      element_list_.erase(it);
      DoLRU(emplace_result.first->second);
      return emplace_result.first->second->second;
    }
    // Take the result before evicting: when max_size_ == 0 the entry just
    // inserted is itself the one evicted, and erasing it from the map would
    // leave emplace_result.first dangling.
    TypePtr result = it->second;
    if (element_list_.size() > max_size_) {
      Element* last = &element_list_.back();
      element_map_.erase(&last->first);
      element_list_.pop_back();
    }
    return result;
  }

  // Retrieves the object stored under `key`, if any. On a hit, the entry is
  // moved to the head of the LRU list. Returns nullptr if no object with the
  // specified key is found within the cache.
  TypePtr Get(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return nullptr;
    }
    DoLRU(it->second);
    return it->second->second;
  }

  // Returns the object at the head of the LRU list: the most recently added
  // or retrieved one. It does not change the LRU order. The cache must be
  // non-empty; this is enforced with TORCH_CHECK.
  TypePtr GetLatest() {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_list_.size() > 0);
    return element_list_.front().second;
  }

  // Removes the entry for `key`. Returns true if an entry was removed, or
  // false if `key` was not cached.
  bool Erase(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return false;
    }
    auto lit = it->second;
    // Erase the map entry first: its key pointer refers into the list node.
    element_map_.erase(it);
    element_list_.erase(lit);
    return true;
  }

  // Removes every entry from the cache.
  void Clear() {
    std::lock_guard<std::mutex> slock(lock_);
    element_map_.clear();
    element_list_.clear();
  }

  // Returns the number of cached entries. Uses TORCH_CHECK to verify that the
  // index and the LRU list agree.
  int Numel() const {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_map_.size() == element_list_.size());
    // Kept as `int` for API compatibility; the narrowing is intentional.
    return static_cast<int>(element_map_.size());
  }

 private:
  using ElementList = std::list<Element>;

  // Hashes a key through the pointer stored as the map key.
  struct Hasher {
    size_t operator()(const K* key) const {
      return hasher(*key);
    }
    H hasher;
  };

  // Compares two keys through the pointers stored as map keys.
  struct Equaler {
    bool operator()(const K* k1, const K* k2) const {
      return equaler(*k1, *k2);
    }
    E equaler;
  };

  using ElementMap = std::
      unordered_map<const K*, typename ElementList::iterator, Hasher, Equaler>;

  // Moves the node at `it` to the head (most recently used end) of the list.
  // splice relinks the node without copying it, so the iterator and key
  // pointer held in element_map_ remain valid.
  void DoLRU(typename ElementList::iterator it) {
    element_list_.splice(element_list_.begin(), element_list_, it);
  }

  // Guards all state below; mutable so that const Numel() can lock it.
  mutable std::mutex lock_;
  size_t max_size_ = 0;
  // Entries ordered most recently used (front) to least recently used (back).
  ElementList element_list_;
  // Key pointer (into an element_list_ node) -> that node's iterator.
  ElementMap element_map_;
};
} // namespace lazy
} // namespace torch
|