File: segment_reduction_op_gpu.cuh

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (119 lines) | stat: -rw-r--r-- 2,831 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <cstdint>

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh>
#include <cub/device/device_scan.cuh>
#include "caffe2/core/context_gpu.h"


// Block-count tuning constant for the segment-reduction kernels.
// In this header it is only consumed on ROCm builds, as the second argument of
// C10_LAUNCH_BOUNDS_2 on sparse_length_sum_kernel (presumably the
// "minimum blocks per multiprocessor" hint of a __launch_bounds__ wrapper).
// NOTE(review): the non-ROCm value is not used in this header; other
// translation units may use the macro for launch sizing — confirm before
// changing either value.
#ifndef USE_ROCM
#define SEGREDUCE_MINBLOCKS 16
#else
#define SEGREDUCE_MINBLOCKS 8
#endif

namespace caffe2 {

// Typed accessor for a kernel's dynamically sized shared-memory buffer (its
// size is the dynamic shared-memory argument of the launch). Only the
// specializations below exist; other element types fail to compile.
//
// Each specialization declares its own `extern __shared__` array under a
// distinct name, because extern shared arrays of different element types may
// not share a name within one translation unit.
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<double> {
  __device__ double* getPointer() {
    extern __shared__ double s_double[];
    return s_double;
  }
};

template <>
struct SharedMemory<float> {
  __device__ float* getPointer() {
    extern __shared__ float s_float[];
    return s_float;
  }
};

template <>
struct SharedMemory<at::Half> {
  __device__ at::Half* getPointer() {
    extern __shared__ at::Half s_half[];
    return s_half;
  }
};

// Converts one input element to the accumulation/output type. The generic
// version relies on implicit conversion; at::Half -> float is specialized to
// use the __half2float intrinsic.
template <typename InType, typename OutType>
__device__ inline OutType convert_type(const InType in) {
  return in;
}

template <>
__device__ inline float convert_type<at::Half, float>(const at::Half in) {
  return __half2float(in);
}

// Segmented gather-and-sum over rows of `in`.
//
// For each segment g (one per block, g = blockIdx.x) and column c in
// [0, post):
//   out[g * post + c] = sum over line in [start, end) of
//                       in[indices[line] * post + c]
// where
//   start = (g == 0) ? 0 : prefix_sum_length_data[g - 1]
//   end   = prefix_sum_length_data[g]
// i.e. prefix_sum_length_data holds inclusive prefix sums of segment lengths.
// With Average = true the sum is divided by (end - start) when that length is
// greater than 1. Empty segments produce zeros.
//
// Launch layout:
//   * One block per segment; gridDim.x is expected to equal len_length.
//   * ExactBlock = true: expects blockDim.x == post (NOT checked here). The
//     blockDim.y threads of each column stride over the segment's lines and
//     combine their partial sums through dynamic shared memory, which must
//     hold at least blockDim.x * blockDim.y elements of OutType.
//   * ExactBlock = false: threads stride over the `post` columns by
//     blockDim.x. threadIdx.y is not used (blockDim.y > 1 only duplicates
//     work), and no shared memory is accessed.
//
// Parameters:
//   in                      input rows; row r occupies in[r*post, (r+1)*post).
//   out                     output rows; segment g writes out[g*post, (g+1)*post).
//   prefix_sum_length_data  inclusive prefix sums of segment lengths.
//   indices                 row ids into `in`; len_indices entries.
//   N                       not read by this kernel (presumably the number of
//                           rows in `in` — TODO confirm at call sites).
//   post                    number of elements per row.
//   len_length              not read by the kernel body; expected to be the
//                           number of segments (== gridDim.x).
//   len_indices             entries in `indices`; segment bounds are asserted
//                           not to exceed it.
//
// Row offsets are computed in 64-bit arithmetic. With a 32-bit IndexType (or
// many segments), `row * post` could otherwise overflow int and address
// memory out of bounds.
template <
    typename InType,
    typename OutType,
    typename IndexType,
    bool ExactBlock = false,
    bool Average = false>
#if defined(USE_ROCM)
C10_LAUNCH_BOUNDS_2(1024, SEGREDUCE_MINBLOCKS)
#endif
__global__ void sparse_length_sum_kernel(
    const InType* __restrict__ in,
    OutType* __restrict__ out,
    const int* __restrict__ prefix_sum_length_data,
    const IndexType* __restrict__ indices,
    int N,
    int post,
    int len_length,
    int len_indices) {
  // One block per segment.
  const int group = blockIdx.x;

  const int start = group == 0 ? 0 : prefix_sum_length_data[group - 1];
  const int end = prefix_sum_length_data[group];
  CUDA_KERNEL_ASSERT(start <= len_indices);
  CUDA_KERNEL_ASSERT(end <= len_indices);

  // 64-bit so large outputs do not overflow the row offset.
  const int64_t out_offset = static_cast<int64_t>(group) * post;

  SharedMemory<OutType> smem;
  OutType* reduceVals = smem.getPointer();

  if (ExactBlock) {
    OutType sum = (OutType)0;

    // Each x-thread owns one column; y-threads split the segment's lines.
    in += threadIdx.x;
    for (int line = start + threadIdx.y; line < end; line += blockDim.y) {
      const int64_t row = static_cast<int64_t>(indices[line]);
      sum += convert_type<InType, OutType>(in[row * post]);
    }

    reduceVals[threadIdx.y * blockDim.x + threadIdx.x] = sum;
    // Reached by every thread in the block: partials must be visible before
    // the y == 0 row reads them.
    __syncthreads();

    if (threadIdx.y == 0) {
      sum = (OutType)0;
      for (unsigned int y = 0; y < blockDim.y; ++y) {
        sum += reduceVals[y * blockDim.x + threadIdx.x];
      }
      // Length 0 or 1 needs no division (and 0 must not be divided by).
      if (Average && (end - start) > 1) {
        sum /= (end - start);
      }

      out[out_offset + threadIdx.x] = sum;
    }
  } else {
    for (int i = threadIdx.x; i < post; i += blockDim.x) {
      OutType sum = (OutType)0;
      for (int line = start; line < end; ++line) {
        const int64_t row = static_cast<int64_t>(indices[line]);
        sum += convert_type<InType, OutType>(in[row * post + i]);
      }
      // Length 0 or 1 needs no division (and 0 must not be divided by).
      if (Average && (end - start) > 1) {
        sum /= (end - start);
      }
      out[out_offset + i] = sum;
    }
  }
}

}  // namespace caffe2