#pragma once
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/maxinfo_propagator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/reduction_heuristic.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
class SchedulerRuntimeInfo;
class ExpressionEvaluator;
class HeuristicSummary;
namespace scheduler_utils {
// Assume only half of the register file is available to spend on buffers.
// This is because when we allocate a buffer in registers it has to be
// accessed with a compile time constant index, and nvcc seems to use many
// registers for that indexing. This is a rough estimate of the extra register
// use, but it's hard to get a better one.
constexpr int64_t register_file_size = 256 * 1024 / 2;
constexpr int64_t x_grid_limit = ((int64_t)1 << (int64_t)31) - (int64_t)1;
constexpr int64_t y_grid_limit = 65535;
constexpr int64_t z_grid_limit = 65535;
constexpr int64_t z_block_limit = 64;
// Largest power of 2 less than or equal to n (returns 1 when n == 0)
constexpr int64_t lastPow2(int64_t n) {
TORCH_INTERNAL_ASSERT(n >= 0);
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
n |= (n >> 16); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
n |= (n >> 32); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
return std::max((int64_t)1, n - (n >> 1));
}
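// A few values, as a quick sanity check of the bit trick above:
//   lastPow2(1) == 1, lastPow2(10) == 8, lastPow2(16) == 16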
// Divide x by y, clamping the result to a minimum of 1
inline int64_t safeDiv(const int64_t x, const int64_t y) {
return std::max(x / y, (int64_t)1);
}
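// For example: safeDiv(10, 4) == 2 and safeDiv(3, 8) == 1 (never 0), so
// downstream launch-parameter math never degenerates to zero.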
// Split the given dimensions in `to_split`. Also update the dimensions in
// `to_update` to their positions in the split tensor. Splitting one dimension
// multiple times is supported, and in that case the order of `to_split`
// matters. All given dimensions are positions before any split.
TORCH_CUDA_CU_API void splitDims(
TensorView* tv,
std::vector<std::pair<size_t, size_t>> to_split, // (dim, size)
std::vector<size_t>& to_update);
TORCH_CUDA_CU_API inline void splitDims(
TensorView* tv,
std::vector<std::pair<size_t, size_t>> to_split) { // (dim, size)
std::vector<size_t> unused;
splitDims(tv, std::move(to_split), unused);
}
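// Usage sketch (shape and positions assumed for illustration): for a tv with
// domain [I0, I1, I2], split I0 by a factor of 4 while tracking where I2 ends
// up:
//   std::vector<size_t> to_update{2};
//   splitDims(tv, {{0, 4}}, to_update);
//   // domain is now [I0/4, 4, I1, I2] and to_update[0] == 3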
// Merge all the given dimensions in `to_merge` into a single dimension. Also
// update the dimensions in `to_update` to their positions in the merged
// tensor. Returns the merged dimension. All given dimensions are positions
// before any merge.
TORCH_CUDA_CU_API c10::optional<size_t> mergeDims(
TensorView* tv,
std::vector<size_t> to_merge,
std::vector<size_t>& to_update);
TORCH_CUDA_CU_API inline c10::optional<size_t> mergeDims(
TensorView* tv,
std::vector<size_t> to_merge) {
std::vector<size_t> unused;
return mergeDims(tv, std::move(to_merge), unused);
}
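// Usage sketch (domain assumed for illustration): for a tv with domain
// [I0, I1, I2, I3], merge I1 and I3 into a single dimension while tracking
// I2:
//   std::vector<size_t> to_update{2};
//   auto merged = mergeDims(tv, {1, 3}, to_update);
// `merged` holds the position of the combined dimension in the new domain,
// and to_update[0] holds I2's new position.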
// Merge all reduction axes to the right side and return the total number of
// reduction axes. Axes in `dont_merge` are left alone; this is typically used
// for trivial reductions.
size_t mergeReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// Merge all non-reduction axes to the left side and return the total number
// of iteration axes. Axes in `dont_merge` are left alone; this is typically
// used for trivial reductions.
size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});
// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
TORCH_CUDA_CU_API void parallelizeAllLike(
TensorView* reference_tv,
int64_t pos = -1,
std::vector<TensorView*> selected_tvs = {},
const std::unordered_set<ParallelType>& selected_parallel_types = {},
bool propagate_padding = true);
TORCH_CUDA_CU_API inline void parallelizeAllLike(
TensorView* reference_tv,
std::vector<TensorView*> selected_tvs,
const std::unordered_set<ParallelType>& selected_parallel_types = {},
bool propagate_padding = true) {
parallelizeAllLike(
reference_tv,
-1,
std::move(selected_tvs),
selected_parallel_types,
propagate_padding);
}
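// Typical usage sketch: parallelize the reference, then propagate to every
// mapped axis of every tensor in the fusion:
//   reference_tv->axis(0)->parallelize(ParallelType::BIDx);
//   reference_tv->axis(1)->parallelize(ParallelType::TIDx);
//   parallelizeAllLike(reference_tv);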
TORCH_CUDA_CU_API void computeAtInputs(
TensorView* consumer,
int pos,
ComputeAtMode mode = ComputeAtMode::Standard);
TORCH_CUDA_CU_API void computeWithOutputs(
TensorView* producer,
int pos,
ComputeAtMode mode = ComputeAtMode::Standard);
struct PersistentBufferInfo {
std::vector<TensorView*> persistent_buffers;
std::unordered_set<IterDomain*> unmappable_dims;
// Persistent buffers are needed until the path through the reduction-
// broadcast chain is resolved by any other chain using the persistent buffer
// that is not going through a reduction. This assumes all reduction paths
// have the same reduction pattern. Order is the same as persistent_buffers.
std::vector<std::vector<TensorView*>> persistent_buffer_resolution_points;
// Not all persistent buffers can be projected to inputs. If a buffer can be
// projected to the inputs, which may reduce the persistent buffer size (BN
// backwards specifically), it is tracked here. Persistent buffers that have a
// persistent buffer/reduction before them should not be projected through
// that.
std::vector<TensorView*> projectable_persistent_buffers;
// Track inputs of input projectable buffers
std::vector<TensorView*> projectable_buffer_inputs;
// Map unmappable dims to projectable_buffer_inputs
std::unordered_set<IterDomain*> unamppable_dims_projected_to_inputs;
};
// Buffers whose roots can't map to all producer roots based on compute at.
// These are the buffers we would make persistent in a persistent kernel, or
// would have to recompute if we can't make a persistent kernel. This function
// will also return inputs as being marked persistent if they follow this
// pattern. It is important to note, however, that inputs don't strictly have
// to be persistent as they can simply be read multiple times from GMEM in the
// same kernel.
TORCH_CUDA_CU_API PersistentBufferInfo persistentBuffers(Fusion* fusion);
struct TvProperties {
// How many elements in the tensor view are there to reduce.
int64_t total_reduction_numel = 1;
// How many reductions do we need to perform, i.e., how many iter domain
// elements are there.
int64_t total_iteration_numel = 1;
// Whether the inner most dimension is a reduction (marked true if there are
// no reductions).
bool fastest_dim_reduction = true;
// How many elements are in the inner most dimension, merging surrounding
// domains that match in type. This is used for 3D schedulers in
// reduction/normalization.
int64_t inner_most_dimension_numel = 1;
// Same as above, but the number of dimensions instead of the numel.
int64_t inner_most_dimension_ndims = 1;
// The resulting dimensionality of the problem after merging neighboring
// iteration domains and reduction domains.
int64_t dimensionality = 1;
};
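// For example (a hypothetical tensor): for tv = T0[i0{128}, i1{8}, r2{1024}],
// with the reduction in the fastest dimension, getProperties (below) would
// fill in roughly: total_reduction_numel = 1024,
// total_iteration_numel = 128 * 8, and fastest_dim_reduction = true.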
// Fill a TvProperties structure for tv
TvProperties getProperties(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
TensorView* tv);
// Struct to store persistent buffer sizes. Also holds the persistent buffer
// size when the buffers are projected to the inputs.
struct PersistentBufferSizeReturn {
int64_t persistent_buffer_size = 0;
int64_t projected_persistent_buffer_size = 0;
};
// Compute the amount of register space that would be needed to perform this
// kernel persistently, based only on buffers that must be persistent, and on
// the maximum of all minimum size requirements. I.e., if a buffer must be
// persistent, only hold its persistent dimension.
TORCH_CUDA_CU_API PersistentBufferSizeReturn persistentBufferSize(
Fusion* fusion,
SchedulerRuntimeInfo& runtime_info,
PersistentBufferInfo& persistent_buffers,
HeuristicSummary* data_cache = nullptr);
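// Usage sketch for persistentBuffers/persistentBufferSize above (roughly what
// a persistent scheduler would do):
//   auto buffer_info = persistentBuffers(fusion);
//   auto size_info =
//       persistentBufferSize(fusion, runtime_info, buffer_info, data_cache);
//   bool can_be_persistent =
//       size_info.persistent_buffer_size <= register_file_size;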
// Returns a set of all iteration domains (in roots of tensors) that map to a
// trivial reduction
std::unordered_set<IterDomain*> getTrivialReductionMap(Fusion* fusion);
// Merges tensor view to the form:
// [IterationDomain, ReductionDomain, TrivialReductionDim0,
// TrivialReductionDim1, ...]
// Returns a pair of flags: <has iteration dimensions, has reduction
// dimensions>.
std::pair<bool, bool> canonicalDimReduction(
Fusion* fusion,
TensorView* tv,
bool schedule_3D = false);
// Return a list of tensor views that are outputs of reduction operations. If
// multiple outputs of an expression are found, only include one in the list
TORCH_CUDA_CU_API std::vector<TensorView*> getReductionTvs(
Fusion* fusion,
bool ignore_trivial = true);
// Returns a list of TensorViews that are the consumer tv for a view operation.
std::vector<TensorView*> getViewTVs(Fusion* fusion);
// Reset inputs and outputs to global memory, everything else to local.
void clearMemorySpace(Fusion* fusion);
// Returns the cache-after tensors of the fusion inputs if unrolled; otherwise
// returns an empty vector.
TORCH_CUDA_CU_API std::vector<TensorView*> cacheInputs(
Fusion* fusion,
bool unroll);
// Returns the pairs of <cache of each fusion output, corresponding output> for
// all outputs.
TORCH_CUDA_CU_API std::vector<std::pair<TensorView*, TensorView*>>
cacheAndForkOutputs(Fusion* fusion, bool unroll);
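// Usage sketch: schedulers typically cache before transforming, e.g.:
//   auto cached_inputs = cacheInputs(fusion, /*unroll=*/true);
//   auto cached_outputs = cacheAndForkOutputs(fusion, /*unroll=*/true);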
// Ignoring broadcasts and reductions, returns the iter domain in the root
// domain that is "inner most". If this is an rfactored reduction domain,
// actually check the root domain. This is because the rfactored reduction
// tensorview carries the vectorized dimension, which means the rfactor domain
// could have reordered what we consider the "inner most" allocated position.
//
// In short: if tv is a reduction and has an rfactor, use the root domain;
// otherwise use the rfactor domain.
IterDomain* innerMostRootDim(TensorView* tv);
// Looks through the fusion and finds all dims that map to the one provided in
// the given tensorview. The iter domain must be a root domain. If inner_only,
// dimensions are only mapped if they're in the inner most position. This is
// important when projecting a dimension between an rfactor position and its
// root position when mapping from consumer to producer. If inner_only=true,
// takes the rfactor/root dimension that maps, projects it to the root/rfactor
// domain, but only follows the inner most path when encountering a
// split/merge. When propagating backward through a split, it only propagates
// if the mapped dimension is the inner portion of the split. For a merge,
// inner_only makes no difference and it propagates through the inner portion
// of the merge. When propagating forward, the logic is symmetric with the
// backward case.
class FindAllMappedDims : public MaxInfoSpanningTree::Propagator {
std::unordered_map<TensorView*, IterDomain*> mapped_root_ids_;
std::unordered_map<TensorView*, IterDomain*> mapped_rfactor_ids_;
TensorView* starting_tv_ = nullptr;
IterDomain* starting_id_ = nullptr;
bool inner_only_;
public:
FindAllMappedDims(TensorView* from, IterDomain* starting_id, bool inner_only);
virtual void setUp() override;
virtual void propagateC2P(TensorView* from, TensorView* to) override;
virtual void propagateP2C(TensorView* from, TensorView* to) override;
virtual void propagateSibling(TensorView* from, TensorView* to) override;
std::unordered_set<IterDomain*> get() const;
};
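// Usage sketch (assuming a spanning tree such as MaxRootDomainInfoSpanningTree
// from maxinfo_propagator.h):
//   FindAllMappedDims all_mapped_dims(
//       ref_tv, inner_most_id, /*inner_only=*/true);
//   MaxRootDomainInfoSpanningTree tree(ref_tv);
//   tree.traverse(&all_mapped_dims);
//   auto mapped_ids = all_mapped_dims.get();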
// Checks if the tensor view has an iter domain from `vector_dims` in its
// inner most root position (excluding broadcast and reduction), and checks
// whether it is a contiguous dimension.
bool hasInnerDim(
TensorView* tv,
std::unordered_set<IterDomain*> vector_dims,
bool should_vectorize);
// Returns all inputs and outputs that share the inner most dimension of the
// provided reference. If the reference is an input, reduction axes are
// ignored; broadcast axes are always ignored. If inner_only, an inner->inner
// mapping is required through views; otherwise any inner->any mapping is
// allowed. If vectorize_pass, contiguity is checked for vectorization;
// otherwise it just checks that the tensor has that inner dim.
std::vector<TensorView*> getInputsOutputsWithInnerDim(
TensorView* reference_tv,
bool inner_only,
bool vectorize_pass);
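// Usage sketch when selecting vectorization candidates:
//   auto vectorizable_tvs = getInputsOutputsWithInnerDim(
//       reference_tv, /*inner_only=*/true, /*vectorize_pass=*/true);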
// Structure to hold byte multiples for break points. I.e. if we have the
// tensors:
// T0[I0, I1] float
// T1[I0, I1] bool
// T2[I0] half
// T3 [I1] double
// and a break point of 1 the multiples would be:
// lhs_multiple = 4 + 1 + 2 = 7
// rhs_multiple = 4 + 1 + 8 = 13
struct BroadcastMultiple {
int64_t rhs_multiple = 0;
int64_t lhs_multiple = 0;
};
// Returns a vector of counts, size = reference_tv->getRootDomain().size(), each
// entry [i] is the number of inputs/outputs that have a non-broadcast dimension
// mapped to the corresponding dimension in reference_tv. Count includes
// reference_tv if reference_tv is an input or output. Count is multiplied by
// data type size.
std::vector<BroadcastMultiple> getBroadcastMultiples(
TensorView* reference_tv,
DataType index_type);
//! Collect maximum vectorization word size of a tensor whose
//! innermost domain is leaf_merged_domain. Contig merging is taken
//! into account to expand vectorization if possible.
size_t collectMaxVectorizeSizeWithContigMerge(
TensorView* tv,
IterDomain* leaf_merged_domain,
size_t max_word_size_in_byte,
ExpressionEvaluator& expression_evaluator,
DataType index_type);
namespace matmul_utils {
//! Utilities in this namespace facilitate scheduling matmul kernels with
//! hierarchical tiling specified in MatMulTileOptions.
//! Schedule utility for the matmul prologue:
//! Use all the threads on a CTA tile to load matmul operands
//! into shared memory with the given vectorization word.
//! TODO:
//! will need to add bank conflict removal swizzle in a follow up.
TORCH_CUDA_CU_API void scheduleContiguousVectorLoad(
TensorView* tv,
MatMulTileOptions tile,
int vector_word,
bool vectorize = true);
//! Schedule utility for mma output in matmul main loop:
//! Realize the hierarchical tiling based on the given tiling options.
//! TODO: rewrite this one with makeTile
TORCH_CUDA_CU_API void scheduleWarpTileWithReduction(
TensorView* tv,
MatMulTileOptions tile);
//! Schedule utility for mma output in matmul main loop:
//! Realize the hierarchical tiling based on the given tiling options
//! on consumers of mma ops in the epilogue.
//! TODO: remove this one eventually.
TORCH_CUDA_CU_API void scheduleWarpTileWithNoReduction(
TensorView* tv,
MatMulTileOptions tile);
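//! Usage sketch for the two warp tile utilities above (tile sizes assumed for
//! illustration):
//!   MatMulTileOptions gemm_tile;
//!   gemm_tile.cta_tile = GemmTile(128, 128, 32);
//!   gemm_tile.warp_tile = GemmTile(64, 64, 32);
//!   gemm_tile.instruction_tile = GemmTile(16, 8, 16);
//!   scheduleWarpTileWithReduction(mma_result_tv, gemm_tile);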
//! Lower level primitive splitting inner iterdomains into tiles:
//! Eg.
//! A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
TORCH_CUDA_CU_API void makeTile(TensorView* tv, std::vector<int> tile_sizes);
//! Order the inner tile dimensions as the original order in
//! root domain. Also putting broadcast domains on the left.
//! Eg. A[I0o,I1o,B2o,I0i,I1i,B2i] (root domain: I1,B,I0)
//! -> A[I0o, I1o, B2o, B2i, I1i, I0i]
//! This is used to facilitate data layout swizzling and
//! defining vectorized loads.
TORCH_CUDA_CU_API void orderTiledConcreteIdAsRoot(TensorView* tv);
//! Orders the root id ordering of the given tv as
//! [Batch, Previous Reduction, M, N, K]
//! for easier processing of later scheduling steps.
//!
//! This matching works on root domain only, and
//! will throw if the tv has a leaf iterdomain that is
//! not a root id.
TORCH_CUDA_CU_API void canonicalizeMmaTvOrdering(TensorView* tv);
} // namespace matmul_utils
//! Propagate current transformations on from_tv up to the given
//! position, to all tensorviews on the owning fusion that have
//! a connection with `from_tv` on the fusion graph.
TORCH_CUDA_CU_API void transformPropagateToAllFrom(
TensorView* from_tv,
int pos);
//! A type of custom transform propagator that propagates iterdomain
//! transforms from a source tv to all tvs that are selected
//! using a "direction" and a "boundary".
//!
//! The propagation model always assumes a `from_tv`, a `direction` and a
//! `boundary`.
//!
//! This propagator will only transform producers and consumers
//! of `from_tv`, and all propagation modes **require** a boundary to be
//! specified to signify where the propagation should stop.
//!
//! There are currently three modes of propagation: forward, backward and
//! both-way, see comment on the interface functions for details.
struct TORCH_CUDA_CU_API BoundedDirectionalTransformPropagator {
//! Custom option container for configuring
//! the transform propagation actions.
//! All option values default to false unless
//! the corresponding setter is called.
struct Options {
//! If true, the transform propagator will
//! also propagate parallel types from
//! `from_tv` to all selected tvs.
bool propagate_parallel_type = false;
//! If true, the specified boundary tvs
//! will also be replayed as `from_tv`.
//! If false, they will not be affected
//! by the propagation pass.
bool transform_boundary = false;
//! Sets the position boundary in parallel
//! type propagation, see comment on
//! scheduler_utils::parallelizeAllLike.
//! Only used if propagate_parallel_type==true.
int parallel_propagation_pos = -1;
//! Setter for enabling parallel type
//! propagation. see comment on the variable.
//!
//! \param up_to_pos sets the parallel type
//! propagation boundary. see comment on
//! scheduler_utils::parallelizeAllLike.
Options propagateParallelType(int up_to_pos = -1) {
propagate_parallel_type = true;
parallel_propagation_pos = up_to_pos;
return *this;
}
//! Setter for enabling propagation to
//! boundary tvs. see comment on the variable
Options propagateToBoundary() {
transform_boundary = true;
return *this;
}
};
//! Replay transforms from tensorview `from`
//! to the tensorviews that are consumers
//! of boundary tensorviews in `to` and producers of `from`.
static void backward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options = c10::nullopt);
//! Replay transforms from tensorview `from`
//! to the tensorviews that are producers
//! of boundary tensorviews in `to` and consumers of `from`.
static void forward(
TensorView* from,
int pos,
std::vector<TensorView*> to,
c10::optional<Options> options = c10::nullopt);
//! Replay transforms from tensorview `from`
//! to all the tensorviews that are consumers
//! of tensorviews in `backward_to` and producers
//! of tensorviews in `forward_to` while being
//! either a producer or a consumer of tensorview `from`.
static void bothWays(
TensorView* from,
int pos,
std::vector<TensorView*> backward_to,
std::vector<TensorView*> forward_to,
c10::optional<Options> options = c10::nullopt);
private:
//! Utility function:
//! Will realize the transform propagation to the
//! tensorviews in `included_tvs`.
//! Assumes that all tvs in included_tvs are either
//! a producer or a consumer of from_tv.
static void propagate(
TensorView* from_tv,
int pos,
std::unordered_set<TensorView*> included_tvs,
Options options);
};
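// Usage sketch (tensorviews assumed for illustration): propagate a tiled mma
// result's transforms back to its shared memory operands, including parallel
// types, and transform the boundary tvs themselves:
//   BoundedDirectionalTransformPropagator::backward(
//       mma_result_tv,
//       -1,
//       {a_smem_tv, b_smem_tv},
//       BoundedDirectionalTransformPropagator::Options()
//           .propagateParallelType()
//           .propagateToBoundary());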
} // namespace scheduler_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch