#ifndef CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
#define CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
#include "caffe2/core/context.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
namespace detail {
// Owns an array of miopenTensorDescriptor_t (one per entry, e.g. per
// timestep of an RNN sequence). The constructor and destructor are defined
// out of line — presumably they create and destroy the MIOpen descriptors;
// confirm against the implementation file.
template <typename T>
class TensorDescriptors {
 public:
  TensorDescriptors(
      size_t n,
      // dim and stride are not declared as const as opposed to cuDNN
      // since miopenSetTensorDescriptor doesn't take const arguments
      std::vector<int>& dim,
      std::vector<int>& stride);
  ~TensorDescriptors();
  // Non-owning pointer to the contiguous descriptor array; valid only for
  // the lifetime of this TensorDescriptors object.
  const miopenTensorDescriptor_t* descs() const {
    return descs_.data();
  }

 private:
  std::vector<miopenTensorDescriptor_t> descs_;
};
} // namespace detail
// Common base for the MIOpen-backed recurrent operators below. Holds the
// MIOpen wrapper plus all descriptors and cached sizes shared by the
// forward, gradient, and parameter-access ops. T is the element type
// (e.g. float); all derived ops run on HIPContext.
template <typename T>
class RecurrentBaseOp : public Operator<HIPContext> {
 public:
  USE_OPERATOR_FUNCTIONS(HIPContext);
  RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
  virtual ~RecurrentBaseOp();

 protected:
  // Sets up the MIOpen descriptors from the input tensor's dimensions.
  // Defined out of line; based on the cached-dims member below it is
  // presumably skipped/re-run depending on whether the input shape
  // changed — confirm against the implementation file.
  void initialize(
      const Tensor& input,
      // If passed, reshapes to the appropriate size
      Tensor* output = nullptr,
      Tensor* hiddenOutput = nullptr,
      Tensor* cellOutput = nullptr);

  MIOPENWrapper miopen_wrapper_;
  // RNN configuration descriptor.
  miopenRNNDescriptor_t rnnDesc_;
  // Weight (w), hidden-state in/out (hx/hy), and cell-state in/out (cx/cy)
  // tensor descriptors.
  miopenTensorDescriptor_t wDesc_;
  miopenTensorDescriptor_t hxDesc_;
  miopenTensorDescriptor_t cxDesc_;
  miopenTensorDescriptor_t hyDesc_;
  miopenTensorDescriptor_t cyDesc_;
  // Per-timestep descriptors for the input (x) and output (y) sequences.
  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
  // Last-seen input dims, used to detect shape changes between runs.
  std::vector<int64_t> cachedInputDims_;
  // Byte sizes reported by MIOpen for the reserve and workspace buffers.
  size_t reserveNbytes_;
  size_t miopenWsNbytes_;

 private:
};
// Pulls the dependent base-class members of RecurrentBaseOp<T> into a
// derived class template's scope. Required because names from a dependent
// base are not visible in derived templates without explicit
// qualification or using-declarations.
#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(HIPContext);         \
  using RecurrentBaseOp<T>::miopen_wrapper_;  \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::miopenWsNbytes_;  \
  using RecurrentBaseOp<T>::initialize;
// Forward-pass recurrent operator. RunOnDevice is defined out of line in
// the implementation file.
template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}
  bool RunOnDevice() override;

 protected:
  // Inputs: sequence data, initial hidden/cell states, and the weights.
  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
  // Outputs include RNN_SCRATCH (reserve space consumed by the gradient
  // op below) and DROPOUT_STATES.
  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
};
// Whether RecurrentParamAccessOp writes (SET_PARAM) or reads (GET_PARAM)
// the packed RNN parameter blob.
enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };

// Accesses individual RNN parameters (weights/biases) inside the packed
// MIOpen parameter blob; the mode template argument selects set vs. get.
// RunOnDevice is defined out of line.
template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}
  bool RunOnDevice() override;
};
// Backward-pass recurrent operator. Consumes the forward op's outputs and
// scratch (reserve) space and produces gradients w.r.t. the inputs,
// states, and weights. RunOnDevice is defined out of line.
template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}
  bool RunOnDevice() override;

 protected:
  // RNN_SCRATCH / OUTPUT come from the forward RecurrentOp; the GRAD_*
  // inputs are the gradients flowing back from downstream ops.
  INPUT_TAGS(
      INPUT,
      HIDDEN_INPUT,
      CELL_INPUT,
      WEIGHT,
      RNN_SCRATCH,
      OUTPUT,
      GRAD_OUTPUT,
      GRAD_HIDDEN_OUTPUT,
      GRAD_CELL_OUTPUT);
  OUTPUT_TAGS(
      GRAD_INPUT,
      GRAD_HIDDEN_INPUT,
      GRAD_CELL_INPUT,
      GRAD_WEIGHT,
      DROPOUT_STATES,
      RNN_SCRATCH_OUT);
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_