File: adagrad_avx2.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (125 lines) | stat: -rw-r--r-- 3,885 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include "caffe2/perfkernels/adagrad.h"
#include "caffe2/perfkernels/cvtsh_ss_bugfix.h"

#include <emmintrin.h>
#include <immintrin.h>

namespace caffe2 {

// version without prefetching
void adagrad_update__avx2_fma(
    int N,
    const float* w,
    const float* g,
    const float* h,
    float* nw,
    float* nh,
    float epsilon,
    float decay,
    float lr,
    float weight_decay = 0.f) {
  constexpr int kSize = 8;
  auto i = 0;
  for (; i + kSize <= N; i += kSize) {
    __m256 gi = _mm256_loadu_ps(g + i);
    __m256 hi = _mm256_loadu_ps(h + i);
    __m256 wi = _mm256_loadu_ps(w + i);
    gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);

    __m256 nhi = _mm256_add_ps(
        _mm256_mul_ps(_mm256_set1_ps(decay), hi), _mm256_mul_ps(gi, gi));
    _mm256_storeu_ps(nh + i, nhi);
    __m256 vtmp = _mm256_div_ps(
        _mm256_mul_ps(_mm256_set1_ps(lr), gi),
        _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
    _mm256_storeu_ps(nw + i, _mm256_add_ps(wi, vtmp));
  }

  for (; i < N; ++i) {
    float gi = std::fma(weight_decay, w[i], g[i]);
    float hi = nh[i] = decay * h[i] + gi * gi;
    nw[i] = w[i] + lr * gi / (std::sqrt(hi) + epsilon);
  }
}

// Dense fp32 AdaGrad step with software prefetching (AVX2/FMA build).
//
// This entry point has no kernel of its own. It forwards every argument,
// unchanged and in order, to internal::adagrad_update_prefetch_inlined.
// That helper is declared outside this file and is not visible here.
// Unlike adagrad_update__avx2_fma, this signature takes no `decay` argument.
//
// Each *_n argument is a "prefetch ptr". It is presumably the buffer of the
// next update, to be pulled into cache while this one is computed.
// NOTE(review): confirm against the inlined helper.
void adagrad_update_prefetch__avx2_fma(
    int N,
    const float* w,
    const float* w_n, // prefetch ptr

    const float* g,

    const float* h,
    const float* h_n, // prefetch ptr

    float* nw,
    float* nw_n, // prefetch ptr

    float* nh,
    float* nh_n, // prefetch ptr

    float epsilon,
    float lr,
    float weight_decay = 0.f) {
  internal::adagrad_update_prefetch_inlined(
      N,
      w, w_n, // weights + prefetch ptr
      g, //      gradients
      h, h_n, // accumulator + prefetch ptr
      nw, nw_n, // output weights + prefetch ptr
      nh, nh_n, // output accumulator + prefetch ptr
      epsilon,
      lr,
      weight_decay);
}

// Compute adagrad sparse, assumes embedding and momentum are at::Half.
//
// This is an AdaGrad step with fp16 storage. The weights (w, "embedding") and
// the squared-gradient accumulator (h, "momentum") are stored as IEEE half
// bits and converted with F16C. The gradient g is fp32. All arithmetic is
// done in fp32:
//   grad  = g[k] + weight_decay * w[k]                  (fused multiply-add)
//   acc   = h[k] + grad * grad                          (no decay factor here)
//   nh[k] = half(acc)
//   nw[k] = half(w[k] + lr * grad / (sqrt(acc) + epsilon))
// Note that sqrt uses the un-rounded fp32 `acc`, not the stored half value.
// Every conversion back to half uses rounding-control immediate 0, which
// means round to nearest even.
//
// The function presumably operates on one embedding row of a sparse update.
// NOTE(review): confirm against callers.
//
// Full groups of eight elements use AVX2. Each such group first issues a T0
// prefetch at the same offset in w_n, h_n, nw_n and nh_n. The remaining
// N % 8 elements go through a scalar loop with no prefetching.
void adagrad_fp16_update_prefetch__avx2_fma(
    int N,
    const at::Half* w,
    const at::Half* w_n, // prefetch ptr
    const float* g,
    const at::Half* h,
    const at::Half* h_n, // prefetch ptr
    at::Half* nw,
    at::Half* nw_n, // prefetch ptr
    at::Half* nh,
    at::Half* nh_n, // prefetch ptr
    float epsilon,
    float lr,
    float weight_decay = 0.f) {
  constexpr int kSize = 8;
  auto i = 0;
  for (; i + kSize <= N; i += kSize) {
    _mm_prefetch(reinterpret_cast<const char*>(&w_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&h_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&nw_n[i]), _MM_HINT_T0);
    _mm_prefetch(reinterpret_cast<const char*>(&nh_n[i]), _MM_HINT_T0);

    // only convert momentum and embedding, gradient is fp32
    __m256 gi = _mm256_loadu_ps(g + i);
    __m128i hhi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(h + i));
    __m256 hi = _mm256_cvtph_ps(hhi);
    __m128i whi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(w + i));
    __m256 wi = _mm256_cvtph_ps(whi);
    gi = _mm256_fmadd_ps(_mm256_set1_ps(weight_decay), wi, gi);

    __m256 nhi = _mm256_add_ps(hi, _mm256_mul_ps(gi, gi));
    __m128i nhhi = _mm256_cvtps_ph(nhi, 0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(nh + i), nhhi);

    __m256 vtmp = _mm256_div_ps(
        _mm256_mul_ps(_mm256_set1_ps(lr), gi),
        _mm256_add_ps(_mm256_sqrt_ps(nhi), _mm256_set1_ps(epsilon)));
    __m256 nwi = _mm256_add_ps(wi, vtmp);
    __m128i nhwi = _mm256_cvtps_ph(nwi, 0);
    _mm_storeu_si128(reinterpret_cast<__m128i*>(nw + i), nhwi);
  }

  // Scalar tail. The at::Half buffers are accessed as raw 16-bit patterns
  // for _cvtsh_ss / _cvtss_sh.
  const auto* w_bits = reinterpret_cast<const unsigned short*>(w);
  const auto* h_bits = reinterpret_cast<const unsigned short*>(h);
  auto* nw_bits = reinterpret_cast<unsigned short*>(nw);
  auto* nh_bits = reinterpret_cast<unsigned short*>(nh);
  for (; i < N; ++i) {
    // Each input is read and converted exactly once, before any store. This
    // matches the vector loop's load-then-store order. Because nh and w share
    // a pointer type and may alias, a second read of w[i] after writing nh[i]
    // could neither be folded away nor be relied on to see the original
    // weight.
    const float wi = _cvtsh_ss(w_bits[i]);
    const float hi = _cvtsh_ss(h_bits[i]);
    const float gi = std::fma(weight_decay, wi, g[i]);

    const float nhi = hi + gi * gi;
    nh_bits[i] = _cvtss_sh(nhi, 0);

    const float nwi = wi + lr * gi / (std::sqrt(nhi) + epsilon);
    nw_bits[i] = _cvtss_sh(nwi, 0);
  }
}

} // namespace caffe2