1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
|
/**
* This file implements the core classes for Tensor Expressions.
*
* The structure of the expressions is inspired by Halide/TVM IR.
*/
#pragma once
#include <string>
#include <utility>
#include <vector>

#include <torch/csrc/jit/tensorexpr/ir_mutator.h>
#include <torch/csrc/jit/tensorexpr/ir_visitor.h>
#include <torch/csrc/jit/tensorexpr/mem_arena.h>
#include <torch/csrc/jit/tensorexpr/types.h>
namespace torch {
namespace jit {
namespace tensorexpr {
// Tag stored on every Expr (read back through Expr::expr_type()) that names
// the kind of IR node. The enumerators are unscoped and implicitly numbered
// from 0 in the order written, so reordering them changes their values.
enum IRNodeType {
  // Tag that Var and Buf pass to their base constructor.
  kPrimitive,
  // Arithmetic.
  kAdd, kSub, kMul, kDiv, kMod,
  // Maximum / minimum.
  kMax, kMin,
  // Bitwise operations and shifts.
  kAnd, kOr, kLshift, kRshift, kXor,
  // Remaining node kinds.
  kCompareSelect, kLet, kCast, kBroadcast, kRamp,
  // NOTE(review): the node classes for the next five tags are not defined in
  // this header -- confirm their meaning against the code that uses them.
  kPolynomial, kTerm, kRoundOff, kMaxTerm, kMinTerm,
  // Default tag when an Expr is constructed without an explicit type.
  kNone,
  kExtra
};
// Common base class of all expression nodes. Each node records a Dtype and an
// IRNodeType tag at construction; this class exposes no way to change either
// afterwards.
class Expr : public KernelScopedObject {
 public:
  // `expr_type` falls back to kNone when the caller does not provide a tag.
  explicit Expr(Dtype dtype, IRNodeType expr_type = kNone)
      : dtype_(dtype), expr_type_(expr_type) {}

  // The data type this expression was constructed with.
  Dtype dtype() const {
    return dtype_;
  }

  // Visitor entry point; must be implemented by derived nodes.
  TORCH_API virtual void accept(IRVisitor* visitor) const = 0;

  // Mutator entry point; must be implemented by derived nodes. Returns an
  // expression pointer supplied by the implementation.
  virtual const Expr* accept_mutator(IRMutator* mutator) const = 0;

  // The node-kind tag this expression was constructed with.
  IRNodeType expr_type() const {
    return expr_type_;
  }

  // True if this node is a fixed (constant) immediate value. The base
  // implementation always answers false.
  virtual bool isConstant() const {
    return false;
  }

 private:
  Dtype dtype_;
  IRNodeType expr_type_;
};
// CRTP helper. `Op` is the concrete node class deriving from this template;
// accept() and accept_mutator() downcast `this` to `Op` so the visitor or
// mutator receives the concrete node type rather than the base.
template <class Op, class Base = Expr>
class ExprNode : public Base {
 public:
  // Re-export every constructor of Base.
  using Base::Base;

  // Shorthand for derived classes. Note that this names ExprNode<Op> with the
  // default Base, independent of the Base of the current instantiation.
  using ExprNodeBase = ExprNode<Op>;

  // Dispatches to the visitor with `this` typed as const Op*.
  void accept(IRVisitor* visitor) const override {
    const Op* const self = static_cast<const Op*>(this);
    visitor->visit(self);
  }

  // Declared here, defined out of line.
  const Expr* accept_mutator(IRMutator* mutator) const override;
};
// Handle that wraps a pointer to an underlying Expr node; the pointer is the
// only state it holds. It is also the primary way to build and combine
// expressions, through the overloaded operators declared below.
class TORCH_API ExprHandle {
 public:
  // Creates an empty handle that refers to no node.
  ExprHandle() {}

  // Wraps `node`. The pointer is stored without its const qualifier.
  explicit ExprHandle(const Expr* node)
      : base_expr_node_(const_cast<Expr*>(node)) {}

  // The wrapped node; nullptr for an empty handle.
  Expr* node() {
    return base_expr_node_;
  }
  const Expr* node() const {
    return base_expr_node_;
  }

  // True when the handle refers to no node.
  bool empty() const {
    return !base_expr_node_;
  }

  // Declares one converting constructor per scalar type enumerated by the
  // macro; the definitions are not in this header.
#define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
  AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE);
#undef IMM_EXPR_DECLARE

  // dynamic_cast of the wrapped node to Op; yields nullptr if the cast fails.
  template <class Op>
  Op* AsNode() {
    return dynamic_cast<Op*>(this->node());
  }
  template <class Op>
  const Op* AsNode() const {
    return dynamic_cast<const Op*>(this->node());
  }

  // Data type of the wrapped node. Dereferences the node pointer, so the
  // handle must not be empty.
  Dtype dtype() const {
    return node()->dtype();
  }

  // Arithmetic operators. Each returns an ExprHandle.
  ExprHandle operator+(const ExprHandle& other) const;
  ExprHandle operator-(const ExprHandle& other) const;
  ExprHandle operator*(const ExprHandle& other) const;
  ExprHandle operator/(const ExprHandle& other) const;
  ExprHandle operator%(const ExprHandle& other) const;

  // Comparison operators. These also return an ExprHandle, not a bool.
  ExprHandle operator==(const ExprHandle& other) const;
  ExprHandle operator!=(const ExprHandle& other) const;
  ExprHandle operator>(const ExprHandle& other) const;
  ExprHandle operator>=(const ExprHandle& other) const;
  ExprHandle operator<(const ExprHandle& other) const;
  ExprHandle operator<=(const ExprHandle& other) const;

  // Bitwise and shift operators.
  ExprHandle operator&(const ExprHandle& other) const;
  ExprHandle operator|(const ExprHandle& other) const;
  ExprHandle operator^(const ExprHandle& other) const;
  ExprHandle operator<<(const ExprHandle& other) const;
  ExprHandle operator>>(const ExprHandle& other) const;

 private:
  Expr* base_expr_node_ = nullptr;
};
// The node that backs a variable.
// Each Var object currently stands for a distinct variable even when two of
// them carry the same name hint. Adding a unique_name should be considered.
class Var : public ExprNode<Var> {
 public:
  // Vars are tagged kPrimitive.
  Var(const std::string& name_hint, Dtype dtype)
      : ExprNodeBase(dtype, kPrimitive), name_hint_(name_hint) {}

  // Allocates a new Var and returns it wrapped in an ExprHandle.
  // NOTE(review): there is no matching delete here; the node's lifetime is
  // presumably managed through the KernelScopedObject base -- confirm.
  static ExprHandle make(const std::string& name_hint, Dtype dtype) {
    Var* var = new Var(name_hint, dtype);
    return ExprHandle(var);
  }

  // Same as above, with an empty name hint.
  static ExprHandle make(Dtype dtype) {
    return make("", dtype);
  }

  // TODO: unique_name
  const std::string& name_hint() const {
    return name_hint_;
  }

 private:
  std::string name_hint_;
};
// Expression node describing a buffer: a base Var plus a list of dimension
// expressions. Like Var, it is tagged kPrimitive.
class TORCH_API Buf : public ExprNode<Buf> {
 public:
  // Factory functions returning the new Buf wrapped in an ExprHandle.
  // Declared here; the definitions are not in this header.
  static ExprHandle make(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype);
  static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);

  // TODO: unique_name
  // The Var this buffer is built on. Never null: the constructor checks it.
  const Var* base_handle() const {
    return base_handle_;
  }

  // Name hint of the base Var.
  const std::string& name_hint() const {
    return base_handle_->name_hint();
  }

  // Creates a fresh base Var of dtype kHandle named `name_hint`, then
  // delegates to the constructor below.
  Buf(const std::string& name_hint,
      const std::vector<const Expr*>& dims,
      Dtype dtype)
      : Buf(new Var(name_hint, kHandle), dims, dtype) {}

  // Builds a buffer on an existing Var. Fails a TORCH_CHECK if `var` is null.
  Buf(const Var* var, const std::vector<const Expr*>& dims, Dtype dtype)
      : ExprNodeBase(dtype, kPrimitive), base_handle_(var), dims_(dims) {
    TORCH_CHECK(var);
  }

  // Number of dimensions.
  size_t ndim() const {
    return dims_.size();
  }

  // Expression for dimension `index`. Fails a TORCH_CHECK when `index` is not
  // smaller than ndim(), instead of indexing the vector out of bounds.
  const Expr* dim(size_t index) const {
    TORCH_CHECK(
        index < dims_.size(),
        "Buf::dim: index ",
        index,
        " is out of range for a buffer with ",
        dims_.size(),
        " dimension(s)");
    return dims_[index];
  }

  // Copy of the full dimension list.
  std::vector<const Expr*> dims() const {
    return dims_;
  }

  // Replaces the dimension list. `dims` is taken by value, so it is moved
  // into place rather than copied a second time.
  void set_dims(std::vector<const Expr*> dims) {
    dims_ = std::move(dims);
  }

