#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <bitset>
#include <cstdint>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
// Provides utilities for dealing with nested ForLoop and IfThenElse scopes
namespace torch {
namespace jit {
namespace fuser {
class ThreadPredicateMap;
namespace scope_utils {
// Collect the ForLoops enclosing scope, starting from scope and working outward
std::vector<kir::ForLoop*> getLoops(Expr* scope);
// Return how deeply scope is nested in for loops
unsigned int computeForDepth(Expr* scope);
// Append expr to the end of scope
void pushBack(Expr* scope, Expr* expr);
// Insert expr in scope before ref
void insertBefore(Expr* scope, Expr* ref, Expr* expr);
// Returns whether expr is directly in scope; nested scopes are not checked
bool exprInScope(Expr* scope, Expr* expr);
// Return the parent of the active scope
Expr* getParent(Expr* scope);
// Open a new innermost for loop
kir::ForLoop* openFor(Expr* scope, IterDomain*);
// Return a new loop nest matching the one provided. The clone's parent scope
// is set to parent_scope, but the clone is not inserted into parent_scope.
kir::ForLoop* cloneLoopNest(kir::ForLoop* to_clone, Expr* parent_scope);
// Run through a scope and replace expressions inside with replacement_map
void replaceExprsInScope(
Expr* scope,
std::unordered_map<Expr*, Expr*> replacement_map);
Expr* firstInnerMostScope(Expr* scope);
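/*
Sketch, assuming a hypothetical Scope struct with a parent pointer and an
is_for_loop flag in place of Expr and kir::ForLoop: getLoops and
computeForDepth can be read as a walk outward through parent scopes that
collects for loops and reports them outermost first. This illustrates the idea
only; it is not the implementation behind the declarations above.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for Expr / kir::ForLoop: each scope knows its parent.
struct Scope {
  Scope* parent = nullptr;
  bool is_for_loop = false;
};

// Walk outward from scope, collecting every for loop on the way, then reverse
// so the outermost loop comes first.
std::vector<Scope*> getLoopsSketch(Scope* scope) {
  std::vector<Scope*> loops;
  for (Scope* s = scope; s != nullptr; s = s->parent) {
    if (s->is_for_loop) {
      loops.push_back(s);
    }
  }
  std::reverse(loops.begin(), loops.end());
  return loops;
}

// The depth is the number of for loops collected by the walk.
unsigned int computeForDepthSketch(Scope* scope) {
  return static_cast<unsigned int>(getLoopsSketch(scope).size());
}
```

Non-loop scopes such as an IfThenElse are skipped by the walk rather than
counted toward the depth.
*/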
} // namespace scope_utils
namespace ir_utils {
// Sometimes we want to temporarily view a TensorView with another TensorDomain.
// This isn't a permanent transformation: in indexing we want to index
// producers with a consumer's set of indices, so we need to view the producer
// transformed like the consumer while we index. This guard sets tv's domain to
// td for the lifetime of the guard.
class TVDomainGuard {
private:
TensorView* tv_;
TensorDomain* prev_domain;
public:
explicit TVDomainGuard(TensorView* _tv, TensorDomain* td);
~TVDomainGuard();
};
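/*
Sketch of the RAII pattern TVDomainGuard describes, using hypothetical View and
Domain structs in place of TensorView and TensorDomain: the constructor swaps
in the temporary domain and the destructor restores the previous one. This is
an illustration under those assumptions, not the guard's actual definition.

```cpp
#include <cassert>

// Hypothetical stand-ins for TensorDomain and TensorView.
struct Domain {
  int id;
};

struct View {
  Domain* domain;
};

// Swap in a temporary domain on construction; restore the old one on
// destruction.
class DomainGuard {
 public:
  DomainGuard(View* view, Domain* temp) : view_(view), prev_(view->domain) {
    view_->domain = temp;
  }
  ~DomainGuard() {
    view_->domain = prev_;
  }
  // Copies would restore the domain more than once, so forbid them.
  DomainGuard(const DomainGuard&) = delete;
  DomainGuard& operator=(const DomainGuard&) = delete;

 private:
  View* view_;
  Domain* prev_;
};
```

Scoping the guard to a block confines the temporary view to the indexing code
that needs it.
*/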
// Return the inputs of the provided IterDomains that are themselves IterDomains
std::vector<IterDomain*> iterDomainInputsOf(const std::vector<IterDomain*>&);
// Return the inputs of the provided IterDomains that are themselves
// IterDomains, ordered as they appear in the second provided vector.
std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
const std::vector<IterDomain*>& of,
const std::vector<IterDomain*>& order);
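/*
Sketch of the ordering step in iterDomainInputsOfOrderedAs, assuming the inputs
have already been computed (the IterDomain graph traversal is omitted) and
using ints as hypothetical stand-ins for IterDomain*: keep the entries of order
that are inputs, in the sequence they appear in order.

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Keep the elements of `order` that appear in `inputs`, preserving the
// sequence of `order`.
std::vector<int> orderedAsSketch(
    const std::vector<int>& inputs,
    const std::vector<int>& order) {
  std::unordered_set<int> input_set(inputs.begin(), inputs.end());
  std::vector<int> result;
  for (int id : order) {
    if (input_set.count(id) != 0) {
      result.push_back(id);
    }
  }
  return result;
}
```

In this sketch, entries of order that are not inputs are dropped, and inputs
that never appear in order are not returned.
*/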
std::vector<Val*> indices(std::vector<kir::ForLoop*>);
bool isTV(const Val* const);
bool isTVOp(const Expr*);
TensorView* getTVOutput(const Expr*);
bool isScalarOp(const Expr*);
void ASSERT_EXPR(Statement*);
bool isScope(const Expr*);
Expr* asExpr(Statement*);
// TODO: Remove in favor of ->as<TensorView>()
TensorView* asTV(Val*);
// TODO: Remove in favor of ->as<ForLoop>()
kir::ForLoop* asForLoop(Statement*);
// TODO: Remove in favor of ->as<TensorView>()
const TensorView* asConstTV(const Val*);
bool isUnrolledFor(const Expr*);
// Represents a mapping from each of BIDx, BIDy, BIDz, TIDx, TIDy and TIDz to a
// bool.
class ParallelTypeBitmap {
public:
static constexpr int num_p_type = 6;
ParallelTypeBitmap() = default;
bool get(ParallelType pt) const;
bool set(ParallelType pt, bool);
ParallelTypeBitmap operator&=(const ParallelTypeBitmap& other);
ParallelTypeBitmap operator|=(const ParallelTypeBitmap& other);
ParallelTypeBitmap operator^=(const ParallelTypeBitmap& other);
ParallelTypeBitmap operator~() const;
bool none() const;
bool any() const;
bool all() const;
bool operator[](size_t pos) const;
std::map<ParallelType, bool> getMap() const;
private:
ParallelTypeBitmap(const std::bitset<num_p_type>& bs) : bitset_(bs) {}
std::bitset<num_p_type> bitset_;
const static std::unordered_map<ParallelType, int, TypeHash> pt_to_offset_;
const static std::unordered_map<int, ParallelType> offset_to_pt_;
};
ParallelTypeBitmap operator&(
const ParallelTypeBitmap& lhs,
const ParallelTypeBitmap& rhs);
ParallelTypeBitmap operator|(
const ParallelTypeBitmap& lhs,
const ParallelTypeBitmap& rhs);
ParallelTypeBitmap operator^(
const ParallelTypeBitmap& lhs,
const ParallelTypeBitmap& rhs);
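/*
Simplified, hypothetical stand-in for ParallelTypeBitmap: a std::bitset<6>
indexed by a six-value PType enum. Here the enum order doubles as the bit
offset, whereas the class above keeps explicit pt_to_offset_ and offset_to_pt_
tables; having set return the previous value is an assumption of this sketch.

```cpp
#include <bitset>
#include <cassert>

// Hypothetical stand-in for the six block/thread ParallelTypes.
enum class PType { BIDx, BIDy, BIDz, TIDx, TIDy, TIDz };

class PTypeBitmapSketch {
 public:
  static constexpr int kNumTypes = 6;

  bool get(PType pt) const {
    return bits_[offset(pt)];
  }
  // Set the bit for pt and return its previous value.
  bool set(PType pt, bool value) {
    bool old = bits_[offset(pt)];
    bits_[offset(pt)] = value;
    return old;
  }
  PTypeBitmapSketch& operator&=(const PTypeBitmapSketch& other) {
    bits_ &= other.bits_;
    return *this;
  }
  bool none() const { return bits_.none(); }
  bool any() const { return bits_.any(); }
  bool all() const { return bits_.all(); }

 private:
  // The enum's declaration order is used directly as the bit offset.
  static int offset(PType pt) {
    return static_cast<int>(pt);
  }
  std::bitset<kNumTypes> bits_;
};

// Binary & built on top of &=.
PTypeBitmapSketch operator&(PTypeBitmapSketch lhs, const PTypeBitmapSketch& rhs) {
  lhs &= rhs;
  return lhs;
}
```

Backing the map with a bitset keeps the set operations (&, |, ^, ~) and the
none/any/all queries to single bitset calls.
*/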
// Returns a ParallelTypeBitmap representing which domains need
// blockBroadcast.
// Even when a domain is broadcast and parallelized, it does not need
// blockBroadcast unless it is predicated.
ParallelTypeBitmap getParallelBroadcastDomains(
const Val* bop_out,
const ThreadPredicateMap& preds);
} // namespace ir_utils
namespace loop_utils {
// I wanted to make the tvs in these utility functions const, but that started
// a long const-correctness project going into TensorView (making functions
// const there) and then into lower_loops, where we sort exprs.
// TODO: We should fix this when we have some time.
// Figure out which loop the allocation needs to be in. Returns nullptr if the
// allocation belongs outside the first loop in loops. Also returns the index
// in tv of the first axis that needs to be allocated, meaning we need to
// allocate that local axis and above.
std::pair<kir::ForLoop*, int64_t> getAllocPoint(
TensorView* tv,
const std::vector<kir::ForLoop*>& loops);
// Go through exprs mapping root domains from producer to consumer. Provides a
// ground truth for how root domains map through our expressions. Needed for
// unrolling.
std::unordered_map<IterDomain*, IterDomain*> p2cRootMap(
const std::vector<Expr*>& exprs);
// Given a root IterDomain and a p2c_root_map, find the root IterDomain it maps
// to that is furthest down the sorted expr list. Needed for unrolling.
IterDomain* getTermIDInMap(
IterDomain* root_id,
std::unordered_map<IterDomain*, IterDomain*> p2c_root_map);
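/*
Sketch of getTermIDInMap as described above: follow the producer-to-consumer
root map from root_id until reaching an ID with no further entry. Ints are
hypothetical stand-ins for IterDomain*, and the map is assumed to be acyclic;
otherwise the loop would not terminate.

```cpp
#include <cassert>
#include <unordered_map>

// Follow p2c links until the current ID has no mapping; that ID is the
// furthest one down the chain.
int getTermIDSketch(int root_id, const std::unordered_map<int, int>& p2c) {
  int current = root_id;
  auto it = p2c.find(current);
  while (it != p2c.end()) {
    current = it->second;
    it = p2c.find(current);
  }
  return current;
}
```

An ID with no entry in the map is its own terminal ID.
*/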
} // namespace loop_utils
} // namespace fuser
} // namespace jit
} // namespace torch