1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
#ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
#define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
namespace detail {
// Holds a list of raw cuDNN tensor descriptor handles and exposes them as a
// contiguous C array via descs().
//
// The constructor and destructor are defined out of line.
// NOTE(review): presumably the constructor creates `n` descriptors configured
// with `dim` / `stride` for element type T and the destructor releases them --
// confirm against the definitions.
template <typename T>
class TensorDescriptors {
 public:
  // n:      number of descriptors requested.
  // dim:    per-descriptor dimensions.
  // stride: per-descriptor strides.
  TensorDescriptors(
      size_t n,
      const std::vector<int>& dim,
      const std::vector<int>& stride);
  ~TensorDescriptors();

  // descs_ stores raw handles and this class declares a destructor, so an
  // implicit copy would leave two objects holding the same handles.
  // Copying is therefore disallowed (which also suppresses implicit moves).
  TensorDescriptors(const TensorDescriptors&) = delete;
  TensorDescriptors& operator=(const TensorDescriptors&) = delete;

  // Returns a pointer to the first stored descriptor handle. The pointer is
  // only valid while this object is alive.
  const cudnnTensorDescriptor_t* descs() const {
    return descs_.data();
  }

 private:
  std::vector<cudnnTensorDescriptor_t> descs_;
};
} // namespace detail
// Common base for the cuDNN-backed recurrent operators declared in this file.
//
// The constructor creates the dropout, RNN, filter and hx tensor descriptors;
// the destructor and initialize() are defined out of line.
// NOTE(review): cxDesc_, hyDesc_ and cyDesc_ are not created here --
// presumably they are set up in initialize(); confirm against the definition.
template <typename T>
class RecurrentBaseOp : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  // Forwards every argument to Operator<CUDAContext> and then creates the
  // cuDNN descriptors. CUDNN_ENFORCE reports a failing cuDNN call.
  template <class... Args>
  explicit RecurrentBaseOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
    CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
  }

  virtual ~RecurrentBaseOp();

 protected:
  // Prepares the operator state for `input` (defined out of line).
  // `dropoutStates` is optional and defaults to nullptr.
  void initialize(
      const Tensor& input,
      Tensor* dropoutStates = nullptr,
      // If passed, reshapes to the appropriate size
      Tensor* output = nullptr,
      Tensor* hiddenOutput = nullptr,
      Tensor* cellOutput = nullptr);

  CuDNNWrapper cudnn_wrapper_;

  // Raw cuDNN handles. They are value-initialized so that members the
  // constructor does not create never hold an indeterminate value.
  cudnnDropoutDescriptor_t dropoutDesc_{};
  cudnnRNNDescriptor_t rnnDesc_{};
  cudnnFilterDescriptor_t wDesc_{};
  cudnnTensorDescriptor_t hxDesc_{};
  cudnnTensorDescriptor_t cxDesc_{};
  cudnnTensorDescriptor_t hyDesc_{};
  cudnnTensorDescriptor_t cyDesc_{};

  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;

  std::vector<int64_t> cachedInputDims_;

  // Byte counts; zero until computed.
  // NOTE(review): presumably filled in by initialize() -- confirm.
  size_t reserveNbytes_{0};
  size_t cudnnWsNbytes_{0};
};
// Brings the protected members of the dependent base RecurrentBaseOp<T> (and
// the usual operator helpers) into the scope of a derived class template, so
// they can be named without a `this->` / `RecurrentBaseOp<T>::` prefix.
// The expansion ends with a semicolon, so use sites do not add one.
#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(CUDAContext);        \
  using RecurrentBaseOp<T>::initialize;       \
  using RecurrentBaseOp<T>::cudnn_wrapper_;   \
  using RecurrentBaseOp<T>::dropoutDesc_;     \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::cudnnWsNbytes_;
// Recurrent operator built on RecurrentBaseOp<T>; RunOnDevice() is defined
// out of line.
// NOTE(review): presumably this is the forward pass matching
// RecurrentGradientOp -- confirm against the definition.
template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS

  // Passes all constructor arguments straight through to the base class.
  template <typename... CtorArgs>
  explicit RecurrentOp(CtorArgs&&... ctorArgs)
      : RecurrentBaseOp<T>(std::forward<CtorArgs>(ctorArgs)...) {}

  bool RunOnDevice() override;

 protected:
  // Named indices for the operator's inputs and outputs, in declaration order.
  INPUT_TAGS(
      INPUT,
      HIDDEN_INPUT,
      CELL_INPUT,
      WEIGHT);
  OUTPUT_TAGS(
      OUTPUT,
      HIDDEN_OUTPUT,
      CELL_OUTPUT,
      RNN_SCRATCH,
      DROPOUT_STATES);
};
// Mode selector used as the second template argument of
// RecurrentParamAccessOp.
// NOTE(review): presumably SET_PARAM writes a parameter and GET_PARAM reads
// one -- confirm against the operator's definition.
enum RecurrentParamOpMode {
  SET_PARAM,
  GET_PARAM
};
// Operator parameterized on a RecurrentParamOpMode value; shares the cuDNN
// state of RecurrentBaseOp<T>. RunOnDevice() is defined out of line.
template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS

  // Passes all constructor arguments straight through to the base class.
  template <typename... CtorArgs>
  explicit RecurrentParamAccessOp(CtorArgs&&... ctorArgs)
      : RecurrentBaseOp<T>(std::forward<CtorArgs>(ctorArgs)...) {}

  bool RunOnDevice() override;
};
// Gradient operator built on RecurrentBaseOp<T>; RunOnDevice() is defined
// out of line.
template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS

  // Passes all constructor arguments straight through to the base class.
  template <typename... CtorArgs>
  explicit RecurrentGradientOp(CtorArgs&&... ctorArgs)
      : RecurrentBaseOp<T>(std::forward<CtorArgs>(ctorArgs)...) {}

  bool RunOnDevice() override;

 protected:
  // Named input indices, in declaration order.
  INPUT_TAGS(
      INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT,
      RNN_SCRATCH, OUTPUT,
      GRAD_OUTPUT, GRAD_HIDDEN_OUTPUT, GRAD_CELL_OUTPUT);

  // Named output indices, in declaration order.
  OUTPUT_TAGS(
      GRAD_INPUT, GRAD_HIDDEN_INPUT, GRAD_CELL_INPUT, GRAD_WEIGHT,
      DROPOUT_STATES, RNN_SCRATCH_OUT);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
|