1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
|
#include "kl_minimization.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

#include "caffe2/core/logging.h"
using namespace std;
namespace dnnlowp {

namespace {

// Returns how much of source bin `src_bin` lies inside the real-valued source
// interval [src_begin_exact, src_end_exact), where src_begin is
// floor(src_begin_exact) and src_end is ceil(src_end_exact), i.e. the integer
// bounds of the bins overlapping that interval. Interior bins are fully
// covered (1); the first and last bins may be partially covered.
double SourceBinFraction(
    int src_bin,
    int src_begin,
    int src_end,
    double src_begin_exact,
    double src_end_exact) {
  const bool is_first = src_bin == src_begin;
  const bool is_last = src_bin == src_end - 1;
  double fraction = 1.0;
  if (is_first && is_last) {
    fraction = src_end_exact - src_begin_exact;
  } else if (is_first) {
    fraction = (src_begin + 1) - src_begin_exact;
  } else if (is_last) {
    fraction = src_end_exact - (src_end - 1);
  }
  assert(fraction >= 0);
  return fraction;
}

} // namespace

/// Chooses quantization parameters for the values summarized by `hist`.
///
/// Every candidate clipping range of histogram bins
/// [start_bin, start_bin + nbins_selected) is mapped onto 1 << precision
/// destination (quantized) bins. The range whose quantized distribution has
/// the smallest KL divergence from the original histogram is selected.
///
/// @param hist histogram of observed values; Min() <= 0 and Max() >= 0 are
///     asserted.
/// @param preserve_sparsity when true, only one start bin is tried per range
///     length: bin 0 if Min() == 0, otherwise the range centered on the bin
///     boundary nearest to 0.
/// @param precision number of quantized bits (1 << precision destination
///     bins).
/// @return the parameters QuantizationFactory::ChooseQuantizationParams
///     produces for the selected [min, max] range.
TensorQuantizationParams KLDivergenceMinimization::ChooseQuantizationParams(
    const Histogram& hist,
    bool preserve_sparsity,
    int precision) {
  // The histogram is only read, so bind to it instead of copying every bin.
  const std::vector<uint64_t>& bins = *hist.GetHistogram();
  const int nbins = static_cast<int>(bins.size());
  const int dst_nbins = 1 << precision;
  float min = hist.Min(), max = hist.Max();
  assert(min <= 0.f);
  assert(max >= 0.f);

  // Degenerate histogram: with no bins the bin width is infinite, and with a
  // zero-width range it is 0, making -min / bin_width NaN (converting NaN to
  // int is undefined behavior). There is nothing to search over, so quantize
  // the observed range directly.
  if (nbins == 0 || max == min) {
    return QuantizationFactory::GetDefaultInstance()->ChooseQuantizationParams(
        min, max);
  }

  // Division is done in float precision, then widened to double.
  const double bin_width = static_cast<double>((max - min) / nbins);
  // Index of the histogram bin boundary closest to 0.
  const int zero_bin = static_cast<int>(std::round(-min / bin_width));

  // prefix[i] is the number of samples in bins [0, i). It turns the outlier
  // counts below into O(1) lookups instead of re-summing bins per candidate.
  std::vector<uint64_t> prefix(nbins + 1);
  for (int i = 0; i < nbins; ++i) {
    prefix[i + 1] = prefix[i] + bins[i];
  }
  const double total_sum = static_cast<double>(prefix[nbins]);
  // Number of samples in bins [0, b), with b clamped to [0, nbins].
  const auto count_before = [&prefix, nbins](int b) {
    return prefix[std::min(std::max(b, 0), nbins)];
  };

  // best_start_bins[n] = (best start bin, its KL divergence) for ranges that
  // span n source bins.
  std::vector<std::pair<int, double>> best_start_bins(nbins + 1);
  // Look at mapping [start_bin, start_bin + nbins_selected) to
  // [0, 1 << precision) for every (start_bin, nbins_selected) combination and
  // pick the one with smallest KL divergence. Each outer iteration writes only
  // its own slot of best_start_bins.
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
    double kl_min = std::numeric_limits<double>::max();
    int best_start_bin = 0;
    int start_bin_begin = 0;
    int start_bin_end = nbins - nbins_selected + 1;
    if (preserve_sparsity) {
      // A single candidate: bin 0 when there are no negative values,
      // otherwise the range centered on zero_bin.
      start_bin_begin = (min == 0) ? 0 : zero_bin - nbins_selected / 2;
      start_bin_end = start_bin_begin + 1;
    }
    for (int start_bin = start_bin_begin; start_bin < start_bin_end;
         ++start_bin) {
      // Samples outside the selected range are folded into the first/last
      // histogram bin kept in the range.
      const uint64_t left_outliers = count_before(start_bin);
      const uint64_t right_outliers =
          prefix[nbins] - count_before(start_bin + nbins_selected);
      const int first_kept_bin = std::max(start_bin, 0);
      const int last_kept_bin =
          std::min(start_bin + nbins_selected - 1, nbins - 1);

      double kl = 0;
      // each destination bin corresponds to a quantized value
      for (int dst_bin = 0; dst_bin < dst_nbins; ++dst_bin) {
        // Real-valued source interval covered by this destination bin.
        const double src_begin_exact = start_bin +
            static_cast<double>(dst_bin) * nbins_selected / dst_nbins;
        const double src_end_exact = start_bin +
            static_cast<double>(dst_bin + 1) * nbins_selected / dst_nbins;
        // floor, not truncation toward zero: start_bin can be negative when
        // preserve_sparsity is set, and truncating e.g. -0.5 to 0 would give
        // bin 0 a coverage fraction above 1.
        const int src_begin = static_cast<int>(std::floor(src_begin_exact));
        const int src_end = static_cast<int>(std::ceil(src_end_exact));
        // Only bins that exist in the histogram contribute.
        const int lo = std::max(src_begin, 0);
        const int hi = std::min(src_end, nbins);

        // Pass 1: in-range mass of this destination bin (outliers excluded)
        // and the covered length of its non-empty source bins.
        double sum = 0;
        double non_zero_length = 0;
        for (int src_bin = lo; src_bin < hi; ++src_bin) {
          const double fraction = SourceBinFraction(
              src_bin, src_begin, src_end, src_begin_exact, src_end_exact);
          double bin = static_cast<double>(bins[src_bin]) * fraction;
          sum += bin;
          if (src_bin == first_kept_bin) {
            bin += left_outliers;
          }
          if (src_bin == last_kept_bin) {
            bin += right_outliers;
          }
          if (bin > 0) {
            non_zero_length += fraction;
          }
        }

        // Pass 2: accumulate sum p * log(p / q). p is the original histogram
        // mass (outliers folded in); q spreads this destination bin's mass
        // uniformly over its non-empty source bins.
        for (int src_bin = lo; src_bin < hi; ++src_bin) {
          const double fraction = SourceBinFraction(
              src_bin, src_begin, src_end, src_begin_exact, src_end_exact);
          // Kept in double: an integer count here would truncate the
          // fractional mass of partially covered bins (e.g. 1 * 0.5 -> 0).
          double bin = static_cast<double>(bins[src_bin]);
          if (src_bin == first_kept_bin) {
            bin += left_outliers;
          }
          if (src_bin == last_kept_bin) {
            bin += right_outliers;
          }
          bin *= fraction;
          if (bin > 0) {
            const double p = bin / total_sum;
            const double q = sum * fraction / non_zero_length / total_sum;
            kl += p * std::log(p / q);
          }
        }
      } // dst_bin
      assert(kl >= 0);
      if (kl < kl_min) {
        kl_min = kl;
        best_start_bin = start_bin;
      }
    } // for each start_bin
    best_start_bins[nbins_selected] = {best_start_bin, kl_min};
  } // for each nbins_selected

