File: future.h

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (187 lines) | stat: -rw-r--r-- 5,222 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#pragma once

#include <ATen/core/ivalue.h>

namespace torch {

namespace utils {

// FutureError inherits from std::exception. It is constructed from a
// std::string error message and exposes that message as a const char* via
// what().
class TORCH_API FutureError final : public std::exception {
 public:
  // Builds an error whose message is the empty string.
  FutureError() = default;

  // Builds an error that owns `errorMsg`. The constructor is intentionally
  // left non-explicit, so a std::string converts to a FutureError implicitly.
  FutureError(std::string errorMsg) : errorMsg_{std::move(errorMsg)} {}

  // std::exception interface: returns the stored message as a C string. The
  // pointer refers to storage owned by this object.
  const char* what() const noexcept override {
    return errorMsg_.c_str();
  }

 private:
  // Message handed out by what().
  std::string errorMsg_;
};

// This class holds a value of type T that will become ready in the future.
// Most of the implementation is copied from FutureMessage and
// c10::ivalue::Future.
template <typename T>
class TORCH_PYTHON_API Future final {
 public:
  // Creates a future that has not been completed yet.
  Future() = default;

  // Creates a future that is already completed with `value`.
  Future(T value) : completed_(true), value_(std::move(value)) {}

  // Blocks until the future is completed. If it was completed with an error,
  // throws that FutureError; otherwise returns a reference to the value.
  const T& wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    finished_cv_.wait(lock, [this] { return completed_.load(); });
    if (error_) {
      throw *error_;
    }
    return value_;
  }

  // Blocks until the future is completed and returns a reference to value_
  // without looking at error_. When the future was completed through
  // setError*(), value_ was never assigned by this class, so callers should
  // consult hasError()/error() before using the result.
  const T& waitNoThrow() {
    std::unique_lock<std::mutex> lock(mutex_);
    finished_cv_.wait(lock, [this] { return completed_.load(); });
    return value_;
  }

  // These constValue/moveValue accessors should only be used if
  // we know that the future is completed() with no error.
  // Both assert completed_ and do not wait.
  const T& constValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed_);
    return value_;
  }

  // Rvalue-qualified: only callable on a temporary / std::move'd future.
  // Returns an rvalue reference to value_ so the caller can move it out.
  T&& moveValue() && {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed_);
    return std::move(value_);
  }

  // Marks the future complete only if it hasn't been marked completed already.
  // If it is already complete, `value` is dropped and a message is logged.
  void markCompletedIfNeeded(T value) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      LOG(INFO) << "markCompletedIfNeeded skipped since future is already complete.";
      return;
    }
    markCompletedInternal(std::move(value), lock);
  }

  // Completes the future with `value`, wakes all waiters and runs the
  // registered callbacks on the calling thread. Fails a TORCH_CHECK if the
  // future is already completed.
  void markCompleted(T value) {
    std::unique_lock<std::mutex> lock(mutex_);
    markCompletedInternal(std::move(value), lock);
  }

  // Sets error only if the future hasn't been marked completed already.
  // Useful in avoiding races where multiple threads try to setError
  // on a future.
  void setErrorIfNeeded(FutureError error) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      // This should be rare and shouldn't cause log spew. It's important to
      // log errors and that's why we have this log here.
      LOG(INFO) << "Skipping setting following error on the Future since " <<
        "it is already marked completed (this is not neccessarily an error): "
        << error.what();
      return;
    }
    setErrorInternal(std::move(error), lock);
  }

  // Convenience overload: wraps `errorMsg` in a FutureError.
  void setErrorIfNeeded(std::string errorMsg) {
    setErrorIfNeeded(FutureError(std::move(errorMsg)));
  }

  // Completes the future with `error`, wakes all waiters and runs the
  // registered callbacks on the calling thread. Fails a TORCH_CHECK if the
  // future is already completed.
  void setError(FutureError error) {
    std::unique_lock<std::mutex> lock(mutex_);
    setErrorInternal(std::move(error), lock);
  }

  // Convenience overload: wraps `errorMsg` in a FutureError.
  void setError(std::string errorMsg) {
    setError(FutureError(std::move(errorMsg)));
  }

  // Returns whether the future has been completed, with either a value or an
  // error. Reads the atomic flag without taking the mutex.
  bool completed() const {
    return completed_;
  }

  // Returns true if an error has been stored in the future.
  bool hasError() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return static_cast<bool>(error_);
  }

  // Returns a copy of the stored error, or an empty optional if none is set.
  c10::optional<FutureError> error() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return error_;
  }

  // If completed() the callback will be invoked in-place (on the calling
  // thread, after the mutex has been released). Otherwise it is stored and
  // invoked later by the thread that completes the future.
  void addCallback(std::function<void(void)> cb) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      lock.unlock();
      cb();
      return;
    }
    callbacks_.emplace_back(std::move(cb));
  }

  // Same as above, but the callback receives this future as its argument.
  // The wrapper captures `this`, so the future must still be alive when the
  // callback runs.
  void addCallback(std::function<void(const Future<T>& future)> cb) {
    addCallback([this, cb = std::move(cb)]() { cb(*this); });
  }

 private:
  // Stores `error` and completes the future. `lock` must hold mutex_ on
  // entry; it is released before this function returns normally.
  void setErrorInternal(
      FutureError error,
      std::unique_lock<std::mutex>& lock) {
    TORCH_CHECK(!completed_);
    error_ = std::move(error);
    completeAndRunCallbacks(lock);
  }

  // Stores `value` and completes the future. `lock` must hold mutex_ on
  // entry; it is released before this function returns normally.
  void markCompletedInternal(T value,
      std::unique_lock<std::mutex>& lock) {
    TORCH_CHECK(!completed_);
    value_ = std::move(value);
    completeAndRunCallbacks(lock);
  }

  // Shared tail of both completion paths. Must be called with `lock` holding
  // mutex_ and with value_ or error_ already stored. Flips completed_,
  // releases the lock, wakes every waiter and then runs the callbacks that
  // were registered before completion, in registration order.
  void completeAndRunCallbacks(std::unique_lock<std::mutex>& lock) {
    completed_ = true;

    // Move callbacks to a vector on the stack so we can access it without
    // holding a lock
    std::vector<std::function<void(void)>> cbs;
    cbs.swap(callbacks_);
    lock.unlock();
    finished_cv_.notify_all();
    // There is no need to protect callbacks_ with the lock.
    // Once completed_ is set to true, no one can add new callback to the
    // list: addCallback() invokes late callbacks directly instead.
    for (auto& callback : cbs) {
      callback();
    }
  }

  mutable std::mutex mutex_; // guards callbacks_, value_ and error_
  std::atomic_bool completed_{false}; // is this future complete
  std::condition_variable finished_cv_; // signalled once on completion
  std::vector<std::function<void(void)>> callbacks_; // pending callbacks
  T value_; // result; assigned by the value constructor or markCompleted*()
  c10::optional<FutureError> error_; // set by setError*()
};

} // namespace utils
} // namespace torch