File: kernel_meta_info.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (142 lines) | stat: -rw-r--r-- 5,712 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#if !defined(C10_MOBILE) && !defined(ANDROID)
#pragma once

#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>
#include <torch/csrc/dynamo/guards.h>

#include <string>

namespace torch::inductor {

// Regarding a aten operation implemented by AOTI, the metadata of the input
// tensors will be cached on the disk to accelerate next run. TensorMetada
// structure is to represent the metadata of each input tensor. It includes
// whether the tensor is symbolic, the dtype, the device, the sizes and the
// strides of the tensor. When the metadata of the input tensors is the same as
// the cached metadata, the cached kernel library will be loaded and executed.
// Otherwise, the AOT Inductor will be called again to generate the kernel
// library.
// Beyond the TensorMetadata, we build guard/TensorCheck for each input tensor
// as well to support symbolic shape. We intend to utilize TensorCheck to find
// out the proper kernel rather than TensorMetada comparison. Suppose an
// operation with a single input tensor and two kernels:
//   kernel1: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU,
//   sizes=[s0, s1, s2], strides=[s1 * s2, s2, 1]) kernel2:
//   TensorMetadata(is_symbolic=false, dtype=Float, device=CPU, sizes=[3, s1,
//   s2], strides=[s1 * s2, s2, 1])
// If a tensor with sizes=[3, 4, 5] is passed to the operation, both kernel1 and
// kernel2 support the tensor shape. In this case, we need to use TensorCheck
// plus some heruistic rules to find out the proper kernel.
struct TensorMetadata {
  // Indicate whether the tensor is symbolic and it may be concluded by sizes_
  // and strides_ in the future.
  bool is_symbolic_;
  // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor)
  c10::ScalarType dtype_ = c10::ScalarType::Undefined;
  // Device of a tensor.
  c10::Device device_;
  // Dispatch key set of a tensor
  c10::DispatchKeySet dispatch_key_set_;
  // Sizes of a tensor. Currently, we only support static shape and use int64_t
  // to represent the sizes. In the future, we will create symbolic size and use
  // SymInt to represent it to support symbolic shape.
  std::vector<int64_t> sizes_;
  // Strides of a tensor. For symbolic shape support, it is the same as sizes_
  std::vector<int64_t> strides_;
  // requires grad
  bool requires_grad_ = false;
  // TensorCheck for the tensor
  std::optional<dynamo::TensorCheck> tensor_check_;

  TensorMetadata()
      : is_symbolic_(false),
        device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES),
        sizes_({}),
        strides_({}) {}
  TensorMetadata(const at::Tensor& src_tensor);
  TensorMetadata(
      bool is_symbolic,
      c10::ScalarType dtype,
      c10::Device device,
      c10::DispatchKeySet dispatch_key_set,
      std::vector<int64_t> sizes,
      std::vector<int64_t> strides,
      bool requires_grad = false);

  // Build TensorCheck for the tensor by using the data fields in TensorMetadata
  void build_guard(const dynamo::LocalState& local_state);

  // Compare two TensorMetadata objects
  bool operator==(const TensorMetadata& other) const;
};

// ParameterTag is to represent the type of the input parameters of a aten
// operation. Currently, we support the following types:
//   1. TENSOR: a single tensor
//   2. TENSOR_OPTIONAL: a single optional tensor
//   3. TENSOR_LIST: a list of tensors
//   4. TENSOR_LIST_OPTIONAL: a list of optional tensors
//   5. SCALAR: a scalar value
// If we need to support more types in the future, we will add more types in the
// ParameterTag enum. For example, we will extend the enum to support string,
// Dimname and so on to support more types of input parameters of aten
// operations.
enum ParameterTag {
  TENSOR,
  TENSOR_OPTIONAL,
  TENSOR_LIST,
  TENSOR_LIST_OPTIONAL,
  SCALAR,
  STRING,
  DEVICE,
  INVALID,
};

// ParameterMetadataValue is to represent the value of the input parameters of a
// aten operation.
using ParameterMetadataValue = std::variant<
    TensorMetadata,
    std::vector<TensorMetadata>,
    c10::Scalar,
    std::string,
    c10::Device>;

// ParameterMetadata is to represent the metadata of the input parameters of a
// aten operation. It includes the tag of the parameter, the value of the
// parameter and the order of the parameter.
struct ParameterMetadata {
  // The tag of the parameter. It indicates the type of the parameter.
  ParameterTag tag_;
  // The value of the parameter. It can be a tensor, a list of tensors or a
  // scalar.
  ParameterMetadataValue value_;
  // The order of the parameter is used to distinguish the parameters with the
  // same tag. For example, an operation with two input tensors, the first
  // tensor is a optional tensor and the second tensor is a tensor. The first
  // tensor will have the order 0 and the second tensor will have the order 1.
  uint64_t order_{};

  ParameterMetadata() : tag_(INVALID) {}
  ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order);
  ParameterMetadata(const at::Tensor& tensor, uint64_t input_order);
  ParameterMetadata(
      const std::vector<at::Tensor>& tensor_list,
      uint64_t input_order);
  ParameterMetadata(
      const std::vector<TensorMetadata>& tensor_metadata_list,
      uint64_t input_order);
  ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order);
  ParameterMetadata(const std::string& string_value, uint64_t input_order);
  ParameterMetadata(const c10::Device& device, uint64_t input_order);

  bool operator==(const ParameterMetadata& other) const;

 private:
  // Helper function to compare two ParameterMetadata objects with the same
  // SCALAR tag.
  bool equal_to(const c10::Scalar& scalar) const;
};

} // namespace torch::inductor
#endif