1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
#pragma once
#include "caffe2/core/operator.h"
#include "c10/util/irange.h"
#include <cmath>
#include <limits>
namespace caffe2 {
/**
 * Operator that computes an approximate quantile over the union of all of
 * its input tensors.
 *
 * Arguments (read from the operator definition):
 *   - "quantile" (float, default -1.0): target quantile; enforced to lie in
 *     [0, 1], so the default is rejected and the argument is effectively
 *     required.
 *   - "abs" (bool, default true): when true, every element is replaced by
 *     its absolute value before being compared.
 *   - "tol" (float, default 1e-3): relative tolerance used to terminate the
 *     bisection; enforced to be strictly positive.
 *
 * The element type is dispatched on Input(0) among float and double; all
 * other inputs are enforced to share that dtype. The single output
 * (QUANTILE_VAL) is resized to one element holding the result.
 */
template <typename Context>
class QuantileOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  QuantileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        quantile_(this->template GetSingleArgument<float>("quantile", -1.0)),
        abs_(this->template GetSingleArgument<bool>("abs", true)),
        tol_(this->template GetSingleArgument<float>("tol", 1e-3)) {
    CAFFE_ENFORCE_GE(
        quantile_,
        0,
        "input quantile should be ",
        "no less than 0, got ",
        quantile_);
    CAFFE_ENFORCE_GE(
        1.0f,
        quantile_,
        "input quantile should be ",
        "no larger than 1, got ",
        quantile_);
    // The check is strict: a zero tolerance would only let the bisection
    // stop once hi == lo, so the message must say "larger than 0".
    CAFFE_ENFORCE_GT(
        tol_, 0, "tolerance should be ", "larger than 0, got ", tol_);
  }

  // Dispatches to DoRunWithType<T>() based on the dtype of Input(0).
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
  }

  /**
   * Computes the quantile for element type T and writes it to the output.
   *
   * Let cnt(x) be the number of (optionally abs-transformed) elements that
   * are <= x, and target_cnt = ceil(numel * quantile). The routine returns
   * the minimum when target_cnt == 0 or cnt(min) >= target_cnt, the maximum
   * when target_cnt == numel, and otherwise bisects the value range
   * [min, max] until it finds a value with cnt == target_cnt or the range
   * is narrower than the relative tolerance.
   *
   * Enforces that all inputs share one dtype and that the total number of
   * elements is positive. Always returns true.
   */
  template <typename T>
  bool DoRunWithType() {
    Output(QUANTILE_VAL)->Resize(1);
    auto* quantile_val = Output(QUANTILE_VAL)->template mutable_data<T>();

    // Validate dtypes and accumulate the total element count over all inputs.
    auto& input_zero = Input(0);
    int64_t numel = input_zero.numel();
    for (const auto i : c10::irange(1, InputSize())) {
      CAFFE_ENFORCE_EQ(
          Input(i).dtype(),
          input_zero.dtype(),
          "All inputs must have the same type, expected: ",
          input_zero.dtype().name(),
          " but got: ",
          Input(i).dtype().name(),
          " for input: ",
          i);
      numel += Input(i).numel();
    }
    CAFFE_ENFORCE_GT(
        numel,
        0,
        "number of total element in input tensor should be ",
        "larger than 0, got ",
        numel);

    // The expected number of elements less than or equal to the target value.
    const int64_t target_cnt =
        static_cast<int64_t>(std::ceil(numel * quantile_));

    T hi = 0.0;
    T lo = 0.0;
    GetRangeFromInputs(&lo, &hi);

    if (target_cnt == 0) {
      // Lowest possible value.
      quantile_val[0] = lo;
      return true;
    }
    if (target_cnt == numel) {
      // Highest possible value.
      quantile_val[0] = hi;
      return true;
    }

    const int64_t lo_cnt = CountLowerEq(lo);
    if (lo_cnt >= target_cnt) {
      // Enough elements are tied at the minimum to reach the target count.
      quantile_val[0] = lo;
      return true;
    }

    // Bisect on the value range. Given the early returns above and the
    // branches below, the loop maintains (for NaN-free inputs)
    //   CountLowerEq(lo) < target_cnt < CountLowerEq(hi).
    while (std::abs(hi - lo) > tol_ * (std::abs(hi) + std::abs(lo))) {
      const T mid = lo + (hi - lo) / 2.0;
      const int64_t mid_cnt = CountLowerEq(mid);
      if (mid_cnt > target_cnt) {
        // mid == hi means the midpoint can no longer be represented
        // strictly inside the interval, so no further progress is possible.
        CAFFE_ENFORCE_NE(
            hi, mid, "numeric precision at limit, unable to continue bisect");
        hi = mid;
      } else if (mid_cnt < target_cnt) {
        CAFFE_ENFORCE_NE(
            lo, mid, "numeric precision at limit, unable to continue bisect");
        lo = mid;
      } else {
        // mid_cnt == target_cnt: exact hit.
        quantile_val[0] = mid;
        return true;
      }
    }

    // The interval is within tolerance; report its upper bound.
    quantile_val[0] = hi;
    return true;
  }

 protected:
  float quantile_; // target quantile, enforced to be in [0, 1]
  bool abs_; // whether elements are compared by absolute value
  float tol_; // relative bisection tolerance, enforced to be > 0
  OUTPUT_TAGS(QUANTILE_VAL);

  /**
   * Scans every element of every input and stores the smallest and largest
   * (optionally abs-transformed) values in *lo and *hi respectively.
   */
  template <typename T>
  void GetRangeFromInputs(T* lo, T* hi) {
    *hi = std::numeric_limits<T>::lowest();
    *lo = std::numeric_limits<T>::max();
    for (const auto i : c10::irange(InputSize())) {
      const auto* input = Input(i).template data<T>();
      for (const auto j : c10::irange(Input(i).numel())) {
        const T val = abs_ ? std::abs(input[j]) : input[j];
        if (*hi < val) {
          *hi = val;
        }
        if (*lo > val) {
          *lo = val;
        }
      }
    }
  }

  /**
   * Returns how many (optionally abs-transformed) elements across all
   * inputs are less than or equal to thd. Each call is a full pass over
   * the inputs.
   */
  template <typename T>
  int64_t CountLowerEq(const T& thd) {
    int64_t count = 0;
    for (const auto i : c10::irange(InputSize())) {
      const auto* input = Input(i).template data<T>();
      for (const auto j : c10::irange(Input(i).numel())) {
        const T val = abs_ ? std::abs(input[j]) : input[j];
        if (val <= thd) {
          count++;
        }
      }
    }
    return count;
  }
};
} // namespace caffe2
|