1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/language_detection/core/quantization_utils.h"
#include <algorithm>
#include <cmath>
#include "base/check_op.h"
namespace language_detection {
namespace {

// Computes a quantization grid that maps the real range [`min`, `max`] onto the
// integer range [`quant_min`, `quant_max`]. The range is "nudged" so that the
// integer zero point is a whole number in [`quant_min`, `quant_max`].
//
// Outputs:
//   `scale`      - real-valued width of one quantization step:
//                  (max - min) / (quant_max - quant_min).
//   `nudged_min` - real value represented by `quant_min` after nudging.
//   `nudged_max` - real value represented by `quant_max` after nudging.
//
// NOTE(review): callers presumably guarantee `max > min`. When `max == min`,
// `scale` is 0 and the outputs are not meaningful.
void Nudge(const float min,
           const float max,
           const unsigned int quant_min,
           const unsigned int quant_max,
           float* nudged_min,
           float* nudged_max,
           float* scale) {
  const float quant_min_float = static_cast<float>(quant_min);
  const float quant_max_float = static_cast<float>(quant_max);
  *scale = (max - min) / (quant_max_float - quant_min_float);
  // Fractional integer position that the real value 0.0 would occupy.
  const float zero_point_from_min = quant_min_float - min / *scale;
  // Clamp into the representable integer range, then round to the nearest
  // integer. The zero point is deliberately kept as a float. A narrower
  // integer type (previously uint16_t) truncates it whenever `quant_max`
  // exceeds 16 bits. Converting an out-of-range or NaN float to an integer
  // type is also undefined behavior.
  const float nudged_zero_point = std::round(
      std::clamp(zero_point_from_min, quant_min_float, quant_max_float));
  *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}

}  // namespace
// Returns the nudged quantization parameters for mapping the real range
// [`min_val`, `max_val`] onto the unsigned integer range [0, 2^num_bits - 1].
// `num_bits` must be in [2, 31] (checked with DCHECK).
QuantizationParams GetQuantizationParams(float min_val,
                                         float max_val,
                                         int num_bits) {
  DCHECK_GT(num_bits, 1);
  DCHECK_LT(num_bits, 32);
  QuantizationParams params;
  // Shift an unsigned 1. With a signed int, num_bits == 31 (allowed above)
  // makes `(1 << 31) - 1` overflow, which is undefined behavior.
  params.quant_max_uint32 = (1u << num_bits) - 1u;
  // Pass the integer range straight through. Nudge converts it to float itself.
  Nudge(min_val, max_val, 0u, params.quant_max_uint32, &params.nudged_min,
        &params.nudged_max, &params.nudged_scale);
  return params;
}
// Quantizes `x` onto the integer grid derived from [`min_val`, `max_val`] and
// `num_bits`. Steps:
//   1. Clamp `x` to the nudged range.
//   2. Offset it by the nudged minimum.
//   3. Convert the offset to a number of steps, rounding half-up.
// The result never exceeds `quant_max_uint32`.
uint32_t FloatToQuantized(float x, float min_val, float max_val, int num_bits) {
  const QuantizationParams qp =
      GetQuantizationParams(min_val, max_val, num_bits);
  const float steps_per_unit = 1.0f / qp.nudged_scale;
  const float offset =
      std::clamp(x, qp.nudged_min, qp.nudged_max) - qp.nudged_min;
  // `offset` is non-negative after clamping, so +0.5 then truncation rounds
  // to the nearest step.
  const uint32_t level = static_cast<uint32_t>(offset * steps_per_unit + 0.5f);
  return level < qp.quant_max_uint32 ? level : qp.quant_max_uint32;
}
// Recomputes the quantization parameters for [`min_val`, `max_val`] and
// `num_bits`, then delegates to QuantizedToFloatWithQuantParams (defined
// elsewhere). That function is presumably the inverse of FloatToQuantized.
float QuantizedToFloat(uint32_t x, float min_val, float max_val, int num_bits) {
  return QuantizedToFloatWithQuantParams(
      x, GetQuantizationParams(min_val, max_val, num_bits));
}
} // namespace language_detection
|