 private:
  const Var* base_handle_;
  std::vector<const Expr*> dims_;
};
// ExprHandle specialization whose node is a Buf. It adds no data members of
// its own.
class TORCH_API BufHandle : public ExprHandle {
 public:
  // Creates a new Buf through Buf::make and wraps it.
  BufHandle(
      const std::string& name_hint,
      const std::vector<ExprHandle>& dims,
      Dtype dtype)
      : ExprHandle(Buf::make(name_hint, dims, dtype)) {}

  // Wraps an existing Buf node.
  explicit BufHandle(const Buf* node) : ExprHandle(node) {}

  // The wrapped node, typed as Buf. Hides ExprHandle::node().
  const Buf* node() const {
    const Expr* base = ExprHandle::node();
    return static_cast<const Buf*>(base);
  }

  // True when the handle wraps no node.
  bool empty() const {
    return node() == nullptr;
  }

  // Name hint of the wrapped Buf; the handle must not be empty.
  const std::string& name_hint() const {
    return node()->name_hint();
  }

  // Handles compare equal exactly when they wrap the same node pointer.
  bool operator==(const BufHandle& other) const {
    return node() == other.node();
  }
  bool operator!=(const BufHandle& other) const {
    return node() != other.node();
  }
};
// An expression to construct the underlying variable node.
// Note: do not store any info here, since it is often possible to slice this
// object. For example: VarHandle x('x'); ExprHandle x2 = x;
class VarHandle : public ExprHandle {
public:
VarHandle() : ExprHandle(nullptr) {}
explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
VarHandle(const std::string& name_hint, Dtype dtype)
: ExprHandle(Var::make(name_hint, dtype)) {}
explicit VarHandle(const Var* node) : ExprHandle(node) {}
const Var* node() const {
return static_cast<const Var*>(ExprHandle::node());
}
bool operator==(const VarHandle& other) const {
return this->node() == other.node();
}
bool operator!=(const VarHandle& other) const {
return !(*this == other);
}
const std::string& name_hint() const {
return this->node()->name_hint();
}
bool empty() const {
return (this->node() == nullptr);
}
};
// Out-of-line definition of ExprNode::accept_mutator. The mutator is handed
// `this` as a non-const Op*, so constness is cast away before the downcast to
// the concrete node type.
template <class Op, class Base>
const Expr* ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) const {
  return mutator->mutate(static_cast<Op*>(const_cast<ExprNode*>(this)));
}
inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
}
// Free-function builders over expression handles. They are only declared
// here; the definitions are not in this header. Each takes its operand(s) by
// const reference and returns an ExprHandle.
// NOTE(review): presumably each returns an expression representing the named
// math operation applied to its operand(s) -- confirm against the definitions.

// Single-operand functions.
TORCH_API ExprHandle sin(const ExprHandle& v);
TORCH_API ExprHandle cos(const ExprHandle& v);
TORCH_API ExprHandle tan(const ExprHandle& v);
TORCH_API ExprHandle asin(const ExprHandle& v);
TORCH_API ExprHandle acos(const ExprHandle& v);
TORCH_API ExprHandle atan(const ExprHandle& v);
TORCH_API ExprHandle sinh(const ExprHandle& v);
TORCH_API ExprHandle cosh(const ExprHandle& v);
TORCH_API ExprHandle tanh(const ExprHandle& v);
TORCH_API ExprHandle sigmoid(const ExprHandle& v);
TORCH_API ExprHandle exp(const ExprHandle& v);
TORCH_API ExprHandle expm1(const ExprHandle& v);
TORCH_API ExprHandle fabs(const ExprHandle& v);
TORCH_API ExprHandle log(const ExprHandle& v);
TORCH_API ExprHandle log2(const ExprHandle& v);
TORCH_API ExprHandle log10(const ExprHandle& v);
TORCH_API ExprHandle log1p(const ExprHandle& v);
TORCH_API ExprHandle erf(const ExprHandle& v);
TORCH_API ExprHandle erfc(const ExprHandle& v);
TORCH_API ExprHandle sqrt(const ExprHandle& v);
TORCH_API ExprHandle rsqrt(const ExprHandle& v);
TORCH_API ExprHandle ceil(const ExprHandle& v);
TORCH_API ExprHandle floor(const ExprHandle& v);
TORCH_API ExprHandle round(const ExprHandle& v);
TORCH_API ExprHandle trunc(const ExprHandle& v);
TORCH_API ExprHandle frac(const ExprHandle& v);
TORCH_API ExprHandle lgamma(const ExprHandle& v);
// Two-operand functions.
TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);
// Three-operand conditional taking `c`, `t` and `f`.
// NOTE(review): presumably c is the condition, t the value when it holds and
// f the value otherwise -- confirm against the definition.
TORCH_API ExprHandle
ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);
} // namespace tensorexpr
} // namespace jit
} // namespace torch
|