1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316
|
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <cuda.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/codegen/cuda/executor_kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <string>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace executor_utils {
//! Returns a preamble bundling every helper function that the generated
//! kernel code might need.
std::string kernelPreamble();
//! Validates the runtime arguments in `args` against the inputs of `fusion`
//! for execution on `device`.
//! NOTE(review): the concrete checks live in the implementation file (not
//! visible here) — presumably device placement and tensor metadata; confirm.
void validateKernelInputs(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const c10::Device& device);
//! Validates the output tensors in `outputs` against the outputs of `fusion`
//! for execution on `device`.
//! NOTE(review): the concrete checks live in the implementation file (not
//! visible here) — confirm there.
void validateKernelOutputs(
    Fusion* fusion,
    const std::vector<at::Tensor>& outputs,
    const c10::Device& device);
//! Binds the input values of `kernel` to the runtime values held in `args`
//! and returns the kir::ExpressionEvaluator carrying those bindings.
//! `check_consistency` defaults to true; presumably it enables checks that
//! repeated bindings agree — confirm in the implementation.
kir::ExpressionEvaluator bindKernelInputs(
    const KernelArgumentHolder& args,
    kir::Kernel* kernel,
    bool check_consistency = true);
//! Binds the input values of `fusion` to the runtime values held in `args`
//! and returns the ExpressionEvaluator carrying those bindings.
//! Exported from the library via TORCH_CUDA_CU_API.
TORCH_CUDA_CU_API ExpressionEvaluator bindFusionInputs(
    const KernelArgumentHolder& args,
    Fusion* fusion);
//! Handles produced by nvrtcCompile: a CUDA driver module handle and a
//! function handle. Both members are value-initialized by default.
struct NvrtcFunction {
  CUmodule module{};
  CUfunction function{};
};
//! Initializes the CUDA context.
//! NOTE(review): presumably required before the driver-API calls used to
//! load compiled kernels — confirm in the implementation file.
void initializeCudaContext();
//! Compiles `code` and returns the executable function handles together
//! with the ptxas log produced by the compilation.
//!
//! @param code            Source text to compile.
//! @param func_name       Name of the kernel function inside `code`.
//! @param id              Identifier of this compilation (its exact use is
//!                        defined by the implementation — confirm there).
//! @param opt_block_size  Optional block size; c10::nullopt when omitted.
std::pair<NvrtcFunction, std::string> nvrtcCompile(
    const std::string& code,
    const std::string& func_name,
    int id,
    c10::optional<int> opt_block_size = c10::nullopt);
namespace caching {
// TODO: Consider moving some of this logic into a common location so it
// can be re-used.
//! Every entry type that can be stored in the `FusionExecutor`
//! compile-time data cache. Each entry class below names one of these
//! as its `EntryType`.
enum class CompileTimeEntryType : int {
  PARALLEL_BINDING_ITERDOMAINS, //!< ParallelBindingIterDomains
  PARALLEL_ITER_EXTENT_MAP, //!< ParallelIterExtentMap
  SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP, //!< SimplifiedParallelIterExtentMap
  WARP_PADDED_PARALLEL_EXTENTS, //!< WarpPaddedParallelExtents
  VECTORIZED_TENSOR_VALIDATION, //!< VectorizedTensorValidation
  INPUT_ALIAS_INDICES, //!< InputAliasIndices
  OUTPUT_ALIAS_INDICES, //!< OutputAliasIndices
};
//! Entry class definitions, one per entry type; each one names the data type
//! stored for its entry type.
//!
//! ParallelBindingIterDomains — compile-time info cached per FusionExecutor.
//! Holds every iterdomain parallelized in the scheduled Fusion graph. These
//! are used when iterating launch parameters, and their extents may be
//! supplied by launch constraints.
class ParallelBindingIterDomains {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::PARALLEL_BINDING_ITERDOMAINS;
  using DataType = std::vector<IterDomain*>;
};
//! ParallelIterExtentMap — compile-time info cached per FusionExecutor.
//! For each parallel type in use, holds the symbolic extents of all
//! iterdomains parallelized with that type.
class ParallelIterExtentMap {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::PARALLEL_ITER_EXTENT_MAP;
  using DataType =
      std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
};
//! SimplifiedParallelIterExtentMap — compile-time info cached per
//! FusionExecutor; a reduced form of ParallelIterExtentMap.
//!
//! Launch-parameter binding only needs the most concrete iterdomain from
//! each disjoint set in CaParallelMap. This entry keeps the extents that
//! still have to be bound once that reduction has been applied.
//!
//! ParallelIterExtentMap remains necessary because concrete values must be
//! bound to the extents of every parallelized iterdomain. Those extra
//! bindings could be skipped if the integer machine understood equality and
//! could be configured at compile time — a longer-term goal.
class SimplifiedParallelIterExtentMap {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::SIMPLIFIED_PARALLEL_ITER_EXTENT_MAP;
  using DataType =
      std::unordered_map<ParallelType, std::vector<const Val*>, TypeHash>;
};
//! WarpPaddedExtentsInfo — helper data type stored by the
//! WarpPaddedParallelExtents entry class.
struct WarpPaddedExtentsInfo {
  //! Extents of warp-padded parallel iterdomains.
  std::unordered_set<const Val*> warp_padded_extent_set;
  //! Warp-padded extents paired with an int64_t constant value.
  std::unordered_map<const Val*, int64_t> warp_padded_constant;
};
//! WarpPaddedParallelExtents — compile-time info cached per FusionExecutor.
//! Holds the symbolic and constant extents of warp-padded parallel
//! iterdomains, packaged as a WarpPaddedExtentsInfo.
class WarpPaddedParallelExtents {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::WARP_PADDED_PARALLEL_EXTENTS;
  using DataType = WarpPaddedExtentsInfo;
};
//! VectorizedTensorInfo — helper data type stored by the
//! VectorizedTensorValidation entry class.
struct VectorizedTensorInfo {
  //! Positions of fusion inputs vectorized with aligned access.
  std::vector<int> aligned_vectorized_inp_tensor_pos;
  //! Positions of fusion outputs vectorized with aligned access.
  std::vector<int> aligned_vectorized_out_tensor_pos;
  //! Input tensors that use misaligned vectorization.
  std::unordered_set<TensorView*> global_inp_misaligned_tv;
  //! Output tensors that use misaligned vectorization.
  std::unordered_set<TensorView*> global_out_misaligned_tv;
  //! Positions of the misaligned input tensors.
  std::vector<int> inp_misaligned_tensors_pos;
  //! Positions of the misaligned output tensors.
  std::vector<int> out_misaligned_tensors_pos;
};
//! Compile-time info to be cached in each FusionExecutor:
//! VectorizedTensorValidation
//! Stores, as a VectorizedTensorInfo, the positions of aligned vectorized
//! input/output tensors and the misaligned vectorized input/output tensors
//! together with their positions, to be used in vectorization validation.
class VectorizedTensorValidation {
 public:
  using DataType = VectorizedTensorInfo;
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::VECTORIZED_TENSOR_VALIDATION;
};
//! InputAliasIndices — compile-time info cached per FusionExecutor.
//! Holds position info for aliased input tensors as pairs of ints.
//! NOTE(review): which element of each pair is the input vs. the aliasing
//! position is defined by the producer of this entry — confirm in the .cpp.
class InputAliasIndices {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::INPUT_ALIAS_INDICES;
  using DataType = std::vector<std::pair<int, int>>;
};
//! OutputAliasIndices — compile-time info cached per FusionExecutor.
//! Holds the set of positions of aliased output tensors.
class OutputAliasIndices {
 public:
  static const CompileTimeEntryType EntryType =
      CompileTimeEntryType::OUTPUT_ALIAS_INDICES;
  using DataType = std::unordered_set<int>;
};
//! Common base class for unified storage in `ExecutorCompileTimeInfoCache`;
//! each entry held by `ExecutorCompileTimeInfoCache` is a subclass of it.
//! Every instance is tagged with the CompileTimeEntryType it was built for.
class CompileTimeInfoBase : public PolymorphicBase {
 public:
  //! `explicit`: a bare entry type must not implicitly convert into a cache
  //! entry object.
  explicit CompileTimeInfoBase(CompileTimeEntryType entry_type)
      : entry_type_(entry_type) {}

  //! Returns the entry type passed at construction. Read-only, hence const.
  CompileTimeEntryType type() const {
    return entry_type_;
  }

 private:
  CompileTimeEntryType entry_type_;
};
// Note: Do NOT export this class. MSVC issue with exported class that contains
// std::vector<unique_ptr<xxx>>: https://godbolt.org/z/3E4e8T1P1
//! Compile-time information cache
class ExecutorCompileTimeInfoCache {
using Entry = CompileTimeInfoBase;
using EntryOwningPtr = std::unique_ptr<Entry>;
using EntryPtr = Entry*;
using EntryType = CompileTimeEntryType;
public:
void insert(EntryOwningPtr new_entry);
EntryPtr at(EntryType entry_type) {
return entry_type_map_.at(entry_type);
}
bool has(EntryType entry_type) {
return entry_type_map_.count(entry_type);
}
private:
std::vector<EntryOwningPtr> entries_;
std::unordered_map<EntryType, EntryPtr> entry_type_map_;
};
//! A utility class to facilitate accessing ExecutorCompileTimeInfoCache.
template <typename EntryClass>
class ExecutorCompileTimeEntry {
  using EntryDataType = typename EntryClass::DataType;
  using EntryDataTypeOwnPtr = std::unique_ptr<EntryDataType>;
  using MakerFnType = std::function<EntryDataTypeOwnPtr()>;

 public:
  //! Creates a data entry with the type defined in EntryClass,
  //! e.g. EntryClass = VectorizedTensorValidation.
  //!
  //! @param data_cache A pointer to an instantiated compile-time info cache,
  //!   or nullptr. The info data will be
  //!   1. read from the data cache if it has the corresponding entry;
  //!   2. written into the data cache if it doesn't have the entry;
  //!   3. managed by owned_data_ if the data cache is nullptr.
  //! @param fn The factory function; it must return an owning pointer,
  //!   i.e. std::unique_ptr<EntryClass::DataType>. It is only called when
  //!   the data cache is missing the entry or when no data cache is given.
  ExecutorCompileTimeEntry(
      ExecutorCompileTimeInfoCache* data_cache,
      MakerFnType fn);

  //! Unified interface to get the actual data, whether it came from the
  //! cache or from the factory function.
  EntryDataType& get() {
    return *data_ptr_;
  }

  //! Read-only counterpart of get() for const instances; returns the same
  //! data.
  const EntryDataType& get() const {
    return *data_ptr_;
  }

 private:
  //! Internal owning pointer that manages the computed data when there is
  //! no data cache.
  EntryDataTypeOwnPtr owned_data_ = nullptr;
  //! Pointer to the valid data entry that can be accessed via get().
  EntryDataType* data_ptr_ = nullptr;
};
} // namespace caching
//! Returns the vector of iterdomains that will be used to bind parallel
//! dimensions, computed from `lower` and the tensorviews in `used_tvs`.
std::vector<IterDomain*> getParallelBindingsIterDomains(
    GpuLower* lower,
    const std::vector<TensorView*>& used_tvs);
//! Map from each parallel type to the extents of the iterdomains bound to
//! it. Spelled identically to caching::ParallelIterExtentMap::DataType and
//! caching::SimplifiedParallelIterExtentMap::DataType, so it is the same
//! type as both.
using ParallelExtentMap = std::unordered_map<
    ParallelType,
    std::vector<const Val*>,
    TypeHash>;
//! Groups the extents of every iterdomain in `parallel_binding_ids` by the
//! parallel type it is bound to, returning an owning pointer to the map.
std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
    std::vector<IterDomain*>& parallel_binding_ids);
//! Computes, using `lower`, the reduced set of extents from
//! `parallel_binding_ids` that launch-parameter binding actually needs,
//! returning an owning pointer to the map.
std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
    GpuLower* lower,
    std::vector<IterDomain*>& parallel_binding_ids);
//! Returns the symbolic or constant extetns of warp padded parallel
//! iterdomains in the given vector.
std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
kir::Kernel* lower,
std::vector<IterDomain*>& parallel_binding_ids);
//! Validates the vectorized tensors of `kernel` against the runtime inputs
//! in `args` and the output tensors in `outputs`.
//! `data_cache` is the compile-time info cache consulted for this check, and
//! `expr_eval` provides runtime values.
//! NOTE(review): presumably checks the aligned/misaligned vectorization
//! requirements recorded in caching::VectorizedTensorInfo — confirm in the
//! implementation file.
void validateVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval);
} // namespace executor_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|