File: local_response_normalization_op.cu

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (345 lines) | stat: -rw-r--r-- 11,756 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/local_response_normalization_op.h"

namespace caffe2 {

namespace {
// Computes the LRN scale for NCHW input with one thread per (n, h, w) column
// (the kernel derives n, h, w from `index`, so nthreads is expected to be
// N * height * width). For every channel c of the column:
//   scale[n, c, h, w] = bias + alpha_over_size * sum_j in[n, j, h, w]^2
// where j ranges over [c - pre_pad, c + post_pad] clipped to [0, channels),
// with pre_pad = (size - 1) / 2 and post_pad = size - pre_pad - 1.
//
// The channel sum is kept as a running window: at each step the channel that
// enters the window is added and the one that leaves it is subtracted. Every
// channel index is bounds-checked, so `size` may exceed `channels` without
// reading or writing outside the [n, :, h, w] column.
template <typename T>
__global__ void LRNFillScaleNCHW(const int nthreads, const T* in,
    const int channels, const int height,
    const int width, const int size, const T alpha_over_size,
    const T bias, T* scale) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // find out the local offset
    const int w = index % width;
    const int h = (index / width) % height;
    const int n = index / width / height;
    const int offset = (n * channels * height + h) * width + w;
    // Distance between consecutive channels of the same (n, h, w).
    const int step = height * width;
    // Views onto this column; the kernel parameters are left untouched.
    const T* const in_off = in + offset;
    T* const scale_off = scale + offset;
    const int pre_pad = (size - 1) / 2;
    const int post_pad = size - pre_pad - 1;
    T accum_scale = 0;
    // `head` is the newest channel in the window. Once it has been added, the
    // window [head - size + 1, head] is complete for channel head - post_pad.
    for (int head = 0; head < channels + post_pad; ++head) {
      // Add the entering channel (only real channels contribute).
      if (head < channels) {
        accum_scale += in_off[head * step] * in_off[head * step];
      }
      // Remove the channel that just left the window. tail < channels always
      // holds here because head < channels + post_pad.
      const int tail = head - size;
      if (tail >= 0) {
        accum_scale -= in_off[tail * step] * in_off[tail * step];
      }
      // Emit the scale for the channel the window is centered on, if any.
      const int c = head - post_pad;
      if (c >= 0) {
        scale_off[c * step] = bias + accum_scale * alpha_over_size;
      }
    }
  }
}

// Computes the LRN scale for NHWC input, one thread per element. Channel
// neighbours of an element are adjacent in memory (index +/- 1), and the
// element's channel is c = index % channels. For each element:
//   scale[index] = bias + alpha_over_size * sum of in[j]^2
// over the same pixel's channels in [c - pre_pad, c - pre_pad + size - 1]
// clipped to [0, channels), where pre_pad = (size - 1) / 2.
template <typename T>
__global__ void LRNFillScaleNHWC(const int nthreads, const T *const in,
    const int channels, const int size, const T alpha_over_size,
    const T bias, T* scale) {
  // Loop-invariant: the number of window taps before the current channel.
  const int pre_pad = (size - 1) / 2;
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int c = index % channels;
    // Accumulate in a register instead of read-modify-writing scale[index] in
    // global memory once per tap. Because `in` and `scale` may alias, the
    // compiler cannot do this on its own. The summation order is unchanged.
    T accum_scale = 0;
    for (int i = 0; i < size; ++i) {
      const int raw_idx = c + i - pre_pad;
      // Skip taps that fall outside this pixel's channel range.
      if (raw_idx >= 0 && raw_idx < channels) {
        const T v = in[index + i - pre_pad];
        accum_scale += v * v;
      }
    }
    scale[index] = bias + accum_scale * alpha_over_size;
  }
}

