File: recurrent_op_miopen.h

#ifndef CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
#define CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_

#include "caffe2/core/context.h"
#include "caffe2/core/hip/context_gpu.h"
#include "caffe2/core/hip/miopen_wrapper.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {
namespace detail {

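// Holds a batch of MIOpen tensor descriptors (one per sequence step) in a
// contiguous vector so they can be handed to the MIOpen RNN API as a single
// array via descs(). The constructor presumably creates and configures one
// descriptor per step from dim/stride, and the destructor releases them;
// both are defined in the .cc file.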
template <typename T>
class TensorDescriptors {
 public:
  TensorDescriptors(
      size_t n,
      // Unlike the cuDNN version, dim and stride are not declared const,
      // because miopenSetTensorDescriptor does not take const arguments.
      std::vector<int>& dim,
      std::vector<int>& stride);
  ~TensorDescriptors();
  const miopenTensorDescriptor_t* descs() const {
    return descs_.data();
  }

 private:
  std::vector<miopenTensorDescriptor_t> descs_;
};

} // namespace detail

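// Shared state for the MIOpen-backed recurrent operators: the MIOpen handle
// wrapper, the RNN/weight/hidden/cell descriptors, the per-step input and
// output descriptor arrays, and the reserve/workspace byte counts reported
// by MIOpen. initialize() builds these from the input tensor and, when the
// optional output pointers are given, reshapes those outputs; the cached
// input dims presumably allow later runs to skip re-initialization when the
// input shape is unchanged.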
template <typename T>
class RecurrentBaseOp : public Operator<HIPContext> {
 public:
  USE_OPERATOR_FUNCTIONS(HIPContext);
  RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
  virtual ~RecurrentBaseOp();

 protected:
  void initialize(
      const Tensor& input,
      // If passed, reshapes to the appropriate size
      Tensor* output = nullptr,
      Tensor* hiddenOutput = nullptr,
      Tensor* cellOutput = nullptr);

  MIOPENWrapper miopen_wrapper_;
  miopenRNNDescriptor_t rnnDesc_;
  miopenTensorDescriptor_t wDesc_;
  miopenTensorDescriptor_t hxDesc_;
  miopenTensorDescriptor_t cxDesc_;
  miopenTensorDescriptor_t hyDesc_;
  miopenTensorDescriptor_t cyDesc_;

  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;

  std::vector<int64_t> cachedInputDims_; // input dims cached to detect shape changes
  size_t reserveNbytes_;                 // byte size of the MIOpen RNN reserve space
  size_t miopenWsNbytes_;                // byte size of the MIOpen RNN workspace

 private:
};

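// RecurrentBaseOp<T> is a dependent base class, so its members are not
// visible to the derived templates below without qualification. This macro
// pulls the protected members into scope with using-declarations, avoiding
// an explicit this-> on every access.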
#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(HIPContext);         \
  using RecurrentBaseOp<T>::miopen_wrapper_;  \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::miopenWsNbytes_;  \
  using RecurrentBaseOp<T>::initialize;

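// Forward pass: consumes the input sequence, the initial hidden/cell states
// and the packed weight blob, and produces the output sequence plus the
// final hidden/cell states. RNN_SCRATCH is the reserve space reused by the
// gradient op below; DROPOUT_STATES presumably holds MIOpen's dropout state.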
template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
};

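// RecurrentParamAccessOp either writes (SET_PARAM) or reads (GET_PARAM) an
// individual recurrent parameter, presumably addressing a slice of the
// packed weight buffer in the same way as the cuDNN version of this op.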
enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };

template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}

  bool RunOnDevice() override;
};

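// Backward pass: takes the forward inputs and outputs, the gradients flowing
// into OUTPUT / HIDDEN_OUTPUT / CELL_OUTPUT, and the RNN_SCRATCH reserve
// space saved by the forward op, and emits gradients with respect to the
// input sequence, the initial states, and the weights.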
template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : RecurrentBaseOp<T>(operator_def, ws) {}

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(
      INPUT,
      HIDDEN_INPUT,
      CELL_INPUT,
      WEIGHT,
      RNN_SCRATCH,
      OUTPUT,
      GRAD_OUTPUT,
      GRAD_HIDDEN_OUTPUT,
      GRAD_CELL_OUTPUT);
  OUTPUT_TAGS(
      GRAD_INPUT,
      GRAD_HIDDEN_INPUT,
      GRAD_CELL_INPUT,
      GRAD_WEIGHT,
      DROPOUT_STATES,
      RNN_SCRATCH_OUT);
};


} // namespace caffe2

#endif // CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_