1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
#pragma once
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_container.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Forward declarations. IrCloner is only referenced through a pointer in this
// header (IrBuilder::clone), so its full definition is not needed here.
// kir::Kernel is not referenced by the declarations below; NOTE(review):
// presumably kept for includers / the Kernel-vs-IrContainer TODO in
// IrBuilderPasskey — confirm before removing.
namespace kir {
class Kernel;
} // namespace kir
class IrCloner;
//! Passkey that only IrBuilder can mint.
//!
//! IrBuilder hands a freshly constructed passkey to every IR node constructor
//! and to IrContainer::registerStmt. Because the constructor is private and
//! IrBuilder is the only friend, receiving a passkey shows that the call was
//! routed through IrBuilder. The passkey also carries the container the node
//! is being built into.
class TORCH_CUDA_CU_API IrBuilderPasskey {
  // Grants IrBuilder access to the private constructor below.
  friend class IrBuilder;

 private:
  //! Only reachable from IrBuilder (see friend declaration above).
  explicit IrBuilderPasskey(IrContainer* ir_container);

 public:
  // TODO: Collapse ir_container and Kernel once Kernel inherits from
  // IrContainer
  //! Container the node being built will belong to; fixed at construction.
  IrContainer* const ir_container_{nullptr};
};
//! IR builder interface.
//!
//! IR nodes are allocated through create<T>(), which constructs the node with
//! an IrBuilderPasskey and registers it with its container.
class TORCH_CUDA_CU_API IrBuilder {
 public:
  //! Allocate a new IR node in the currently active fusion
  //! (FusionGuard::getCurFusion()), forwarding the arguments to the
  //! appropriate constructor and registering the node with that container.
  //! Fails TORCH_INTERNAL_ASSERT if there is no active fusion.
  template <class T, class... Args>
  static T* create(Args&&... args) {
    auto container = FusionGuard::getCurFusion();
    return createInContainer<T>(container, std::forward<Args>(args)...);
  }

  //! Allocate a new IR node in the given container, forwarding the arguments
  //! to the appropriate constructor and registering the node with it.
  //! Fails TORCH_INTERNAL_ASSERT if container is null.
  template <class T, class... Args>
  static T* create(IrContainer* container, Args&&... args) {
    return createInContainer<T>(container, std::forward<Args>(args)...);
  }

  //! Clone an IR node, forwarding the arguments to the IrCloner constructor.
  //! Register clones with IrCloner's target container.
  template <class T>
  static T* clone(const T* src, IrCloner* ir_cloner);

  // Expression helpers. Each returns a Val* — presumably the output of a newly
  // created expression node; implementations are not in this header.

  // Unary operations
  static Val* negExpr(Val* val);
  static Val* notExpr(Val* val);
  static Val* setExpr(Val* val);
  static Val* setExprNamedScalar(const std::string& name, Val* val);
  static Val* addressExprNamedScalar(const std::string& name, Val* val);

  // Binary operations
  static Val* andExpr(Val* lhs, Val* rhs);
  static Val* eqExpr(Val* lhs, Val* rhs);
  static Val* gtExpr(Val* lhs, Val* rhs);
  static Val* ltExpr(Val* lhs, Val* rhs);
  static Val* leExpr(Val* lhs, Val* rhs);
  static Val* geExpr(Val* lhs, Val* rhs);
  static Val* addExpr(Val* lhs, Val* rhs);
  static Val* subExpr(Val* lhs, Val* rhs);
  static Val* mulExpr(Val* lhs, Val* rhs);
  static Val* divExpr(Val* lhs, Val* rhs);
  static Val* ceilDivExpr(Val* lhs, Val* rhs);
  static Val* modExpr(Val* lhs, Val* rhs);
  static Val* maxExpr(Val* lhs, Val* rhs);
  static Val* minExpr(Val* lhs, Val* rhs);

  // Ternary operations
  static Val* whereExpr(Val* pred, Val* lhs, Val* rhs);

  // Swizzle operations
  static Val* swizzle2DIntExpr(
      Val* x,
      Val* y,
      Val* extent_x,
      Val* extent_y,
      Swizzle2DType swizzle_type);
  static Val* pairSelectExpr(Val* in, kir::PairSelect::Selection sel);

 private:
  //! Shared body of both create() overloads: assert a container is present,
  //! construct the node with a passkey, and register it with the container.
  //!
  //! Why a separately named helper: the first create() overload cannot simply
  //! forward to create<T>(container, ...). getCurFusion() presumably returns a
  //! Fusion* (a type derived from IrContainer). That binds to the variadic
  //! Args&&... overload as an exact match, which outranks the derived-to-base
  //! conversion to IrContainer*, so create() would call itself forever.
  //!
  //! ContainerT is deduced rather than fixed to IrContainer so that
  //! registerStmt is called through the same static type each overload used
  //! before this helper existed.
  template <class T, class ContainerT, class... Args>
  static T* createInContainer(ContainerT* container, Args&&... args) {
    TORCH_INTERNAL_ASSERT(
        container != nullptr, "Need an active container to build IR.");
    T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
    container->registerStmt(IrBuilderPasskey(container), node);
    return node;
  }

  static Val* newResult(DataType dtype);
  static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
  static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
};
//! IrBuilder variant that folds trivial expressions while building them.
//!
//! Examples (operands shown as the IR scalars they denote):
//!   addExpr(Int(1), Int(2))             -> Int(3)
//!   addExpr(Int(0), NamedScalar("foo")) -> NamedScalar("foo")
//!
//! It is meant for the predicate and index expressions of generated code.
//! Without this folding, shift validation may fail.
//!
//! Note: redeclaring a name here hides every IrBuilder overload of that same
//! name. Operations not listed below (e.g. divExpr, eqExpr, whereExpr) are
//! inherited unchanged from IrBuilder.
class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder {
 public:
  // --- Unary -------------------------------------------------------------
  static Val* negExpr(Val* val);
  static Val* notExpr(Val* val);

  // --- Addition / subtraction ---------------------------------------------
  static Val* addExpr(Val* lhs, Val* rhs);
  static Val* addExpr(Int* lhs, Int* rhs);
  static Val* addExpr(Val* lhs, Int::ScalarType rhs);
  static Val* addExpr(Int* lhs, Int::ScalarType rhs);
  static Val* subExpr(Val* lhs, Val* rhs);

  // --- Multiplication ------------------------------------------------------
  static Val* mulExpr(Val* lhs, Val* rhs);
  static Val* mulExpr(Int* lhs, Int* rhs);
  static Val* mulExpr(Val* lhs, Int::ScalarType rhs);
  static Val* mulExpr(Int* lhs, Int::ScalarType rhs);

  // --- Logical and min/max -------------------------------------------------
  static Val* andExpr(Val* lhs, Val* rhs);
  static Val* maxExpr(Val* lhs, Val* rhs);
  static Val* minExpr(Val* lhs, Val* rhs);
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
|