// TODO(Yangqing): check if it would be faster to just put it into the previous
// kernel.
//
// Elementwise LRN output: out[i] = in[i] * scale[i]^negative_beta, one thread
// per element. `in`, `scale` and `out` are indexed identically, so this kernel
// is layout-agnostic.
template <typename T>
__global__ void LRNComputeOutput(const int nthreads, const T* in,
    const T* scale, const T negative_beta, T* out) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const T normalizer = pow(scale[index], negative_beta);
    out[index] = in[index] * normalizer;
  }
}

// Gradient of LRN for NCHW input, one thread per (n, h, w) column (the kernel
// derives n, h, w from `index`, so nthreads is expected to be
// N * height * width). For every channel c of the column:
//   bottom_diff[c] = top_diff[c] * scale[c]^negative_beta
//       - cache_ratio * bottom_data[c] * sum_j top_diff[j] * top_data[j] / scale[j]
// where j ranges over [c - size / 2, c + (size - 1) / 2] clipped to
// [0, channels). Those are exactly the channels whose forward window (see
// LRNFillScaleNCHW) contains c. This file's callers pass
// cache_ratio = 2 * alpha * beta / size.
//
// The sum is kept as a running window: the entering channel is added and the
// leaving one subtracted. Every channel index is bounds-checked, so `size`
// may exceed `channels` without touching memory outside the column.
template <typename T>
__global__ void LRNComputeDiffNCHW(const int nthreads, const T* bottom_data,
    const T* top_data, const T* scale, const T* top_diff,
    const int channels, const int height,
    const int width, const int size, const T negative_beta,
    const T cache_ratio,
    T* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // find out the local offset
    const int w = index % width;
    const int h = (index / width) % height;
    const int n = index / width / height;
    const int offset = (n * channels * height + h) * width + w;
    // Distance between consecutive channels of the same (n, h, w).
    const int step = height * width;
    // Views onto this column; the kernel parameters are left untouched.
    const T* const bottom_off = bottom_data + offset;
    const T* const top_off = top_data + offset;
    const T* const scale_off = scale + offset;
    const T* const top_diff_off = top_diff + offset;
    T* const bottom_diff_off = bottom_diff + offset;
    // Mirror of the forward padding: pre_pad == size / 2 and
    // post_pad == (size - 1) / 2.
    const int pre_pad = size - (size + 1) / 2;
    const int post_pad = size - pre_pad - 1;
    T accum_ratio = 0;
    // `head` is the newest channel in the window. Once it has been added, the
    // window [head - size + 1, head] is complete for channel head - post_pad.
    for (int head = 0; head < channels + post_pad; ++head) {
      // Add the entering channel (only real channels contribute).
      if (head < channels) {
        accum_ratio += top_diff_off[head * step] * top_off[head * step] /
            scale_off[head * step];
      }
      // Remove the channel that just left the window. tail < channels always
      // holds here because head < channels + post_pad.
      const int tail = head - size;
      if (tail >= 0) {
        accum_ratio -= top_diff_off[tail * step] * top_off[tail * step] /
            scale_off[tail * step];
      }
      // Emit the gradient for the channel the window is centered on, if any.
      const int c = head - post_pad;
      if (c >= 0) {
        bottom_diff_off[c * step] = top_diff_off[c * step]
            * pow(scale_off[c * step], negative_beta) - cache_ratio *
            bottom_off[c * step] * accum_ratio;
      }
    }
  }
}

