1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/reduction_ops.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/cub_namespace.cuh"

#include <cstdint>
namespace caffe2 {
// CUDA registrations for the forward reduction operators.
REGISTER_CUDA_OPERATOR(SumElements, SumElementsOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(SumElementsInt, SumElementsIntOp<int, CUDAContext>);
REGISTER_CUDA_OPERATOR(SumSqrElements, SumSqrElementsOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(RowwiseMax, MaxReductionOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(ColwiseMax, MaxReductionOp<float, CUDAContext, false>);

// CUDA registrations for the gradient operators. The trailing bool template
// argument of MaxReductionGradientOp selects the row-wise (true) or
// column-wise (false) variant, mirroring the forward registrations above.
REGISTER_CUDA_OPERATOR(
    RowwiseMaxGradient,
    MaxReductionGradientOp<float, CUDAContext, true>);
REGISTER_CUDA_OPERATOR(
    ColwiseMaxGradient,
    MaxReductionGradientOp<float, CUDAContext, false>);
REGISTER_CUDA_OPERATOR(
    SumElementsGradient,
    SumElementsGradientOp<float, CUDAContext>);
// Writes the scalar upstream gradient *dY into every one of the N elements of
// dX. When `average` is set, the scalar is divided by N first.
// Every thread reads dY[0], so dY must point to at least one element.
template <typename T>
__global__ void
SumElementsGradientKernel(bool average, const int N, const T* dY, T* dX) {
  const T upstream = *dY;
  const T fill = average ? upstream / N : upstream;
  CUDA_1D_KERNEL_LOOP(idx, N) {
    dX[idx] = fill;
  }
}
// Backward pass of a row-wise max over the last axis.
//
// Expects X and dX to be contiguous [batch_size, M, N] buffers and Y and dY to
// be contiguous [batch_size, M] buffers, where Y holds each row's maximum.
// Each element of X that equals its row maximum receives that row's upstream
// gradient; all other elements receive 0. Ties are not split: every element
// equal to the maximum gets the full gradient.
// Iterates over all batch_size * M * N elements via CUDA_1D_KERNEL_LOOP.
__global__ void rowwise_max_gradient_kernel(
    const int batch_size,
    const int M,
    const int N,
    const float* X,
    const float* Y,
    const float* dY,
    float* dX) {
  const int input_size = M * N;
  CUDA_1D_KERNEL_LOOP(i, batch_size * M * N) {
    // Decompose i = b * (M * N) + m * N + n.
    const int b_i = i / input_size;
    // Row index within the batch. The previous `i / input_size / N` divided
    // the batch index by N, which selects the wrong Y/dY entry for M > 1.
    const int b_m = (i % input_size) / N;
    const int y_index = b_i * M + b_m;
    dX[i] = (X[i] == Y[y_index]) ? dY[y_index] : 0.0f;
  }
}
// CUDA entry point for SumSqrElements. Dispatches on the dtype of Input(0)
// over the type list {float, at::Half} via DispatchHelper.
template <>
bool SumSqrElementsOp<CUDAContext>::RunOnDevice() {
  using SupportedTypes = TensorTypes<float, at::Half>;
  const auto& input = Input(0);
  return DispatchHelper<SupportedTypes>::call(this, input);
}
// Backward pass of a column-wise max over the middle axis.
//
// Expects X and dX to be contiguous [batch_size, M, N] buffers and Y and dY to
// be contiguous [batch_size, N] buffers, where Y holds each column's maximum.
// Elements equal to their column maximum receive that column's upstream
// gradient; all others receive 0 (ties each get the full gradient).
__global__ void colwise_max_gradient_kernel(
    const int batch_size,
    const int M,
    const int N,
    const float* X,
    const float* Y,
    const float* dY,
    float* dX) {
  const int per_batch = M * N;
  CUDA_1D_KERNEL_LOOP(idx, batch_size * M * N) {
    // Split idx into its batch index and its column index within that batch.
    const int batch = idx / per_batch;
    const int col = idx % per_batch % N;
    const int reduced = batch * N + col;
    dX[idx] = X[idx] == Y[reduced] ? dY[reduced] : 0.0f;
  }
}
// CUDA backward pass for SumElements.
// Inputs:  Input(0) = X (only its shape is used), Input(1) = dY, a one-element
//          tensor holding the upstream gradient of the reduced scalar.
// Output:  Output(0) = dX, shaped like X and filled with *dY, divided by
//          X.numel() when average_ is set.
template <>
bool SumElementsGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& dY = Input(1);
  // The kernel dereferences dY[0] on the device, so a missing scalar would be
  // an out-of-bounds read. Enforce this in every build, not only debug builds
  // (a DCHECK is compiled out in release).
  CAFFE_ENFORCE_EQ(dY.numel(), 1);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  SumElementsGradientKernel<float>
      <<<CAFFE_GET_BLOCKS(X.numel()),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          average_,
          X.numel(),
          dY.data<float>(),
          dX->template mutable_data<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// CUDA backward pass for RowwiseMax / ColwiseMax.
// Inputs:  Input(0) = X, a rank-3 tensor [batch_size, M, N];
//          Input(1) = Y, the forward max, one value per (batch, row) when
//          ROWWISE, otherwise one per (batch, column);
//          Input(2) = dY, the upstream gradient, same size as Y.
// Output:  Output(0) = dX, shaped like X.
// Throws (via CAFFE_ENFORCE) if X is not rank 3 or if Y/dY have the wrong
// number of elements.
template <typename T, class Context, bool ROWWISE>
bool MaxReductionGradientOp<T, Context, ROWWISE>::RunOnDevice() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);

  // Validate before allocating dX so malformed inputs fail without resizing
  // the output.
  CAFFE_ENFORCE_EQ(X.dim(), 3);
  const int batch_size = X.dim32(0);
  const int M = X.dim32(1);
  const int N = X.dim32(2);
  const int input_size = M * N;

  // The kernels index Y and dY without bounds checks; a size mismatch would
  // be an out-of-bounds device access, so enforce the expected element count.
  const int64_t reduced_size =
      static_cast<int64_t>(batch_size) * (ROWWISE ? M : N);
  CAFFE_ENFORCE_EQ(Y.numel(), reduced_size);
  CAFFE_ENFORCE_EQ(dY.numel(), reduced_size);

  auto* dX = Output(0, X.sizes(), at::dtype<T>());
  const T* Xdata = X.template data<T>();
  const T* Ydata = Y.template data<T>();
  const T* dYdata = dY.template data<T>();
  T* dXdata = dX->template mutable_data<T>();

  // Both kernels share a signature and launch shape; only the mapping from an
  // element of X to its entry in Y/dY differs.
  auto* kernel =
      ROWWISE ? rowwise_max_gradient_kernel : colwise_max_gradient_kernel;
  kernel<<<
      CAFFE_GET_BLOCKS(batch_size * input_size),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(
      batch_size, M, N, Xdata, Ydata, dYdata, dXdata);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
} // namespace caffe2
|