#pragma once
#include <atomic>
#include <condition_variable>
#include <thread>
#include "c10/util/thread_name.h"
#include <c10/util/irange.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#if defined(_MSC_VER)
#include <intrin.h>
#endif
namespace caffe2 {
// Uses code derived from gemmlowp,
// https://github.com/google/gemmlowp/blob/6c91e1ed0c2eff1182d804310b92911fe9c18019/internal/multi_thread_gemm.h
// Changes:
// - Allocation-free execute().
// - Use RAII where possible.
// - Run the first task on the main thread (since that is the largest task).
// - Removed the custom allocator.
// - Removed some ifdefs.
// - Cache-line align Worker.
// - Use std::atomic instead of volatile and custom barriers.
// - Use std::mutex/std::condition_variable instead of raw pthreads.
constexpr size_t kGEMMLOWPCacheLineSize = 64;
template <typename T>
struct AllocAligned {
// Allocate a T aligned to a kGEMMLOWPCacheLineSize-byte boundary.
template <typename... Args>
static T* alloc(Args&&... args) {
void* p = nullptr;
#if defined(__ANDROID__)
p = memalign(kGEMMLOWPCacheLineSize, sizeof(T));
#elif defined(_MSC_VER)
p = _aligned_malloc(sizeof(T), kGEMMLOWPCacheLineSize);
#else
posix_memalign((void**)&p, kGEMMLOWPCacheLineSize, sizeof(T));
#endif
if (p) {
return new (p) T(std::forward<Args>(args)...);
}
return nullptr;
}
// Free a T previously allocated via AllocAligned<T>::alloc()
static void release(T* p) {
if (p) {
p->~T();
#if defined(_MSC_VER)
_aligned_free((void*)p);
#else
free((void*)p);
#endif
}
}
};
// Deleter object for unique_ptr for an aligned object
template <typename T>
struct AlignedDeleter {
void operator()(T* p) const { AllocAligned<T>::release(p); }
};
// make_unique that guarantees alignment
template <typename T>
struct MakeAligned {
template <typename... Args>
static std::unique_ptr<T, AlignedDeleter<T>> make(Args&&... args) {
return std::unique_ptr<T, AlignedDeleter<T>>(
AllocAligned<T>::alloc(std::forward<Args>(args)...));
}
};
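// Illustrative usage sketch (the type `Foo` is hypothetical, not part of this
// header): allocate a cache-line-aligned object and let the deleter reclaim it.
//
//   struct Foo { int x; };
//   auto foo = MakeAligned<Foo>::make(); // std::unique_ptr<Foo, AlignedDeleter<Foo>>
//   foo->x = 42;
//   // When `foo` goes out of scope, AlignedDeleter calls
//   // AllocAligned<Foo>::release(), which runs ~Foo() and frees the aligned storage.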
const int kMaxBusyWaitNOPs = 32 * 1000 * 1000;
#if defined(_MSC_VER)
#define GEMMLOWP_NOP __nop();
#else
#define GEMMLOWP_NOP "nop\n"
#endif
#define GEMMLOWP_STRING_CONCAT_4(X) X X X X
#define GEMMLOWP_NOP4 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP)
#define GEMMLOWP_NOP16 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP4)
#define GEMMLOWP_NOP64 GEMMLOWP_STRING_CONCAT_4(GEMMLOWP_NOP16)
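// Issues the GEMMLOWP_NOP64 burst defined above (a run of NOP instructions) and
// returns how many NOPs were executed, so the busy-wait loop in
// WaitForVariableChange() can bound its total spinning by kMaxBusyWaitNOPs.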
inline int Do256NOPs() {
#if defined(_MSC_VER)
GEMMLOWP_NOP64;
#else
asm volatile(GEMMLOWP_NOP64);
#endif
return 64;
}
#undef GEMMLOWP_STRING_CONCAT_4
#undef GEMMLOWP_NOP256
#undef GEMMLOWP_NOP64
#undef GEMMLOWP_NOP16
#undef GEMMLOWP_NOP4
#undef GEMMLOWP_NOP
// Waits until *var != initial_value.
//
// Returns the new value of *var. The guarantee here is that
// the return value is different from initial_value, and that that
// new value has been taken by *var at some point during the
// execution of this function. There is no guarantee that this is
// still the value of *var when this function returns, since *var is
// not assumed to be guarded by any lock.
//
// First does some busy-waiting for a fixed number of no-op cycles,
// then falls back to passive waiting for the given condvar, guarded
// by the given mutex.
//
// The idea of doing some initial busy-waiting is to help get
// better and more consistent multithreading benefits for small GEMM sizes.
// Busy-waiting helps ensure that if we need to wake up soon after having
// started waiting, then we can wake up quickly (as opposed to, say,
// having to wait to be scheduled again by the OS). On the other hand,
// we must still eventually revert to passive waiting for longer waits
// (e.g. worker threads having finished a GEMM and waiting until the next GEMM)
// so as to avoid permanently spinning.
//
template <typename T>
T WaitForVariableChange(std::atomic<T>* var,
T initial_value,
std::condition_variable* cond,
std::mutex* mutex) {
// If we are on a platform that supports it, spin for some time.
{
int nops = 0;
// First, trivial case where the variable already changed value.
T new_value = var->load(std::memory_order_relaxed);
if (new_value != initial_value) {
std::atomic_thread_fence(std::memory_order_acquire);
return new_value;
}
// Then try busy-waiting.
while (nops < kMaxBusyWaitNOPs) {
nops += Do256NOPs();
new_value = var->load(std::memory_order_relaxed);
if (new_value != initial_value) {
std::atomic_thread_fence(std::memory_order_acquire);
return new_value;
}
}
}
// Finally, do real passive waiting.
{
std::unique_lock<std::mutex> g(*mutex);
T new_value = var->load(std::memory_order_relaxed);
// Handle spurious wakeups.
cond->wait(g, [&]() {
new_value = var->load(std::memory_order_relaxed);
return new_value != initial_value;
});
TORCH_DCHECK_NE(static_cast<size_t>(new_value), static_cast<size_t>(initial_value));
return new_value;
}
}
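// Illustrative caller sketch (the names `flag`, `cv`, and `mu` are hypothetical):
// one thread blocks until another thread publishes a new value and signals it,
// mirroring how Worker and BlockingCounter below use this helper.
//
//   std::atomic<int> flag{0};
//   std::condition_variable cv;
//   std::mutex mu;
//   // Waiting thread: spins briefly, then sleeps until flag != 0.
//   int observed = WaitForVariableChange(&flag, 0, &cv, &mu);
//   // Signalling thread (elsewhere):
//   { std::lock_guard<std::mutex> g(mu); flag.store(1, std::memory_order_relaxed); }
//   cv.notify_one();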
// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
public:
// Sets/resets the counter; initial_count is the number of
// decrementing events that the Wait() call will be waiting for.
void Reset(std::size_t initial_count) {
std::lock_guard<std::mutex> g(mutex_);
TORCH_DCHECK_EQ(count_, 0);
count_ = initial_count;
}
// Decrements the counter; if the counter hits zero, signals
// the thread that was waiting for that, and returns true.
// Otherwise (if the decremented count is still nonzero),
// returns false.
bool DecrementCount() {
const auto count_value = count_.fetch_sub(1, std::memory_order_relaxed) - 1;
TORCH_DCHECK_GE(count_value, 0);
if (count_value == 0) {
std::lock_guard<std::mutex> g(mutex_);
cond_.notify_one();
}
bool retval = count_value == 0;
return retval;
}
// Waits for the N other threads (N having been set by Reset())
// to hit the BlockingCounter.
void Wait() {
while (size_t count_value = count_.load(std::memory_order_relaxed)) {
WaitForVariableChange(&count_, count_value, &cond_, &mutex_);
}
}
private:
std::condition_variable cond_;
std::mutex mutex_;
std::atomic<std::size_t> count_{0};
};
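// Illustrative usage sketch (the count `n_workers` is hypothetical): the master
// resets the counter to the number of outstanding events, each worker calls
// DecrementCount() exactly once, and the master blocks in Wait() until the
// count reaches zero.
//
//   BlockingCounter done;
//   done.Reset(n_workers);
//   // ... each worker thread calls done.DecrementCount() when it finishes ...
//   done.Wait(); // returns once all n_workers decrements have happened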
// A workload for a worker.
struct Task {
Task() {}
virtual ~Task() {}
virtual void Run() = 0;
};
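// Illustrative Task subclass (a hypothetical helper, not provided by this
// header; needs <functional>): wraps a closure so arbitrary work can be handed
// to a Worker or WorkersPool.
//
//   struct ClosureTask final : public Task {
//     explicit ClosureTask(std::function<void()> f) : fn_(std::move(f)) {}
//     void Run() override { fn_(); }
//     std::function<void()> fn_;
//   };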
// A worker thread.
class alignas(kGEMMLOWPCacheLineSize) Worker {
public:
enum class State : uint8_t {
ThreadStartup, // The initial state before the thread main loop runs.
Ready, // Is not working, has not yet received new work to do.
HasWork, // Has work to do.
ExitAsSoonAsPossible // Should exit at earliest convenience.
};
explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
: task_(nullptr),
state_(State::ThreadStartup),
counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
thread_ = std::make_unique<std::thread>([this]() { this->ThreadFunc(); });
}
~Worker() {
ChangeState(State::ExitAsSoonAsPossible);
thread_->join();
}
// Changes State; may be called from either the worker thread
// or the master thread; however, not all state transitions are legal,
// and illegal transitions are caught by assertions.
void ChangeState(State new_state) {
std::lock_guard<std::mutex> g(state_mutex_);
DCHECK(new_state != state_.load(std::memory_order_relaxed));
switch (state_.load(std::memory_order_relaxed)) {
case State::ThreadStartup:
DCHECK(new_state == State::Ready);
break;
case State::Ready:
DCHECK(new_state == State::HasWork || new_state == State::ExitAsSoonAsPossible);
break;
case State::HasWork:
DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible);
break;
default:
abort();
}
state_.store(new_state, std::memory_order_relaxed);
state_cond_.notify_one();
if (new_state == State::Ready) {
counter_to_decrement_when_ready_->DecrementCount();
}
}
// Thread entry point.
void ThreadFunc() {
c10::setThreadName("CaffeWorkersPool");
ChangeState(State::Ready);
// Thread main loop
while (true) {
// Get a state to act on
// In the 'Ready' state, we have nothing to do but to wait until
// we switch to another state.
State state_to_act_upon =
WaitForVariableChange(&state_, State::Ready, &state_cond_, &state_mutex_);
// We now have a state to act on, so act.
switch (state_to_act_upon) {
case State::HasWork:
// Got work to do! So do it, and then revert to 'Ready' state.
DCHECK(task_.load());
(*task_).Run();
task_ = nullptr;
ChangeState(State::Ready);
break;
case State::ExitAsSoonAsPossible:
return;
default:
abort();
}
}
}
static void* ThreadFunc(void* arg) {
static_cast<Worker*>(arg)->ThreadFunc();
return nullptr;
}
// Called by the master thread to give this worker work to do.
// It is only legal to call this if the worker is in the 'Ready' state.
void StartWork(Task* task) {
DCHECK(!task_.load());
task_ = task;
DCHECK(state_.load(std::memory_order_acquire) == State::Ready);
ChangeState(State::HasWork);
}
private:
// The underlying thread.
std::unique_ptr<std::thread> thread_;
// The task to be worked on.
std::atomic<Task*> task_;
// The condition variable and mutex guarding state changes.
std::condition_variable state_cond_;
std::mutex state_mutex_;
// The state enum tells if we're currently working, waiting for work, etc.
std::atomic<State> state_;
// Pointer to the master thread's BlockingCounter object, used to notify the
// master thread when this worker switches to the 'Ready' state.
BlockingCounter* const counter_to_decrement_when_ready_;
};
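// Illustrative master-side handshake (the `task` pointer is hypothetical); in
// practice WorkersPool below drives Workers in essentially this way.
//
//   BlockingCounter ready;
//   ready.Reset(1);
//   auto worker = MakeAligned<Worker>::make(&ready);
//   ready.Wait();              // the worker thread reached the 'Ready' state
//   ready.Reset(1);
//   worker->StartWork(task);   // Ready -> HasWork; task is a Task*
//   ready.Wait();              // the worker ran the task and is 'Ready' again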
class WorkersPool {
public:
WorkersPool() {}
void Execute(const std::vector<std::shared_ptr<Task>>& tasks) {
CAFFE_ENFORCE_GE(tasks.size(), 1);
// One of the tasks will be run on the current thread.
int workers_count = tasks.size() - 1;
CreateWorkers(workers_count);
TORCH_DCHECK_LE(workers_count, (int)workers_.size());
counter_to_decrement_when_ready_.Reset(workers_count);
for (const auto task : c10::irange(1, tasks.size())) {
workers_[task - 1]->StartWork(tasks[task].get());
}
// Execute the remaining workload immediately on the current thread.
auto& task = tasks.front();
task->Run();
// Wait for the workers submitted above to finish.
counter_to_decrement_when_ready_.Wait();
}
private:
// Ensures that the pool has at least the given count of workers.
// If any new workers have to be created, this function waits for them to
// be ready.
void CreateWorkers(std::size_t workers_count) {
if (workers_.size() >= workers_count) {
return;
}
counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
while (workers_.size() < workers_count) {
workers_.push_back(MakeAligned<Worker>::make(&counter_to_decrement_when_ready_));
}
counter_to_decrement_when_ready_.Wait();
}
C10_DISABLE_COPY_AND_ASSIGN(WorkersPool);
std::vector<std::unique_ptr<Worker, AlignedDeleter<Worker>>> workers_;
// The BlockingCounter used to wait for the workers.
BlockingCounter counter_to_decrement_when_ready_;
};
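// Illustrative usage sketch (reuses the hypothetical ClosureTask from above):
// run four tasks, one on the calling thread and the rest on pool workers.
//
//   WorkersPool pool;
//   std::vector<std::shared_ptr<Task>> tasks;
//   for (int i = 0; i < 4; ++i) {
//     tasks.push_back(std::make_shared<ClosureTask>([i]() { /* work for shard i */ }));
//   }
//   pool.Execute(tasks); // tasks[0] runs on this thread; returns after all tasks finish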
} // namespace caffe2