#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_base_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h>
#include <torch/csrc/jit/ir/ir.h>
/*
 * The nodes in this file are intended to be "user facing": a user, in this
 * sense, being someone who wants to be able to generate CUDA code.
 */
namespace torch {
namespace jit {
namespace fuser {
/*
* A Bool value.
* This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
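//
// A minimal usage sketch (assumes an active Fusion with which Val
// construction registers the node, as elsewhere in this IR; names are
// illustrative):
//
//   Bool* symbolic_pred = new Bool();   // symbolic: bound at runtime
//   Bool* const_pred = new Bool(true);  // constant: inlined into the kernel
//   symbolic_pred->isSymbolic();        // true
//   const_pred->isConst();              // true; const_pred->value() == true
//
// The same symbolic/constant pattern applies to Float, Half, and Int below.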
class TORCH_CUDA_API Bool : public Val {
public:
~Bool() = default;
Bool() : Val(ValType::Scalar, DataType::Bool), maybe_value_{c10::nullopt} {}
explicit Bool(bool _value)
: Val(ValType::Scalar, DataType::Bool), maybe_value_{_value} {}
Bool(const Bool* src, IrCloner* ir_cloner);
Bool(const Bool& other) = delete;
Bool& operator=(const Bool& other) = delete;
Bool(Bool&& other) = delete;
Bool& operator=(Bool&& other) = delete;
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
bool isConst() const {
return maybe_value_.has_value();
}
c10::optional<bool> value() const {
return maybe_value_;
}
bool sameAs(const Bool* const other) const;
private:
const c10::optional<bool> maybe_value_;
};
/*
 * A Float32 value. For now we don't have any other floating-point
 * computation type besides Float32. This value can be a symbolic value
 * (defined after the kernel is compiled) or a constant value (inlined into
 * the kernel definition).
 */
class TORCH_CUDA_API Float : public Val {
public:
using ScalarType = double;
~Float() = default;
Float() : Val(ValType::Scalar, DataType::Float), maybe_value_{c10::nullopt} {}
explicit Float(ScalarType _value)
: Val(ValType::Scalar, DataType::Float), maybe_value_{_value} {}
Float(const Float* src, IrCloner* ir_cloner);
Float(const Float& other) = delete;
Float& operator=(const Float& other) = delete;
Float(Float&& other) = delete;
Float& operator=(Float&& other) = delete;
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
bool isConst() const {
return maybe_value_.has_value();
}
c10::optional<ScalarType> value() const {
return maybe_value_;
}
bool sameAs(const Float* const other) const;
private:
const c10::optional<ScalarType> maybe_value_;
};
/*
* An IEEE 754 Float16 value.
* This value can be a symbolic value (defined after the kernel
* is compiled) or a constant value (inlined into the kernel definition).
*/
class TORCH_CUDA_API Half : public Val {
public:
~Half() = default;
Half() : Val(ValType::Scalar, DataType::Half), maybe_value_{c10::nullopt} {}
explicit Half(float _value)
: Val(ValType::Scalar, DataType::Half), maybe_value_{_value} {}
Half(const Half* src, IrCloner* ir_cloner);
Half(const Half& other) = delete;
Half& operator=(const Half& other) = delete;
Half(Half&& other) = delete;
Half& operator=(Half&& other) = delete;
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
bool isConst() const {
return maybe_value_.has_value();
}
c10::optional<float> value() const {
return maybe_value_;
}
bool sameAs(const Half* const other) const;
private:
const c10::optional<float> maybe_value_;
};
// An Int64 value. If used for indexing, it is emitted as a size_t in the
// generated kernel; otherwise it is an inlined literal in the kernel
// definition.
class TORCH_CUDA_API Int : public Val {
public:
using ScalarType = int64_t;
~Int() = default;
Int() : Val(ValType::Scalar, DataType::Int), maybe_value_{c10::nullopt} {}
explicit Int(ScalarType _value)
: Val(ValType::Scalar, DataType::Int), maybe_value_{_value} {}
Int(const Int* src, IrCloner* ir_cloner);
Int(const Int& other) = delete;
Int& operator=(const Int& other) = delete;
Int(Int&& other) = delete;
Int& operator=(Int&& other) = delete;
bool isSymbolic() const {
return !(maybe_value_.has_value());
}
bool isConst() const {
return maybe_value_.has_value();
}
c10::optional<ScalarType> value() const {
return maybe_value_;
}
bool sameAs(const Int* const other) const;
private:
const c10::optional<ScalarType> maybe_value_;
};
class ComputeAt;
class TransformReplay;
class TransformIter;
class OptOutMutator;
class LoopNestGenerator;
namespace ir_utils {
class TVDomainGuard;
}
// TensorView is our primitive Tensor Type used in code generation. It can be
// thought of as representing physical memory; however, its dimensionality is
// modified as split/merge/computeAt functions are called. The history of
// these transformations is kept and used for generating actual code
// referencing physical memory. Generally, when users are thinking of code
// generation in reference to a Tensor, this is the class they should be
// interacting with.
//
// The reason we need both TensorView and TensorDomain is that we need to have a
// record of both what is being computed and how it is being computed. For
// example we may have the operation: TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
// The mathematical operations here are on the tensor views TV1, TV2, and TV3.
// This operation is a pointwise operation. To compute this pointwise operation
// we iterate over the 3D TensorDomain [I, J, K], where K is the fastest
// changing dimension.
//
// TODO: Need to work on the const model for TensorView, making all functions
// that should be const, const. Gave this a try, but the changes expanded
// really quickly. The biggest headache is that getComputeAtAxis cannot be
// const because it can return a TV that some callers expect to be non-const.
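//
// A minimal sketch of the pointwise example above (assumes an active Fusion
// and the fuser arithmetic helpers, e.g. add() declared elsewhere in the
// fuser; names here are illustrative):
//
//   // tv1 and tv2 are 3-D TensorViews over root domain [I, J, K]
//   Val* out = add(tv2, tv1);
//   TensorView* tv3 = static_cast<TensorView*>(out);
//   // tv3's TensorDomain starts as [I, J, K]; the scheduling calls below
//   // (split/merge/reorder/computeAt) then decide how those loops are
//   // actually generated.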
class TORCH_CUDA_API TensorView : public Val {
public:
~TensorView() = default;
TensorView(const TensorView& other) = delete;
TensorView& operator=(const TensorView& other) = delete;
TensorView(TensorView&& other) = delete;
TensorView& operator=(TensorView&& other) = delete;
TensorView(
TensorDomain* _domain,
DataType dtype,
MemoryType mtype = MemoryType::Local);
TensorView(const std::shared_ptr<c10::TensorType>& tensor_type);
TensorView(const std::shared_ptr<Value>& jit_value)
: TensorView(jit_value->type()->cast<c10::TensorType>()) {}
TensorView(const TensorView* src, IrCloner* ir_cloner);
TensorDomain* domain() const {
return domain_;
}
bool hasReduction() const;
bool hasBlockReduction() const;
bool hasGridReduction() const;
bool hasBlockBroadcast() const;
bool hasBroadcast() const;
bool hasRFactor() const;
c10::optional<unsigned int> getReductionAxis() const;
const std::vector<IterDomain*>& getRootDomain() const;
const std::vector<IterDomain*>& getRFactorDomain() const;
// If an rfactor domain exists in domain(), return it; otherwise return the
// root domain.
const std::vector<IterDomain*>& getMaybeRFactorDomain() const;
IterDomain* axis(int pos) const;
// Is there an active computeAt TensorView/Axis
bool hasComputeAt() const {
return compute_at_view_ != nullptr;
}
// Return the TensorView we're computing at
TensorView* getComputeAtView() const {
return compute_at_view_;
}
size_t nDims() const;
// Return compute at axis relative to this domain
unsigned int getThisComputeAtAxis() const {
return this_compute_at_axis_;
}
// Return compute at axis relative to compute at view
unsigned int getRelativeComputeAtAxis() const {
return relative_compute_at_axis_;
}
// Return the position in compute_at_view that lines up with this->axis(pos).
int getComputeAtRelPos(int pos);
// Checks if an axis is inside the computeAt axis and fetches the reference
// to be used in code generation.
std::pair<int, TensorView*> getComputeAtPos(int pos) {
pos = normalizeAxisPos(pos);
TORCH_INTERNAL_ASSERT(
nDims() > 0, "Tried to access a computeAt axis in a 0-dim TensorView");
if (!hasComputeAt() || getThisComputeAtAxis() <= (unsigned int)pos)
return std::make_pair(pos, this);
return compute_at_view_->getComputeAtPos(getComputeAtRelPos(pos));
}
std::pair<IterDomain*, TensorView*> getComputeAtAxis(int pos) {
const auto computeAtPos = getComputeAtPos(pos);
return std::make_pair(
computeAtPos.second->axis(computeAtPos.first), computeAtPos.second);
}
// Compute this TensorView relative to another tensor at axis
TensorView* computeAt(TensorView* consumer, int axis);
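// For example (a sketch), with TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]:
//   tv1->computeAt(tv3, 1);
// requests that tv1 be computed within the first axis of tv3's loop nest, so
// tv1's computation is inlined under the I loop of its consumer.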
void clearComputeAt() {
this_compute_at_axis_ = 0;
relative_compute_at_axis_ = 0;
compute_at_view_ = nullptr;
}
// Split "axis" into 2 axes where the inner axes is size of "factor"
// and outer axis is size axis.size() / factor
TensorView* split(int axis, unsigned int factor);
// Split "axis" into 2 axes where the inner axes is size of "factor"
// and outer axis is size axis.size() / factor. Factor can be a symbolic
// value instead of constant. This requires setting the symbolic value as an
// input, or using a parallel dim from NamedScalar::getParallelDim
TensorView* split(int axis, Val* factor);
// Merge axis_o and axis_i into 1 IterDomain
TensorView* merge(int axis_o, int axis_i);
// Merge axis and axis+1 into 1 IterDomain
TensorView* merge(int axis) {
return merge(axis, axis + 1);
}
// Reorder axes according to old2new[old_pos] = new_pos
TensorView* reorder(const std::unordered_map<int, int>& old2new);
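// An illustrative scheduling sequence (a sketch; bracketed extents show the
// resulting domain):
//   // tv = [I0, I1, I2]
//   tv->split(2, 32);               // -> [I0, I1, I2/32, 32]
//   tv->merge(0);                   // -> [I0*I1, I2/32, 32]
//   tv->reorder({{1, 2}, {2, 1}});  // -> [I0*I1, 32, I2/32]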
// WARNING: rFactor does not return this TensorView; it returns a new
// TensorView consumed by this one!
//
// Takes the given reduction axes out of this domain and creates a new
// domain that performs them. The new domain will then be used to create this
// domain.
//
// For example:
// TV1[I0, R1, R2, I3] = TV0[I0, I1, I2, I3]
//
// After:
// TV1->rFactor({1}), TV1 is transformed to: TV1[I0, R2, I3]
//
// The TensorView returned is: TV2[I0, R1, I2, I3]
//
// The reduction is now set up as:
// TV2[I0, R1, I2, I3] = TV0[I0, I1, I2, I3]
// TV1[I0, R2, I3] = TV2[I0, R1, I2, I3]
//
TensorView* rFactor(const std::vector<int>& axes);
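// Illustrative usage, following the example above:
//   // tv1 = TV1[I0, R1, R2, I3], reducing R1 and R2
//   TensorView* tv2 = tv1->rFactor({1});
//   // tv2 performs the partial reduction over R1; tv1 now reduces only R2
//   // and consumes tv2.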
// Create a TensorView before the original tensor. A common use case is to
// write results into shared memory or registers before moving to global
// memory. Analogous to TVM Cache_Write.
TensorView* cache_before();
// Create a TensorView after the original tensor. A common use case is to
// read a tensor into shared memory or registers. Analogous to TVM Cache_Read.
TensorView* cache_after();
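// Illustrative usage (a sketch):
//   TensorView* staged = tv->cache_after();    // tv -> staged -> former uses of tv
//   staged->setMemoryType(MemoryType::Shared); // stage reads through shared memory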
MemoryType getMemoryType() const {
return memory_type_;
}
void setMemoryType(MemoryType mt);
friend TORCH_CUDA_API TransformReplay;
friend TORCH_CUDA_API OptOutMutator;
friend TORCH_CUDA_API LoopNestGenerator;
friend ComputeAt;
friend void IrFixComputeAt(Fusion*);
friend void adjustMemoryTypes(Fusion* fusion);
friend class ir_utils::TVDomainGuard;
protected:
// Make an exact copy of this tensor (similar to clone()); however, it also
// grabs the same name. The current use of this is for the initialization of
// reductions. This will break our dependency chain, as it is a literal clone
// of a TensorView but has a different dependency chain. We need to improve
// our dependency model to allow for the initialization of reduction buffers.
// The only reason we can get away with this for now is because we don't use
// dependency analysis for the IR after we call this.
TensorView* unsafeClone() const;
void setDomain(TensorDomain* td) {
domain_ = td;
}
void setComputeAt(TensorView* computeAtView, int axis);
// Set all computeAt members without checking any correctness. Useful for
// computeAt with outputs relative to each other.
void setComputeAt(TensorView* computeAtView, int thisPos, int relPos);
private:
int normalizeAxisPos(int pos) const {
if (pos < 0) {
pos += nDims();
}
return pos;
}
// In Cache Before, for the origin expr of the original tensor, we create a
// new operation where the original tensor is replaced with the new cache
// tensor. This function creates a new expr given the consumer, i.e. the
// output of the expression.
void createExprConsumer(Expr* expr, TensorView* consumer);
// In Cache After, for all the uses of the original tensor, we create a new
// operation where the original tensor is replaced with the new cache tensor.
// This function creates a new expr given a producer, i.e. an input of the
// expression.
void createExprProducer(
Expr* expr,
TensorView* current,
TensorView* producer);
void setThisComputeAtAxis();
private:
TensorDomain* domain_ = nullptr;
TensorView* compute_at_view_ = nullptr;
// compute at axis in compute at view
unsigned int relative_compute_at_axis_ = 0;
unsigned int this_compute_at_axis_ = 0;
MemoryType memory_type_ = MemoryType::Local;
};
} // namespace fuser
} // namespace jit
} // namespace torch