File: sequence_ops.cu

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (378 lines) | stat: -rw-r--r-- 11,023 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
#include <algorithm>

#include <cub/cub.cuh>
#include "caffe2/utils/cub_namespace.cuh"

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/sequence_ops.h"

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

namespace {
// Inserts start/end padding around every segment of `in`.
//
// Launch: one CUDA block per segment (MakePadding launches `lengths_size`
// blocks of CAFFE_CUDA_NUM_THREADS threads); the threads of a block walk the
// segment's elements in strides of blockDim.x. No shared memory is used.
//
// Data is viewed as rows of `block_size` contiguous elements. For segment
// `element_idx` the kernel writes, contiguously in `out`:
//   start_padding_width_blocks rows of start padding,
//   the segment's rows copied from `in`,
//   end_padding_width_blocks rows of end padding.
// A padding row is taken from `padding_start_ptr` / `padding_end_ptr`
// (`block_size` values, repeated for every padding row) or is T(0) when the
// pointer is null.
//
// lengths_prefix_sum: inclusive prefix sum of the segment lengths in rows, or
//   nullptr for a single segment spanning all `outer_size` input rows.
// lengths_out: optional (may be nullptr); receives each segment's padded
//   length in rows.
//
// Row counts fit in int32, but rows * block_size may not, so element offsets
// are computed in 64-bit arithmetic.
template <typename T>
__global__ void AddPaddingKernel(
    const T* in,
    int block_size,
    int lengths_size,
    int outer_size,
    const int32_t* lengths_prefix_sum,
    const T* padding_start_ptr,
    int start_padding_width_blocks,
    const T* padding_end_ptr,
    int end_padding_width_blocks,
    T* out,
    int32_t* lengths_out) {
  const int element_idx = blockIdx.x;
  const int pad_blocks = start_padding_width_blocks + end_padding_width_blocks;

  // Length and first input row of this segment.
  int64_t len_blocks;
  int64_t in_start_block;
  if (lengths_prefix_sum) {
    const int64_t prev_end =
        element_idx == 0 ? 0 : lengths_prefix_sum[element_idx - 1];
    len_blocks = lengths_prefix_sum[element_idx] - prev_end;
    in_start_block = prev_end;
  } else {
    // Only one element, use the outer size
    CUDA_KERNEL_ASSERT(lengths_size == 1);
    len_blocks = outer_size;
    in_start_block = 0;
  }
  // Each earlier segment grew by pad_blocks rows in the output.
  const int64_t out_start_block =
      in_start_block + static_cast<int64_t>(element_idx) * pad_blocks;

  const int64_t row = block_size;
  const int64_t len = len_blocks * row;
  const int64_t start_padding_width = start_padding_width_blocks * row;
  const int64_t end_padding_width = end_padding_width_blocks * row;

  const T* in_ptr = in + in_start_block * row;
  T* out_ptr = out + out_start_block * row;

  // start pad
  for (int64_t i = threadIdx.x; i < start_padding_width; i += blockDim.x) {
    out_ptr[i] = padding_start_ptr ? padding_start_ptr[i % row] : T(0);
  }

  // payload
  T* payload_ptr = out_ptr + start_padding_width;
  for (int64_t i = threadIdx.x; i < len; i += blockDim.x) {
    payload_ptr[i] = in_ptr[i];
  }

  // end pad
  T* end_ptr = payload_ptr + len;
  for (int64_t i = threadIdx.x; i < end_padding_width; i += blockDim.x) {
    end_ptr[i] = padding_end_ptr ? padding_end_ptr[i % row] : T(0);
  }

  // update the lengths
  if (threadIdx.x == 0 && lengths_out != nullptr) {
    lengths_out[element_idx] = static_cast<int32_t>(len_blocks + pad_blocks);
  }
}

// Inverse of AddPaddingKernel: strips start/end padding from every segment.
//
// Launch: one CUDA block per segment (RemovePaddingOp launches
// `lengths_size` blocks of CAFFE_CUDA_NUM_THREADS threads); the threads of a
// block walk the segment's payload in strides of blockDim.x. No shared memory.
//
// Data is viewed as rows of `block_size` contiguous elements.
// lengths_prefix_sum: inclusive prefix sum of the *padded* segment lengths in
//   rows, or nullptr for a single segment spanning all `outer_size` rows.
// For each segment, the first start_padding_width_blocks rows and the last
// end_padding_width_blocks rows are dropped. The remaining rows are copied to
// `out`, with segments packed back to back.
// lengths_out: optional (may be nullptr); receives each segment's unpadded
//   length in rows.
//
// Element offsets are computed in 64-bit arithmetic.
template <typename T>
__global__ void RemovePaddingKernel(
    const T* in,
    int block_size,
    int lengths_size,
    int outer_size,
    const int32_t* lengths_prefix_sum,
    int start_padding_width_blocks,
    int end_padding_width_blocks,
    T* out,
    int32_t* lengths_out) {
  const int element_idx = blockIdx.x;
  const int pad_blocks = start_padding_width_blocks + end_padding_width_blocks;

  // Padded length and first input row of this segment.
  int64_t len_blocks;
  int64_t in_start_block;
  if (lengths_prefix_sum) {
    const int64_t prev_end =
        element_idx == 0 ? 0 : lengths_prefix_sum[element_idx - 1];
    len_blocks = lengths_prefix_sum[element_idx] - prev_end;
    in_start_block = prev_end;
  } else {
    // Only one element, use the outer size
    CUDA_KERNEL_ASSERT(lengths_size == 1);
    len_blocks = outer_size;
    in_start_block = 0;
  }
  // A segment shorter than the padding would make the output offsets of all
  // later segments wrong (possibly negative, i.e. out of bounds).
  CUDA_KERNEL_ASSERT(len_blocks >= pad_blocks);

  // Each earlier segment shrank by pad_blocks rows in the output.
  const int64_t out_start_block =
      in_start_block - static_cast<int64_t>(element_idx) * pad_blocks;

  const int64_t row = block_size;
  const int64_t start_padding_width = start_padding_width_blocks * row;
  // Only the rows between the two paddings are copied.
  const int64_t payload_len = (len_blocks - pad_blocks) * row;

  // payload: skip this segment's start padding in the input.
  const T* in_ptr = in + in_start_block * row + start_padding_width;
  T* out_ptr = out + out_start_block * row;
  for (int64_t i = threadIdx.x; i < payload_len; i += blockDim.x) {
    out_ptr[i] = in_ptr[i];
  }

  // update the lengths
  if (threadIdx.x == 0 && lengths_out != nullptr) {
    lengths_out[element_idx] = static_cast<int32_t>(len_blocks - pad_blocks);
  }
}

// Prefix sum of `num_items` int32 device values `lengths`, written to
// `prefix_sum` (resized to num_items) on `context`'s stream.
//   Inclusive == true : prefix_sum[i] = lengths[0] + ... + lengths[i]
//   Inclusive == false: prefix_sum[i] = lengths[0] + ... + lengths[i - 1]
// `prefix_buffer` is resized and used as CUB's temporary storage. Errors
// reported by CUB are raised via C10_CUDA_CHECK. No host synchronization is
// performed.
template <bool Inclusive = true>
void lengths_prefix_sum(
    const int32_t* lengths,
    int32_t num_items,
    Tensor* prefix_buffer,
    Tensor* prefix_sum,
    CUDAContext* context) {
  prefix_sum->Resize(num_items);
  int32_t* prefix_sum_ptr = prefix_sum->template mutable_data<int32_t>();

  // Runs the selected CUB scan. When temp_storage is nullptr, CUB only stores
  // the required scratch size in temp_storage_bytes (CUB's two-phase API).
  const auto run_scan = [&](void* temp_storage, size_t& temp_storage_bytes) {
    if (Inclusive) {
      C10_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
          temp_storage,
          temp_storage_bytes,
          lengths,
          prefix_sum_ptr,
          num_items,
          context->cuda_stream()));
    } else {
      C10_CUDA_CHECK(cub::DeviceScan::ExclusiveSum(
          temp_storage,
          temp_storage_bytes,
          lengths,
          prefix_sum_ptr,
          num_items,
          context->cuda_stream()));
    }
  };

  // Phase 1: query the scratch size.
  size_t temp_storage_bytes = 0;
  run_scan(nullptr, temp_storage_bytes);

  // Scratch lives in an int32 tensor; the element count is rounded so the
  // buffer holds at least temp_storage_bytes bytes.
  const auto buffer_size =
      (temp_storage_bytes + sizeof(int32_t)) / sizeof(int32_t);
  prefix_buffer->Resize(buffer_size);
  void* d_temp_storage =
      static_cast<void*>(prefix_buffer->template mutable_data<int32_t>());

  // Phase 2: run the scan.
  run_scan(d_temp_storage, temp_storage_bytes);
}
} // namespace

