File: recurrent_op_cudnn.h (pytorch 1.13.1+dfsg-4)

#ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
#define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
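
// cuDNN-backed recurrent network operators for Caffe2. RecurrentBaseOp owns
// the cuDNN descriptors and buffer bookkeeping shared by the forward op
// (RecurrentOp), its gradient (RecurrentGradientOp), and the packed-weight
// accessor (RecurrentParamAccessOp).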

#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {
namespace detail {

template <typename T>
class TensorDescriptors {
 public:
  TensorDescriptors(
      size_t n,
      const std::vector<int>& dim,
      const std::vector<int>& stride);
  ~TensorDescriptors();
  const cudnnTensorDescriptor_t* descs() const {
    return descs_.data();
  }

 private:
  std::vector<cudnnTensorDescriptor_t> descs_;
};
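
// A hedged sketch (assumed; the real implementation lives in
// recurrent_op_cudnn.cc) of how such a per-time-step descriptor array is
// typically built: one cudnnTensorDescriptor_t per step, each configured
// with cudnnSetTensorNdDescriptor and released again in the destructor.
//
//   template <typename T>
//   TensorDescriptors<T>::TensorDescriptors(
//       size_t n,
//       const std::vector<int>& dim,
//       const std::vector<int>& stride) {
//     CAFFE_ENFORCE_EQ(dim.size(), stride.size());
//     descs_.resize(n);
//     for (auto& desc : descs_) {
//       CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc));
//       CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
//           desc,
//           cudnnTypeWrapper<T>::type,
//           dim.size(),
//           dim.data(),
//           stride.data()));
//     }
//   }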

} // namespace detail

template <typename T>
class RecurrentBaseOp : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);
  template <class... Args>
  explicit RecurrentBaseOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
    CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
  }
  virtual ~RecurrentBaseOp();

 protected:
  void initialize(
      const Tensor& input,
      Tensor* dropoutStates = nullptr,
      // If passed, the output tensors below are reshaped to the
      // appropriate size.
      Tensor* output = nullptr,
      Tensor* hiddenOutput = nullptr,
      Tensor* cellOutput = nullptr);

  CuDNNWrapper cudnn_wrapper_;
  cudnnDropoutDescriptor_t dropoutDesc_;
  cudnnRNNDescriptor_t rnnDesc_;
  cudnnFilterDescriptor_t wDesc_;
  cudnnTensorDescriptor_t hxDesc_;
  cudnnTensorDescriptor_t cxDesc_;
  cudnnTensorDescriptor_t hyDesc_;
  cudnnTensorDescriptor_t cyDesc_;

  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;

  std::vector<int64_t> cachedInputDims_;
  size_t reserveNbytes_;
  size_t cudnnWsNbytes_;

};
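
// What initialize() needs to accomplish (an outline inferred from the members
// above; the actual logic is in recurrent_op_cudnn.cc and is re-run only when
// the input shape differs from cachedInputDims_):
//
//   1. Configure dropoutDesc_ (cudnnSetDropoutDescriptor), using
//      `dropoutStates` as the persistent dropout-state buffer when given.
//   2. Configure rnnDesc_ from operator arguments (hidden size, layer count,
//      cell mode, direction) via the cudnnSetRNNDescriptor family.
//   3. Build xDesc_/yDesc_ with one descriptor per time step, and set up
//      hxDesc_ and friends for the hidden/cell states.
//   4. Query buffer sizes: cudnnGetRNNParamsSize (to validate wDesc_),
//      cudnnGetRNNWorkspaceSize (-> cudnnWsNbytes_), and
//      cudnnGetRNNTrainingReserveSize (-> reserveNbytes_).
//   5. If output/hiddenOutput/cellOutput were passed, resize them to match.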

#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(CUDAContext);        \
  using RecurrentBaseOp<T>::cudnn_wrapper_;   \
  using RecurrentBaseOp<T>::dropoutDesc_;     \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::cudnnWsNbytes_;   \
  using RecurrentBaseOp<T>::initialize;
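
// The using-declarations above are needed because RecurrentBaseOp<T> is a
// dependent base class: during the first phase of two-phase lookup, a class
// template does not look up unqualified names in dependent bases. Without the
// macro, every subclass would have to qualify each access, e.g. (hypothetical
// subclass shown for illustration):
//
//   template <typename T>
//   bool MyRecurrentOp<T>::RunOnDevice() {
//     this->initialize(this->Input(0)); // or RecurrentBaseOp<T>::initialize
//     ...
//   }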

template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
};
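
// A hedged sketch of the forward RunOnDevice() flow, shown only to document
// the cuDNN entry point it wraps (the literal implementation is in
// recurrent_op_cudnn.cc): call initialize(...) on the current INPUT, ensure
// the workspace holds cudnnWsNbytes_, then launch
//
//   CUDNN_ENFORCE(cudnnRNNForwardTraining(
//       handle, rnnDesc_, seqLength,
//       xDesc_->descs(), X, hxDesc_, HX, cxDesc_, CX,
//       wDesc_, W,
//       yDesc_->descs(), Y, hyDesc_, HY, cyDesc_, CY,
//       workspace, cudnnWsNbytes_,
//       reserveSpace, reserveNbytes_));
//
// with RNN_SCRATCH serving as the reserve space consumed later by the
// gradient op (cudnnRNNForwardInference drops the reserve arguments when no
// gradient is required).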

enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };

template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentParamAccessOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;
};
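
// Depending on `mode`, this op either writes user-supplied values into the
// packed cuDNN weight blob (SET_PARAM) or reads them back (GET_PARAM). The
// offset of an individual matrix inside the blob is not computed by hand;
// cuDNN reports it, e.g. with the legacy (pre-cuDNN 8) accessor:
//
//   void* matPtr = nullptr;
//   CUDNN_ENFORCE(cudnnGetRNNLinLayerMatrixParams(
//       handle, rnnDesc_, pseudoLayer, xDesc_->descs()[0],
//       wDesc_, W, linLayerId, matDesc, &matPtr));
//   // matPtr now points into W at that layer's matrix; biases are located
//   // analogously via cudnnGetRNNLinLayerBiasParams.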

template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentGradientOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(
      INPUT,
      HIDDEN_INPUT,
      CELL_INPUT,
      WEIGHT,
      RNN_SCRATCH,
      OUTPUT,
      GRAD_OUTPUT,
      GRAD_HIDDEN_OUTPUT,
      GRAD_CELL_OUTPUT);
  OUTPUT_TAGS(
      GRAD_INPUT,
      GRAD_HIDDEN_INPUT,
      GRAD_CELL_INPUT,
      GRAD_WEIGHT,
      DROPOUT_STATES,
      RNN_SCRATCH_OUT);
};
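
// The gradient op consumes RNN_SCRATCH (the reserve space filled by the
// forward pass) and, roughly, issues the two backward calls of the legacy
// cuDNN RNN API (a sketch of the expected sequence, not the literal .cc
// code):
//
//   cudnnRNNBackwardData(...);    // -> GRAD_INPUT, GRAD_HIDDEN_INPUT,
//                                 //    GRAD_CELL_INPUT
//   cudnnRNNBackwardWeights(...); // -> GRAD_WEIGHT
//
// Note that cudnnRNNBackwardWeights accumulates into its output buffer, so
// GRAD_WEIGHT must be zero-initialized before the call.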


} // namespace caffe2

#endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_