1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
#include <algorithm>
#include <cmath>
#include <immintrin.h>
namespace dnnlowp {
namespace internal {

/// @brief AVX2/FMA kernel computing the squared-error (L2) cost of mapping a
/// source histogram onto a grid of (1 << precision) destination bins.
///
/// Positions are measured from the beginning of destination bin 0, which
/// coincides with the beginning of source bin `start_bin`. Each source bin's
/// count is turned into a density (bins[src_bin] / bin_width) that multiplies
/// the integral, over the source bin's extent, of the squared distance from
/// each point to the center of the destination bin that point falls into.
/// Destination bin indices are clamped to [0, (1 << precision) - 1]. Points
/// below the first or above the last destination bin are therefore measured
/// against that edge bin's center.
///
/// With w = dst_bin_width, c the destination bin center, and [a, b] the source
/// bin extent:
///   - same destination bin:  ((b - c)^3 - (a - c)^3) / 3
///   - spanning bins i..j:    -(a - c_i)^3 / 3 + (j - i) * w^3 / 12
///                            + (b - c_j)^3 / 3
/// The return value is the sum over all source bins of density * integral.
///
/// The first nbins / 8 * 8 bins are processed 8 at a time with AVX2/FMA
/// intrinsics. The remaining bins go through an equivalent scalar loop. The
/// two paths round slightly differently: the vector path uses reciprocal
/// multiplies and fused multiply-adds, while the scalar path uses divisions
/// and separate operations. The result is therefore not bit-identical to an
/// all-scalar evaluation.
///
/// Uses AVX2 (_mm256_i32gather_ps, 256-bit integer ops) and FMA
/// (_mm256_fmadd_ps) instructions. It must only run on CPUs supporting both,
/// and this file must be compiled with them enabled.
/// NOTE(review): CPU feature dispatch is not visible here — presumably done
/// by the caller; confirm.
///
/// @param precision     log2 of the number of destination bins. Must be small
///                      enough that (1 << precision) does not overflow int.
/// @param bins          source histogram, read at indices [0, nbins). It is
///                      only read, never written, despite the non-const
///                      pointer.
/// @param nbins         number of source bins; values <= 0 yield 0.
/// @param bin_width     width of each source bin (divided by; must be nonzero).
/// @param dst_bin_width width of each destination bin (divided by; must be
///                      nonzero).
/// @param start_bin     index of the source bin whose beginning coincides with
///                      the beginning of destination bin 0.
/// @return the accumulated density-weighted squared error.
float L2MinimizationKernelAVX2(
    int precision,
    float* bins,
    int nbins,
    float bin_width,
    float dst_bin_width,
    int start_bin) {
  float norm = 0;
  constexpr int VLEN = 8;
  // Squared-error integral over one whole destination bin:
  // integral of x^2 for x in [-w/2, w/2] = w^3 / 12, with w = dst_bin_width.
  float norm_delta_default = dst_bin_width * dst_bin_width * dst_bin_width / 12;
  // Lane offsets. _mm256_set_epi32 lists lanes from high to low, so lane i
  // holds i, and lane i of src_bin_v below is src_bin + i.
  __m256i identity_v = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  // Loop-invariant broadcasts. The vector path multiplies by these
  // precomputed reciprocals where the scalar path divides.
  __m256 bin_width_v = _mm256_set1_ps(bin_width);
  __m256 bin_width_inverse_v = _mm256_set1_ps(1.0f / bin_width);
  __m256 dst_bin_width_v = _mm256_set1_ps(dst_bin_width);
  __m256 dst_bin_width_inverse_v = _mm256_set1_ps(1.0f / dst_bin_width);
  // Per-lane partial sums, reduced to a scalar after the vector loop.
  __m256 norm_v = _mm256_setzero_ps();
  int src_bin = 0;
  // Vector path over whole groups of VLEN source bins. The remaining bins are
  // handled by the scalar loop further down.
  for (; src_bin < nbins / VLEN * VLEN; src_bin += VLEN) {
    // distances from the beginning of first dst_bin to the beginning and
    // end of src_bin
    __m256i src_bin_v =
        _mm256_add_epi32(_mm256_set1_epi32(src_bin), identity_v);
    __m256 src_bin_begin_v = _mm256_mul_ps(
        _mm256_cvtepi32_ps(
            _mm256_sub_epi32(src_bin_v, _mm256_set1_epi32(start_bin))),
        bin_width_v);
    __m256 src_bin_end_v = _mm256_add_ps(src_bin_begin_v, bin_width_v);
    // which dst_bins the beginning and end of src_bin belong to?
    // floor(distance / dst_bin_width), clamped to [0, (1 << precision) - 1].
    // The value is already integral, so the float->int conversion is exact.
    __m256i dst_bin_of_begin_v = _mm256_cvtps_epi32(_mm256_max_ps(
        _mm256_setzero_ps(),
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_begin_v, dst_bin_width_inverse_v)),
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            _mm256_set1_ps((1 << precision) - 1.0f))));
    __m256i dst_bin_of_end_v = _mm256_cvtps_epi32(_mm256_max_ps(
        _mm256_setzero_ps(),
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_end_v, dst_bin_width_inverse_v)),
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            _mm256_set1_ps((1 << precision) - 1.0f))));
    // Center of the destination bin containing the beginning of src_bin:
    // index * w + w / 2.
    __m256 dst_bin_of_begin_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_begin_v),
        dst_bin_width_v,
        _mm256_set1_ps(dst_bin_width / 2));
    // Per-lane density bins[src_bin + i] / bin_width. The gather's scale
    // argument 4 is the byte stride between consecutive floats.
    // Using sizeof(float) instead of 4 generates compilation error in dbg mode.
    __m256 density_v = _mm256_mul_ps(
        _mm256_i32gather_ps(bins, src_bin_v, 4), bin_width_inverse_v);
    // The integral starts with -(delta_begin^3) / 3.
    __m256 delta_begin_v =
        _mm256_sub_ps(src_bin_begin_v, dst_bin_of_begin_center_v);
    __m256 norm_delta_v = _mm256_mul_ps(
        _mm256_mul_ps(
            _mm256_mul_ps(delta_begin_v, delta_begin_v), delta_begin_v),
        _mm256_set1_ps(-1.0f / 3));
    // All-ones lanes where the beginning and end of src_bin fall in the same
    // destination bin.
    __m256i mask_v = _mm256_cmpeq_epi32(dst_bin_of_begin_v, dst_bin_of_end_v);
    // End distance measured against the begin bin's center (same-bin case)...
    __m256 delta_end0_v =
        _mm256_sub_ps(src_bin_end_v, dst_bin_of_begin_center_v);
    // ...and against the end bin's center (spanning case).
    __m256 dst_bin_of_end_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_end_v),
        dst_bin_width_v,
        _mm256_set1_ps(dst_bin_width / 2));
    __m256 delta_end1_v = _mm256_sub_ps(src_bin_end_v, dst_bin_of_end_center_v);
    // blendv picks its second operand (delta_end0_v) in lanes whose mask sign
    // bit is set, i.e. the same-bin lanes, and delta_end1_v elsewhere. This
    // replaces the scalar path's if/else.
    __m256 delta_end_v = _mm256_blendv_ps(
        delta_end1_v, delta_end0_v, _mm256_castsi256_ps(mask_v));
    // += delta_end^3 / 3
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_mul_ps(_mm256_mul_ps(delta_end_v, delta_end_v), delta_end_v),
        _mm256_set1_ps(1.0f / 3),
        norm_delta_v);
    // += (number of destination bins crossed) * w^3 / 12. The count is 0 in
    // same-bin lanes, so this term needs no masking.
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(
            _mm256_sub_epi32(dst_bin_of_end_v, dst_bin_of_begin_v)),
        _mm256_set1_ps(norm_delta_default),
        norm_delta_v);
    // norm_v += density * norm_delta, per lane.
    norm_v = _mm256_fmadd_ps(density_v, norm_delta_v, norm_v);
  } // src_bin loop vectorized
  // Horizontal reduction: add the VLEN per-lane partial sums into norm.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float norm_buf[VLEN];
  _mm256_storeu_ps(norm_buf, norm_v);
  // NOLINTNEXTLINE(modernize-loop-convert)
  for (int i = 0; i < VLEN; ++i) {
    norm += norm_buf[i];
  }
  // Scalar path for the remaining source bins. It computes the same formula
  // as the vector body above, using division instead of reciprocal multiply
  // and an explicit branch instead of a blend.
  for (; src_bin < nbins; ++src_bin) {
    // distances from the beginning of first dst_bin to the beginning and
    // end of src_bin
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    float src_bin_begin = (src_bin - start_bin) * bin_width;
    float src_bin_end = src_bin_begin + bin_width;
    // which dst_bins the beginning and end of src_bin belong to?
    // Clamped to [0, (1 << precision) - 1] in float, then implicitly
    // converted to int. The conversion is exact because the value is integral.
    int dst_bin_of_begin = std::min(
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        (1 << precision) - 1.0f,
        std::max(0.0f, floorf(src_bin_begin / dst_bin_width)));
    int dst_bin_of_end = std::min(
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        (1 << precision) - 1.0f,
        std::max(0.0f, floorf(src_bin_end / dst_bin_width)));
    float dst_bin_of_begin_center =
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        dst_bin_of_begin * dst_bin_width + dst_bin_width / 2;
    float density = bins[src_bin] / bin_width;
    float delta_begin = src_bin_begin - dst_bin_of_begin_center;
    // The integral starts with -(delta_begin^3) / 3.
    float norm_delta = -(delta_begin * delta_begin * delta_begin) / 3;
    if (dst_bin_of_begin == dst_bin_of_end) {
      // if src_bin is entirely within 1 dst_bin
      float delta_end = src_bin_end - dst_bin_of_begin_center;
      norm_delta += (delta_end * delta_end * delta_end) / 3;
    } else {
      // src_bin spans several dst_bins. Adding one w^3 / 12 per boundary
      // crossed, plus the end term measured against the end bin's center,
      // gives the integral across all spanned destination bins (together
      // with the -delta_begin^3 / 3 term above).
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      norm_delta += (dst_bin_of_end - dst_bin_of_begin) * norm_delta_default;
      float dst_bin_of_end_center =
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          dst_bin_of_end * dst_bin_width + dst_bin_width / 2;
      float delta_end = src_bin_end - dst_bin_of_end_center;
      norm_delta += (delta_end * delta_end * delta_end) / 3;
    }
    norm += density * norm_delta;
  } // src_bin loop remainder
  return norm;
}

} // namespace internal
} // namespace dnnlowp
|