File: scale_blobs_op.cu

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)
#include <algorithm>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/scale_blobs_op.h"

namespace caffe2 {

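// Scales every element of every input blob by `scale`. Each grid-stride
// CUDA_1D_KERNEL_LOOP spreads one blob's elements across the whole grid,
// and all blocks sweep the blobs one by one.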
template <typename T>
__global__ void ScaleBlobsCUDAKernel(
    const float scale,
    const int numBlobs,
    const int* sizeArr,
    T** X,
    T** Y) {
  for (int i = 0; i < numBlobs; ++i) {
    CUDA_1D_KERNEL_LOOP(j, sizeArr[i]) {
      Y[i][j] = X[i][j] * scale;
    }
  }
}

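// Variant tuned for many small blobs: block i handles blob i exclusively,
// with its threads striding over that blob's elements. The launch must use
// exactly numBlobs blocks.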
template <typename T>
__global__ void ScaleBlobsCUDAKernelManyTensors(
    const float scale,
    const int* sizeArr,
    T** X,
    T** Y) {
  for (int i = threadIdx.x; i < sizeArr[blockIdx.x]; i += blockDim.x) {
    Y[blockIdx.x][i] = X[blockIdx.x][i] * scale;
  }
}

template <>
template <typename T>
bool ScaleBlobsOp<CUDAContext>::DoRunWithType() {
  const int numBlobs = InputSize();

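  // Stage the per-blob sizes and input/output pointers in host tensors
  // before mirroring them to the device.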
  ReinitializeTensor(&hostBlobSizes_, {numBlobs}, at::dtype<int>().device(CPU));
  int* hostBlobSizesData = hostBlobSizes_.mutable_data<int>();

  ReinitializeTensor(&hostInputs_, {numBlobs}, at::dtype<T*>().device(CPU));
  T** hostInputsData = hostInputs_.mutable_data<T*>();

  ReinitializeTensor(&hostOutputs_, {numBlobs}, at::dtype<T*>().device(CPU));
  T** hostOutputsData = hostOutputs_.mutable_data<T*>();

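  // Gather each blob's element count and data pointers; the total and
  // maximum sizes drive the kernel-selection heuristic below.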
  int totalSize = 0;
  int maxSize = 0;
  for (int i = 0; i < numBlobs; ++i) {
    hostBlobSizesData[i] = Input(i).numel();
    totalSize += hostBlobSizesData[i];
    maxSize = std::max(maxSize, hostBlobSizesData[i]);
    hostInputsData[i] = Input(i).template data<T>();
    hostOutputsData[i] = Output(i)->template mutable_data<T>();
  }

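  // Copy the size and pointer arrays to the device so the kernels can
  // dereference them.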
  ReinitializeTensor(&inputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
  ReinitializeTensor(&outputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
  ReinitializeTensor(&blobSizes_, {numBlobs}, at::dtype<int>().device(CUDA));

  blobSizes_.CopyFrom(hostBlobSizes_);
  inputs_.CopyFrom(hostInputs_);
  outputs_.CopyFrom(hostOutputs_);

  // Select which kernel to launch based on the number and length of the
  // tensors: the first kernel performs better when there are many short
  // tensors, while the second is better when there is a small number of
  // long tensors.
  if (numBlobs > CAFFE_GET_BLOCKS(maxSize)) {
    // Note: the number of blocks has to be exactly numBlobs
    ScaleBlobsCUDAKernelManyTensors<T>
        <<<numBlobs, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
            scale_,
            blobSizes_.data<int>(),
            inputs_.mutable_data<T*>(),
            outputs_.mutable_data<T*>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {
    ScaleBlobsCUDAKernel<T>
        <<<CAFFE_GET_BLOCKS(maxSize),
           CAFFE_CUDA_NUM_THREADS,
           0,
           context_.cuda_stream()>>>(
            scale_,
            numBlobs,
            blobSizes_.data<int>(),
            inputs_.mutable_data<T*>(),
            outputs_.mutable_data<T*>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  return true;
}

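// Shape every output like its corresponding input, then dispatch on the
// dtype of the first input.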
template <>
bool ScaleBlobsOp<CUDAContext>::RunOnDevice() {
  for (int i = 0; i < InputSize(); ++i) {
    auto& input = this->template Input<Tensor>(i, CUDA);
    auto* output = this->template Output<Tensor>(i, CUDA);
    output->ResizeLike(input);
  }
  return DispatchHelper<TensorTypes<at::Half, float>>::call(this, Input(0));
}

REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);

/*
 * Implementation of a different version of the kernel.
 * It balances the work per thread and could be useful when the tensor
 * sizes are highly imbalanced; however, its memory requirement is very
 * high, so it does not perform well in common scenarios.
 *
 * ScaleBlobsCUDAKernelBalanced requires additional storage for the start
 * coordinates, set up as follows:
 *
    int threadsPerBlock = CAFFE_CUDA_NUM_THREADS;
    int coorArrSize = 2 * ((totalSize - 1) / threadsPerBlock + 1);
    std::vector<int> startCoorArr(coorArrSize);  // a VLA is not standard C++
    int* dStartCoorArr;

    int j = 0, cur = 0, elemsLeftInRow = 0;
    for (int i = 0; i < numBlobs; ++i) {
      if (i == 0) {
        startCoorArr[cur++] = i;
        startCoorArr[cur++] = j;
        elemsLeftInRow = 0;
      }
      while (j < sizeArr[i]) {
        j += threadsPerBlock - elemsLeftInRow;
        if (j < sizeArr[i]) {
          startCoorArr[cur++] = i;
          startCoorArr[cur++] = j;
          elemsLeftInRow = 0;
        } else {
          elemsLeftInRow = sizeArr[i] - j + threadsPerBlock;
          j = 0;
          break;
        }
      }
    }
    cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize);
    cudaMemcpy(dStartCoorArr, startCoorArr.data(), sizeof(int) * coorArrSize,
               cudaMemcpyHostToDevice);

  // ScaleBlobsCUDAKernelBalanced kernel launch
  ScaleBlobsCUDAKernelBalanced<T>
   <<<(totalSize-1)/CAFFE_CUDA_NUM_THREADS+1, CAFFE_CUDA_NUM_THREADS, 0,
   context_.cuda_stream()>>>(
     scale_, numBlobs, coorArrSize, dStartCoorArr, dSizeArr, dInputArr,
     dOutputArr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  cudaFree(dStartCoorArr);
*/

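// Balanced variant used by the commented-out setup above: coorArr stores a
// (tensor index, offset) pair per block, locating the element handled by the
// block's first thread; each thread then walks forward across tensor
// boundaries until its element index falls inside a tensor.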
template <typename T>
__global__ void ScaleBlobsCUDAKernelBalanced(
    const float scale,
    const int numBlobs,
    const int coorArrSize,
    const int* coorArr,
    const int* sizeArr,
    T** X,
    T** Y) {
  int i = coorArr[2 * blockIdx.x + 1] + threadIdx.x;
  int curTen = coorArr[2 * blockIdx.x];
  while (curTen < numBlobs && i >= sizeArr[curTen]) {
    i -= sizeArr[curTen++];
  }
  if (curTen < numBlobs) {
    Y[curTen][i] = X[curTen][i] * scale;
  }
}

} // namespace caffe2