1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
|
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/lstm_unit_op.h"
namespace caffe2 {
namespace detail {
// Logistic sigmoid 1 / (1 + e^(-x)), evaluated entirely in Dtype.
// Relies on an `exp` overload for Dtype being callable from device code.
template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
  const Dtype one = Dtype(1);
  const Dtype denominator = one + exp(-x);
  return one / denominator;
}
// Forward LSTM cell update for a single time step `t`.
//
// Each iteration of CUDA_1D_KERNEL_LOOP handles one element `index` in
// [0, nthreads). The indexing below treats H_prev, C_prev, C and H as N rows of
// `dim` elements (index = n * dim + d). X is treated as N rows of 4 * dim gate
// pre-activations laid out as [i | f | o | g]:
//   i = sigmoid(X[n, d])
//   f = sigmoid(X[n, dim + d] + forget_bias)
//   o = sigmoid(X[n, 2 * dim + d])
//   g = tanh(X[n, 3 * dim + d])
//   C = f * C_prev + i * g
//   H = o * tanh(C)
//
// nthreads     number of state elements to process (N * dim at the call sites).
// dim          width of one state row.
// t            current time step, compared against seqLengths.
// seqLengths   optional (may be nullptr). Row n is only updated while
//              t < seqLengths[n]; otherwise it is treated as finished.
// drop_states  for finished rows: true writes zeros to H and C, false carries
//              H_prev / C_prev through unchanged.
// forget_bias  added to the forget-gate pre-activation before the sigmoid.
// T is the storage type; MATH is the type the arithmetic is carried out in.
template <typename T, typename MATH>
__global__ void LSTMUnitKernel(
    const int nthreads,
    const int dim,
    const int t,
    const T* H_prev,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* C,
    T* H,
    const MATH forget_bias) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;
    const bool valid = seqLengths == nullptr || t < seqLengths[n];
    if (!valid) {
      // Select explicitly instead of multiplying by !drop_states: 0 * NaN and
      // 0 * Inf are NaN, so a multiply would not clear non-finite stale state.
      if (drop_states) {
        H[index] = convert::To<MATH, T>(MATH(0));
        C[index] = convert::To<MATH, T>(MATH(0));
      } else {
        H[index] = H_prev[index];
        C[index] = C_prev[index];
      }
    } else {
      // Compute the row offset in 64 bits: 4 * dim * n can exceed INT_MAX even
      // when nthreads (= N * dim) itself fits in an int.
      const T* X_offset = X + 4 * static_cast<int64_t>(dim) * n;
      const MATH i = cuda_sigmoid(convert::To<T, MATH>(X_offset[d]));
      const MATH f = cuda_sigmoid(
          convert::To<T, MATH>(X_offset[1 * dim + d]) + forget_bias);
      const MATH o = cuda_sigmoid(convert::To<T, MATH>(X_offset[2 * dim + d]));
      const MATH g = tanh(convert::To<T, MATH>(X_offset[3 * dim + d]));
      const MATH c_prev = convert::To<T, MATH>(C_prev[index]);
      const MATH c = f * c_prev + i * g;
      C[index] = convert::To<MATH, T>(c);
      const MATH tanh_c = tanh(c);
      H[index] = convert::To<MATH, T>(o * tanh_c);
    }
  }
}
// Backward pass of the single-step LSTM cell computed by LSTMUnitKernel.
//
// Each iteration of CUDA_1D_KERNEL_LOOP handles one element `index` in
// [0, nthreads), with index = n * dim + d. The gates are recomputed from X
// (layout [i | f | o | g], 4 * dim values per row):
//   dC_total  = C_diff + H_diff * o * (1 - tanh(C)^2)
//   C_prev_diff = dC_total * f
//   di = dC_total * g * i * (1 - i)
//   df = dC_total * C_prev * f * (1 - f)
//   do = H_diff * tanh(C) * o * (1 - o)
//   dg = dC_total * i * (1 - g^2)
// H_prev_diff is written as 0 for active rows, matching the forward kernel,
// which only reads H_prev for finished rows.
//
// For finished rows (seqLengths != nullptr && t >= seqLengths[n]), all four
// X_diff entries are 0. H_prev_diff / C_prev_diff receive H_diff / C_diff
// unchanged, or zeros when drop_states is true.
//
// H is accepted but not read by this kernel.
// T is the storage type; MATH is the type the arithmetic is carried out in.
template <typename T, typename MATH>
__global__ void LSTMUnitGradientKernel(
    const int nthreads,
    const int dim,
    const int t,
    const T* C_prev,
    const T* X,
    const T* C,
    const T* H,
    const int32_t* seqLengths,
    const T* C_diff,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* C_prev_diff,
    T* X_diff,
    const MATH forget_bias) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int n = index / dim;
    const int d = index % dim;
    const bool valid = seqLengths == nullptr || t < seqLengths[n];
    // Compute the row offset in 64 bits: 4 * dim * n can exceed INT_MAX even
    // when nthreads (= N * dim) itself fits in an int.
    const int64_t x_row = 4 * static_cast<int64_t>(dim) * n;
    const T* X_offset = X + x_row;
    T* X_diff_offset = X_diff + x_row;
    T* i_diff = X_diff_offset + d;
    T* f_diff = X_diff_offset + 1 * dim + d;
    T* o_diff = X_diff_offset + 2 * dim + d;
    T* g_diff = X_diff_offset + 3 * dim + d;
    if (!valid) {
      const T zero = convert::To<MATH, T>(MATH(0));
      // Select explicitly instead of multiplying by !drop_states: 0 * NaN and
      // 0 * Inf are NaN, so a multiply would not clear non-finite gradients.
      if (drop_states) {
        H_prev_diff[index] = zero;
        C_prev_diff[index] = zero;
      } else {
        H_prev_diff[index] = H_diff[index];
        C_prev_diff[index] = C_diff[index];
      }
      *i_diff = zero;
      *f_diff = zero;
      *o_diff = zero;
      *g_diff = zero;
    } else {
      const MATH i = cuda_sigmoid(convert::To<T, MATH>(X_offset[d]));
      const MATH f = cuda_sigmoid(
          convert::To<T, MATH>(X_offset[1 * dim + d]) + forget_bias);
      const MATH o = cuda_sigmoid(convert::To<T, MATH>(X_offset[2 * dim + d]));
      const MATH g = tanh(convert::To<T, MATH>(X_offset[3 * dim + d]));
      const MATH c_prev = convert::To<T, MATH>(C_prev[index]);
      const MATH c = convert::To<T, MATH>(C[index]);
      const MATH tanh_c = tanh(c);
      // Load the incoming gradients once, before any output is stored.
      const MATH h_diff = convert::To<T, MATH>(H_diff[index]);
      const MATH c_term_diff = convert::To<T, MATH>(C_diff[index]) +
          h_diff * o * (MATH(1) - tanh_c * tanh_c);
      C_prev_diff[index] = convert::To<MATH, T>(c_term_diff * f);
      H_prev_diff[index] = convert::To<MATH, T>(MATH(0));
      *i_diff = convert::To<MATH, T>(c_term_diff * g * i * (MATH(1) - i));
      *f_diff = convert::To<MATH, T>(c_term_diff * c_prev * f * (MATH(1) - f));
      *o_diff = convert::To<MATH, T>(h_diff * tanh_c * o * (MATH(1) - o));
      *g_diff = convert::To<MATH, T>(c_term_diff * i * (MATH(1) - g * g));
    }
  }
}
// fp32 forward launcher.
// Enqueues LSTMUnitKernel (float storage, float math) on the context's CUDA
// stream, covering all N * D state elements, then checks the launch.
template <>
void LSTMUnit<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* H_prev,
    const float* C_prev,
    const float* X,
    const int32_t* seqLengths,
    bool drop_states,
    float* C,
    float* H,
    const float forget_bias,
    CUDAContext* context) {
  const int total = N * D;
  const int grid = CAFFE_GET_BLOCKS(total);
  const int block = CAFFE_CUDA_NUM_THREADS;
  LSTMUnitKernel<float, float><<<grid, block, 0, context->cuda_stream()>>>(
      total, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H,
      forget_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// fp16 forward launcher.
// Tensors are stored as at::Half, but LSTMUnitKernel is instantiated with
// float as the MATH type, so gate arithmetic runs in float. Launched on the
// context's CUDA stream over all N * D state elements.
template <>
void LSTMUnit<at::Half, CUDAContext>(
    int N,
    int D,
    int t,
    const at::Half* H_prev,
    const at::Half* C_prev,
    const at::Half* X,
    const int32_t* seqLengths,
    bool drop_states,
    at::Half* C,
    at::Half* H,
    const float forget_bias,
    CUDAContext* context) {
  const int total = N * D;
  const int grid = CAFFE_GET_BLOCKS(total);
  const int block = CAFFE_CUDA_NUM_THREADS;
  LSTMUnitKernel<at::Half, float><<<grid, block, 0, context->cuda_stream()>>>(
      total, D, t, H_prev, C_prev, X, seqLengths, drop_states, C, H,
      forget_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// fp32 backward launcher.
// Enqueues LSTMUnitGradientKernel (float storage, float math) on the
// context's CUDA stream over all N * D state elements. The kernel's parameter
// order differs from this function's: C and H come before seqLengths.
template <>
void LSTMUnitGradient<float, CUDAContext>(
    int N,
    int D,
    int t,
    const float* C_prev,
    const float* X,
    const int32_t* seqLengths,
    const float* C,
    const float* H,
    const float* C_diff,
    const float* H_diff,
    bool drop_states,
    float* H_prev_diff,
    float* C_prev_diff,
    float* X_diff,
    const float forget_bias,
    CUDAContext* context) {
  const int total = N * D;
  const int grid = CAFFE_GET_BLOCKS(total);
  const int block = CAFFE_CUDA_NUM_THREADS;
  LSTMUnitGradientKernel<float, float>
      <<<grid, block, 0, context->cuda_stream()>>>(
          total, D, t,
          C_prev, X, C, H, seqLengths,
          C_diff, H_diff, drop_states,
          H_prev_diff, C_prev_diff, X_diff,
          forget_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// fp16 backward launcher.
// Tensors are stored as at::Half; the kernel is instantiated with float as
// the MATH type, so gradient arithmetic runs in float. Launched on the
// context's CUDA stream over all N * D state elements. The kernel takes C and
// H before seqLengths.
template <>
void LSTMUnitGradient<at::Half, CUDAContext>(
    int N,
    int D,
    int t,
    const at::Half* C_prev,
    const at::Half* X,
    const int32_t* seqLengths,
    const at::Half* C,
    const at::Half* H,
    const at::Half* C_diff,
    const at::Half* H_diff,
    bool drop_states,
    at::Half* H_prev_diff,
    at::Half* C_prev_diff,
    at::Half* X_diff,
    const float forget_bias,
    CUDAContext* context) {
  const int total = N * D;
  const int grid = CAFFE_GET_BLOCKS(total);
  const int block = CAFFE_CUDA_NUM_THREADS;
  LSTMUnitGradientKernel<at::Half, float>
      <<<grid, block, 0, context->cuda_stream()>>>(
          total, D, t,
          C_prev, X, C, H, seqLengths,
          C_diff, H_diff, drop_states,
          H_prev_diff, C_prev_diff, X_diff,
          forget_bias);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
// Dispatches the forward op on the element type of input 0; float and
// at::Half are the types accepted on CUDA.
template <>
bool LSTMUnitOp<CUDAContext>::RunOnDevice() {
  using SupportedTypes = TensorTypes<float, at::Half>;
  return DispatchHelper<SupportedTypes>::call(this, Input(0));
}
// Dispatches the gradient op on the element type of input 0; float and
// at::Half are the types accepted on CUDA.
template <>
bool LSTMUnitGradientOp<CUDAContext>::RunOnDevice() {
  using SupportedTypes = TensorTypes<float, at::Half>;
  return DispatchHelper<SupportedTypes>::call(this, Input(0));
}
// Register the CUDA implementations of the forward and gradient operators.
REGISTER_CUDA_OPERATOR(LSTMUnit, LSTMUnitOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<CUDAContext>);
}
|