#ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
#define CAFFE2_OPERATORS_PREFETCH_OP_H_
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread> // NOLINT
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// PrefetchOperator is an operator that prefetches the next batch on a
// background thread. It is almost always used to read data from disk, so it
// takes zero input blobs.
//
// Any operator derived from PrefetchOperator must explicitly call Finalize()
// in its own destructor so that the prefetching thread is joined before the
// derived object is torn down.
//
// Note: We inherit from OperatorBase because we control the synchronization
// properties of this operator ourselves (we notify the waiting producer after
// we synchronize). This is a special case; you should generally inherit from
// Operator<Context> directly.
template <class Context>
class PrefetchOperator : public OperatorBase {
 public:
  PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()),
        prefetched_(false),
        prefetch_success_(true),
        finalize_(false),
        no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
    context_.SwitchToDevice();
  }

  virtual ~PrefetchOperator() noexcept {
    CHECK(finalize_ || !prefetch_thread_.get())
        << "YOU MADE A PROGRAMMING ERROR: derived class of PrefetchOperator "
           "should call Finalize() in its destructor so the prefetching "
           "thread is joined. ";
  }
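/*
Why the most-derived destructor has to call Finalize(): C++ runs the
derived-class destructor body first, then destroys the derived members, and
only then runs ~PrefetchOperator(). If the base destructor were the one to
join the worker, the state that the derived Prefetch() touches would already
be gone while the thread might still be using it. The standalone sketch below
uses hypothetical classes (OperatorLike, ConcreteOperator), not caffe2 types,
and only records the order in which the destructors run.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for PrefetchOperator: logs when its destructor runs.
class OperatorLike {
 public:
  explicit OperatorLike(std::vector<std::string>* log) : log_(log) {}
  virtual ~OperatorLike() { log_->push_back("~OperatorLike"); }

 protected:
  std::vector<std::string>* log_;
};

// Hypothetical concrete operator. Its destructor body is the last point at
// which its own members are still alive, so this is where Finalize() belongs.
class ConcreteOperator : public OperatorLike {
 public:
  explicit ConcreteOperator(std::vector<std::string>* log)
      : OperatorLike(log) {}
  ~ConcreteOperator() override { log_->push_back("~ConcreteOperator"); }

 private:
  std::vector<int> buffer_;  // state a Prefetch() override might fill
};

// Destroys one ConcreteOperator and returns the order of destructor calls.
std::vector<std::string> DestructionOrder() {
  std::vector<std::string> log;
  {
    ConcreteOperator op(&log);
  }
  return log;
}
```

A real subclass would put its Finalize() call at the point where
ConcreteOperator logs "~ConcreteOperator".
*/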

  void Finalize() {
    if (prefetch_thread_.get()) {
      {
        // Wait for any in-flight Prefetch() to finish before asking the
        // worker to quit.
        std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
        while (!prefetched_)
          consumer_.wait(lock);
        finalize_ = true;
        prefetched_ = false;
      }
      producer_.notify_one();
      prefetch_thread_->join();
      prefetch_thread_.reset();
    } else {
      // If we never initialized the prefetch thread, just set
      // finalize anyway.
      finalize_ = true;
    }
  }
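/*
Finalize() first waits for any in-flight Prefetch() to complete (prefetched_
becomes true). It then sets finalize_, clears prefetched_ so the worker's
inner wait loop ends, wakes the worker, and joins it. The sketch below
reproduces that handshake in a hypothetical standalone class
(OneSlotPrefetcher), together with the Run()/PrefetchWorker() exchange it
relies on. The class hands out integers instead of batches. It illustrates
the handshake under those assumptions and is not caffe2 code.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical one-slot prefetcher mirroring PrefetchOperator's handshake
// (ready_ plays the role of prefetched_, finalize_ of finalize_).
class OneSlotPrefetcher {
 public:
  OneSlotPrefetcher() : worker_([this] { Work(); }) {}
  ~OneSlotPrefetcher() { Finalize(); }

  // Consumer side, like Run(): wait for an item, take it, free the slot.
  int Take() {
    std::unique_lock<std::mutex> lock(mu_);
    consumer_.wait(lock, [this] { return ready_; });
    int value = slot_;
    ready_ = false;
    lock.unlock();
    producer_.notify_one();
    return value;
  }

  // Shutdown, like Finalize(): wait for the in-flight item, then tell the
  // worker to quit and join it. Safe to call more than once.
  void Finalize() {
    if (!worker_.joinable()) return;
    {
      std::unique_lock<std::mutex> lock(mu_);
      consumer_.wait(lock, [this] { return ready_; });
      finalize_ = true;
      ready_ = false;
    }
    producer_.notify_one();
    worker_.join();
  }

 private:
  // Producer side, like PrefetchWorker(): fill the slot, then sleep until the
  // consumer empties it or Finalize() clears it.
  void Work() {
    std::unique_lock<std::mutex> lock(mu_);
    while (!finalize_) {
      slot_ = next_++;  // stands in for Prefetch()
      ready_ = true;
      consumer_.notify_one();
      producer_.wait(lock, [this] { return !ready_; });
    }
  }

  std::mutex mu_;
  std::condition_variable producer_, consumer_;
  bool ready_ = false;
  bool finalize_ = false;
  int slot_ = 0;
  int next_ = 0;
  std::thread worker_;  // declared last so it starts after the fields above
};

// Takes three items from the prefetcher, then shuts it down.
std::vector<int> TakeThree() {
  OneSlotPrefetcher p;
  std::vector<int> out;
  for (int i = 0; i < 3; ++i) out.push_back(p.Take());
  p.Finalize();
  return out;
}
```

Work() is not virtual here, so calling Finalize() from this class's own
destructor is safe. In PrefetchOperator the worker calls the virtual
Prefetch(), which is why the call has to live in the derived destructor
instead.
*/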

  bool Run(int /* stream_id */) override {
    if (no_prefetch_) {
      // Synchronous mode: fetch and copy the batch on the calling thread.
      context_.SwitchToDevice();
      bool result = Prefetch() && CopyPrefetched();
      context_.FinishDeviceComputation();
      return result;
    }
    // Note(jiayq): We only start the prefetch_thread in the Run() function
    // instead of in the constructor, because the prefetch_thread needs to
    // start after all derived classes' constructors finish.
    if (!prefetch_thread_) {
      prefetch_thread_.reset(
          new std::thread([this] { this->PrefetchWorker(); }));
    }
    context_.SwitchToDevice();
    // Wait until the worker has produced a batch.
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (!prefetched_)
      consumer_.wait(lock);
    if (!prefetch_success_) {
      LOG(ERROR) << "Prefetching failed.";
      return false;
    }
    if (!CopyPrefetched()) {
      LOG(ERROR) << "Error when copying prefetched data.";
      return false;
    }
    // Free the slot and wake the worker so it prefetches the next batch.
    prefetched_ = false;
    context_.FinishDeviceComputation();
    producer_.notify_one();
    return true;
  }
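/*
The Note(jiayq) in Run() is about virtual dispatch during construction. While
a base-class constructor runs, the object's dynamic type is still the base
class, so virtual calls resolve to the base version. Calling a pure virtual
function such as Prefetch() at that point is undefined behavior. A worker
thread launched from the PrefetchOperator constructor could hit exactly that
window, or touch derived members that are not initialized yet. Starting the
thread lazily in Run() avoids both. The hypothetical classes below (BaseOp,
DerivedOp) record which override a call from the base constructor reaches.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical base class whose constructor makes a virtual call, as a
// PrefetchOperator constructor would if it launched the worker there.
class BaseOp {
 public:
  BaseOp() { seen_in_ctor_ = Kind(); }
  virtual ~BaseOp() = default;
  virtual std::string Kind() const { return "BaseOp"; }
  const std::string& seen_in_ctor() const { return seen_in_ctor_; }

 private:
  std::string seen_in_ctor_;
};

class DerivedOp : public BaseOp {
 public:
  std::string Kind() const override { return "DerivedOp"; }
};

// Returns {override reached from the base constructor,
//          override reached after construction has finished}.
std::pair<std::string, std::string> DispatchDuringConstruction() {
  DerivedOp op;
  return {op.seen_in_ctor(), op.Kind()};
}
```

Kind() deliberately has a body in BaseOp. If it were pure virtual, like
Prefetch(), the call in the constructor would be undefined behavior.
*/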

  void PrefetchWorker() {
    context_.SwitchToDevice();
    std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
    while (prefetched_)
      producer_.wait(lock);
    while (!finalize_) {
      // We need to call FinishDeviceComputation() because the prefetcher
      // thread and the main thread are potentially using different streams
      // (like on GPU).
      try {
        prefetch_success_ = Prefetch();
        context_.FinishDeviceComputation();
      } catch (const std::exception& e) {
        // TODO: propagate exception_ptr to the caller side
        LOG(ERROR) << "Prefetching error " << e.what();
        prefetch_success_ = false;
      }
      // Publish the batch, then sleep until Run() consumes it or Finalize()
      // asks the worker to quit.
      prefetched_ = true;
      consumer_.notify_one();
      while (prefetched_)
        producer_.wait(lock);
    }
  }
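/*
The TODO in PrefetchWorker() asks for the exception to reach the caller
instead of only being logged. A standard way to move an exception across
threads is to capture it with std::current_exception() on the worker and
rethrow it with std::rethrow_exception() on the consumer. The sketch below
shows that mechanism with a hypothetical helper (RunOnWorkerAndPropagate) and
a one-shot thread. It is not what caffe2 implements.

```cpp
#include <cassert>
#include <exception>
#include <stdexcept>
#include <string>
#include <thread>

// Runs `work` on a separate thread. If it throws, the exception is captured
// there and rethrown on the calling thread instead of only being logged.
template <typename F>
void RunOnWorkerAndPropagate(F work) {
  std::exception_ptr error;  // written by the worker, read after join()
  std::thread worker([&] {
    try {
      work();
    } catch (...) {
      error = std::current_exception();
    }
  });
  worker.join();  // the worker's writes are visible after join() returns
  if (error) std::rethrow_exception(error);
}

// Returns the message of the exception thrown on the worker thread, or an
// empty string if nothing was thrown.
std::string CaughtMessage(bool fail) {
  try {
    RunOnWorkerAndPropagate([fail] {
      if (fail) throw std::runtime_error("disk read failed");
    });
  } catch (const std::exception& e) {
    return e.what();
  }
  return "";
}
```

In PrefetchOperator the worker is long-lived, so the hand-off would have to go
through the mutex-protected state instead of join(). For instance, a
hypothetical std::exception_ptr member could be set in the catch block and
rethrown in Run() when prefetch_success_ is false.
*/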

  // Derived classes implement these two functions instead of Run():
  // Prefetch() loads the next batch, and CopyPrefetched() copies it into the
  // operator's outputs.
  virtual bool Prefetch() = 0;
  virtual bool CopyPrefetched() = 0;

 protected:
  Context context_;
  std::mutex prefetch_access_mutex_;
  std::condition_variable producer_, consumer_;
  // prefetched_ is set by the worker when a batch is ready and cleared once
  // Run() has copied it.
  std::atomic<bool> prefetched_;
  // prefetch_success_ records whether the last prefetch succeeded.
  std::atomic<bool> prefetch_success_;
  // finalize_ tells the prefetching thread to quit.
  std::atomic<bool> finalize_;
  std::unique_ptr<std::thread> prefetch_thread_;
  // Whether to skip the background thread and run as a normal operator.
  const bool no_prefetch_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_PREFETCH_OP_H_