template <>
template <typename T>
bool AddPaddingOp<CUDAContext>::MakePadding(
    const T* in_ptr,
    T* out_ptr,
    const int32_t* lengths_ptr,
    int32_t lengths_size,
    int32_t outer_size,
    const T* padding_start_ptr,
    const T* padding_end_ptr,
    int64_t block_size) {
  // Segment boundaries handed to the kernel: the inclusive prefix sum of the
  // lengths. Without a lengths tensor the input is one segment, and the
  // kernel receives nullptr and uses outer_size instead.
  const int32_t* segment_ends = nullptr;
  if (lengths_ptr) {
    lengths_prefix_sum(
        lengths_ptr,
        lengths_size,
        &lengths_prefix_sum_buffer_,
        &lengths_prefix_sum_,
        &context_);
    segment_ends = lengths_prefix_sum_.data<int32_t>();
  }

  // Optional second output: one padded length per segment. It is allocated
  // even when there are no segments.
  int32_t* padded_lengths = nullptr;
  if (OutputSize() > 1) {
    padded_lengths = Output(1, {lengths_size}, at::dtype<int32_t>())
                         ->template mutable_data<int32_t>();
  }

  // One CUDA block per segment; an empty batch launches nothing.
  if (lengths_size != 0) {
    AddPaddingKernel<T>
        <<<lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
            in_ptr,
            block_size,
            lengths_size,
            outer_size,
            segment_ends,
            padding_start_ptr,
            startPaddingWidth_,
            padding_end_ptr,
            endPaddingWidth_,
            out_ptr,
            padded_lengths);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  return true;
}

