File: compile_time_info.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (257 lines) | stat: -rw-r--r-- 9,566 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#pragma once

#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

//! Namespace hosting the catalog of possible compile-time
//!  info that can be cached. Each possible entry type has
//!  a value in `CompileTimeEntryType` and an entry type class
//!  definition like `VectorizableInputsAndOutputs`. The corresponding
//!  classes contain their entry type, data type and maybe more
//!  later depending on use cases.
namespace HeuristicCompileTime {

//! Each entry type under this category represents some information
//!  that can be inferred at compile time, i.e. without any runtime input
//!  meta data. Entries will be stored in `HeuristicSummary` and will
//!  be re-used each time the same fusion is visited.

//! Enum for all possible types of cached entries of compile-time info.
//!  Each enumerator is paired with an entry type class below that fixes
//!  the data type stored under that tag.
enum class CompileTimeEntryType {
  DOMAIN_MAP, // see DomainMap
  REFERENCE_TENSORS, // see ReferenceTensors
  VECTORIZABLE_INPUTS_AND_OUTPUTS, // see VectorizableInputsAndOutputs
  INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS, // see InputsOutputsInnerDimGroups
  UNROLLABLE_INPUTS_AND_OUTPUTS, // see UnrollableInputsAndOutputs
  REDUCTION_TVS, // see ReductionTVs
  PERSISTENT_BUFFER_INFO, // see PersistentBufferInfo
  SCOPE_PERSISTENT_FACTOR_INFO, // see ScopePersistentFactorInfo
  BROADCAST_BYTE_MULTIPLES // see BroadcastMultiples
};

//! Entry type definition class for `DOMAIN_MAP`,
//!  stores the domain map of a fusion.
class DomainMap {
 public:
  using DataType = pointwise_utils::DomainMap;
  // constexpr is implicitly inline in C++17, so the member can be
  //  odr-used without an out-of-line definition (no ODR/link hazard).
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::DOMAIN_MAP;
};

//! Entry type definition class for `REFERENCE_TENSORS`,
//!  stores the reference TensorViews used to schedule a fusion.
class ReferenceTensors {
 public:
  using DataType = std::vector<TensorView*>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::REFERENCE_TENSORS;
};

//! Entry type definition class for `VECTORIZABLE_INPUTS_AND_OUTPUTS`,
//!  stores the vectorizable TensorViews on a fusion's inputs and outputs.
class VectorizableInputsAndOutputs {
 public:
  using DataType = std::vector<TensorView*>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::VECTORIZABLE_INPUTS_AND_OUTPUTS;
};

//! Entry type definition class for `INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS`,
//!  stores the fusion's inputs and outputs grouped by inner most dimension.
class InputsOutputsInnerDimGroups {
 public:
  using DataType = std::vector<std::vector<TensorView*>>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::INPUTS_AND_OUTPUTS_INNER_DIM_GROUPS;
};

//! Entry type definition class for `UNROLLABLE_INPUTS_AND_OUTPUTS`,
//!  stores the unrollable TensorViews on a fusion's inputs and outputs.
class UnrollableInputsAndOutputs {
 public:
  using DataType = std::vector<TensorView*>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::UNROLLABLE_INPUTS_AND_OUTPUTS;
};

//! Entry type definition class for `REDUCTION_TVS`,
//!  stores all TensorViews with non-trivial reduction axes in a fusion.
class ReductionTVs {
 public:
  using DataType = std::vector<TensorView*>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::REDUCTION_TVS;
};

//! Entry type definition class for `PERSISTENT_BUFFER_INFO`,
//!  stores persistent buffers inferred from topology and scheduling of fusion.
class PersistentBufferInfo {
 public:
  using DataType = scheduler_utils::PersistentBufferInfo;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::PERSISTENT_BUFFER_INFO;
};

//! Auxiliary data types for `SCOPE_PERSISTENT_FACTOR_INFO` entry type.
using ScopedPersistenceBufferMap = std::unordered_map<Val*, std::vector<bool>>;

//! Entry type definition class for `SCOPE_PERSISTENT_FACTOR_INFO`,
//!  tracks which buffers are active at a given Val*. The order of the bool
//!  vector is based on the persistent buffer order from the persistent buffer
//!  info, which is then appended by the projectable persistent buffers'
//!  inputs. True in the bool vector means the persistent buffer is active at
//!  the generation of the key.
class ScopePersistentFactorInfo {
 public:
  using DataType = ScopedPersistenceBufferMap;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::SCOPE_PERSISTENT_FACTOR_INFO;
};

//! Entry type definition class for `BROADCAST_BYTE_MULTIPLES`,
//!  stores "byte multiples" information. This information can be used to figure
//!  out if using a 2D scheduler how many bytes have to be transferred with
//!  varying split locations. See BroadcastMultiple definition for more
//!  information.
class BroadcastMultiples {
 public:
  using DataType = std::vector<scheduler_utils::BroadcastMultiple>;
  // constexpr is implicitly inline in C++17, avoiding any need for an
  //  out-of-line definition when the member is odr-used.
  static constexpr CompileTimeEntryType EntryType =
      CompileTimeEntryType::BROADCAST_BYTE_MULTIPLES;
};

//! Base abstract class for unified storage in `HeuristicSummary`,
//!  each entry in `HeuristicSummary` will be a subclass.
class CompileTimeInfoBase : public PolymorphicBase {
 public:
  //! \param entry_type Tag identifying which concrete entry this is.
  //!  explicit: an entry tag should never implicitly convert to an entry.
  explicit CompileTimeInfoBase(CompileTimeEntryType entry_type)
      : entry_type_(entry_type) {}

  //! Returns the entry type tag; const because it only reads state.
  CompileTimeEntryType type() const {
    return entry_type_;
  }

 private:
  CompileTimeEntryType entry_type_;
};

} // namespace HeuristicCompileTime

// Note: Do NOT export this class. MSVC issue with exported class that contains
// std::vector<unique_ptr<xxx>>: https://godbolt.org/z/3E4e8T1P1
//! Compile-time information cache for `canSchedule` and
//!  `getHeuristics` interfaces. Each cache instance
//!  stores information that could be inferred at compile
//!  time in a fusion and therefore corresponds to an
//!  instance of FusionExecutor.
//!  Since each instance of FusionExecutor has a unique
//!   heuristic type, this cache also has a heuristic
//!   type to simplify data validation.
//!  HeuristicSummary has two modes of operation:
//!  - when in `recording` mode, the information is not available
//!     in the cache and entries can be written and stored.
//!  - when not in `recording` mode, compile-time data has
//!     been stored in this cache and the entries can be accessed
//!     but new entries can no longer be inserted.
class HeuristicSummary {
  using Entry = HeuristicCompileTime::CompileTimeInfoBase;
  using EntryOwningPtr = std::unique_ptr<Entry>;
  using EntryPtr = Entry*;
  using EntryType = HeuristicCompileTime::CompileTimeEntryType;

 public:
  //! Builds a cache for the given fusion and heuristic (defined
  //!  out-of-line).
  HeuristicSummary(
      Fusion* fusion,
      ScheduleHeuristic heuristic,
      SchedulerRuntimeInfo& runtime_info);

  //! Returns true while the cache still accepts new entries.
  //!  const: pure accessor.
  bool isRecording() const {
    return recording_;
  }

  //! Transfers ownership of a computed entry into the cache
  //!  (defined out-of-line).
  void insert(EntryOwningPtr new_entry);

  //! Looks up a cached entry by type. Throws std::out_of_range if no
  //!  entry of this type has been inserted. const: lookup does not
  //!  mutate the cache.
  EntryPtr at(EntryType entry_type) const {
    return entry_type_map_.at(entry_type);
  }

 private:
  void validate() const;

 private:
  //! Owns every cached entry.
  std::vector<EntryOwningPtr> entries_;
  //! Non-owning index into entries_, keyed by entry type.
  std::unordered_map<EntryType, EntryPtr> entry_type_map_;
  ScheduleHeuristic heuristic_;
  //! Defaults to recording mode on construction.
  bool recording_ = true;
};

//! A utility class to facilitate accessing HeuristicSummary.
//!  This utility is needed because the information to be stored
//!    in HeuristicSummary is used in several different scenarios
//!    and we want to support all these use cases in one interface.
//!  The current use examples are:
//!   1. During fusion segmentation process, all the fusions
//!     given to canSchedule are temporary and therefore the
//!     compile time info do not need to be cached, and in fact
//!     a cache wouldn't be instantiated by that time.
//!
//!   2. When the compiled kernel is launched the first time, the
//!     cache will be in `recording` phase and all the computed information
//!     should be captured and written into the cache.
//!
//!   3. When we check a compiled fusion for heuristic hit,
//!      we want to use the cached info to save runtime latency.
//!
//! The designed interface is used as:
//!   auto entry = HeuristicSummaryEntry<EntryClass>(data_cache, maker_fn);
//!   auto& data = entry.get();
//!
//!  `maker_fn` will be called to compute the information when no cached data
//!   exists and `entry` will own the computed data when no data cache is
//!   supplied.
template <typename EntryClass>
class HeuristicSummaryEntry {
  // Shorthands derived from the entry type class (e.g. its DataType).
  using EntryDataType = typename EntryClass::DataType;
  using EntryDataTypeOwnPtr = std::unique_ptr<EntryDataType>;
  using MakerFnType = std::function<EntryDataTypeOwnPtr()>;

 public:
  //! Creates a data entry with type defined in EntryClass,
  //!  eg. EntryClass = VectorizableInputsAndOutputs;
  //!
  //! @param data_cache, a pointer to an instantiated compile-time
  //!  info cache. The info data will be
  //!    1. read from data cache if data cache is not recording.
  //!    2. written into data cache if data cache is recording.
  //!    3. managed by owned_data_ if data cache is nullptr
  //! @param fn:
  //!   The factory function that needs to return a owning pointer
  //!  i.e. std::unique_ptr<EntryClass::DataType>. It will only
  //!  be called either when data cache is recording or when no data
  //!  cache is given.
  HeuristicSummaryEntry(HeuristicSummary* data_cache, MakerFnType fn);

  //! Unified interface to get actual data, either from cache
  //!  or from factory function. Only valid after construction has
  //!  set data_ptr_ (see constructor contract above).
  EntryDataType& get() {
    return *data_ptr_;
  }

 private:
  //! Internal data owning pointer that will manage the computed
  //!  data where there is no data cache.
  EntryDataTypeOwnPtr owned_data_ = nullptr;

  //! Pointer to the valid data entry that could be accessed.
  //!  Points either into the cache or at *owned_data_.
  EntryDataType* data_ptr_ = nullptr;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch