File: evaluator_common.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (343 lines) | stat: -rw-r--r-- 10,726 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
#pragma once
#include <torch/csrc/jit/codegen/cuda/dynamic_type.h>
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>

#include <c10/core/DeviceType.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! This is the common space for expression evaluators in
//!  fusion IR and kernel IR context. Much of the evaluator
//!  optimizations and runtimes could share the same code
//!  path and they could be collected here.

class ExpressionEvaluator;

namespace kir {

class ExpressionEvaluator;

} // namespace kir

//! IR Contexts to be passed to generic evaluator optimizations
//!   and runtimes. Defines the essential interface for the
//!   generic logic to get necessary type and function info
//!   from the IR nodes. Generic optimizations will assume
//!   the same list of static definitions are provided
//!   in each of the contexts, just FusionIR and KernelIR
//!   currently.

//! IR context that describes Fusion IR to the generic evaluator code.
//!  Supplies the type aliases and the static operator-type accessors
//!  that the generic logic looks up on its IRContext parameter.
class FusionIRContext {
 public:
  //! Expression evaluator class paired with Fusion IR.
  using EVALUATOR_TYPE = ExpressionEvaluator;

  //! Tensor view node type used in Fusion IR.
  using TV_TYPE = TensorView;

  //! Returns the operator type stored on the given unary expression.
  static UnaryOpType getOpType(UnaryOp* uop) {
    const UnaryOpType unary_type = uop->getUnaryOpType();
    return unary_type;
  }

  //! Returns the operator type stored on the given binary expression.
  static BinaryOpType getOpType(BinaryOp* bop) {
    const BinaryOpType binary_type = bop->getBinaryOpType();
    return binary_type;
  }
};

//! Context for using generic logic on KernelIR
class KernelIRContext {
 public:
  using EVALUATOR_TYPE = kir::ExpressionEvaluator;

  static BinaryOpType getOpType(BinaryOp* bop) {
    return bop->getBinaryOpType();
  }

  static UnaryOpType getOpType(UnaryOp* uop) {
    return uop->getUnaryOpType();
  }
};

template <typename IRContext>
class PrecomputedValuesBase;

//! NaiveValueMachine:
//!  This is an un-optimized runtime for evaluating a
//!   set of values in one run. The runtime contains
//!   a vector of instructions inferred from IR at compile-time
//!   and it currently must be associated with an instance of
//!   PrecomputedValuesBase that will provide the workspace
//!   containing the concrete values for the values.
//!  The machine only keeps a reference to that workspace
//!   (precomputed_values_), so the workspace object has to
//!   outlive the machine.
template <typename IRContext>
class NaiveValueMachine {
  //! The generic types of instructions supported for this
  //!  machine, currently only binary and unary.
  enum class InstructionType { UNARY_OP, BINARY_OP };

 public:
  //! Constructor lowers all the expr IR nodes stored in precomputed_values
  //!  and stores them in the private state.
  //!  NOTE(review): single-argument constructor is not `explicit`;
  //!  consider marking it so if no caller relies on implicit conversion.
  NaiveValueMachine(PrecomputedValuesBase<IRContext>& precomputed_values);

  //! Runs all the instructions and writes results to the associated
  //!  precomputed_values.
  void run();

 private:
  //! Convert a unary IR expr to an instruction
  void makeUnaryOp(UnaryOp* uop);

  //! Convert a binary IR expr to an instruction
  void makeBinaryOp(BinaryOp* bop);

  //! Create an empty instruction with all default values
  //!  and place it at the end of the instruction buffer.
  //!  NOTE(review): the returned int is presumably the index of the
  //!  new entry -- confirm against the definition.
  int makeInstructionEntry();

  //! Run a single instruction at the given index of
  //!  the instruction buffer. Decodes and dispatches
  //!  to the corresponding instruction handle functions.
  void runInstruction(int index);

  //! Runs a unary operation at given index of instruction buffer
  void runUnaryOp(int index);

  //! Runs a binary operation at given index of instruction buffer
  void runBinaryOp(int index);

 private:
  //! The workspace class is granted access to the private state above.
  friend PrecomputedValuesBase<IRContext>;

  //! Reference to the PrecomputedValues workspace associated with
  //!   this runtime. All the instructions will read and write the
  //!   values in this workspace.
  PrecomputedValuesBase<IRContext>& precomputed_values_;

  //! Instruction buffer. All states are in separate vectors and
  //!  the entries of each vector at the same index correspond to
  //!  the same instruction.

  //! Total number of instructions
  int num_of_instructions_ = 0;

  //! Machine instruction type for each instruction i.e.
  //!  unary or binary
  std::vector<InstructionType> inst_type_;

  //! Unary operator type if applicable, contains a default
  //!  value at each index corresponding to a binary op.
  std::vector<UnaryOpType> uop_type_;

  //! Data type for unary op of type UnaryOpType::Cast, contains a default
  //!  value at each index corresponding to other ops.
  std::vector<DataType> data_type_;

  //! Binary operator type if applicable, contains a default
  //!  value at each index corresponding to a unary op.
  std::vector<BinaryOpType> bop_type_;

  //! Indexes of operands and destination of each instruction.
  //!  The indexes correspond to positions in the workspace
  //!  where concrete values are hosted.

  //! Operand 0 of each instruction.
  std::vector<int> src0_;

  //! Operand 1 of each instruction, a default value at
  //!  each index corresponding to a unary op.
  std::vector<int> src1_;

  //! Destination of each instruction.
  std::vector<int> dest_;
};

//! PrecomputedValuesBase:
//!  A class to support optimized evaluation of values
//!  at runtime.
//!    At compile time all necessary values are collected
//!  from given IR nodes, and a value machine plus a pre-allocated
//!  workspace containing the concrete values are created.
//!    At runtime the value machine is used to evaluate all the
//!  values and store them in the workspace ahead of time.
//!    Workspace slots are addressed by integer index; negative
//!  indices are treated as "not in the workspace" (see bindValue).
template <typename IRContext>
class PrecomputedValuesBase {
  using VALUE_MACHINE = NaiveValueMachine<IRContext>;

 public:
  explicit PrecomputedValuesBase() = default;

  //! Returns if the workspace contains evaluated results.
  bool ready() const {
    return has_valid_values_;
  }

  //! Runs the internal value machine that will compute
  //!  the values allocated in the workspace.
  void evaluate();

  //! Returns value for the given IR node if it's stored
  //!  in the workspace and has been evaluated.
  c10::optional<IntOrDouble> getMaybeValueFor(const Val* val);

  //! Debugging helper, prints all the currently known values
  void print() const;

 protected:
  //! Initialize the workspace before first use.
  //!  Assume the given value list IR nodes have
  //!  been topologically sorted.
  void initializeValueList(
      typename IRContext::EVALUATOR_TYPE& evaluator,
      const std::vector<Val*>& sorted_value_list);

  //! Bind concrete value to the given index.
  //!  Does nothing when the index is negative or when the slot
  //!  holds a compile-time constant. Otherwise marks the slot as
  //!  defined, stores the value and records the binding in
  //!  binding_log_.
  void bindValue(int index, IntOrDouble value) {
    if (index < 0 || is_constant_[index]) {
      return;
    }
    defined_[index] = true;
    values_[index] = value;
    binding_log_.emplace_back(index, value);
  }

  //! Invalidate all computed values in the workspace.
  void invalidate();

  //! Interface for subclasses to replace symbols_; the passed
  //!  vector is moved into the workspace.
  void loadSymbols(std::vector<Val*> symbols) {
    symbols_ = std::move(symbols);
  }

  //! Interface for subclasses to get mutable access to symbols_
  std::vector<Val*>& symbols() {
    return symbols_;
  }

  //! Initialize the value runtime that will
  //!  infer instructions from the workspace. The created machine
  //!  is constructed from *this and owned by value_machine_.
  void initializeIntegerMachine() {
    value_machine_ = std::make_unique<VALUE_MACHINE>(*this);
  }

  //! Same flag as ready(), exposed to subclasses.
  bool hasValidValues() const {
    return has_valid_values_;
  }

 private:
  //! Post evaluation check, throws if any computed value
  //!  is inconsistent with its bound value
  void validate();

  //! Returns true if workspace has a computed or constant
  //!  value for given index. Asserts that the index is not negative.
  bool hasValue(int index) const {
    // Slot 0 is a regular workspace position; only negative indices are
    //  invalid, which is the same rule bindValue applies.
    TORCH_INTERNAL_ASSERT(index >= 0);
    return defined_[index] || is_constant_[index];
  }

 private:
  friend VALUE_MACHINE;

  //! Marks if an evaluation has finished
  bool has_valid_values_ = false;

  //! The size of workspace, -1 until it has been set.
  int num_of_values_ = -1;

  //! Marks if a value has been bound or
  //!  computed at each index.
  std::vector<bool> defined_;

  //! Marks if a value is compile-time constant
  //!  at each index.
  std::vector<bool> is_constant_;

  //! Stores the concrete values at each index.
  std::vector<IntOrDouble> values_;

  //! Stores the IR nodes corresponding to each index.
  std::vector<Val*> symbols_;

  //! An internal log to keep track of all the bindings
  //!  used in each evaluation cycle, stored as (index, value)
  //!  pairs. To be used for consistency check.
  std::vector<std::pair<int, IntOrDouble>> binding_log_;

  //! Value machine that realizes the value computations;
  //!  created by initializeIntegerMachine().
  std::unique_ptr<VALUE_MACHINE> value_machine_;
};

//! PrecomputedValues workspace in Fusion IR context,
//!  defines the set of values to be collected in each
//!  fusion graph and the input value binding given each
//!  fusion runtime input.
class FusionPrecomputedValues : public PrecomputedValuesBase<FusionIRContext> {
  //! Shorthand for the base class instantiation.
  using precomputedValuesBaseType = PrecomputedValuesBase<FusionIRContext>;

 public:
  //! Creates the workspace for the given fusion.
  //!  NOTE(review): defined out of line; presumably collects the values
  //!  of `fusion` and keeps the pointer in fusion_ -- confirm in the .cpp.
  FusionPrecomputedValues(Fusion* fusion);

  //! Bind concrete values from fusion runtime inputs
  void bindFusionInputs(const KernelArgumentHolder& args);

 private:
  //! Helper taking a tensor view and the matching runtime tensor argument.
  //!  NOTE(review): presumably binds the tensor's metadata values
  //!  (e.g. sizes) for `tv` from `tensor_arg_abstract` -- confirm.
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

 private:
  //! Non-owning fusion pointer, null by default.
  Fusion* fusion_ = nullptr;
};
//! PrecomputedValues workspace in Kernel IR context,
//!  defines the set of values to be collected in each
//!  kernel IR sequence and the input value binding given each
//!  fusion runtime input and launch constraints.
class KernelPrecomputedValues : public PrecomputedValuesBase<KernelIRContext> {
  //! Shorthand for the base class instantiation.
  using precomputedValuesBaseType = PrecomputedValuesBase<KernelIRContext>;

 public:
  //! Maps a parallel type to a list of IR values.
  //!  NOTE(review): presumably the extents parallelized on that
  //!  type -- confirm against the producer of this map.
  using ParallelExtentMap =
      std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;

  //! Creates the workspace for the given kernel (defined out of line).
  KernelPrecomputedValues(kir::Kernel* kernel);

  //! Bind concrete values from fusion runtime inputs
  void bindKernelInputs(kir::Kernel* kernel, const KernelArgumentHolder& args);

  //! Bind concrete values from launch constraints
  void bindParallelExtents(
      const ParallelExtentMap& parallel_extents,
      const LaunchParams& launch_constraint);

  //! Bind the NamedScalars corresponding to the
  //!  concrete parallel dimension sizes after the
  //!  actual value has been resolved.
  void bindConcreteParallelTypeValue(ParallelType pt, int64_t value);

 private:
  //! Helper taking a tensor view and the matching runtime tensor argument.
  //!  NOTE(review): presumably binds the tensor's metadata values
  //!  (e.g. sizes) for `tv` from `tensor_arg_abstract` -- confirm.
  void bindTensorMetaData(
      TensorView* tv,
      const TensorArgAbstract* tensor_arg_abstract);

  //! Iterate through all the named scalars corresponding
  //!  to thread sizes and pre-group them by their parallel
  //!  types.
  void initializeNamedScalars();

 private:
  //! Contains, per parallel type, the workspace indices of all the
  //!  named scalars corresponding to the thread size of that type.
  std::unordered_map<ParallelType, std::unique_ptr<std::vector<int>>, TypeHash>
      thread_dim_value_indices_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch