File: transpose.cu

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (233 lines) | stat: -rw-r--r-- 10,557 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#include "caffe2/utils/math/transpose.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/utils.h"

namespace caffe2 {
namespace math {

namespace {

constexpr int kTileDim = 32;
constexpr int kBlockRows = 8;

// Splits each H x W matrix of the batch into kTileDim x kTileDim (32 x 32)
// tiles. Each block transposes one tile by staging it through shared memory.
// Reference https://devblogs.nvidia.com/efficient-matrix-transpose-cuda-cc/
//
// Expected launch (see BatchTranspose2DCUDAImpl): a 1-D grid of N * dh * dw
// blocks, with dh = DivUp(H, kTileDim) and dw = DivUp(W, kTileDim), and a
// block of dim3(kTileDim, kBlockRows) threads. Each thread therefore handles
// kTileDim / kBlockRows rows of its tile. The shared tile is padded by one
// column (kTileDim + 1) to avoid bank conflicts on the column-wise read.
//
// X holds N row-major H x W matrices (matrix n starts at n * H * W); Y
// receives the N row-major W x H transposes at the same offsets.
// Tile coordinates are kept in TIndex so that 64-bit dimensions are not
// truncated to int.
template <typename TIndex, typename TData>
__global__ void BatchTranspose2DCUDAKernel(
    const TIndex H,
    const TIndex W,
    const TIndex dh,
    const TIndex dw,
    const TData* X,
    TData* Y) {
  __shared__ TData tile[kTileDim][kTileDim + 1];
  // Decompose the flat block index into (matrix n, tile row r, tile col c).
  const TIndex n = blockIdx.x / (dh * dw);
  const TIndex k = blockIdx.x % (dh * dw);
  const TIndex r = k / dw;
  const TIndex c = k % dw;
  const TIndex offset = n * H * W;
  // Load phase: adjacent threadIdx.x read adjacent elements of a row of X.
  TIndex x = c * kTileDim + threadIdx.x;
  TIndex y = r * kTileDim + threadIdx.y;
  if (x < W) {
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < H; i += kBlockRows) {
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
      // Read-only cache load (SM35+ / ROCm).
      tile[threadIdx.y + i][threadIdx.x] = __ldg(X + offset + (y + i) * W + x);
#else
      tile[threadIdx.y + i][threadIdx.x] = X[offset + (y + i) * W + x];
#endif
    }
  }
  // Every thread of the block reaches this barrier (no early exit above).
  __syncthreads();
  // Store phase: swap tile coordinates so writes to Y are also row-contiguous.
  x = r * kTileDim + threadIdx.x;
  y = c * kTileDim + threadIdx.y;
  if (x < H) {
    for (int i = 0; threadIdx.y + i < kTileDim && y + i < W; i += kBlockRows) {
      Y[offset + (y + i) * H + x] = tile[threadIdx.x][threadIdx.y + i];
    }
  }
}

// Launches BatchTranspose2DCUDAKernel on the context's stream. X holds N
// row-major H x W matrices; Y receives the N corresponding W x H transposes.
// An empty batch is a no-op: a launch with a zero-sized grid
// (N * dh * dw == 0) would be rejected by CUDA as an invalid configuration.
template <typename TIndex, typename TData>
void LaunchBatchTranspose2DCUDAKernel(
    const TIndex N,
    const TIndex H,
    const TIndex W,
    const TData* X,
    TData* Y,
    CUDAContext* context) {
  if (N == 0 || H == 0 || W == 0) {
    return;
  }
  const TIndex dh = DivUp<TIndex>(H, kTileDim);
  const TIndex dw = DivUp<TIndex>(W, kTileDim);
  BatchTranspose2DCUDAKernel<TIndex, TData>
      <<<N * dh * dw, dim3(kTileDim, kBlockRows), 0, context->cuda_stream()>>>(
          H, W, dh, dw, X, Y);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Batched 2-D transpose (N row-major H x W -> N row-major W x H) for element
// types without a cuBLAS fast path: always uses the tiled kernel.
template <typename TIndex, typename TData>
void BatchTranspose2DCUDAImpl(
    const TIndex N,
    const TIndex H,
    const TIndex W,
    const TData* X,
    TData* Y,
    CUDAContext* context) {
  LaunchBatchTranspose2DCUDAKernel<TIndex, TData>(N, H, W, X, Y, context);
}

// float/double specializations. A single matrix (N == 1) is transposed with
// cuBLAS geam. In cuBLAS's column-major view, X is a W x H matrix with leading
// dimension W. The call computes C = 1 * X^T + 0 * Y into Y as an H x W
// column-major matrix (leading dimension H), i.e. row-major W x H. Y is also
// passed as the B operand with beta = 0. geam takes int sizes, so matrices
// whose H or W do not fit in int use the tiled kernel instead, as do batches
// with N > 1. Empty inputs return without touching cuBLAS or launching.
#define DELEGATE_TRANSPOSE_2D_CUDA_IMPL(TIndex, TData, CuBLASFunc)            \
  template <>                                                                 \
  void BatchTranspose2DCUDAImpl<TIndex, TData>(                               \
      const TIndex N,                                                         \
      const TIndex H,                                                         \
      const TIndex W,                                                         \
      const TData* X,                                                         \
      TData* Y,                                                               \
      CUDAContext* context) {                                                 \
    if (N == 0 || H == 0 || W == 0) {                                         \
      return;                                                                 \
    }                                                                         \
    constexpr std::int64_t kMaxCuBLASDim = std::numeric_limits<int>::max();   \
    const bool fits_cublas = static_cast<std::int64_t>(H) <= kMaxCuBLASDim && \
        static_cast<std::int64_t>(W) <= kMaxCuBLASDim;                        \
    if (N == 1 && fits_cublas) {                                              \
      const int m = static_cast<int>(H);                                      \
      const int n = static_cast<int>(W);                                      \
      const TData kAlpha = TData(1);                                          \
      const TData kBeta = TData(0);                                           \
      CUBLAS_ENFORCE(cublasSetPointerMode(                                    \
          context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));               \
      CUBLAS_ENFORCE(CuBLASFunc(                                              \
          context->cublas_handle(),                                           \
          CUBLAS_OP_T,                                                        \
          CUBLAS_OP_N,                                                        \
          m,                                                                  \
          n,                                                                  \
          &kAlpha,                                                            \
          X,                                                                  \
          n,                                                                  \
          &kBeta,                                                             \
          Y,                                                                  \
          m,                                                                  \
          Y,                                                                  \
          m));                                                                \
    } else {                                                                  \
      LaunchBatchTranspose2DCUDAKernel<TIndex, TData>(                        \
          N, H, W, X, Y, context);                                            \
    }                                                                         \
  }
DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, float, cublasSgeam)
DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, float, cublasSgeam)
DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int32_t, double, cublasDgeam)
DELEGATE_TRANSPOSE_2D_CUDA_IMPL(std::int64_t, double, cublasDgeam)
#undef DELEGATE_TRANSPOSE_2D_CUDA_IMPL

