File: registry.h

#pragma once
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/compile_time_info.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class SegmentedGroup;
class ExpressionEvaluator;

//! SchedulerRuntimeInfo is the abstraction for passing runtime,
//!  input-dependent information to the schedulers and kernel caches.
//!
//! Note:
//!  If any additional info is needed, or perhaps just the inputs themselves,
//!   it can simply be added to this class, and it will be distributed to the
//!   segmenter and schedulers.
//!  It is important that the input id encoding stays up to date with any
//!   change to this class, to avoid launching compiled kernels with illegal
//!   inputs.
class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
 public:
  // Max vector size we will consider, in bytes,
  //  currently set to 16 bytes = 128 bits
  static constexpr size_t max_alignment_size_in_byte = 16;

  //! Create runtime info for the given fusion and inputs. Creating and
  //!  binding the evaluator is optional. The evaluator is used to manage
  //!  intermediate integers in the fusion. The segmenter and schedulers
  //!  need it, but it is not needed when this class is only used to provide
  //!  additional encoding for kernel cache lookup.
  SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      const KernelArgumentHolder& inputs,
      bool create_expr_evaluator = false);

  // TODO: Remove this guy below. Everything needs to go into the other ctor
  SchedulerRuntimeInfo(
      Fusion* complete_fusion,
      const at::ArrayRef<at::IValue>& aten_inputs,
      bool create_expr_evaluator = false);
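
  // Usage sketch (illustrative only; `fusion` and `args` are assumed to be
  //  supplied by the caller, e.g. by the fusion kernel runtime):
  //
  //    SchedulerRuntimeInfo runtime_info(
  //        fusion, args, /*create_expr_evaluator=*/true);
  //    auto index_mode = runtime_info.getIndexMode();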

  //! Looks up the alignment size of the given tv. Currently returns actual
  //!  alignment info only for input tensors to the complete fusion; for
  //!  other intermediate/fuser-allocated tensors it returns
  //!  max_alignment_size_in_byte.
  size_t getAlignmentSize(TensorView* tv);

  // Gets the maximum vectorizable width of tv, assuming all iteration
  // domains can be merged if contiguous. Does not permute dimensions to fix
  // contiguity. Ignores dimensions that are broadcast or reduction.
  size_t getMaxVectorizableWidth(TensorView* tv);

  // Gets the vectorizable width of the innermost dimension of tv if it is
  // contiguous. Ignores innermost dimensions that are broadcast or reduction.
  size_t getInnerDimVectorizableWidth(TensorView* tv);
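
  // Illustrative example (a sketch under stated assumptions, not a guarantee
  //  of the implementation): for a contiguous float tensor whose innermost
  //  extent is divisible by 4 and whose base address is 16-byte aligned, both
  //  queries would typically return 4, since 4 * sizeof(float) == 16 bytes ==
  //  max_alignment_size_in_byte:
  //
  //    size_t width = runtime_info.getInnerDimVectorizableWidth(tv);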

  // Computes the alignment size in bytes for the provided pointer address
  static size_t computeAlignmentSize(size_t ptr_address);
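
  // Conceptually this is the largest power-of-two divisor of the address,
  //  capped at max_alignment_size_in_byte (a hedged sketch; the actual
  //  definition lives in the implementation file):
  //
  //    size_t alignment = 1;
  //    while (alignment < max_alignment_size_in_byte &&
  //           ptr_address % (alignment * 2) == 0) {
  //      alignment *= 2;
  //    }
  //    return alignment;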

  // Returns the runtime pointer value for the provided tensor view
  size_t ptrOf(TensorView* tv);

  KernelIndexMode getIndexMode() {
    return index_mode_;
  }

  Fusion* fusion() {
    return complete_fusion_;
  }

  ExpressionEvaluator& expressionEvaluator() {
    TORCH_INTERNAL_ASSERT(expression_evaluator_ != nullptr);
    return *expression_evaluator_;
  }
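
  // Sketch of probing a runtime extent through the bound evaluator (assumes
  //  the evaluator was created at construction and `tv` is a fusion input
  //  TensorView; the exact return type of evaluate() is defined in
  //  expr_evaluator.h):
  //
  //    auto extent = runtime_info.expressionEvaluator().evaluate(
  //        tv->axis(0)->extent());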

 private:
  // Bind full fusion inputs to the internal expression evaluator
  void initializeExpressionEvaluator(const KernelArgumentHolder& inputs);

  // Initialize SchedulerRuntimeInfo
  void initialize(const KernelArgumentHolder& args, bool create_expr_evaluator);

  bool isInputTv(TensorView* tv) {
    return std::find(
               complete_fusion_->inputs().begin(),
               complete_fusion_->inputs().end(),
               tv) != complete_fusion_->inputs().end();
  }

 private:
  // Returns the offset of tv in the inputs, ignoring non-TensorView inputs.
  // Used to access input_sizes, input_strides, input_ptr
  int offsetTensorPos(TensorView* tv);

  // Expression evaluator used to probe sizes in the fusion IR
  std::unique_ptr<ExpressionEvaluator> expression_evaluator_ = nullptr;

  // Fusion reference that this runtime info is associated with
  Fusion* complete_fusion_ = nullptr;

  // Copy of aten input pointer addresses
  // TODO: Support output tensor pointers
  std::unordered_map<Val*, size_t> input_ptrs_;

  // Cache for getAlignmentSize
  std::unordered_map<TensorView*, size_t> alignment_map_;
  // Cache for getMaxVectorizableWidth
  std::unordered_map<TensorView*, size_t> max_vectorword_map_;
  // Cache for getInnerDimVectorizableWidth
  std::unordered_map<TensorView*, size_t> inner_vectorword_map_;

  // Index mode the kernel needs to be run in
  KernelIndexMode index_mode_ = KernelIndexMode::INT64;

  // TODO: Remove
  std::unordered_map<TensorView*, size_t> vectorword_map_;
};

class HeuristicSummary;

//! Virtual base class for schedule heuristics.
//!   Heuristic implementations derive from this class and implement
//!   a schedule(Fusion*) and a bool canSchedule(Fusion*) interface.
class TORCH_CUDA_CU_API SchedulerEntry {
 public:
  //! Fusion runtime facing API,
  //!   builds a new entry with the given heuristic
  //!   for the given fusion
  static std::unique_ptr<SchedulerEntry> makeEntry(
      ScheduleHeuristic sh,
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr);

  virtual ~SchedulerEntry() = default;

  //! External access to the canSchedule utilities through SchedulerEntry,
  //!  to avoid exposing a free function in the namespace
  static bool canSchedule(
      ScheduleHeuristic sh,
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info,
      HeuristicSummary* data_cache = nullptr);

  //! Fusion segmenter facing API,
  //!   returns a schedule heuristic that applies to the given fusion, or
  //!   nullopt if no scheduler in the registry can handle it.
  static c10::optional<ScheduleHeuristic> proposeHeuristics(
      Fusion* fusion,
      SchedulerRuntimeInfo& runtime_info);

  //! Fusion runtime facing API,
  //!   schedules the given fusion with the heuristics owned by this entry;
  //!   concrete heuristics override this.
  virtual void schedule(Fusion* fusion) = 0;
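
  // Typical runtime flow tying the pieces above together (an illustrative
  //  sketch only; `fusion` and `args` are assumed to come from the caller):
  //
  //    SchedulerRuntimeInfo runtime_info(fusion, args, true);
  //    auto heuristic =
  //        SchedulerEntry::proposeHeuristics(fusion, runtime_info);
  //    if (heuristic.has_value()) {
  //      auto entry =
  //          SchedulerEntry::makeEntry(*heuristic, fusion, runtime_info);
  //      entry->schedule(fusion);
  //    }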

  //! Heuristic comparison
  bool sameAs(const SchedulerEntry* other);

  ScheduleHeuristic heuristic() const {
    return heuristc_;
  }

  KernelIndexMode indexMode() const {
    return index_mode_;
  }

  const std::shared_ptr<HeuristicParams>& params() const {
    return params_;
  }

  const ReductionParams& reductionParams() const {
    auto rparams = std::dynamic_pointer_cast<ReductionParams>(params_);
    TORCH_INTERNAL_ASSERT(
        rparams != nullptr, "Heuristic parameter is not a reduction parameter");
    return *rparams;
  }

  const PointwiseParams& pointwiseParams() const {
    auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params_);
    TORCH_INTERNAL_ASSERT(
        pparams != nullptr, "Heuristic parameter is not a pointwise parameter");
    return *pparams;
  }

  const TransposeParams& transposeParams() const {
    auto tparams = std::dynamic_pointer_cast<TransposeParams>(params_);
    TORCH_INTERNAL_ASSERT(
        tparams != nullptr, "Heuristic parameter is not a transpose parameter");
    return *tparams;
  }
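
  // The typed accessors above assert on a mismatch, so callers typically
  //  dispatch on heuristic() first (a sketch; the exact ScheduleHeuristic
  //  enumerator name is an assumption based on the accessor names):
  //
  //    if (entry->heuristic() == ScheduleHeuristic::Reduction) {
  //      const ReductionParams& rparams = entry->reductionParams();
  //      // ... consume the reduction parameters ...
  //    }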

  void updateLaunchConstraint(const LaunchParams& launch_params) {
    params_->lparams = launch_params;
  }

 protected:
  explicit SchedulerEntry(ScheduleHeuristic heuristic) : heuristc_(heuristic) {}

  //! Heuristic parameters if applicable
  std::shared_ptr<HeuristicParams> params_ = nullptr;

 private:
  //! What kind of heuristics does this entry have?
  const ScheduleHeuristic heuristc_;

  //! Kernel Index Mode
  KernelIndexMode index_mode_ = KernelIndexMode::INT64;
};

//! Hash function for a scheduler entry
class TORCH_CUDA_CU_API SchedulerEntryHash {
 public:
  size_t operator()(const SchedulerEntry& se) const;
};
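
//! Usage sketch (illustrative; a matching equality functor, e.g. one
//!  wrapping SchedulerEntry::sameAs, is assumed to be supplied alongside,
//!  and the names `Value` and `SameAsEq` are hypothetical):
//!
//!    std::unordered_map<SchedulerEntry, Value, SchedulerEntryHash, SameAsEq>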

//! Debug print function for heuristics
TORCH_CUDA_CU_API std::string toString(ScheduleHeuristic sh);

//! Debug print function for heuristics
TORCH_CUDA_CU_API std::ostream& operator<<(
    std::ostream& os,
    ScheduleHeuristic sh);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch