File: cudnn_extension.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (80 lines) | stat: -rw-r--r-- 3,381 bytes parent folder | download | duplicates (4)
/*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
 * 1) Check arguments. torch::check* functions provide a standard way to
 *    validate input and provide pretty errors.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */

#include <torch/extension.h>

#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cudnn/Handle.h> // for getCudnnHandle

#include <memory> // for std::unique_ptr (descriptor RAII guard)
#include <type_traits> // for std::remove_pointer_t

// Name of the function exposed in the python module, and the name passed to
// the torch::check* validators below so their error messages identify which
// operator rejected the arguments.
const char* cudnn_relu_name = "cudnn_relu";

// Validate the arguments of cudnn_relu. Throws (never returns a status) when
// a tensor is non-contiguous, not float32, not on CUDA, or when the two
// tensors differ in size.
void cudnn_relu_check(
    const torch::Tensor& inputs,
    const torch::Tensor& outputs) {
  // Wrap each tensor in a TensorArg, which records the argument's name and
  // position so the torch::check* helpers can produce readable errors.
  torch::TensorArg arg_inputs(inputs, "inputs", 0);
  torch::TensorArg arg_outputs(outputs, "outputs", 1);
  // Per-tensor validation shared by both arguments. These functions will
  // throw an error if they fail, so no return codes need to be inspected;
  // messages are populated from the TensorArg metadata.
  const auto check_tensor = [](const torch::TensorArg& arg) {
    torch::checkContiguous(cudnn_relu_name, arg);
    torch::checkScalarType(cudnn_relu_name, arg, torch::kFloat);
    torch::checkBackend(cudnn_relu_name, arg.tensor, torch::Backend::CUDA);
  };
  check_tensor(arg_inputs);
  check_tensor(arg_outputs);
  // Finally, input and output must have identical sizes.
  torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
}

// Applies CuDNN ReLU: reads `inputs` and writes the result into the
// pre-allocated `outputs` tensor (same size, float32, CUDA, contiguous).
// Most CuDNN extensions will follow a similar pattern.
void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
  // Step 1: Check inputs. This will throw an error if inputs are invalid, so
  // no need to check return codes here.
  cudnn_relu_check(inputs, outputs);
  // Step 2: Create descriptors
  cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
  // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
  // size and type and contiguous, so one descriptor is sufficient.
  torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
  cudnnActivationDescriptor_t activationDesc;
  // Note: Always check return value of cudnn functions using AT_CUDNN_CHECK
  // (it throws on any non-success status).
  AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
  // RAII guard: because AT_CUDNN_CHECK throws on failure, the original
  // explicit cudnnDestroyActivationDescriptor call at the end of this
  // function would be skipped if cudnnSetActivationDescriptor or
  // cudnnActivationForward failed, leaking the descriptor. The guard's
  // deleter runs on every exit path (the destroy status is intentionally
  // unchecked, since it may run during exception unwinding).
  std::unique_ptr<
      std::remove_pointer_t<cudnnActivationDescriptor_t>,
      decltype(&cudnnDestroyActivationDescriptor)>
      activation_guard(activationDesc, &cudnnDestroyActivationDescriptor);
  AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
      activationDesc,
      /*mode=*/CUDNN_ACTIVATION_RELU,
      /*reluNanOpt=*/CUDNN_PROPAGATE_NAN,
      /*coef=*/1.));
  // Step 3: Apply CuDNN function. alpha/beta blend with the existing output:
  // out = alpha * relu(in) + beta * out, so alpha=1, beta=0 overwrites.
  float alpha = 1.;
  float beta = 0.;
  AT_CUDNN_CHECK(cudnnActivationForward(
      cuDnn,
      activationDesc,
      &alpha,
      input_tensor_desc.desc(),
      inputs.data_ptr(),
      &beta,
      input_tensor_desc.desc(), // output descriptor same as input
      outputs.data_ptr()));
  // Step 4: Destroy descriptors — handled by activation_guard on scope exit.
  // Step 5: Return something (optional)
}

// Create the pybind11 module. TORCH_EXTENSION_NAME is defined by the
// PyTorch extension build system to match the name the module is built
// under, so the compiled module imports correctly.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Use the same name as the check functions so error messages make sense
  m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU");
}