File: softmax_op_cudnn.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (148 lines) | stat: -rw-r--r-- 4,054 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <limits>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/softmax_op.h"

namespace caffe2 {

namespace {
// Descriptor slot indices. NOTE(review): none of the constants in this
// namespace is referenced by the operators in this file, each of which owns a
// single tensor descriptor (desc_); they look like leftovers and could likely
// be removed.
constexpr int BOTTOM_DESC_ID{0};
constexpr int TOP_DESC_ID{1};
constexpr int TOP_GRADIENT_DESC_ID{2};

// Descriptor counts for the forward and gradient variants, respectively.
constexpr int NUM_DESCRIPTORS{2};
constexpr int GRADIENT_NUM_DESCRIPTORS{3};
} // namespace

// Softmax operator backed by cuDNN.
//
// Input 0 (X) is viewed as a 2-D [N, D] matrix around the "axis" argument
// (default 1): N is the product of X's dimensions before the canonical axis
// and D is the product of the dimensions from that axis onward. Output 0 (Y)
// has X's shape and element type. The tensor is described to cuDNN as
// (N, D, 1, 1) NCHW and normalized with CUDNN_SOFTMAX_MODE_INSTANCE, so each of
// the N rows is softmax-normalized over its D entries. Supported element
// types: float and at::Half.
class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  // Computes Y = softmax(X) for element type T (float or at::Half).
  template <typename T>
  bool DoRunWithType() {
    auto& X = Input(0);

    const auto canonical_axis = X.canonical_axis_index(axis_);
    // Kept as int64_t so ConfigureDescriptor can reject extents that cuDNN's
    // int-typed descriptor API would otherwise silently truncate.
    const int64_t N = X.size_to_dim(canonical_axis);
    const int64_t D = X.size_from_dim(canonical_axis);

    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    auto* Y_data = Y->template mutable_data<T>();
    if (N == 0 || D == 0) {
      // Empty input: Y has already been resized to X's shape; nothing to do.
      return true;
    }
    ConfigureDescriptor<T>(N, D);
    // alpha = 1, beta = 0: Y is overwritten with the softmax result.
    CUDNN_ENFORCE(cudnnSoftmaxForward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        Y_data));
    return true;
  }

  bool RunOnDevice() override {
    // Dispatch on the element type of X.
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 private:
  // Points desc_ at an (N, D, 1, 1) NCHW tensor whose element type is T.
  // The cuDNN call is skipped when desc_ already describes exactly that
  // configuration. The data type is part of the cache key so that running the
  // same op instance on float and then at::Half inputs of equal shape does not
  // reuse a descriptor with the wrong element type.
  template <typename T>
  void ConfigureDescriptor(int64_t N, int64_t D) {
    // cudnnSetTensor4dDescriptor takes int extents; refuse to truncate.
    CAFFE_ENFORCE_LE(
        N,
        std::numeric_limits<int>::max(),
        "Softmax outer size does not fit in a cuDNN int extent");
    CAFFE_ENFORCE_LE(
        D,
        std::numeric_limits<int>::max(),
        "Softmax inner size does not fit in a cuDNN int extent");
    const int n = static_cast<int>(N);
    const int d = static_cast<int>(D);
    const cudnnDataType_t type = cudnnTypeWrapper<T>::type;
    if (n == desc_n_ && d == desc_d_ && type == desc_type_) {
      return;
    }
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        type,
        n,
        d,
        1,
        1));
    desc_n_ = n;
    desc_d_ = d;
    desc_type_ = type;
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  // Configuration desc_ was last set to. desc_n_ == -1 means "never set";
  // -1 never matches a real extent, because empty inputs return before
  // ConfigureDescriptor is reached.
  int desc_n_ = -1;
  int desc_d_ = -1;
  cudnnDataType_t desc_type_{};
};

// Softmax gradient operator backed by cuDNN.
//
// Inputs: Y (input 0), passed to cudnnSoftmaxBackward as the softmax output,
// and dY (input 1), the gradient with respect to Y. dY must have Y's shape.
// Output: dX (output 0), with Y's shape and element type. The tensors are
// viewed as [N, D] around the "axis" argument (default 1): N is the product of
// the dimensions before the canonical axis and D is the product of the rest.
// They are described to cuDNN as (N, D, 1, 1) NCHW with
// CUDNN_SOFTMAX_MODE_INSTANCE. Supported element types: float and at::Half.
class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxGradientOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxGradientOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  // Computes dX from (Y, dY) for element type T (float or at::Half).
  template <typename T>
  bool DoRunWithType() {
    auto& Y = Input(0);
    auto& dY = Input(1);

    const auto canonical_axis = Y.canonical_axis_index(axis_);
    // Kept as int64_t so ConfigureDescriptor can reject extents that cuDNN's
    // int-typed descriptor API would otherwise silently truncate.
    const int64_t N = Y.size_to_dim(canonical_axis);
    const int64_t D = Y.size_from_dim(canonical_axis);

    TORCH_CHECK_EQ(Y.sizes(), dY.sizes());
    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    auto* dX_data = dX->template mutable_data<T>();
    if (N == 0 || D == 0) {
      // Empty input: dX has already been resized to Y's shape; nothing to do.
      return true;
    }
    ConfigureDescriptor<T>(N, D);
    // Y, dY and dX share one layout, so a single descriptor serves all three.
    // alpha = 1, beta = 0: dX is overwritten with the gradient.
    CUDNN_ENFORCE(cudnnSoftmaxBackward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        Y.template data<T>(),
        desc_,
        dY.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        dX_data));
    return true;
  }

  bool RunOnDevice() override {
    // Dispatch on the element type of Y.
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 private:
  // Points desc_ at an (N, D, 1, 1) NCHW tensor whose element type is T.
  // The cuDNN call is skipped when desc_ already describes exactly that
  // configuration. The data type is part of the cache key so that running the
  // same op instance on float and then at::Half inputs of equal shape does not
  // reuse a descriptor with the wrong element type.
  template <typename T>
  void ConfigureDescriptor(int64_t N, int64_t D) {
    // cudnnSetTensor4dDescriptor takes int extents; refuse to truncate.
    CAFFE_ENFORCE_LE(
        N,
        std::numeric_limits<int>::max(),
        "Softmax gradient outer size does not fit in a cuDNN int extent");
    CAFFE_ENFORCE_LE(
        D,
        std::numeric_limits<int>::max(),
        "Softmax gradient inner size does not fit in a cuDNN int extent");
    const int n = static_cast<int>(N);
    const int d = static_cast<int>(D);
    const cudnnDataType_t type = cudnnTypeWrapper<T>::type;
    if (n == desc_n_ && d == desc_d_ && type == desc_type_) {
      return;
    }
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        type,
        n,
        d,
        1,
        1));
    desc_n_ = n;
    desc_d_ = d;
    desc_type_ = type;
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  // Configuration desc_ was last set to. desc_n_ == -1 means "never set";
  // -1 never matches a real extent, because empty inputs return before
  // ConfigureDescriptor is reached.
  int desc_n_ = -1;
  int desc_d_ = -1;
  cudnnDataType_t desc_type_{};
};

namespace {
// Registers CuDNNSoftmaxOp and CuDNNSoftmaxGradientOp under the operator names
// "Softmax" and "SoftmaxGradient" via REGISTER_CUDNN_OPERATOR.
// NOTE(review): presumably this makes them selectable as the CUDNN engine
// implementation of those operators; confirm against the macro's definition.
REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp);
REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp);
} // namespace
} // namespace caffe2