// Generic rank-D transpose: each thread writes at most one element of Y.
// Expected launch (see TransposeCUDAImpl): a 1-D grid of
// DivUp(size, CAFFE_CUDA_NUM_THREADS) blocks of CAFFE_CUDA_NUM_THREADS threads.
// Y_dims are Y's dimensions. X_strides[i] is the X stride multiplied by Y's
// i-th coordinate, as produced by utils::ComputeTransposedStrides. The flat Y
// index is decoded into per-axis coordinates, innermost (last) axis first, and
// re-encoded with X_strides to locate the source element.
template <typename TIndex, typename TData, int D>
__global__ void TransposeCUDAKernel(
    const TIndex size,
    const SimpleArray<TIndex, D> X_strides,
    const SimpleArray<TIndex, D> Y_dims,
    const TData* X,
    TData* Y) {
  // Compute the global index in TIndex: with 64-bit TIndex, a 32-bit int
  // would overflow once size exceeds 2^31 - 1.
  const TIndex Y_index =
      static_cast<TIndex>(blockIdx.x) * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
  if (Y_index < size) {
    TIndex X_index = 0;
    TIndex v = Y_index;
#pragma unroll
    for (int i = D - 1; i >= 0; --i) {
      X_index += v % Y_dims.data[i] * X_strides.data[i];
      v /= Y_dims.data[i];
    }
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
    // Read-only cache load (SM35+ / ROCm).
    Y[Y_index] = __ldg(X + X_index);
#else
    Y[Y_index] = X[X_index];
#endif
  }
}

// Host-side launcher for TransposeCUDAKernel with compile-time rank D.
// Y's d-th dimension is dims[axes[d]]. The element count is the product of
// dims, and one thread is launched per element on the context's stream.
template <typename TIndex, typename TData, int D>
void TransposeCUDAImpl(
    const TIndex* dims,
    const int* axes,
    const TData* X,
    TData* Y,
    CUDAContext* context) {
  SimpleArray<TIndex, D> X_strides;
  utils::ComputeTransposedStrides<TIndex>(D, dims, axes, X_strides.data);

  SimpleArray<TIndex, D> Y_dims;
  TIndex total = 1;
  for (int d = 0; d < D; ++d) {
    total *= dims[d];
    Y_dims.data[d] = dims[axes[d]];
  }

  const TIndex num_blocks = DivUp<TIndex>(total, CAFFE_CUDA_NUM_THREADS);
  TransposeCUDAKernel<TIndex, TData, D>
      <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
          total, X_strides, Y_dims, X, Y);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

} // namespace

// Defines Transpose<TIndex, TData, CUDAContext>: permutes the ndim-D tensor X
// (shape dims) into Y according to axes. Dispatch order:
//   * empty tensor -> no-op;
//   * identity permutation -> same-device copy;
//   * utils::IsBatchTranspose2D(ndim, axes) -> batched 2-D transpose of the
//     last two dims (H = dims[ndim - 2], W = dims[ndim - 1],
//     N = size / (H * W));
//   * otherwise -> the rank-dispatched generic TransposeCUDAImpl.
// H, W and N stay in TIndex. Narrowing int64 dims to int, and forming H * W
// in int, would overflow for large tensors and yield a wrong batch count.
#define CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(TIndex, TData)                    \
  template <>                                                               \
  CAFFE2_CUDA_EXPORT void Transpose<TIndex, TData, CUDAContext>(            \
      const int ndim,                                                       \
      const TIndex* dims,                                                   \
      const int* axes,                                                      \
      const TData* X,                                                       \
      TData* Y,                                                             \
      CUDAContext* context) {                                               \
    const TIndex size = std::accumulate(                                    \
        dims, dims + ndim, TIndex(1), std::multiplies<TIndex>());           \
    if (size == 0) {                                                        \
      return;                                                               \
    }                                                                       \
    if (utils::IsIdentityPermutation(ndim, axes)) {                         \
      context->template CopySameDevice<TData>(size, X, Y);                  \
      return;                                                               \
    }                                                                       \
    if (utils::IsBatchTranspose2D(ndim, axes)) {                            \
      const TIndex H = dims[ndim - 2];                                      \
      const TIndex W = dims[ndim - 1];                                      \
      const TIndex N = size / (H * W);                                      \
      BatchTranspose2DCUDAImpl<TIndex, TData>(N, H, W, X, Y, context);      \
      return;                                                               \
    }                                                                       \
    DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2(                                 \
        ndim, TransposeCUDAImpl, TIndex, TData, dims, axes, X, Y, context); \
  }
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, float)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, float)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, double)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, double)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int32_t)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int32_t)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int32_t, std::int64_t)
CAFFE2_SPECIALIZED_CUDA_TRANSPOSE(std::int64_t, std::int64_t)
#undef CAFFE2_SPECIALIZED_CUDA_TRANSPOSE

// NCHW -> NHWC: each of the N batch entries is a row-major C x HxW matrix that
// becomes HxW x C, i.e. a batched 2-D transpose with rows = C, cols = HxW.
#define CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(TData)                        \
  template <>                                                           \
  CAFFE2_CUDA_EXPORT void NCHW2NHWC<TData, CUDAContext>(                \
      const int N,                                                      \
      const int C,                                                      \
      const int HxW,                                                    \
      const TData* X,                                                   \
      TData* Y,                                                         \
      CUDAContext* context) {                                           \
    const int rows = C;                                                 \
    const int cols = HxW;                                               \
    BatchTranspose2DCUDAImpl<int, TData>(N, rows, cols, X, Y, context); \
  }
CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC(float)
#undef CAFFE2_SPECIALIZED_CUDA_NCHW2NHWC

// NHWC -> NCHW: each of the N batch entries is a row-major HxW x C matrix that
// becomes C x HxW, i.e. a batched 2-D transpose with rows = HxW, cols = C.
#define CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(TData)                        \
  template <>                                                           \
  CAFFE2_CUDA_EXPORT void NHWC2NCHW<TData, CUDAContext>(                \
      const int N,                                                      \
      const int C,                                                      \
      const int HxW,                                                    \
      const TData* X,                                                   \
      TData* Y,                                                         \
      CUDAContext* context) {                                           \
    const int rows = HxW;                                               \
    const int cols = C;                                                 \
    BatchTranspose2DCUDAImpl<int, TData>(N, rows, cols, X, Y, context); \
  }
CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW(float)
#undef CAFFE2_SPECIALIZED_CUDA_NHWC2NCHW

} // namespace math
} // namespace caffe2