1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
|
#pragma once
#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
//! Usage: FusionGuard and Fusion are required user interfaces for any operation
//! underlying the code generator. In order to create values, expressions, and
//! generate code a Fusion instance must be active. It is the responsibility of
//! the user to create a Fusion instance and register it with the fusion guard.
//! The simplest example of this is:
//!
//! Fusion fusion;
//! FusionGuard fg(&fusion);
//!
//! Once a fusion is active all values and operations will be registered with
//! it.
//!
//! FusionGuard and Fusion are critical to the lifetime model of the IR system.
//! FusionGuard is a convenient way to set what base container instance holds
//! the defined IR. Statements that are defined are registered through the
//! FusionGuard with a particular Fusion. FusionGuard provides convenient
//! methods to access the active fusion so it doesn't need to be passed around
//! constantly. Any IR node class derived from Statement must register with a
//! Fusion to avoid memory leaks.
//!
//! Fusion is generally thought of as a translated fusion group from the JIT. It
//! is likely a single kernel, although we don't have to stick to this in the
//! future and could in theory generate multiple kernels with an executor to run
//! them.
//!
//! Fusion also allows users to set input/output values that will allow us to
//! figure out how to hook up runtime data to and from the JIT as well as
//! provide us mechanisms for dependency analysis and DCE including safety
//! checks.
class Fusion;
class TensorView;
class WelfordResult;
class SegmentCandidateFinder;
class SegmentedFusion;
class KernelArgumentHolder;
//! FusionGuard is our "context manager". It holds the active fusion and
//! allows it to be accessed anywhere through FusionGuard::getCurFusion().
class TORCH_CUDA_CU_API FusionGuard {
 public:
  //! Installs `fusion` as the active fusion so it can be manipulated while
  //! this guard is alive.
  explicit FusionGuard(Fusion* fusion);

  //! Ends the guard's scope. NOTE(review): the definition is not in this
  //! header; presumably it reinstates `prev_fusion` as the active fusion.
  //! Copy operations are not disabled, so copying a guard would run that
  //! logic twice -- avoid copying.
  ~FusionGuard();

  //! Returns the fusion that is currently active.
  static Fusion* getCurFusion();

  //! Static setter for the active fusion (definition not visible here).
  static void setCurFusion(Fusion* fusion);

