File: fp32_momentum_sgd_op.cu

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (141 lines) | stat: -rw-r--r-- 3,945 bytes parent folder | download
(line-number gutter from the HTML source listing removed — actual file content follows)
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"

#include "caffe2/sgd/fp32_momentum_sgd_op.h"

namespace caffe2 {
namespace {

// Momentum SGD update with fused L2 weight decay, vectorized over float2 so
// each loop iteration processes two adjacent floats.
//
//   N        - total element count in floats; only the first 2*(N/2) floats
//              are touched, so an odd trailing element is skipped (see the
//              host-side launcher's note).
//   g        - input gradients (read only)
//   m        - input momentum buffer (read only)
//   ng       - output: the effective update that was applied to param
//   nm       - output: the new momentum buffer
//   lr       - learning rate, a single-element device buffer
//   mom      - momentum coefficient
//   nesterov - selects the Nesterov variant of the update
//   wd       - weight-decay (L2) coefficient
//   param    - parameters, read for weight decay and updated in place
//
// NOTE(review): the __CUDA_ARCH__ >= 530 guard looks inherited from the FP16
// sibling of this kernel (the intrinsics used here exist on all supported
// arches) — TODO confirm. On older architectures the kernel asserts.
__global__ void FP32MomentumSGDKernel(
    int N,
    const float2* g,
    const float2* m,
    float2* ng,
    float2* nm,
    const float* lr,
    const float mom,
    bool nesterov,
    const float wd,
    float2* param) {
#if __CUDA_ARCH__ >= 530
  // lr is a one-element device buffer; read it once per thread.
  const float lr2 = lr[0];
  const float LR = lr2;
  const float momentum = mom;
  const float weight_decay = wd;

  // Number of float2 pairs; an odd trailing float is not processed.
  int n = N / 2;
  if (!nesterov) {
    CUDA_1D_KERNEL_LOOP(i, n) {
      // grad_eff = grad + weight_decay * param  (fused L2 regularization)
      ng[i].x = __fmaf_rn(weight_decay, param[i].x, g[i].x);
      ng[i].y = __fmaf_rn(weight_decay, param[i].y, g[i].y);

      // m_new = momentum * m + LR * grad_eff
      float2 mi_float2 = m[i];
      float2 adjusted_gradient_float2;
      adjusted_gradient_float2.x =
          __fmaf_rn(LR, ng[i].x, __fmul_rn(momentum, mi_float2.x));
      adjusted_gradient_float2.y =
          __fmaf_rn(LR, ng[i].y, __fmul_rn(momentum, mi_float2.y));

      // Classic momentum: the applied update equals the new momentum value.
      nm[i] = adjusted_gradient_float2;
      ng[i] = adjusted_gradient_float2;

      // NOTE(review): param was already dereferenced above for weight decay,
      // so this null check is effectively dead — TODO confirm intent.
      if (param) {
        param[i].x = __fsub_rn(param[i].x, adjusted_gradient_float2.x);
        param[i].y = __fsub_rn(param[i].y, adjusted_gradient_float2.y);
      }
    }
  } else {
    CUDA_1D_KERNEL_LOOP(i, n) {
      // computing the term (grad + lambda*weight)
      // might need to change in case of denormalization

      ng[i].x = __fmaf_rn(weight_decay, param[i].x, g[i].x);
      ng[i].y = __fmaf_rn(weight_decay, param[i].y, g[i].y);

      // m_new = momentum * m + LR * grad_eff
      const float2 mi_float2 = m[i];
      float2 mom_mi_float2;
      mom_mi_float2.x = __fmul_rn(momentum, mi_float2.x);
      mom_mi_float2.y = __fmul_rn(momentum, mi_float2.y);
      float2 mi_new_float2;
      mi_new_float2.x = __fmaf_rn(LR, ng[i].x, mom_mi_float2.x);
      mi_new_float2.y = __fmaf_rn(LR, ng[i].y, mom_mi_float2.y);

      nm[i] = mi_new_float2;
      // Nesterov update: (1 + momentum) * m_new - momentum * m_old
      ng[i].x = __fsub_rn(
          __fmaf_rn(mi_new_float2.x, momentum, mi_new_float2.x),
          mom_mi_float2.x);
      ng[i].y = __fsub_rn(
          __fmaf_rn(mi_new_float2.y, momentum, mi_new_float2.y),
          mom_mi_float2.y);

      // NOTE(review): dead null check — param already dereferenced above.
      if (param) {
        param[i].x = __fsub_rn(param[i].x, ng[i].x);
        param[i].y = __fsub_rn(param[i].y, ng[i].y);
      }
    }
  }
#else
   CUDA_KERNEL_ASSERT(false);
#endif // __CUDA_ARCH__ >= 530
}
}

// Host-side launcher for FP32MomentumSGDKernel on the context's CUDA stream.
//
//   N            - element count in floats. The kernel processes float2
//                  pairs, so only the first 2*(N/2) floats are updated; an
//                  odd trailing element is left untouched.
//   g, m         - device pointers: gradients and momentum buffer (read).
//   ng, nm       - device pointers: effective update and new momentum (written).
//   lr           - device pointer to a single-element learning-rate buffer.
//   param        - device pointer: parameters, updated in place.
//
// The float* buffers are reinterpreted as float2*; callers presumably
// guarantee 8-byte alignment of these tensors — TODO confirm.
template <>
void fp32_momentum_sgd_update<CUDAContext>(
    int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    float momentum,
    bool nesterov,
    float weight_decay,
    float* param,
    CUDAContext* context) {
  // Grid is sized for N/2 float2 elements, matching the kernel's internal
  // n = N / 2 loop bound.
  FP32MomentumSGDKernel<<<
      CAFFE_GET_BLOCKS(N / 2),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context->cuda_stream()>>>(
      N,
      reinterpret_cast<const float2*>(g),
      reinterpret_cast<const float2*>(m),
      reinterpret_cast<float2*>(ng),
      reinterpret_cast<float2*>(nm),
      lr,
      momentum,
      nesterov,
      weight_decay,
      reinterpret_cast<float2*>(param));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // not setting N to N/2
  // TODO_ check float performance vs float2
}

// Register the CUDA implementation of FP32MomentumSGDUpdate and declare its
// schema: 4 inputs, 3 outputs. From the aliasing/inference below the mapping
// is input 0 -> output 0, input 1 -> output 1, input 3 -> output 2
// (presumably grad, momentum, lr, param -> adjusted grad, new momentum,
// updated param — TODO confirm against fp32_momentum_sgd_op.h).
REGISTER_CUDA_OPERATOR(
    FP32MomentumSGDUpdate,
    FP32MomentumSGDUpdateOp<float, CUDAContext>);
OPERATOR_SCHEMA(FP32MomentumSGDUpdate)
    .NumInputs(4)
    .NumOutputs(3)
    // Outputs may alias their corresponding inputs for in-place updates.
    .AllowInplace({{0, 0}, {1, 1}, {3, 2}})
    // Output shapes mirror the inputs they are derived from.
    .TensorInferenceFunction([](const OperatorDef& /* unused */,
                                const vector<TensorShape>& in) {
      vector<TensorShape> out(3);
      out[0] = in[0];
      out[1] = in[1];
      out[2] = in[3];
      return out;
    })
    .SetDoc(R"DOC(

Computes the momentum SGD update similarly to the MomentumSGDUpdateOp,
however this op also performs the weight decay update at the same time, thus
making it more efficient.

This op is also functionally equivalent to the FP16MomentumSGDUpdateOp, however
it expects FP32 data and performs its updates in FP32 precision.

)DOC");
}