// Gradient of LRN for NHWC input, one thread per element. Channel neighbours
// of an element are adjacent in memory (index +/- 1), and the element's
// channel is c = index % channels. Each element computes
//   bottom_diff[i] = top_diff[i] * scale[i]^negative_beta
//       - cache_ratio * bottom_data[i] * sum_j top_diff[j] * top_data[j] / scale[j]
// where j covers the same pixel's channels in [c - size / 2, c + (size - 1) / 2]
// clipped to [0, channels). Every thread recomputes its whole window sum
// instead of sharing a running 1-D sum, so this may not be the fastest
// implementation.
template <typename T>
__global__ void LRNComputeDiffNHWC(const int nthreads, const T* bottom_data,
    const T* top_data, const T* scale, const T* top_diff,
    const int channels, const int size, const T negative_beta, const T cache_ratio,
    T* bottom_diff) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int c = index % channels;
    // The window spans offsets [-before, after) around the current channel.
    const int before = size / 2;
    const int after = size - before;
    T accum_ratio = 0;
    for (int d = -before; d < after; ++d) {
      const int neighbor = c + d;
      if (neighbor < 0 || neighbor >= channels) {
        continue;  // outside this pixel's channel range
      }
      accum_ratio += top_diff[index + d] * top_data[index + d] /
          scale[index + d];
    }
    bottom_diff[index] = top_diff[index] * pow(scale[index], negative_beta) -
                         cache_ratio * bottom_data[index] * accum_ratio;
  }
}
}  // namespace

// Forward LRN for a 4-D NCHW float input X. It first fills a per-element scale
//   scale = bias_ + (alpha_ / size_) * (windowed sum of squares over channels)
// and then writes Y = X * scale^(-beta_). The scale lives in Output(1) when the
// operator has a second output; otherwise it lives in local_scale_tensor_.
template<>
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(0);

  TORCH_DCHECK_EQ(X.dim(), 4);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int H = X.dim32(2);
  const int W = X.dim32(3);
  const float* const x_data = X.data<float>();
  float* const y_data =
      Output(0, X.sizes(), at::dtype<float>())->template mutable_data<float>();

  // Choose where the scale is stored: the optional second output, or a
  // scratch tensor owned by the operator.
  if (OutputSize() > 1) {
    scale_ = Output(1);
  } else if (!scale_) {
    scale_ = &local_scale_tensor_;
  }
  scale_->ResizeLike(X);
  float* const scale_data = scale_->template mutable_data<float>();

  // Scale pass: one thread per (n, h, w) column.
  const int num_columns = N * H * W;
  LRNFillScaleNCHW<float>
      <<<CAFFE_GET_BLOCKS(num_columns),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_columns, x_data, C, H, W, size_, alpha_ / size_, bias_,
          scale_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // Output pass: one thread per element.
  const int num_elements = X.numel();
  LRNComputeOutput<float>
      <<<CAFFE_GET_BLOCKS(num_elements),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_elements, x_data, scale_data, -beta_, y_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// Forward LRN for a 4-D NHWC float input X. It first fills a per-element scale
//   scale = bias_ + (alpha_ / size_) * (windowed sum of squares over channels)
// and then writes Y = X * scale^(-beta_). Both kernels run one thread per
// element. The scale lives in Output(1) when the operator has a second output;
// otherwise it lives in local_scale_tensor_.
template<>
bool LRNOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(0);

  TORCH_DCHECK_EQ(X.dim(), 4);
  const int N = X.dim32(0);
  const int H = X.dim32(1);
  const int W = X.dim32(2);
  const int C = X.dim32(3);
  const float* const x_data = X.data<float>();
  float* const y_data =
      Output(0, X.sizes(), at::dtype<float>())->template mutable_data<float>();

  // Choose where the scale is stored: the optional second output, or a
  // scratch tensor owned by the operator.
  if (OutputSize() > 1) {
    scale_ = Output(1);
  } else if (!scale_) {
    scale_ = &local_scale_tensor_;
  }
  scale_->ResizeLike(X);
  float* const scale_data = scale_->template mutable_data<float>();

  const int num_elements = X.numel();
  LRNFillScaleNHWC<float>
      <<<CAFFE_GET_BLOCKS(num_elements),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_elements, x_data, C, size_, alpha_ / size_, bias_, scale_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  LRNComputeOutput<float>
      <<<CAFFE_GET_BLOCKS(num_elements),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_elements, x_data, scale_data, -beta_, y_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// Backward LRN for NCHW float tensors. The inputs are X (forward input), Y
