1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
|
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/resize_op.h"
#include "caffe2/utils/GpuAtomics.cuh"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
// Nearest-neighbor resize of a 4-D float tensor (forward pass).
//
// Each loop iteration writes one element of Y. The flat index is unpacked as
// (n, c, h, w) with w varying fastest, so Y is addressed as
// [n][c][output_height][output_width] and X as
// [n][c][input_height][input_width].
//
// Output row h reads source row trunc(min(h / height_scale, input_height - 1)).
// Columns use width_scale in the same way.
//
// Iteration over [0, size) is delegated to CUDA_1D_KERNEL_LOOP, which is
// defined elsewhere (presumably a grid-stride loop; TODO confirm). The kernel
// uses no shared memory.
__global__ void NearestNeighborKernel(
    const int size,
    const int num_channels,
    const int input_height,
    const int input_width,
    const int output_height,
    const int output_width,
    const float height_scale,
    const float width_scale,
    const float* X,
    float* Y) {
  CUDA_1D_KERNEL_LOOP(i, size) {
    // Peel the coordinates off the flat index, innermost (width) first.
    int rest = i;
    const int out_x = rest % output_width;
    rest /= output_width;
    const int out_y = rest % output_height;
    rest /= output_height;
    const int channel = rest % num_channels;
    const int batch = rest / num_channels;

    // Map back onto the source grid, clamping to the last valid row/column,
    // then truncate toward zero.
    const int src_y =
        static_cast<int>(fminf(out_y / height_scale, input_height - 1));
    const int src_x =
        static_cast<int>(fminf(out_x / width_scale, input_width - 1));

    const int plane = batch * num_channels + channel;
    Y[i] = X[(plane * input_height + src_y) * input_width + src_x];
  }
}
// Backward pass of the nearest-neighbor resize: scatters dY into dX.
//
// Naming note: input_height/input_width are the spatial dims of dY (the
// gradient being propagated in). output_height/output_width are the spatial
// dims of dX, which has the same spatial size as the forward input X.
//
// Each loop iteration handles one dY element. The flat index is unpacked as
// (n, c, y, x) with x fastest. The element is accumulated into the dX pixel
// (trunc(min(y / height_scale, output_height - 1)),
//  trunc(min(x / width_scale, output_width - 1))).
//
// dX must be zero-initialized by the caller; this kernel only adds into it.
// Assumes height_scale and width_scale are > 0. A negative scale would yield
// a negative dX row/column and an out-of-bounds write.
//
// The __ldg read-only-cache load is used only when compiling for SM 3.5+.
__global__ void NearestNeighborGradientKernel(
    const int size,
    const int num_channels,
    const int input_height,
    const int input_width,
    const int output_height,
    const int output_width,
    const float height_scale,
    const float width_scale,
    const float* dY,
    float* dX) {
  CUDA_1D_KERNEL_LOOP(i, size) {
    // Unpack the dY coordinate, innermost (width) first.
    int rest = i;
    const int dy_x = rest % input_width;
    rest /= input_width;
    const int dy_y = rest % input_height;
    rest /= input_height;
    const int channel = rest % num_channels;
    const int batch = rest / num_channels;

    // Same source-pixel mapping as NearestNeighborKernel, with dX's dims as
    // the clamp bounds.
    const int dx_y =
        static_cast<int>(fminf(dy_y / height_scale, output_height - 1));
    const int dx_x =
        static_cast<int>(fminf(dy_x / width_scale, output_width - 1));
    const int target =
        ((batch * num_channels + channel) * output_height + dx_y) *
            output_width +
        dx_x;

#if __CUDA_ARCH__ >= 350
    const float grad = __ldg(dY + i);
#else
    const float grad = dY[i];
#endif
    // Several dY elements can map onto the same dX element (e.g. when a scale
    // exceeds 1), so the accumulation has to be atomic.
    gpu_atomic_add(dX + target, grad);
  }
}
} // namespace
// Forward pass: Y = nearest-neighbor resize of X.
//
// Input(0)  X      : 4-D float tensor whose dims are read as
//                    (batch, channels, height, width).
// Input(1)  scales : optional 1-D float tensor of exactly 2 elements
//                    {height_scale, width_scale}. When present, it is copied
//                    to the host with context_.CopyToCPU and overrides the
//                    operator's stored scales.
// Output(0) Y      : (batch, channels, trunc(height * height_scale),
//                    trunc(width * width_scale)).
//
// Both scales must be strictly positive. A zero, negative or NaN scale is
// rejected with a clear error instead of yielding an empty or negative output
// shape.
template <>
bool ResizeNearestOp<float, CUDAContext>::RunOnDevice() {
  const auto& X = Input(0);
  CAFFE_ENFORCE_EQ(4, X.dim());
  const int batch_size = X.dim32(0), num_channels = X.dim32(1),
            input_height = X.dim32(2), input_width = X.dim32(3);

  if (InputSize() == 2) {
    const auto& scales = Input(1);
    CAFFE_ENFORCE_EQ(scales.dim(), 1);
    CAFFE_ENFORCE_EQ(scales.numel(), 2);
    float scales_data[2];
    context_.CopyToCPU<float>(2, scales.data<float>(), scales_data);
    height_scale_ = scales_data[0];
    width_scale_ = scales_data[1];
  }
  // Validate whichever scales are in effect (stored or just read from the
  // tensor). `> 0` is also false for NaN, so NaN is rejected too.
  CAFFE_ENFORCE_GT(
      height_scale_, 0, "ResizeNearest: height_scale must be positive");
  CAFFE_ENFORCE_GT(
      width_scale_, 0, "ResizeNearest: width_scale must be positive");

  // Float-to-int conversion truncates toward zero.
  const int output_width = input_width * width_scale_;
  const int output_height = input_height * height_scale_;
  auto* Y = Output(
      0,
      {batch_size, num_channels, output_height, output_width},
      at::dtype<float>());

  const auto size = Y->numel();
  NearestNeighborKernel<<<
      CAFFE_GET_BLOCKS(size),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(
      size,
      num_channels,
      input_height,
      input_width,
      output_height,
      output_width,
      height_scale_,
      width_scale_,
      X.data<float>(),
      Y->template mutable_data<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Backward pass of ResizeNearest: dX = scatter-add of dY onto X's grid.
//
// Input(0)  dY     : 4-D float tensor; its dims are read as
//                    (batch, channels, height, width).
// Input(1)  X      : 4-D tensor; only its spatial dims (2 and 3) are used,
//                    to size dX.
// Input(2)  scales : optional 1-D float tensor of exactly 2 elements
//                    {height_scale, width_scale}. When present, it is copied
//                    to the host and overrides the operator's stored scales.
// Output(0) dX     : (dY batch, dY channels, X height, X width). It is
//                    zero-filled, then each dY element is atomically added
//                    into the dX pixel it maps to.
//
// Both scales must be strictly positive. In the kernel's index mapping, a
// negative scale produces negative dX row/column indices, which would write
// out of bounds. Zero or NaN would silently route the gradient to the wrong
// pixels.
template <>
bool ResizeNearestGradientOp<float, CUDAContext>::RunOnDevice() {
  const auto& dY = Input(0);
  const auto& X = Input(1);
  CAFFE_ENFORCE_EQ(4, dY.dim());
  CAFFE_ENFORCE_EQ(4, X.dim());

  // Kernel naming: "input_*" are dY's spatial dims, "output_*" are dX's.
  const int batch_size = dY.dim32(0), num_channels = dY.dim32(1),
            input_height = dY.dim32(2), input_width = dY.dim32(3);
  const int output_height = X.dim32(2);
  const int output_width = X.dim32(3);

  if (InputSize() == 3) {
    const auto& scales = Input(2);
    CAFFE_ENFORCE_EQ(scales.dim(), 1);
    CAFFE_ENFORCE_EQ(scales.numel(), 2);
    float scales_data[2];
    context_.CopyToCPU<float>(2, scales.data<float>(), scales_data);
    height_scale_ = scales_data[0];
    width_scale_ = scales_data[1];
  }
  // Guard the kernel's dX indexing against non-positive or NaN scales.
  CAFFE_ENFORCE_GT(
      height_scale_,
      0,
      "ResizeNearestGradient: height_scale must be positive");
  CAFFE_ENFORCE_GT(
      width_scale_,
      0,
      "ResizeNearestGradient: width_scale must be positive");

  auto* dX = Output(
      0,
      {batch_size, num_channels, output_height, output_width},
      at::dtype<float>());
  // The kernel only accumulates, so dX has to start from zero.
  math::Set<float, CUDAContext>(
      dX->numel(), 0.0f, dX->template mutable_data<float>(), &context_);

  const auto size = dY.numel();
  NearestNeighborGradientKernel<<<
      CAFFE_GET_BLOCKS(size),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(
      size,
      num_channels,
      input_height,
      input_width,
      output_height,
      output_width,
      height_scale_,
      width_scale_,
      dY.data<float>(),
      dX->template mutable_data<float>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Register the CUDA implementations of the forward and gradient operators
// under the operator names "ResizeNearest" and "ResizeNearestGradient".
// Only the float specializations are registered for CUDA.
REGISTER_CUDA_OPERATOR(ResizeNearest, ResizeNearestOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
ResizeNearestGradient,
ResizeNearestGradientOp<float, CUDAContext>);
} // namespace caffe2
// Alias for the float CUDA forward op. NOTE(review): presumably needed because
// the export macro below cannot accept a template argument list containing a
// comma directly; confirm against the macro definition.
// C10_EXPORT_CAFFE2_OP_TO_C10_CUDA presumably exposes this Caffe2 op through
// the c10 operator registry under the name ResizeNearest (TODO confirm). Only
// the forward op is exported here; the gradient op is not.
using ResizeNearestOpFloatCUDA =
caffe2::ResizeNearestOp<float, caffe2::CUDAContext>;
C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(ResizeNearest, ResizeNearestOpFloatCUDA);
|