File: rms_norm_op.cu

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (194 lines) | stat: -rw-r--r-- 5,349 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#include "caffe2/operators/rms_norm_op.h"

#include <vector>

#include <thrust/tuple.h>

#include "c10/cuda/CUDAMathCompat.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/math/reduce.cuh"
#include "caffe2/utils/math/utils.h"

namespace caffe2 {

namespace {

// Computes one reciprocal root-mean-square per row of a row-major M x N
// matrix X:
//   rrms[row] = rsqrt(sum_j X[row, j]^2 / N + eps)
//
// Launch layout: one block per row (row index = blockIdx.x). The block's
// threads stride over the N columns, and their partial sums are combined
// with BlockReduce<T>. blockDim.x therefore presumably has to equal the block
// size BlockReduce<T> is specialized for (likely CAFFE_CUDA_NUM_THREADS;
// confirm in reduce.cuh). Only thread 0 holds the reduced sum and writes
// rrms[row].
template <typename T>
__global__ void RowwiseRMSCUDAKernel(int64_t N, T eps, const T* X, T* rrms) {
  __shared__ typename BlockReduce<T>::TempStorage rms_storage;
  const int64_t row = blockIdx.x;
  const T* const x_row = X + row * N;

  T sq_sum = 0;
  for (int64_t col = threadIdx.x; col < N; col += blockDim.x) {
    sq_sum += x_row[col] * x_row[col];
  }
  sq_sum = BlockReduce<T>(rms_storage).Sum(sq_sum);

  // The reduction result is only meaningful in thread 0.
  if (threadIdx.x != 0) {
    return;
  }
  rrms[row] =
      c10::cuda::compat::rsqrt(sq_sum / static_cast<T>(N) + static_cast<T>(eps));
}

// Applies the RMSNorm affine transform to a row-major M x N matrix:
//   Y[i, j] = rrms[i] * X[i, j] * gamma[j] + beta[j]
// Launch layout: one block per row (row index = blockIdx.x). Threads stride
// over the N columns, so the kernel is correct for any blockDim.x.
template <typename T>
__global__ void RMSNormForwardCUDAKernel(
    int64_t N,
    const T* X,
    const T* gamma,
    const T* beta,
    const T* rrms,
    T* Y) {
  const int64_t row = blockIdx.x;
  const T* const x_row = X + row * N;
  T* const y_row = Y + row * N;
  for (int64_t col = threadIdx.x; col < N; col += blockDim.x) {
    y_row[col] = rrms[row] * x_row[col] * gamma[col] + beta[col];
  }
}

// For each row i of the row-major M x N matrices dY and X, computes the
// per-row coefficient that multiplies X in the RMSNorm input gradient:
//   c2[i] = -(sum_j dY[i, j] * X[i, j] * gamma[j]) * rrms[i]^3 / N
// The rrms^3 / N factor comes from differentiating
// rrms = (mean(x^2) + eps)^(-1/2) with respect to x_k, which gives
// -rrms^3 * x_k / N.
//
// Launch layout: one block per row (row index = blockIdx.x). The block-wide
// sum uses BlockReduce<T>, so blockDim.x presumably has to equal the block
// size it is specialized for (likely CAFFE_CUDA_NUM_THREADS; confirm in
// reduce.cuh). Only thread 0 writes c2[i].
template <typename T>
__global__ void ComputeInternalGradientsCUDAKernel(
    int64_t N,
    const T* dY,
    const T* X,
    const T* gamma,
    const T* rrms,
    T* c2) {
  __shared__ typename BlockReduce<T>::TempStorage ds_storage;
  const int64_t i = blockIdx.x;
  T ds = 0;
  for (int64_t j = threadIdx.x; j < N; j += blockDim.x) {
    // Keep the flat index 64-bit. i * N exceeds INT_MAX once the tensor
    // holds more than 2^31 elements, and narrowing it to int would address
    // the wrong (or out-of-bounds) element. The other kernels in this file
    // index the same way.
    const int64_t index = i * N + j;
    ds += dY[index] * X[index] * gamma[j];
  }
  ds = BlockReduce<T>(ds_storage).Sum(ds);
  if (threadIdx.x == 0) {
    c2[i] = -ds * math::utils::Cube<T>(rrms[i]) / static_cast<T>(N);
  }
}

// Assembles the RMSNorm input gradient from per-row coefficients c1 and c2:
//   dX[i, j] = c1[i] * dY[i, j] * gamma[j] + c2[i] * X[i, j]
// Launch layout: one block per row (row index = blockIdx.x). Threads stride
// over the N columns, so the kernel is correct for any blockDim.x.
template <typename T>
__global__ void RMSNormBackwardCUDAKernel(
    int64_t N,
    const T* dY,
    const T* X,
    const T* gamma,
    const T* c1,
    const T* c2,
    T* dX) {
  const int64_t row = blockIdx.x;
  const int64_t row_base = row * N;
  for (int64_t col = threadIdx.x; col < N; col += blockDim.x) {
    const int64_t k = row_base + col;
    dX[k] = c1[row] * dY[k] * gamma[col] + c2[row] * X[k];
  }
}

// Column-wise reductions over the M rows that give the parameter gradients:
//   dg[j] = sum_i dY[i, j] * X[i, j] * rrms[i]
//   db[j] = sum_i dY[i, j]
// One thread per column j on a 1-D grid covering N. Each thread walks the M
// rows sequentially, and adjacent threads touch adjacent columns of each row.
// This direct loop is chosen on the assumption that the batch size M stays
// small.
template <typename T>
__global__ void GammaBetaBackwardCUDAKernel(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T* rrms,
    T* dg,
    T* db) {
  const int64_t col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) {
    return;
  }
  T dg_acc = 0;
  T db_acc = 0;
  for (int64_t row = 0; row < M; ++row) {
    const int64_t k = row * N + col;
    dg_acc += dY[k] * X[k] * rrms[row];
    db_acc += dY[k];
  }
  dg[col] = dg_acc;
  db[col] = db_acc;
}

} // namespace

