File: order_switch_ops_cudnn.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (169 lines) | stat: -rw-r--r-- 5,871 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include "caffe2/operators/order_switch_ops.h"

#include <algorithm>
#include <array>
#include <functional>
#include <iterator>
#include <numeric>
#include <vector>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/types.h"

namespace caffe2 {

namespace {

// Shared state and helpers for the cuDNN NHWC <-> NCHW layout-switch ops:
// owns one input and one output tensor descriptor plus the cuDNN wrapper used
// to issue the transform, and remembers the input shape the descriptors were
// last built for.
class CuDNNOrderSwithOpBase : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNOrderSwithOpBase(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
  }

  // X_desc_ / Y_desc_ are raw cuDNN handles created in the constructor and
  // destroyed in the destructor; a copy would destroy them twice.
  CuDNNOrderSwithOpBase(const CuDNNOrderSwithOpBase&) = delete;
  CuDNNOrderSwithOpBase& operator=(const CuDNNOrderSwithOpBase&) = delete;

  ~CuDNNOrderSwithOpBase() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
  }

 protected:
  // Describes a tensor with shape `data_dims`, laid out in `order`, in
  // `data_desc`.
  //
  //  * 3-D shapes are described as 4-D with H == 1.
  //  * 4-D shapes map directly onto a 4-D descriptor.
  //  * Higher ranks collapse every spatial axis after H and W into a single
  //    axis D and are described as a 5-D [N, C, H, W, D] tensor with explicit
  //    strides matching `order`.
  //
  // Throws (via CAFFE_ENFORCE) if `data_dims` has fewer than 3 entries.
  // TODO: std::vector<int> -> std::vector<int64_t>
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const StorageOrder order,
      const std::vector<int>& data_dims,
      cudnnTensorDescriptor_t data_desc) const {
    const int ndim = data_dims.size();
    // Lower ranks would fall through to the N-D branch below and index
    // data_dims past its end.
    CAFFE_ENFORCE_GE(ndim, 3);
    const bool is_nchw = order == StorageOrder::NCHW;
    const int N = data_dims[0];
    const int C = is_nchw ? data_dims[1] : data_dims.back();
    if (ndim == 3) {
      const int H = 1;
      const int W = is_nchw ? data_dims[2] : data_dims[1];
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
    } else if (ndim == 4) {
      const int H = is_nchw ? data_dims[2] : data_dims[1];
      const int W = is_nchw ? data_dims[3] : data_dims[2];
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
    } else {
      const int H = is_nchw ? data_dims[2] : data_dims[1];
      const int W = is_nchw ? data_dims[3] : data_dims[2];
      // Remaining spatial axes: everything after W, excluding the trailing
      // channel axis in NHWC.
      const auto l_iter =
          is_nchw ? data_dims.cbegin() + 4 : data_dims.cbegin() + 3;
      const auto r_iter = is_nchw ? data_dims.cend() : data_dims.cend() - 1;
      const int D = std::accumulate(l_iter, r_iter, 1, std::multiplies<int>());
      const std::array<int, 5> dims = {N, C, H, W, D};
      const std::array<int, 5> strides = is_nchw
          ? std::array<int, 5>{C * H * W * D, H * W * D, W * D, D, 1}
          : std::array<int, 5>{C * H * W * D, 1, W * D * C, D * C, C};
      CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
          data_desc, data_type, 5, dims.data(), strides.data()));
    }
  }

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t X_desc_;
  cudnnTensorDescriptor_t Y_desc_;

  // Input shape X_desc_ / Y_desc_ were last configured for; starts empty so
  // the first run always configures them.
  std::vector<int> cached_X_dims_;
};

class CuDNNNHWC2NCHWOp final : public CuDNNOrderSwithOpBase {
 public:
  template <class... Args>
  explicit CuDNNNHWC2NCHWOp(Args&&... args)
      : CuDNNOrderSwithOpBase(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);

    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = X.dim32(ndim - 1);
    const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
    std::vector<int> Y_dims(ndim);
    Y_dims[0] = N;
    Y_dims[1] = C;
    std::copy(X_dims.cbegin() + 1, X_dims.cend() - 1, Y_dims.begin() + 2);
    std::vector<int64_t> Y_dims_64;
    std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64));
    auto* Y = Output(0, Y_dims_64, at::dtype<T>());
    if (cached_X_dims_ != X_dims) {
      cached_X_dims_ = X_dims;
      SetTensorDescriptor(
          cudnnTypeWrapper<T>::type, StorageOrder::NHWC, X_dims, X_desc_);
      SetTensorDescriptor(
          cudnnTypeWrapper<T>::type, StorageOrder::NCHW, Y_dims, Y_desc_);
    }
    CUDNN_ENFORCE(cudnnTransformTensor(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        X_desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        Y_desc_,
        Y->template mutable_data<T>()));
    return true;
  }
};

class CuDNNNCHW2NHWCOp final : public CuDNNOrderSwithOpBase {
 public:
  template <class... Args>
  explicit CuDNNNCHW2NHWCOp(Args&&... args)
      : CuDNNOrderSwithOpBase(std::forward<Args>(args)...) {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);

    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = X.dim32(1);
    const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
    std::vector<int> Y_dims(ndim);
    Y_dims[0] = N;
    Y_dims[ndim - 1] = C;
    std::copy(X_dims.cbegin() + 2, X_dims.cend(), Y_dims.begin() + 1);
    std::vector<int64_t> Y_dims_64;
    std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64));
    auto* Y = Output(0, Y_dims_64, at::dtype<T>());
    if (cached_X_dims_ != X_dims) {
      cached_X_dims_ = X_dims;
      SetTensorDescriptor(
          cudnnTypeWrapper<T>::type, StorageOrder::NCHW, X_dims, X_desc_);
      SetTensorDescriptor(
          cudnnTypeWrapper<T>::type, StorageOrder::NHWC, Y_dims, Y_desc_);
    }
    CUDNN_ENFORCE(cudnnTransformTensor(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        X_desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        Y_desc_,
        Y->template mutable_data<T>()));
    return true;
  }
};

} // namespace

// Register the cuDNN-backed implementations defined above under the NHWC2NCHW
// and NCHW2NHWC operator names.
REGISTER_CUDNN_OPERATOR(NHWC2NCHW, CuDNNNHWC2NCHWOp);
REGISTER_CUDNN_OPERATOR(NCHW2NHWC, CuDNNNCHW2NHWCOp);

} // namespace caffe2