#pragma once
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "caffe2/core/logging.h"
namespace caffe2 {
/**
* A thread_local pointer in C++ is per thread. However, sometimes we want
* thread-local state that is both per thread and per instance. For example,
* given the following class:
* class A {
*   ThreadLocalPtr<int> x;
* };
* we would like a separate copy of x for each thread and each instance of A.
* This is useful for storing per-instance thread-local state of a class
* when multiple instances of that class can exist in the same thread.
* We implemented the subset of folly::ThreadLocalPtr functionality that is
* needed to support BlackBoxPredictor.
*/
class ThreadLocalPtrImpl;
class ThreadLocalHelper;
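/*
Illustrative sketch only, not part of caffe2: the per-thread, per-instance
idea reduced to a thread_local std::unordered_map keyed by the instance
address. NaivePerInstanceTLS and naivePerInstanceDemo are hypothetical names.
Unlike ThreadLocalPtr, the sketch never erases an entry when its instance is
destroyed, so a later object allocated at the same address could see a stale
value; that is the gap AllThreadLocalHelperVector::erase_tlp, declared below,
is meant to close.

```cpp
#include <cassert>
#include <memory>
#include <thread>
#include <unordered_map>

// Hypothetical, simplified stand-in for ThreadLocalPtr<T>. Unlike the real
// class it never erases its entries when an instance is destroyed.
template <typename T>
class NaivePerInstanceTLS {
 public:
  // Returns this thread's value for this instance, or nullptr if unset.
  T* get() {
    auto& slots = threadSlots();
    auto it = slots.find(this);
    return it == slots.end() ? nullptr : static_cast<T*>(it->second.get());
  }

  // Replaces this thread's value for this instance; takes ownership of p.
  void reset(T* p) {
    threadSlots()[this] = std::shared_ptr<void>(p);
  }

 private:
  // One map per thread (thread_local), keyed by instance address (the map).
  static std::unordered_map<const void*, std::shared_ptr<void>>& threadSlots() {
    thread_local std::unordered_map<const void*, std::shared_ptr<void>> slots;
    return slots;
  }
};

// Two instances on one thread keep separate values, and a second thread
// starts with no value for either instance.
inline bool naivePerInstanceDemo() {
  NaivePerInstanceTLS<int> a;
  NaivePerInstanceTLS<int> b;
  a.reset(new int(1));
  b.reset(new int(2));
  bool emptyInOtherThread = false;
  std::thread worker([&] {
    emptyInOtherThread = a.get() == nullptr && b.get() == nullptr;
  });
  worker.join();
  return *a.get() == 1 && *b.get() == 2 && emptyInOtherThread;
}
```

The map provides the per-instance dimension and thread_local provides the
per-thread one, which is the same split the declarations below describe.
*/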
/**
* Map from object pointer to that object's instance in the current thread.
* Combined with thread_local, this yields a pointer that is per thread
* (via thread_local) and per object (via the map).
*/
typedef std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>
UnsafeThreadLocalMap;
ThreadLocalHelper* getThreadLocalHelper();
typedef std::vector<ThreadLocalHelper*> UnsafeAllThreadLocalHelperVector;
/**
* A thread-safe vector of all ThreadLocalHelpers. It encapsulates the
* locking required by any change to the global AllThreadLocalHelperVector
* instance.
*/
class AllThreadLocalHelperVector {
public:
AllThreadLocalHelperVector() {}
// Add a new ThreadLocalHelper to the vector
void push_back(ThreadLocalHelper* helper);
// Erase a ThreadLocalHelper from the vector
void erase(ThreadLocalHelper* helper);
// Erase the object's entry from every helper stored in the vector.
// Called from the destructor of a ThreadLocalPtrImpl.
void erase_tlp(ThreadLocalPtrImpl* ptr);
private:
UnsafeAllThreadLocalHelperVector vector_;
std::mutex mutex_;
};
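/*
The definitions of these methods are not in this header. Assuming the
straightforward approach (one registry mutex, plus one mutex per helper so
another thread can erase from it), they could look like the following sketch.
HelperStandIn, HelperRegistry and registryDemo are hypothetical stand-ins for
ThreadLocalHelper and AllThreadLocalHelperVector, not the caffe2 code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for ThreadLocalHelper: a per-thread map guarded by
// its own mutex so another thread can erase entries from it.
struct HelperStandIn {
  std::unordered_map<const void*, std::shared_ptr<void>> mapping;
  std::mutex mutex;

  void erase(const void* key) {
    std::lock_guard<std::mutex> guard(mutex);
    mapping.erase(key);
  }
};

// Hypothetical stand-in for AllThreadLocalHelperVector.
class HelperRegistry {
 public:
  void push_back(HelperStandIn* helper) {
    std::lock_guard<std::mutex> guard(mutex_);
    vector_.push_back(helper);
  }

  void erase(HelperStandIn* helper) {
    std::lock_guard<std::mutex> guard(mutex_);
    vector_.erase(std::remove(vector_.begin(), vector_.end(), helper),
                  vector_.end());
  }

  // Mirrors erase_tlp: drop one object's slot from every registered helper.
  // Locks are always taken in the order registry mutex, then helper mutex.
  void erase_tlp(const void* key) {
    std::lock_guard<std::mutex> guard(mutex_);
    for (HelperStandIn* helper : vector_) {
      helper->erase(key);
    }
  }

  std::size_t size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return vector_.size();
  }

 private:
  std::vector<HelperStandIn*> vector_;
  std::mutex mutex_;
};

// Two helpers (standing in for two threads) both hold a slot for `key`;
// erase_tlp clears it from both, then one helper is unregistered.
inline bool registryDemo() {
  HelperRegistry registry;
  HelperStandIn h1;
  HelperStandIn h2;
  registry.push_back(&h1);
  registry.push_back(&h2);
  int key = 0;
  h1.mapping[&key] = std::make_shared<int>(1);
  h2.mapping[&key] = std::make_shared<int>(2);
  registry.erase_tlp(&key);
  registry.erase(&h2);
  return h1.mapping.empty() && h2.mapping.empty() && registry.size() == 1;
}
```

Keeping a single lock order (registry first, helper second) is what lets
erase_tlp reach into other threads' maps without risking a deadlock in this
sketch.
*/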
/**
* ThreadLocalHelper is per thread
*/
class ThreadLocalHelper {
public:
ThreadLocalHelper();
// When the thread exits, remove *this* from the global
// AllThreadLocalHelperVector
~ThreadLocalHelper();
// Insert an (object, ptr) pair into the thread-local map
void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr);
// Get the ptr associated with the object
void* get(ThreadLocalPtrImpl* key);
// Erase the ptr associated with the object in the map
void erase(ThreadLocalPtrImpl* key);
private:
// mapping of object -> ptr in each thread
UnsafeThreadLocalMap mapping_;
std::mutex mutex_;
}; // ThreadLocalHelper
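/*
A guess at how a per-thread helper can keep the global helper list up to
date: a thread_local object registers itself in its constructor and
unregisters itself in its destructor, which runs when its thread exits.
RegisteredHelper, getRegisteredHelper, liveHelperCount, registrationDemo and
the global vector are hypothetical; the real getThreadLocalHelper() is
defined outside this header.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

class RegisteredHelper;

// Hypothetical global list of live per-thread helpers, with its lock.
static std::mutex gHelpersMutex;
static std::vector<RegisteredHelper*> gHelpers;

// Hypothetical stand-in for ThreadLocalHelper: registers itself on
// construction and unregisters itself when its thread exits.
class RegisteredHelper {
 public:
  RegisteredHelper() {
    std::lock_guard<std::mutex> guard(gHelpersMutex);
    gHelpers.push_back(this);
  }
  ~RegisteredHelper() {
    std::lock_guard<std::mutex> guard(gHelpersMutex);
    gHelpers.erase(std::remove(gHelpers.begin(), gHelpers.end(), this),
                   gHelpers.end());
  }
};

// Hypothetical analogue of getThreadLocalHelper(): the helper is created
// lazily, once per thread, on first use.
inline RegisteredHelper* getRegisteredHelper() {
  thread_local RegisteredHelper helper;
  return &helper;
}

inline std::size_t liveHelperCount() {
  std::lock_guard<std::mutex> guard(gHelpersMutex);
  return gHelpers.size();
}

// A worker thread's helper is registered while the thread runs and is gone
// once the thread has been joined.
inline bool registrationDemo() {
  const std::size_t before = liveHelperCount();
  std::size_t during = 0;
  std::thread worker([&] {
    getRegisteredHelper();
    during = liveHelperCount();
  });
  worker.join();
  return during == before + 1 && liveHelperCount() == before;
}
```

The check after join() relies on thread_local destructors finishing before
the thread is considered complete, so the worker's helper has already
unregistered itself.
*/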
/** ThreadLocalPtrImpl is per object
*/
class ThreadLocalPtrImpl {
public:
ThreadLocalPtrImpl() {}
// Non-copyable and non-movable: the address of this object is the key
// into every thread's map
ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete;
ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete;
ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete;
ThreadLocalPtrImpl& operator=(ThreadLocalPtrImpl&&) = delete;
// If the object is destroyed before the threads that used it, clean up
// its state in every thread
~ThreadLocalPtrImpl();
template <typename T>
T* get() {
return static_cast<T*>(getThreadLocalHelper()->get(this));
}
template <typename T>
void reset(T* newPtr = nullptr) {
VLOG(2) << "In Reset(" << newPtr << ")";
auto* wrapper = getThreadLocalHelper();
// Clean up the object (of type T) previously stored for this
// ThreadLocalPtrImpl in the current thread
wrapper->erase(this);
if (newPtr != nullptr) {
// A shared_ptr<void> constructed from a T* captures a deleter for T,
// so newPtr is still destroyed as a T despite the type erasure
std::shared_ptr<void> sharedPtr(newPtr);
wrapper->insert(this, std::move(sharedPtr));
}
}
}; // ThreadLocalPtrImpl
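/*
Why the std::shared_ptr<void> in reset() is safe: the shared_ptr constructor
that takes a T* records a deleter for T, so the object is destroyed as a T
even though the map's value type is std::shared_ptr<void>. Deleting through a
plain void* instead would be undefined behavior. Tracked and
typeErasedDeleterRunsDestructor are hypothetical names used only for this
check.

```cpp
#include <cassert>
#include <memory>

// Hypothetical type that records when its destructor runs.
struct Tracked {
  explicit Tracked(bool* destroyedFlag) : destroyed(destroyedFlag) {}
  ~Tracked() { *destroyed = true; }
  bool* destroyed;
};

// Mirrors what reset<T>() does with newPtr: the pointer is stored as a
// shared_ptr<void>, yet ~Tracked still runs when the last owner goes away,
// because the deleter was captured for Tracked at construction time.
inline bool typeErasedDeleterRunsDestructor() {
  bool destroyed = false;
  {
    std::shared_ptr<void> erased(new Tracked(&destroyed));
    if (destroyed) {
      return false;  // must still be alive while `erased` owns it
    }
  }
  return destroyed;
}
```
*/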
template <typename T>
class ThreadLocalPtr {
public:
auto* operator->() {
return get();
}
auto& operator*() {
return *get();
}
auto* get() {
return impl_.get<T>();
}
auto* operator->() const {
return get();
}
auto& operator*() const {
return *get();
}
auto* get() const {
return impl_.get<T>();
}
void reset(std::unique_ptr<T> ptr = nullptr) {
impl_.reset<T>(ptr.release());
}
private:
// mutable so that the const accessors above can call impl_.get<T>()
mutable ThreadLocalPtrImpl impl_;
};
} // namespace caffe2