// CUDA forward pass of RMSNorm. X is treated as a row-major M x N matrix:
// M is the product of the dimensions before the canonical axis, and N is the
// product of the rest.
// Outputs:
//   Output(0) Y    : shape of X; Y = rrms * X * gamma + beta per row
//   Output(1) rrms : shape X.sizes()[0:axis]; 1 / sqrt(mean(x_row^2) + eps_)
// gamma and beta must each contain exactly N elements. Both kernels are
// queued on this operator's CUDA stream, and no kernel is launched when M is
// 0.
template <>
template <typename T>
bool RMSNormOp<CUDAContext>::DoRunWithType() {
  const auto& X = Input(0);
  const auto& gamma = Input(1);
  const auto& beta = Input(2);
  auto* Y = Output(0, X.sizes(), at::dtype<T>());
  CAFFE_ENFORCE_GE(X.dim(), 2, "RMSNorm requires input dim >= 2.");

  const int canonical_axis = X.canonical_axis_index(axis_);
  const auto x_sizes = X.sizes();
  const std::vector<int64_t> rms_dims(
      x_sizes.cbegin(), x_sizes.cbegin() + canonical_axis);
  auto* rrms = Output(1, rms_dims, at::dtype<T>());

  const int64_t M = X.size_to_dim(canonical_axis);
  const int64_t N = X.size_from_dim(canonical_axis);
  CAFFE_ENFORCE_EQ(gamma.numel(), N);
  CAFFE_ENFORCE_EQ(beta.numel(), N);

  const T* X_data = X.template data<T>();
  const T* gamma_data = gamma.template data<T>();
  const T* beta_data = beta.template data<T>();
  T* Y_data = Y->template data<T>();
  T* rrms_data = rrms->template data<T>();

  // With no rows there is nothing to compute, and a grid of zero blocks is
  // not a valid launch configuration.
  if (M <= 0) {
    return true;
  }

  const auto stream = context_.cuda_stream();

  // Pass 1: one block per row computes rrms[i].
  RowwiseRMSCUDAKernel<T><<<M, CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
      N, static_cast<T>(eps_), X_data, rrms_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // Pass 2: same stream, so it runs after pass 1 and can read rrms.
  RMSNormForwardCUDAKernel<T><<<M, CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
      N, X_data, gamma_data, beta_data, rrms_data, Y_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// CUDA input-gradient pass of RMSNorm for a row-major M x N view of X and dY:
//   dX[i, j] = rrms[i] * dY[i, j] * gamma[j] + c2[i] * X[i, j]
//   c2[i]    = -rrms[i]^3 / N * sum_j dY[i, j] * X[i, j] * gamma[j]
// c2 is staged in the member scratch tensor c2_, which holds M elements on
// the CUDA device. Both kernels are queued on this operator's CUDA stream,
// so the second one sees the c2 values written by the first.
template <>
template <typename T>
void RMSNormGradientOp<CUDAContext>::RMSNormBackward(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T* gamma,
    const T* rrms,
    T* dX) {
  // An empty batch has no gradient to compute. Launching with zero blocks
  // would also fail with an invalid-configuration error, which the launch
  // check turns into an exception. The forward op skips its launches for
  // this case as well.
  if (M <= 0) {
    return;
  }

  ReinitializeTensor(
      &c2_, {M}, at::dtype<T>().device(CUDAContext::GetDeviceType()));
  T* c2_data = c2_.mutable_data<T>();
  ComputeInternalGradientsCUDAKernel<T>
      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, dY, X, gamma, rrms, c2_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // The per-row coefficient on dY * gamma is exactly rrms, so rrms is
  // passed as c1.
  RMSNormBackwardCUDAKernel<T>
      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, dY, X, gamma, rrms, c2_data, dX);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// CUDA gradient pass for the RMSNorm scale and shift parameters:
//   dgamma[j] = sum_i dY[i, j] * X[i, j] * rrms[i]
//   dbeta[j]  = sum_i dY[i, j]
// It uses one thread per column j, with ceil(N / CAFFE_CUDA_NUM_THREADS)
// blocks, queued on this operator's CUDA stream. When M == 0 the kernel
// still runs and writes zeros.
template <>
template <typename T>
void RMSNormGradientOp<CUDAContext>::GammaBetaBackward(
    int64_t M,
    int64_t N,
    const T* dY,
    const T* X,
    const T* rrms,
    T* dgamma,
    T* dbeta) {
  // With no columns there is no gamma or beta entry to write. DivUp(0, ...)
  // would produce a zero-block grid, which CUDA rejects as an invalid launch
  // configuration, and the launch check would then throw.
  if (N <= 0) {
    return;
  }
  const int64_t B = math::DivUp<int64_t>(N, CAFFE_CUDA_NUM_THREADS);
  GammaBetaBackwardCUDAKernel<T>
      <<<B, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          M, N, dY, X, rrms, dgamma, dbeta);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Register the CUDAContext specializations of the forward and gradient
// operators as the CUDA implementations of "RMSNorm" and "RMSNormGradient".
REGISTER_CUDA_OPERATOR(RMSNorm, RMSNormOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(RMSNormGradient, RMSNormGradientOp<CUDAContext>);

} // namespace caffe2