File: norm_minimization_avx2.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (147 lines) | stat: -rw-r--r-- 6,048 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <algorithm>
#include <cmath>

#include <immintrin.h>

namespace dnnlowp {

namespace internal {

/// AVX2/FMA kernel that accumulates a density-weighted L2 error of mapping a
/// source histogram onto coarser destination bins.
///
/// Each source bin `i` is treated as a uniform density `bins[i] / bin_width`
/// over the interval [begin, begin + bin_width), where
/// begin = (i - start_bin) * bin_width is measured from the start of
/// destination bin 0. For every source bin the kernel adds
///   density * ( -delta_begin^3 / 3
///               + (dst_end - dst_begin) * dst_bin_width^3 / 12
///               + delta_end^3 / 3 ).
/// Here dst_begin and dst_end are the destination bins containing the
/// interval's begin and end, clamped to [0, 2^precision - 1]. delta_begin and
/// delta_end are the offsets of the interval's begin and end from the centers
/// of those bins. Together this is the integral of the squared distance to the
/// destination-bin center over the source interval.
///
/// @param precision     destination bin indices are clamped to
///                      [0, (1 << precision) - 1]; must keep the shift in range.
/// @param bins          source histogram, `nbins` entries (only read).
/// @param nbins         number of source bins; 0 yields 0.
/// @param bin_width     width of one source bin (used as a divisor).
/// @param dst_bin_width width of one destination bin (used as a divisor).
/// @param start_bin     source bin index that maps to destination offset 0.
/// @return the accumulated error `norm`.
///
/// The first `nbins / 8 * 8` bins are processed eight at a time with AVX2;
/// the remainder uses an equivalent scalar loop. The caller is responsible for
/// only dispatching here on CPUs that support AVX2 and FMA.
#if (defined(__GNUC__) || defined(__clang__)) && \
    (defined(__x86_64__) || defined(__i386__))
// Declare the ISA this kernel uses (AVX2 gather, FMA) so the definition does
// not depend on the file being compiled with -mavx2 -mfma.
__attribute__((target("avx2,fma")))
#endif
float L2MinimizationKernelAVX2(
    int precision,
    float* bins,
    int nbins,
    float bin_width,
    float dst_bin_width,
    int start_bin) {
  constexpr int VLEN = 8;
  // Integral of x^2 over one whole destination bin centered at 0:
  // \int_{-w/2}^{w/2} x^2 dx = w^3 / 12.
  const float norm_delta_default =
      dst_bin_width * dst_bin_width * dst_bin_width / 12;
  // Largest destination bin index, as a float for the clamps below.
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const float max_dst_bin = (1 << precision) - 1.0f;

  // Loop-invariant broadcasts.
  const __m256i identity_v = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  const __m256i start_bin_v = _mm256_set1_epi32(start_bin);
  const __m256 zero_v = _mm256_setzero_ps();
  const __m256 bin_width_v = _mm256_set1_ps(bin_width);
  const __m256 bin_width_inverse_v = _mm256_set1_ps(1.0f / bin_width);
  const __m256 dst_bin_width_v = _mm256_set1_ps(dst_bin_width);
  const __m256 dst_bin_width_inverse_v = _mm256_set1_ps(1.0f / dst_bin_width);
  const __m256 half_dst_bin_width_v = _mm256_set1_ps(dst_bin_width / 2);
  const __m256 max_dst_bin_v = _mm256_set1_ps(max_dst_bin);
  const __m256 minus_one_third_v = _mm256_set1_ps(-1.0f / 3);
  const __m256 one_third_v = _mm256_set1_ps(1.0f / 3);
  const __m256 norm_delta_default_v = _mm256_set1_ps(norm_delta_default);
  __m256 norm_v = _mm256_setzero_ps();

  int src_bin = 0;
  const int nbins_vectorized = nbins / VLEN * VLEN;
  for (; src_bin < nbins_vectorized; src_bin += VLEN) {
    const __m256i src_bin_v =
        _mm256_add_epi32(_mm256_set1_epi32(src_bin), identity_v);

    // Distances from the beginning of the first dst bin to the beginning and
    // end of each src bin.
    const __m256 src_bin_begin_v = _mm256_mul_ps(
        _mm256_cvtepi32_ps(_mm256_sub_epi32(src_bin_v, start_bin_v)),
        bin_width_v);
    const __m256 src_bin_end_v = _mm256_add_ps(src_bin_begin_v, bin_width_v);

    // Which dst bins the beginning and end of each src bin fall into, clamped
    // to [0, max_dst_bin]. The values are already integral after floor, so the
    // round-to-nearest float->int conversion is exact.
    const __m256i dst_bin_of_begin_v = _mm256_cvtps_epi32(_mm256_max_ps(
        zero_v,
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_begin_v, dst_bin_width_inverse_v)),
            max_dst_bin_v)));
    const __m256i dst_bin_of_end_v = _mm256_cvtps_epi32(_mm256_max_ps(
        zero_v,
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_end_v, dst_bin_width_inverse_v)),
            max_dst_bin_v)));

    const __m256 dst_bin_of_begin_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_begin_v),
        dst_bin_width_v,
        half_dst_bin_width_v);
    const __m256 dst_bin_of_end_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_end_v),
        dst_bin_width_v,
        half_dst_bin_width_v);

    // Using sizeof(float) instead of 4 generates compilation error in dbg mode.
    const __m256 density_v = _mm256_mul_ps(
        _mm256_i32gather_ps(bins, src_bin_v, 4), bin_width_inverse_v);

    // -delta_begin^3 / 3
    const __m256 delta_begin_v =
        _mm256_sub_ps(src_bin_begin_v, dst_bin_of_begin_center_v);
    __m256 norm_delta_v = _mm256_mul_ps(
        _mm256_mul_ps(
            _mm256_mul_ps(delta_begin_v, delta_begin_v), delta_begin_v),
        minus_one_third_v);

    // + delta_end^3 / 3. When a src bin lies inside a single dst bin, its
    // begin and end centers are computed from identical inputs and are
    // bitwise equal, so measuring from the end center is correct for every
    // lane and no per-lane select is needed.
    const __m256 delta_end_v =
        _mm256_sub_ps(src_bin_end_v, dst_bin_of_end_center_v);
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_mul_ps(_mm256_mul_ps(delta_end_v, delta_end_v), delta_end_v),
        one_third_v,
        norm_delta_v);

    // + (dst_bin_of_end - dst_bin_of_begin) * w^3 / 12 (zero when equal).
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(
            _mm256_sub_epi32(dst_bin_of_end_v, dst_bin_of_begin_v)),
        norm_delta_default_v,
        norm_delta_v);

    norm_v = _mm256_fmadd_ps(density_v, norm_delta_v, norm_v);
  } // src_bin loop vectorized

  // Horizontal reduction of the eight partial sums, lane 0 first.
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  float norm_buf[VLEN];
  _mm256_storeu_ps(norm_buf, norm_v);
  float norm = 0;
  for (const float partial : norm_buf) {
    norm += partial;
  }

  // Scalar remainder, mirroring the vector math above.
  for (; src_bin < nbins; ++src_bin) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    const float src_bin_begin = (src_bin - start_bin) * bin_width;
    const float src_bin_end = src_bin_begin + bin_width;

    const int dst_bin_of_begin = static_cast<int>(std::min(
        max_dst_bin,
        std::max(0.0f, std::floor(src_bin_begin / dst_bin_width))));
    const int dst_bin_of_end = static_cast<int>(std::min(
        max_dst_bin, std::max(0.0f, std::floor(src_bin_end / dst_bin_width))));

    const float dst_bin_of_begin_center =
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        dst_bin_of_begin * dst_bin_width + dst_bin_width / 2;
    // Equal to dst_bin_of_begin_center when the src bin sits in one dst bin.
    const float dst_bin_of_end_center =
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        dst_bin_of_end * dst_bin_width + dst_bin_width / 2;

    const float density = bins[src_bin] / bin_width;
    const float delta_begin = src_bin_begin - dst_bin_of_begin_center;
    const float delta_end = src_bin_end - dst_bin_of_end_center;

    float norm_delta = -(delta_begin * delta_begin * delta_begin) / 3;
    // Whole dst bins crossed (zero when the src bin is inside one dst bin).
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    norm_delta += (dst_bin_of_end - dst_bin_of_begin) * norm_delta_default;
    norm_delta += (delta_end * delta_end * delta_end) / 3;
    norm += density * norm_delta;
  } // src_bin loop remainder

  return norm;
}

} // namespace internal

} // namespace dnnlowp