// Declaration of caffe2::AsyncTaskFuture.
#ifndef CAFFE2_NET_ASYNC_TASK_FUTURE_H
#define CAFFE2_NET_ASYNC_TASK_FUTURE_H
#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
namespace caffe2 {
// Tracks the execution state of an AsyncTask. The state can be inspected
// through IsCompleted/IsFailed, and callbacks registered with SetCallback are
// called upon the future's completion.
class AsyncTaskFuture {
 public:
  AsyncTaskFuture();

  // Builds a future that is completed once every future in |futures| is
  // completed.
  explicit AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures);

  ~AsyncTaskFuture();

  // Copy construction and copy assignment are disabled.
  AsyncTaskFuture(const AsyncTaskFuture&) = delete;
  AsyncTaskFuture& operator=(const AsyncTaskFuture&) = delete;

  // State queries. Only the declarations live in this header.
  bool IsCompleted() const;
  bool IsFailed() const;
  std::string ErrorMessage() const;

  // NOTE(review): presumably blocks until the future is completed -- confirm
  // against the definition.
  void Wait() const;

  // Registers |callback|, which is called upon the future's completion.
  void SetCallback(std::function<void(const AsyncTaskFuture*)> callback);

  // NOTE(review): a non-null |err_msg| presumably marks the future as
  // failed -- confirm against the definition.
  void SetCompleted(const char* err_msg = nullptr);

  void ResetState();

 private:
  // Declared mutable so that const member functions are allowed to use them.
  mutable std::mutex mutex_;
  mutable std::condition_variable cv_completed_;

  std::atomic<bool> completed_;
  std::atomic<bool> failed_;
  std::string err_msg_;
  std::vector<std::function<void(const AsyncTaskFuture*)>> callbacks_;

  // Bookkeeping about parent futures: a counter that can be restored to its
  // initial value, a failure flag, and an error message with its own mutex.
  // NOTE(review): presumably only populated by the multi-future constructor
  // -- confirm against the definition.
  struct ParentCounter {
    explicit ParentCounter(int initial_count)
        : init_parent_count_{initial_count},
          parent_count{initial_count},
          parent_failed{false} {}

    // Restores the counter to its initial value and clears the failure flag
    // and the error message, all while holding err_mutex.
    void Reset() {
      std::lock_guard<std::mutex> guard(err_mutex);
      parent_count.store(init_parent_count_);
      parent_failed.store(false);
      err_msg.clear();
    }

    const int init_parent_count_; // value parent_count is reset to
    std::atomic<int> parent_count;
    std::mutex err_mutex; // held by Reset() while it rewrites the state
    std::atomic<bool> parent_failed;
    std::string err_msg;
  };

  std::unique_ptr<ParentCounter> parent_counter_;
};
} // namespace caffe2
#endif // CAFFE2_NET_ASYNC_TASK_FUTURE_H