  //! The fusion that was active before this guard was constructed.
  Fusion* prev_fusion;
};
//! Fusion is mutable but unique. Nodes cannot be copied in any way from one
//! Fusion to another. If anything like that is desired, it would require
//! duplicating all associated values and exprs. Fusion is considered to be SSA,
//! though this could also change in the future if there is a good reason to do
//! so.
//!
//! The Fusion owns the whole IR graph (Vals and Exprs)
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Fusion : public IrContainer {
  //! Map from an input/output position to the permutation recorded for it.
  using PermutationMap = std::unordered_map<int, std::vector<int64_t>>;

 public:
  Fusion() = default;

  Fusion(const Fusion& other);
  Fusion(Fusion&& other) noexcept;

  Fusion& operator=(const Fusion& other);
  Fusion& operator=(Fusion&& other) noexcept;

  ~Fusion();

  friend void swap(Fusion& a, Fusion& b) noexcept;

  void clear() noexcept;

  //! Break dependency chains associated with Expr, remove references to expr,
  //! and delete expr.
  void removeExpr(Expr* expr) override;

  //! Completely remove val from the fusion, breaking all dependencies
  //! associated with it.
  void removeVal(Val* val) override;

  //! Register input as an input of the fusion.
  void addInput(Val* input);

  //! Register output as an output of the fusion.
  void addOutput(Val* output);

  //! Deregister input as an input of the fusion.
  void removeInput(Val* input);

  //! Deregister output as an output of the fusion.
  void removeOutput(Val* output);

  //! Replace output with another value.
  void replaceOutput(Val* output, Val* replacement);

  //! Assert that all leaves found from outputs are registered as an input.
  void validateInputs();

  //! Print this fusion to the console.
  void print();

  //! Print Arith exprs.
  //! \param from_outputs_only Only print exprs reachable from outputs
  void printMath(bool from_outputs_only = true);

  //! Print transformations used in fusion (can be very verbose).
  void printTransforms();

  //! Lower the fusion and print a kernel.
  void printKernel(DataType index_type = DataType::Int);

  //! Return a list of topologically sorted expressions. This only includes
  //! exprs required to generate registered outputs.
  std::vector<Expr*> exprs();

  //! Return a vector of fusion inputs that feed this Val.
  std::vector<Val*> inputsOf(Val* val);

  //! Return all Vals in math expressions that cannot be eliminated.
  //!
  //! It is generally equivalent to vals that are used to generate
  //! outputs, however, when a multi-output expression exists, and only
  //! some of the outputs are used, the remaining unused outputs are
  //! also included as they must show up in the final code.
  std::vector<Val*> usedMathVals();

  //! Returns all vals that are produced by used math expressions and
  //! also do not have further consumers.
  //!
  //! In the case of an active multi-output expression, the returned vector
  //! will include the expression outputs that did not lead to a fusion
  //! output.
  std::vector<Val*> terminatingMathVals();

  //! Return all Exprs that use val.
  std::unordered_set<Expr*> unordered_uses(const Val* val) const;

  //! Return the Expr that produces val.
  Expr* definition(const Val* val) const;

  //! Indicate to kernel to set itself up to generate random numbers.
  bool isStochastic();

  //! Run fusion segmentation algorithm to create a segmented fusion.
  std::unique_ptr<SegmentedFusion> segment(const KernelArgumentHolder& args);

  //! Registered fusion inputs, in registration order (see addInput).
  const auto& inputs() const {
    return inputs_;
  }

  // NOTE(review): judging by the name, returns fusion inputs together with
  // created values; definition not visible here -- confirm.
  std::vector<Val*> inputsAndCreated();

  //! Registered fusion outputs, in registration order (see addOutput).
  const auto& outputs() const {
    return outputs_;
  }

  std::vector<Val*> getTerminatingOutputs() const;

  // Aliasing output to input value, this is a WAR (workaround) to allow
  // in-place update on an input tensor.
  // Note: this is not always safe and should be used with extra caution.
  // Currently the only place it's used is in the running stats update for batch
  // normalization.
  // TODO: alias should be made aware to segmentation, so we'll always include
  // the input tensor to the section where output is produced.
  void aliasOutputToInput(Val* output, Val* input);
  Val* getOutputAlias(Val* output);
  std::unordered_set<int> getOutputAliasIndices() const;
  std::vector<std::pair<int, int>> getInputAliasIndices() const;

  //! Mark input at `index` to be permuted by `permutation`.
  //! If `index` already has a permutation recorded, this call is a no-op
  //! (unordered_map::insert semantics: the first registration wins).
  void setPermutationOnInput(int index, std::vector<int64_t> permutation) {
    // `permutation` is taken by value, so move it into the map instead of
    // making a second copy.
    permuted_input_map_.insert({index, std::move(permutation)});
  }

  //! Mark output at `index` to be restored by `permutation`.
  //! If `index` already has a permutation recorded, this call is a no-op
  //! (unordered_map::insert semantics: the first registration wins).
  void setPermutationOnOutput(int index, std::vector<int64_t> permutation) {
    permuted_output_map_.insert({index, std::move(permutation)});
  }

  //! Return a map of indices to permutation, which indicates all input tensors
  //! that need to be permuted.
  const PermutationMap& getPermutationInputMap() const {
    return permuted_input_map_;
  }

  //! Return a map of indices to permutation, which indicates all output
  //! tensors that need to be permuted.
  const PermutationMap& getPermutationOutputMap() const {
    return permuted_output_map_;
  }

  //! Whether the TV use information currently recorded on IR nodes is valid.
  bool isTVUseInfoValid() const {
    return all_tv_uses_valid_;
  }

  //! Whether a TV-use update is in progress.
  bool isUpdatingTVUseInfo() const {
    return is_during_update_uses_;
  }

  //! Output -> input alias map populated by aliasOutputToInput.
  const auto& ioAlias() const {
    return io_alias_;
  }

 protected:
  friend SegmentCandidateFinder;
  friend SegmentedFusion;
  friend class TranslateApplicableWelford;
  friend Val;

  static IrCloner copy(const Fusion* from, Fusion* to);

  //! Register the Val with this fusion.
  void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency tracking
  //! of Vals. If this container is not a Kernel, it will remove previous
  //! definitions of outputs and register this Expr as the definition. Otherwise
  //! it will update the definition if not previously set, but will not remove
  //! old definitions.
  void registerExpr(Expr* expr) override;

  //! Clear Exprs from TV uses that are not required to produce outputs from
  //! inputs. Only other place this is used (other than Fusion) is in
  //! Val::uses().
  void resetTvUses();

 private:
  // Determine if the two values are compatible for aliasing:
  // same DataType, ValType, and number of dimensions.
  bool isAliasCompatible(Val* left, Val* right);

  // Fusion inputs and outputs
  std::vector<Val*> inputs_;
  std::vector<Val*> outputs_;

  // io alias pointing from output to input
  std::unordered_map<Val*, Val*> io_alias_;

  // See Note [ Permutation support in nvfuser ]
  // map from indices of input tensor to permutation
  PermutationMap permuted_input_map_;
  // map from indices of output tensor to permutation
  PermutationMap permuted_output_map_;

  // Records if the current use data in the IR nodes are valid;
  // the states are either all valid or all invalid.
  bool all_tv_uses_valid_ = false;
  bool is_during_update_uses_ = false;
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|