1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
|
#ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
#define CAFFE2_OPERATORS_COUNTER_OPS_H
#include <atomic>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// A counter whose value lives in a std::atomic<T>; every member function
// below performs exactly one atomic operation on that value.
//
// The counter is not clamped: countDown() keeps decrementing after the value
// has reached zero, so retrieve() may report negative numbers.
template <typename T>
class TORCH_API Counter {
 public:
  // Starts the counter at `count`.
  explicit Counter(T count) : count_(count) {}

  // Atomically subtracts one. Returns false if the value seen *before* the
  // decrement was positive, and true otherwise.
  bool countDown() {
    const T before = count_.fetch_sub(1);
    return before <= 0;
  }

  // Atomically adds one and returns the value held before the increment.
  T countUp() {
    return count_.fetch_add(1);
  }

  // Returns the current value.
  T retrieve() const {
    const T current = count_.load();
    return current;
  }

  // Reports whether the current value is <= 0. The boolean result is
  // converted to T, so callers receive 1 (done) or 0 (not done).
  T checkIfDone() const {
    const bool done = count_.load() <= 0;
    return static_cast<T>(done);
  }

  // Atomically replaces the value with `init_count` and returns the value
  // that was stored previously.
  T reset(T init_count) {
    const T previous = count_.exchange(init_count);
    return previous;
  }

 private:
  std::atomic<T> count_;
};
// TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
// Creates a fresh Counter<T> and stores it, wrapped in a std::unique_ptr, in
// output 0.
//
// The starting value comes from the "init_count" argument (0 when the argument
// is absent). It is read once, at construction time, and construction fails
// the enforce check if the value is negative.
template <typename T, class Context>
class CreateCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CreateCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  // Replaces whatever output 0 currently holds with a new counter that starts
  // at init_count_. Always returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    auto* slot = this->template Output<CounterPtr>(0);
    *slot = std::make_unique<Counter<T>>(init_count_);
    return true;
  }

 private:
  // Initial value handed to every counter this operator creates.
  T init_count_ = 0;
};
// Resets the Counter<T> held by input 0 to the "init_count" argument (0 when
// the argument is absent; negative values fail the enforce check at
// construction).
//
// When the operator has exactly one output, the value the counter held before
// the reset is written to output 0.
template <typename T, class Context>
class ResetCounterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit ResetCounterOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
    CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
  }

  // Performs the reset and, if requested, reports the old value. Always
  // returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    const auto& counter = this->template Input<CounterPtr>(0);
    const T old_count = counter->reset(init_count_);

    // The previous value is only reported when an output was declared.
    if (OutputSize() != 1) {
      return true;
    }

    auto* out = Output(0);
    out->Resize();
    T* dst = out->template mutable_data<T>();
    *dst = old_count;
    return true;
  }

 private:
  // Value the counter is reset to on every run.
  T init_count_;
};
// Will always use TensorCPU regardless of the Context.
//
// Decrements the Counter<T> held by input 0 and writes the bool returned by
// Counter::countDown() to output 0, which is first resized with an empty
// dimension list.
template <typename T, class Context>
class CountDownOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CountDownOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  // Always returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    const auto& counter = this->template Input<CounterPtr>(0);

    auto* out = Output(0);
    out->Resize(std::vector<int>{});

    // Mutate the counter only after the output has been prepared.
    const bool done = counter->countDown();
    *out->template mutable_data<bool>() = done;
    return true;
  }
};
// Will always use TensorCPU regardless of the Context.
//
// Reads the Counter<T> held by input 0 without modifying it and writes the
// result of Counter::checkIfDone(), converted to bool, to output 0, which is
// first resized with an empty dimension list.
template <typename T, class Context>
class CheckCounterDoneOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CheckCounterDoneOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  // Always returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    const auto& counter = this->template Input<CounterPtr>(0);

    auto* out = Output(0);
    out->Resize(std::vector<int>{});

    // checkIfDone() returns T; the output tensor stores it as bool.
    const bool done = counter->checkIfDone();
    *out->template mutable_data<bool>() = done;
    return true;
  }
};
// Will always use TensorCPU regardless of the Context.
//
// Increments the Counter<T> held by input 0 and writes the value returned by
// Counter::countUp() to output 0, which is first resized with an empty
// dimension list.
template <typename T, class Context>
class CountUpOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit CountUpOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  // Always returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    const auto& counter = this->template Input<CounterPtr>(0);

    auto* out = Output(0);
    out->Resize(std::vector<int>{});

    // Mutate the counter only after the output has been prepared.
    const T reported = counter->countUp();
    *out->template mutable_data<T>() = reported;
    return true;
  }
};
// Will always use TensorCPU regardless of the Context.
//
// Reads the Counter<T> held by input 0 without modifying it and writes the
// value returned by Counter::retrieve() to output 0, which is first resized
// with an empty dimension list.
template <typename T, class Context>
class RetrieveCountOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit RetrieveCountOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  // Always returns true.
  bool RunOnDevice() override {
    using CounterPtr = std::unique_ptr<Counter<T>>;
    const auto& counter = this->template Input<CounterPtr>(0);

    auto* out = Output(0);
    out->Resize(std::vector<int>{});

    const T current = counter->retrieve();
    *out->template mutable_data<T>() = current;
    return true;
  }
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_COUNTER_OPS_H
|