// (forward output) and dY. The scale is recomputed from X into scale_, which
// defaults to local_scale_tensor_ when unset. dX is then produced with
// cache_ratio = 2 * alpha_ * beta_ / size_. Both kernels use one thread per
// (n, h, w) column.
template <>
bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(0);
  const auto& Y = Input(1);
  const auto& dY = Input(2);

  TORCH_DCHECK_EQ(X.dim(), 4);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int H = X.dim32(2);
  const int W = X.dim32(3);
  // Only element counts are compared; the shapes are assumed to agree
  // whenever the counts do.
  TORCH_DCHECK_EQ(X.numel(), Y.numel());
  TORCH_DCHECK_EQ(X.numel(), dY.numel());
  auto* dX = Output(0, X.sizes(), at::dtype<float>());

  const float* const x_data = X.data<float>();
  const float* const y_data = Y.data<float>();
  if (!scale_) {
    scale_ = &local_scale_tensor_;
  }
  scale_->ResizeLike(X);
  float* const scale_data = scale_->template mutable_data<float>();

  const int num_columns = N * H * W;
  LRNFillScaleNCHW<float>
      <<<CAFFE_GET_BLOCKS(num_columns),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_columns, x_data, C, H, W, size_, alpha_ / size_, bias_,
          scale_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  const float* const dy_data = dY.data<float>();
  float* const dx_data = dX->template mutable_data<float>();
  const float cache_ratio = 2.f * alpha_ * beta_ / size_;

  LRNComputeDiffNCHW<float>
      <<<CAFFE_GET_BLOCKS(num_columns),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_columns, x_data, y_data, scale_data, dy_data, C, H, W, size_,
          -beta_, cache_ratio, dx_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// Backward LRN for NHWC float tensors. The inputs are X (forward input), Y
// (forward output) and dY. The scale is recomputed from X into scale_, which
// defaults to local_scale_tensor_ when unset. dX is then produced with
// cache_ratio = 2 * alpha_ * beta_ / size_. Both kernels use one thread per
// element.
template <>
bool LRNGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(0);
  const auto& Y = Input(1);
  const auto& dY = Input(2);

  TORCH_DCHECK_EQ(X.dim(), 4);
  const int N = X.dim32(0);
  const int H = X.dim32(1);
  const int W = X.dim32(2);
  const int C = X.dim32(3);
  const float* const x_data = X.data<float>();
  // Only element counts are compared; the shapes are assumed to agree
  // whenever the counts do.
  TORCH_DCHECK_EQ(X.numel(), Y.numel());
  TORCH_DCHECK_EQ(X.numel(), dY.numel());
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  if (!scale_) {
    scale_ = &local_scale_tensor_;
  }
  scale_->ResizeLike(X);
  float* const scale_data = scale_->template mutable_data<float>();

  const int num_elements = X.numel();
  LRNFillScaleNHWC<float>
      <<<CAFFE_GET_BLOCKS(num_elements),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_elements, x_data, C, size_, alpha_ / size_, bias_, scale_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  const float* const y_data = Y.data<float>();
  const float* const dy_data = dY.data<float>();
  float* const dx_data = dX->template mutable_data<float>();
  const float cache_ratio = 2.f * alpha_ * beta_ / size_;

  LRNComputeDiffNHWC<float>
      <<<CAFFE_GET_BLOCKS(num_elements),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          num_elements,
          x_data,
          y_data,
          scale_data,
          dy_data,
          C,
          size_,
          -beta_,
          cache_ratio,
          dx_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}


// Register the float CUDA implementations of the forward (LRN) and backward
// (LRNGradient) operators whose CUDAContext specializations are defined above.
REGISTER_CUDA_OPERATOR(LRN, LRNOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(LRNGradient, LRNGradientOp<float, CUDAContext>);

}  // namespace caffe2