1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/rnn/recurrent_network_op.h"
namespace caffe2 {
namespace detail {
// Forward declaration only: initializes the recurrent state blob(s) for a
// sequence of length `seqLen` and batch size `batchSize` inside workspace
// `ws`. The definition lives elsewhere in the recurrent_network_op
// implementation — presumably shared between CPU/GPU builds (confirm against
// recurrent_network_op.cc).
template <typename T, typename Context>
void initializeRecurrentInput(
const RecurrentInput& rc,
int32_t seqLen,
int32_t batchSize,
Workspace* ws,
Context* context);
namespace {

// Broadcast-copy kernel: replicates one source buffer of `stateSize`
// elements into gridDim.x consecutive slots of `state`.
//
// Launch layout: one block per output replica (blockIdx.x selects the
// destination slot); the threads of a block stride over the stateSize
// elements. `state` must hold gridDim.x * stateSize elements.
template <typename T>
__global__ void initRecurrentInput_kernel(
    size_t stateSize,
    const T* __restrict__ input,
    T* __restrict__ state) {
  // Each block owns exactly one replica of the state buffer.
  T* const dst = state + static_cast<size_t>(blockIdx.x) * stateSize;
  // Block-stride copy. A size_t index avoids the signed/unsigned comparison
  // of the previous int counter and cannot overflow when stateSize exceeds
  // INT_MAX.
  for (size_t i = threadIdx.x; i < stateSize; i += blockDim.x) {
    dst[i] = input[i];
  }
}

} // namespace
// Replicates `src` (n floats) `repeat_n` times back-to-back into `dst`,
// asynchronously on the context's CUDA stream. `dst` must have room for
// repeat_n * n elements.
template <>
void repeatCopy(
    size_t repeat_n,
    size_t n,
    const float* src,
    float* dst,
    CUDAContext* context) {
  // A grid dimension of zero (<<<0, ...>>>) is an invalid launch
  // configuration in CUDA; an empty copy is a legitimate no-op, so return
  // early instead of tripping the launch check.
  if (repeat_n == 0 || n == 0) {
    return;
  }
  initRecurrentInput_kernel<float>
      <<<repeat_n, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
          n, src, dst);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Replicates `src` (n half-precision values) `repeat_n` times back-to-back
// into `dst`, asynchronously on the context's CUDA stream. `dst` must have
// room for repeat_n * n elements.
template <>
void repeatCopy(
    size_t repeat_n,
    size_t n,
    const at::Half* src,
    at::Half* dst,
    CUDAContext* context) {
  // A grid dimension of zero (<<<0, ...>>>) is an invalid launch
  // configuration in CUDA; an empty copy is a legitimate no-op, so return
  // early instead of tripping the launch check.
  if (repeat_n == 0 || n == 0) {
    return;
  }
  initRecurrentInput_kernel<at::Half>
      <<<repeat_n, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
          n, src, dst);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}; // namespace detail
// GPU forward pass: dispatch to the typed implementation based on the
// dtype of the first input (float or fp16 are supported).
template <>
bool RecurrentNetworkOp<CUDAContext>::RunOnDevice() {
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
}
// GPU backward pass: dispatch to the typed implementation based on the
// dtype of the first input (float or fp16 are supported).
template <>
bool RecurrentNetworkGradientOp<CUDAContext>::RunOnDevice() {
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
}
// GPU gradient accumulation: dispatch on the dtype of Input(1) rather than
// Input(0) — presumably Input(0) carries metadata while Input(1) is the
// gradient tensor; confirm against the op schema.
template <>
bool AccumulateInputGradientOp<CUDAContext>::RunOnDevice() {
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(1));
}
// GPU link application: dispatch on the dtype of Input(1), matching the
// convention used by AccumulateInputGradientOp in this file.
template <>
bool RNNApplyLinkOp<CUDAContext>::RunOnDevice() {
return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(1));
}
// Register the CUDA implementations with the caffe2 operator registry under
// their public names. The rnn_internal_* operators are implementation
// details of the recurrent network and not meant for direct use in nets.
REGISTER_CUDA_OPERATOR(
RecurrentNetwork,
RecurrentNetworkOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(
RecurrentNetworkGradient,
RecurrentNetworkGradientOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(
rnn_internal_accumulate_gradient_input,
AccumulateInputGradientOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(
rnn_internal_apply_link,
RNNApplyLinkOp<CUDAContext>);
} // namespace caffe2
|