#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#include <caffe2/utils/threadpool/thread_pool_guard.h>
#include <c10/util/Exception.h>
#include <atomic>
namespace {
// After fork, the child process inherits the data structures of the parent
// process' thread pool, but since those worker threads don't exist in the
// child, the thread pool is corrupt. It is leaked in order to prevent
// segfaults.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool = false;

void child_atfork() {
  leak_corrupted_threadpool = true;
}
} // namespace
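
// Illustrative sketch (not part of this translation unit) of the failure
// mode handled above, assuming a POSIX fork():
//
//   caffe2::pthreadpool();   // parent: lazily spawns the worker threads
//   if (fork() == 0) {
//     // child: the worker threads were not duplicated by fork(), so the
//     // inherited pool is corrupt. child_atfork() has already flagged it,
//     // and the next call leaks the stale pool and builds a fresh one.
//     caffe2::pthreadpool();
//   }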
namespace caffe2 {
PThreadPool::PThreadPool(const size_t thread_count)
    : threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}

size_t PThreadPool::get_thread_count() const {
  std::lock_guard<std::mutex> lock{mutex_};
  TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");
  return pthreadpool_get_threads_count(threadpool_.get());
}
void PThreadPool::set_thread_count(const size_t thread_count) {
  std::lock_guard<std::mutex> lock{mutex_};

  // As it stands, pthreadpool is an entirely data-parallel framework with no
  // support for task parallelism. Hence, all functions are blocking, and no
  // user-provided tasks can be in flight when control is returned to the
  // user of the API. That means re-initializing the library, without waiting
  // on any pending tasks, is all that is needed to adjust the thread count.
  threadpool_.reset(pthreadpool_create(thread_count));
}
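
// Example (illustrative): resizing the pool at runtime. Because pthreadpool
// is fully blocking, no tasks can be pending here, so the underlying pool
// can simply be recreated:
//
//   caffe2::PThreadPool* const pool = caffe2::pthreadpool();
//   pool->set_thread_count(4);
//   TORCH_INTERNAL_ASSERT(pool->get_thread_count() == 4);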
void PThreadPool::run(
    const std::function<void(size_t)>& fn,
    const size_t range) {
  // Run on the calling thread if a _NoPThreadPoolGuard is enabled.
  if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
    for (size_t i = 0; i < range; ++i) {
      fn(i);
    }
    return;
  }

  std::lock_guard<std::mutex> lock{mutex_};

  TORCH_INTERNAL_ASSERT(
      !caffe2::_NoPThreadPoolGuard::is_enabled(),
      "Inside a threadpool guard!");
  TORCH_INTERNAL_ASSERT(threadpool_.get(), "Invalid threadpool!");

  struct Context final {
    const std::function<void(size_t)>& fn;
  } context{
      fn,
  };

  pthreadpool_parallelize_1d(
      threadpool_.get(),
      // Note: pthreadpool_parallelize_1d() is a blocking function. The
      // function pointer to this lambda passed on to
      // pthreadpool_parallelize_1d() cannot go out of scope until
      // pthreadpool_parallelize_1d() returns.
      [](void* const context, const size_t item) {
        reinterpret_cast<Context*>(context)->fn(item);
      },
      &context,
      range,
      0u);
}
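
// Example (illustrative): parallelizing an element-wise loop with run().
// 'data' is a hypothetical buffer; each index in [0, range) is handed to
// the lambda exactly once, possibly from different worker threads:
//
//   std::vector<float> data(1024, 1.0f);
//   caffe2::pthreadpool()->run(
//       [&data](const size_t i) { data[i] *= 2.0f; },
//       data.size());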
// Forward declaration
size_t getDefaultNumThreads();
PThreadPool* pthreadpool() {
  static auto threadpool =
      std::make_unique<PThreadPool>(getDefaultNumThreads());
#if !(defined(WIN32))
  // Register the atfork handler once, so a forked child knows to leak the
  // inherited (and now corrupt) thread pool. See the note at the top.
  static std::once_flag flag;
  std::call_once(flag, []() {
    pthread_atfork(nullptr, nullptr, child_atfork);
  });
#endif
  if (C10_UNLIKELY(leak_corrupted_threadpool)) {
    // In a forked child: leak the inherited pool and replace it with a
    // fresh one of the same size.
    leak_corrupted_threadpool = false;
    if (auto leaked = threadpool.release()) {
      auto num_threads = leaked->get_thread_count();
      // NOLINTNEXTLINE(modernize-make-unique)
      threadpool.reset(new PThreadPool(num_threads));
    }
  }
  return threadpool.get();
}
pthreadpool_t pthreadpool_() {
  if (caffe2::_NoPThreadPoolGuard::is_enabled()) {
    return nullptr;
  }
  PThreadPool* const threadpool = pthreadpool();
  TORCH_INTERNAL_ASSERT(
      threadpool, "Failed to acquire an instance of PThreadPool!");
  return threadpool->threadpool_.get();
}
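
// Example (illustrative): the raw handle is what pthreadpool-aware C APIs
// consume. Callers must tolerate a nullptr handle, which libraries such as
// XNNPACK treat as "run on the calling thread":
//
//   pthreadpool_t handle = caffe2::pthreadpool_();
//   // pass 'handle' (possibly nullptr) to e.g. pthreadpool_parallelize_1d()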
} // namespace caffe2