1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
|
#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/reduce_front_back_max_ops.h"
#include "caffe2/utils/cub_namespace.cuh"
#if defined(USE_ROCM)
#include <cfloat>
#endif
namespace caffe2 {
/***
Max Ops
***/
namespace {
// Max-reduces each column of a row-major `rows` x `cols` matrix (ReduceFrontMax).
//
// For column c this writes out[c] = max over r in [0, length) of
// data[r * cols + c]. `length` is lengths[c] when `lengths` is non-null,
// otherwise `rows`. A column with length 0 yields -FLT_MAX.
// NOTE(review): lengths[c] is not bounds-checked against `rows` here.
//
// Launch: 1-D grid of any size; blocks stride over columns by gridDim.x.
// blockDim.x must be CAFFE_CUDA_NUM_THREADS, because cub::BlockReduce is
// instantiated with that block size. Uses only BlockReduce's static
// shared-memory TempStorage (no dynamic shared memory).
__global__ void columnwise_max_kernel(
    const int rows,
    const int cols,
    const float* data,
    const int* lengths,
    float* out) {
  typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  for (int colIndex = blockIdx.x; colIndex < cols; colIndex += gridDim.x) {
    // The identity for max must be the most negative finite float.
    // FLT_MIN is the smallest *positive* normal float, so seeding with it
    // clamped every all-negative column to ~1.2e-38.
    float mx = -FLT_MAX;
    const int length = lengths == nullptr ? rows : lengths[colIndex];
    for (int rowIndex = threadIdx.x; rowIndex < length;
         rowIndex += blockDim.x) {
      // Compute the offset in 64 bits: rowIndex * cols can exceed INT_MAX.
      mx = fmaxf(mx, data[static_cast<int64_t>(rowIndex) * cols + colIndex]);
    }
    mx = BlockReduce(temp_storage).Reduce(mx, cub::Max());
    if (threadIdx.x == 0) {
      out[colIndex] = mx;
    }
    // temp_storage is reused by the next column's reduction.
    __syncthreads();
  }
}
// Max-reduces each row of a row-major `rows` x `cols` matrix (ReduceBackMax).
//
// For row r this writes out[r] = max over c in [0, length) of
// data[r * cols + c]. `length` is lengths[r] when `lengths` is non-null,
// otherwise `cols`. A row with length 0 yields -FLT_MAX.
// NOTE(review): lengths[r] is not bounds-checked against `cols` here.
//
// Launch: 1-D grid of any size; blocks stride over rows by gridDim.x.
// blockDim.x must be CAFFE_CUDA_NUM_THREADS, because cub::BlockReduce is
// instantiated with that block size. Uses only BlockReduce's static
// shared-memory TempStorage (no dynamic shared memory).
__global__ void rowwise_max_kernel(
    const int rows,
    const int cols,
    const float* data,
    const int* lengths,
    float* out) {
  typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  for (int rowIndex = blockIdx.x; rowIndex < rows; rowIndex += gridDim.x) {
    // The identity for max must be the most negative finite float.
    // FLT_MIN is the smallest *positive* normal float, so seeding with it
    // clamped every all-negative row to ~1.2e-38.
    float mx = -FLT_MAX;
    const int length = lengths == nullptr ? cols : lengths[rowIndex];
    // Row base offset in 64 bits: rowIndex * cols can exceed INT_MAX.
    const int64_t rowBase = static_cast<int64_t>(rowIndex) * cols;
    for (int colIndex = threadIdx.x; colIndex < length;
         colIndex += blockDim.x) {
      mx = fmaxf(mx, data[rowBase + colIndex]);
    }
    mx = BlockReduce(temp_storage).Reduce(mx, cub::Max());
    if (threadIdx.x == 0) {
      out[rowIndex] = mx;
    }
    // temp_storage is reused by the next row's reduction.
    __syncthreads();
  }
}
// Backward pass of the column-wise (front) max reduction.
//
// dX has the same row-major `rows` x `cols` layout as X.
// - An element at (r, c) with r >= lengths[c] (when `lengths` is given)
//   receives 0.
// - Otherwise it receives dY[c] if X equals the column's reduced max Y[c],
//   else 0. Tied maxima all receive the full dY[c].
//
// Iteration uses caffe2's CUDA_1D_KERNEL_LOOP over [0, rows * cols)
// (presumably a grid-stride loop; see the macro definition).
__global__ void columnwise_max_grad_kernel(
    const int rows,
    const int cols,
    const float* dYdata,
    const float* Xdata,
    const float* Ydata,
    const int* lengths,
    float* dXdata) {
  CUDA_1D_KERNEL_LOOP(i, rows * cols) {
    const int r = i / cols;
    const int c = i % cols;
    const bool pastSegmentEnd = lengths != nullptr && r >= lengths[c];
    dXdata[i] = pastSegmentEnd ? 0.0f : (Xdata[i] == Ydata[c]) * dYdata[c];
  }
}
// Backward pass of the row-wise (back) max reduction.
//
// dX has the same row-major `rows` x `cols` layout as X.
// - An element at (r, c) with c >= lengths[r] (when `lengths` is given)
//   receives 0.
// - Otherwise it receives dY[r] if X equals the row's reduced max Y[r],
//   else 0. Tied maxima all receive the full dY[r].
//
// Iteration uses caffe2's CUDA_1D_KERNEL_LOOP over [0, rows * cols)
// (presumably a grid-stride loop; see the macro definition).
__global__ void rowwise_max_grad_kernel(
    const int rows,
    const int cols,
    const float* dYdata,
    const float* Xdata,
    const float* Ydata,
    const int* lengths,
    float* dXdata) {
  CUDA_1D_KERNEL_LOOP(i, rows * cols) {
    const int r = i / cols;
    const int c = i % cols;
    const bool pastSegmentEnd = lengths != nullptr && c >= lengths[r];
    dXdata[i] = pastSegmentEnd ? 0.0f : (Xdata[i] == Ydata[r]) * dYdata[r];
  }
}
} // anonymous namespace
// ReduceFrontMax
template <>
// Launches columnwise_max_kernel on the op's stream to compute one max per
// column (length `cols` output). Blocks are capped at
// CAFFE_MAXIMUM_NUM_BLOCKS; the kernel strides over any remaining columns.
void MaxReduceDimsOp<float, CUDAContext, true>::Compute(
    int rows,
    int cols,
    const float* data,
    const int32_t* lengths_data,
    float* out_data) {
  // With no columns the output is empty. A zero-block grid would also be an
  // invalid launch configuration and trip the launch check below.
  if (cols <= 0) {
    return;
  }
  columnwise_max_kernel<<<
      std::min(cols, CAFFE_MAXIMUM_NUM_BLOCKS),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(rows, cols, data, lengths_data, out_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// ReduceBackMax
template <>
// Launches rowwise_max_kernel on the op's stream to compute one max per row
// (length `rows` output). Blocks are capped at CAFFE_MAXIMUM_NUM_BLOCKS;
// the kernel strides over any remaining rows.
void MaxReduceDimsOp<float, CUDAContext, false>::Compute(
    int rows,
    int cols,
    const float* data,
    const int32_t* lengths_data,
    float* out_data) {
  // With no rows the output is empty. A zero-block grid would also be an
  // invalid launch configuration and trip the launch check below.
  if (rows <= 0) {
    return;
  }
  rowwise_max_kernel<<<
      std::min(rows, CAFFE_MAXIMUM_NUM_BLOCKS),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(rows, cols, data, lengths_data, out_data);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// ReduceFrontMaxGradient
template <>
// Launches columnwise_max_grad_kernel on the op's stream, with one logical
// thread per element of the `rows` x `cols` input gradient.
void MaxReduceDimsGradientOp<float, CUDAContext, true>::Compute(
    int rows,
    int cols,
    const float* dYdata,
    const float* Xdata,
    const float* Ydata,
    const int32_t* lengths_data,
    float* dXdata) {
  const int numElements = rows * cols;
  const auto stream = context_.cuda_stream();
  columnwise_max_grad_kernel<<<
      CAFFE_GET_BLOCKS(numElements), CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
      rows, cols, dYdata, Xdata, Ydata, lengths_data, dXdata);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// ReduceBackMaxGradient
template <>
// Launches rowwise_max_grad_kernel on the op's stream, with one logical
// thread per element of the `rows` x `cols` input gradient.
void MaxReduceDimsGradientOp<float, CUDAContext, false>::Compute(
    int rows,
    int cols,
    const float* dYdata,
    const float* Xdata,
    const float* Ydata,
    const int* lengths_data,
    float* dXdata) {
  const int numElements = rows * cols;
  const auto stream = context_.cuda_stream();
  rowwise_max_grad_kernel<<<
      CAFFE_GET_BLOCKS(numElements), CAFFE_CUDA_NUM_THREADS, 0, stream>>>(
      rows, cols, dYdata, Xdata, Ydata, lengths_data, dXdata);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// CUDA registrations for the float max-reduction operators.
// Template flag `true` selects the front (column-wise) variant, and `false`
// selects the back (row-wise) variant.
REGISTER_CUDA_OPERATOR(ReduceFrontMax, MaxReduceDimsOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(
    ReduceFrontMaxGradient, MaxReduceDimsGradientOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(ReduceBackMax, MaxReduceDimsOp<float, CUDAContext, false>);
REGISTER_CUDA_OPERATOR(
    ReduceBackMaxGradient, MaxReduceDimsGradientOp<float, CUDAContext, false>);
} // namespace caffe2
|