// Register AddPaddingOp<CUDAContext> as the CUDA implementation of AddPadding.
REGISTER_CUDA_OPERATOR(AddPadding, AddPaddingOp<CUDAContext>);

// CUDA RemovePadding: Input(0) is the padded data (at least 1-D; dim 0
// indexes rows). The optional Input(1) holds the padded int32 segment
// lengths; without it the whole input is one segment. Output(0) receives the
// data with startPaddingWidth_ + endPaddingWidth_ rows removed per segment.
// The optional Output(1) receives the unpadded lengths.
template <>
template <typename T>
bool RemovePaddingOp<CUDAContext>::DoRunWithType() {
  const auto& in = Input(0);
  CAFFE_ENFORCE_GE(in.dim(), 1);
  const int32_t outer_size = in.sizes()[0];
  // Elements per row. The initial value fixes std::accumulate's accumulator
  // type, so it must be 64-bit; a plain `1` would truncate to int.
  const auto block_size = std::accumulate(
      in.sizes().begin() + 1,
      in.sizes().end(),
      int64_t{1},
      std::multiplies<int64_t>());

  // if no lengths is provided, assume it is a single full-span entry
  const int32_t* lengths_ptr = nullptr;
  int32_t lengths_size = 1;
  if (InputSize() > 1) {
    const auto& lengths = Input(1);
    lengths_ptr = lengths.data<int32_t>();
    lengths_size = lengths.numel();
  }

  auto out_dims = in.sizes().vec();
  out_dims[0] -= (startPaddingWidth_ + endPaddingWidth_) * lengths_size;
  CAFFE_ENFORCE_GE(
      out_dims[0],
      0,
      "Total padding to remove exceeds the number of input rows");
  auto* out = Output(0, out_dims, at::dtype<T>());
  const auto* in_ptr = in.template data<T>();
  auto* out_ptr = out->template mutable_data<T>();

  // Step 1: compute prefix sum over the (padded) lengths -- unless
  // there were no lengths given, i.e there is only one segment
  const int32_t* lengths_prefix_sum_ptr = nullptr;
  if (lengths_ptr != nullptr) {
    lengths_prefix_sum(
        lengths_ptr,
        lengths_size,
        &lengths_prefix_sum_buffer_,
        &lengths_prefix_sum_,
        &context_);
    lengths_prefix_sum_ptr = lengths_prefix_sum_.data<int32_t>();
  }

  int32_t* lengths_out_ptr = nullptr;
  if (OutputSize() > 1) {
    auto* lengths_out = Output(1, {lengths_size}, at::dtype<int32_t>());
    lengths_out_ptr = lengths_out->template mutable_data<int32_t>();
  }

  // No segments: outputs are allocated, and a zero-block grid is not
  // launchable.
  if (lengths_size == 0) {
    return true;
  }

  // Step 2: strip the padding, one CUDA block per segment.
  RemovePaddingKernel<T>
      <<<lengths_size, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          in_ptr,
          block_size,
          lengths_size,
          outer_size,
          lengths_prefix_sum_ptr,
          startPaddingWidth_,
          endPaddingWidth_,
          out_ptr,
          lengths_out_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return true;
}