  // Pick the range length whose best start bin has the smallest divergence.
  double kl_min = std::numeric_limits<double>::max();
  int best_nbins_selected = dst_nbins, best_start_bin = 0;
  for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
    const double kl = best_start_bins[nbins_selected].second;
    if (kl < kl_min) {
      kl_min = kl;
      best_start_bin = best_start_bins[nbins_selected].first;
      best_nbins_selected = nbins_selected;
    }
  }

  // Samples inside the chosen range, used only for logging.
  const int i_begin = std::max(0, best_start_bin);
  const int i_end = std::min(nbins, best_start_bin + best_nbins_selected);
  const uint64_t selected_sum =
      i_end > i_begin ? prefix[i_end] - prefix[i_begin] : 0;
  VLOG(2) << "best quantization range covers "
          << static_cast<double>(selected_sum) / total_sum * 100 << " %";
  VLOG(2) << "best start_bin " << best_start_bin << " nbins_selected "
          << best_nbins_selected;

  // NOTE(review): min/max are shifted by half a bin (+0.5) relative to the
  // selected bin edges — confirm this offset is intended.
  min = static_cast<float>(hist.Min() + bin_width * (best_start_bin + 0.5));
  max = static_cast<float>(
      hist.Min() + bin_width * (best_start_bin + best_nbins_selected + 0.5));
  QuantizationFactory* qfactory = QuantizationFactory::GetDefaultInstance();
  return qfactory->ChooseQuantizationParams(min, max);
} // ChooseQuantizationParams

} // namespace dnnlowp
|