1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
#include <torch/csrc/lazy/core/multi_wait.h>
#include <chrono>
#include <exception>
#include <stdexcept>
namespace torch {
namespace lazy {
// Records the completion of one unit of work.
//
// The completed counter is incremented while mutex_ is held. If the counter
// is then exactly equal to count_, every thread blocked on cv_ is woken. The
// notification is issued only after the mutex has been released.
void MultiWait::Done() {
  std::unique_lock<std::mutex> guard(mutex_);
  ++completed_count_;
  const bool reached_target = (completed_count_ == count_);
  guard.unlock();

  if (reached_target) {
    cv_.notify_all();
  }
}
// Blocks the caller until completed_count_ is at least count_.
//
// Once the wait is over, if an exception has been stored in exptr_ it is
// rethrown to the caller.
void MultiWait::Wait() {
  std::unique_lock<std::mutex> guard(mutex_);
  // Re-check the condition after every wake-up; a wake-up by itself does not
  // mean that the target count has been reached.
  while (completed_count_ < count_) {
    cv_.wait(guard);
  }
  if (exptr_) {
    std::rethrow_exception(exptr_);
  }
}
// Blocks the caller until completed_count_ is at least count_, giving up
// after wait_seconds seconds.
//
// Throws std::runtime_error("Timeout") when the timed wait ends with the
// condition still unsatisfied. Otherwise, if an exception has been stored in
// exptr_, that exception is rethrown to the caller.
void MultiWait::Wait(double wait_seconds) {
  const std::chrono::duration<double> timeout(wait_seconds);
  const auto all_completed = [this] { return completed_count_ >= count_; };

  std::unique_lock<std::mutex> guard(mutex_);
  const bool finished_in_time = cv_.wait_for(guard, timeout, all_completed);
  if (!finished_in_time) {
    throw std::runtime_error("Timeout");
  }
  if (exptr_ != nullptr) {
    std::rethrow_exception(exptr_);
  }
}
// Re-arms the object: under mutex_, clears any stored exception, zeroes the
// completed counter and installs `count` as the new target.
void MultiWait::Reset(size_t count) {
  std::lock_guard<std::mutex> guard(mutex_);
  exptr_ = nullptr;
  completed_count_ = 0;
  count_ = count;
}
// Wraps `func` in a callable that forwards it to Complete() on this object.
//
// The returned closure captures `this` as a raw pointer, so it refers to
// this MultiWait without extending its lifetime.
std::function<void()> MultiWait::Completer(std::function<void()> func) {
  return [this, func = std::move(func)]() { Complete(func); };
}
// Wraps `func` in a callable that forwards it to Complete() on `mwait`.
//
// The shared_ptr is moved into the returned closure, so the closure holds a
// shared reference to the MultiWait for as long as the closure exists.
std::function<void()> MultiWait::Completer(
    std::shared_ptr<MultiWait> mwait,
    std::function<void()> func) {
  return [owner = std::move(mwait), task = std::move(func)]() {
    owner->Complete(task);
  };
}
// Invokes `func` and then calls Done(), whether or not `func` threw.
//
// Any exception escaping `func` is caught and stored in exptr_ under mutex_,
// replacing whatever exptr_ held before. It is not propagated to the caller
// of Complete().
void MultiWait::Complete(const std::function<void()>& func) {
  std::exception_ptr failure;
  try {
    func();
  } catch (...) {
    failure = std::current_exception();
  }

  if (failure) {
    std::lock_guard<std::mutex> guard(mutex_);
    exptr_ = failure;
  }
  Done();
}
} // namespace lazy
} // namespace torch
|