// Sums the padding rows of all K segments, column by column.
//
// X is viewed as rows of N elements (GatherPadding passes N = block_size).
// Segment k starts at row L[k] (L: exclusive prefix sum of the lengths) and
// spans I[k] rows (I: the lengths). For each column i in [0, N):
//   Y0[i] = sum over k of rows L[k] .. L[k] + Y0Width - 1
//   Y1[i] = sum over k of the last Y1Width rows of segment k
// If Y0 and Y1 point to the same buffer, Y0[i] receives both sums added.
//
// Launch: 1-D grid with CAFFE_CUDA_NUM_THREADS threads per block (this must
// match the BlockReduce template argument). Block b handles columns b,
// b + gridDim.x, ...; inside a block the threads split the (segment, row)
// pairs and combine partial sums with cub::BlockReduce. Uses two static
// shared-memory TempStorage objects.
template <typename T>
__global__ void gather_padding_kernel(
    const int K,
    const int N,
    const int Y0Width,
    const int Y1Width,
    const T* X,
    const int* I,
    const int* L,
    T* Y0,
    T* Y1) {
  // Reduce in T itself: reducing through float would round double / int /
  // int64 sums.
  typedef cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage y0_tmp;
  __shared__ typename BlockReduce::TempStorage y1_tmp;
  // The loop bound depends only on blockIdx/gridDim, so every thread of the
  // block reaches the BlockReduce calls and the barrier below.
  for (int i = blockIdx.x; i < N; i += gridDim.x) {
    T sum_1 = T(0);
    T sum_2 = T(0);
    // Start padding: the first Y0Width rows of each segment.
    for (int j = threadIdx.x; j < K * Y0Width; j += blockDim.x) {
      const int j1 = j / Y0Width;
      const int j2 = j % Y0Width;
      const int idx1 = N * (L[j1] + j2);
      sum_1 += X[idx1 + i];
    }
    // End padding: the last Y1Width rows of each segment.
    for (int j = threadIdx.x; j < K * Y1Width; j += blockDim.x) {
      const int j1 = j / Y1Width;
      const int j2 = j % Y1Width;
      const int idx1 = N * L[j1];
      const int idx2 = idx1 + N * (I[j1] - Y1Width + j2);
      sum_2 += X[idx2 + i];
    }
    sum_1 = BlockReduce(y0_tmp).Reduce(sum_1, cub::Sum());
    sum_2 = BlockReduce(y1_tmp).Reduce(sum_2, cub::Sum());
    // BlockReduce's result is only defined in thread 0.
    if (threadIdx.x == 0) {
      if (Y0 != Y1) {
        Y0[i] = sum_1;
        Y1[i] = sum_2;
      } else {
        Y0[i] = sum_1 + sum_2;
      }
    }
    // The TempStorage objects are reused by the next iteration.
    __syncthreads();
  }
}

// CUDA GatherPadding: accumulates, per column, the start-padding rows of all
// segments into padding_start_ptr and the end-padding rows into
// padding_end_ptr (see gather_padding_kernel). `outer_size` and `pad_width`
// are not used by this implementation. lengths_ptr is read on the device by
// CUB and the kernel.
// NOTE(review): confirm callers never pass a host pointer here (e.g. when no
// lengths input is given).
template <>
template <typename T>
void GatherPaddingOp<CUDAContext>::GatherPadding(
    const int outer_size,
    const int lengths_size,
    const int block_size,
    const int pad_width,
    const T* in_ptr,
    const int* lengths_ptr,
    T* padding_start_ptr,
    T* padding_end_ptr) {
  // Nothing to accumulate when there are no segments or the rows have zero
  // width. The latter would otherwise request a 0-block grid, which CUDA
  // rejects as an invalid launch configuration.
  if (lengths_size <= 0 || block_size <= 0) {
    return;
  }

  // Exclusive prefix sum: the first row of every segment.
  lengths_prefix_sum<false>(
      lengths_ptr,
      lengths_size,
      &lengths_prefix_sum_buffer_,
      &lengths_prefix_sum_,
      &context_);
  gather_padding_kernel<T>
      <<<std::min(block_size, CAFFE_MAXIMUM_NUM_BLOCKS),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          lengths_size,
          block_size,
          startPaddingWidth_,
          endPaddingWidth_,
          in_ptr,
          lengths_ptr,
          lengths_prefix_sum_.template data<int>(),
          padding_start_ptr,
          padding_end_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Register the CUDA implementations of RemovePadding and GatherPadding.
REGISTER_CUDA_OPERATOR(RemovePadding, RemovePaddingOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(GatherPadding, GatherPaddingOp<CUDAContext>);
} // namespace caffe2