File: multi_wait.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (73 lines) | stat: -rw-r--r-- 1,685 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <torch/csrc/lazy/core/multi_wait.h>

#include <chrono>
#include <exception>
#include <stdexcept>

namespace torch {
namespace lazy {

// Records one finished unit of work and wakes every waiter once the number of
// completions reaches the expected count.
//
// The notification is issued while mutex_ is still held. Notifying after the
// lock is released would race with destruction of this object. A waiter whose
// predicate already observes the final count can leave Wait() (for instance on
// a spurious wakeup) and destroy this MultiWait before notify_all() runs, which
// would then touch a dead condition variable. While the lock is held, no waiter
// can return from Wait() until this function has finished using cv_.
void MultiWait::Done() {
  std::lock_guard<std::mutex> lock(mutex_);
  completed_count_ += 1;
  // `>=` mirrors the predicate used by Wait(); an extra notify_all() when
  // Done() is called more often than expected is harmless.
  if (completed_count_ >= count_) {
    cv_.notify_all();
  }
}

// Blocks the calling thread, with no deadline, until the number of reported
// completions reaches the expected count. Afterwards, rethrows the exception
// captured from a failed task, if any was recorded.
void MultiWait::Wait() {
  std::unique_lock<std::mutex> guard(mutex_);
  const auto finished = [this]() { return completed_count_ >= count_; };
  cv_.wait(guard, finished);
  if (exptr_) {
    std::rethrow_exception(exptr_);
  }
}

// Blocks until the number of reported completions reaches the expected count,
// or until `wait_seconds` seconds have elapsed, whichever comes first.
//
// Throws std::runtime_error("Timeout") if the deadline passes first. After a
// successful wait, rethrows the exception captured from a failed task, if any.
//
// Some durations cannot be represented once wait_for() converts them to the
// clock's integer ticks and adds them to now(). Passing them through unchanged
// would be undefined behaviour, so they are normalised:
//   * +inf or anything >= kMaxTimedWaitSeconds: wait without a deadline;
//   * negative values, -inf and NaN: zero timeout (check once, don't block).
void MultiWait::Wait(double wait_seconds) {
  // About 31.7 years: far beyond any realistic timeout, yet small enough that
  // now() + timeout still fits comfortably in a 64-bit nanosecond clock.
  constexpr double kMaxTimedWaitSeconds = 1e9;
  std::unique_lock<std::mutex> lock(mutex_);
  auto all_done = [this] { return completed_count_ >= count_; };
  if (wait_seconds >= kMaxTimedWaitSeconds) {
    cv_.wait(lock, all_done);
  } else {
    // NaN compares false here as well, so it also collapses to 0.
    const double timeout_seconds = wait_seconds > 0.0 ? wait_seconds : 0.0;
    if (!cv_.wait_for(
            lock, std::chrono::duration<double>(timeout_seconds), all_done)) {
      throw std::runtime_error("Timeout");
    }
  }
  if (exptr_ != nullptr) {
    std::rethrow_exception(exptr_);
  }
}

// Re-arms the object for a new batch of `count` expected completions. Clears
// the completion counter and any exception captured during the previous batch.
void MultiWait::Reset(size_t count) {
  std::lock_guard<std::mutex> guard(mutex_);
  exptr_ = nullptr;
  completed_count_ = 0;
  count_ = count;
}

// Wraps `func` in a callable that runs it through Complete(). Complete()
// records any exception thrown by `func` and then reports one completion.
//
// The returned callable holds `this` as a raw pointer, so this MultiWait must
// outlive every invocation of it. Use the shared_ptr-taking overload when that
// cannot be guaranteed.
std::function<void()> MultiWait::Completer(std::function<void()> func) {
  return [this, task = std::move(func)]() { Complete(task); };
}

// Wraps `func` in a callable that runs it through `mwait->Complete()`. The
// callable co-owns `mwait`, so the MultiWait stays alive for as long as the
// returned std::function (or any copy of it) exists.
std::function<void()> MultiWait::Completer(
    std::shared_ptr<MultiWait> mwait,
    std::function<void()> func) {
  return [owner = std::move(mwait), task = std::move(func)]() {
    owner->Complete(task);
  };
}

// Runs `func` and reports one completion via Done(), whether or not `func`
// succeeded. Any exception it throws is stored under mutex_ for Wait() to
// rethrow. Each failure overwrites the previously stored exception.
void MultiWait::Complete(const std::function<void()>& func) {
  std::exception_ptr failure;
  try {
    func();
  } catch (...) {
    failure = std::current_exception();
  }
  if (failure) {
    std::lock_guard<std::mutex> guard(mutex_);
    exptr_ = failure;
  }
  Done();
}

} // namespace lazy
} // namespace torch