1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <c10/util/Logging.h>
#include <torch/csrc/jit/tensorexpr/dim_arg.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>
namespace torch {
namespace jit {
namespace tensorexpr {
/// Builds a tensor whose elements are defined by `body_func`, which receives
/// every loop variable packed into a single vector.
///
/// @param func_name  Name given to the underlying Function.
/// @param dim_args   Output dimensions; unpacked into extent expressions and
///                   the loop variables ranging over them.
/// @param body_func  Callback mapping the loop variables to the element
///                   expression.
/// @return A newly allocated Tensor wrapping output 0 of a new Function.
Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
  std::vector<const Expr*> extents;
  std::vector<const Var*> loop_vars;
  unpack_dim_args(dim_args, &extents, &loop_vars);

  // The callback is written against handles, so wrap each raw Var first.
  auto loop_handles = VarVectorToVarHandleVector(loop_vars);
  const Expr* element = body_func(loop_handles).node();

  Function* function = new Function(func_name, extents, loop_vars, element);
  return new Tensor(function, 0);
}
/// One-dimensional convenience overload of Compute: `body_func` receives the
/// single loop variable directly rather than inside a vector.
///
/// @param func_name  Name given to the underlying Function.
/// @param dim_args   Must describe exactly one dimension.
/// @param body_func  Callback mapping the loop variable to the element
///                   expression.
/// @return A newly allocated Tensor wrapping output 0 of a new Function.
/// @throws malformed_input if `dim_args` does not hold exactly one entry.
Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const VarHandle&)>& body_func) {
  // Reject an arity mismatch before unpacking or invoking the callback.
  if (dim_args.size() != 1) {
    throw malformed_input("mismatch between body and arg size (1)");
  }

  std::vector<const Expr*> extents;
  std::vector<const Var*> loop_vars;
  unpack_dim_args(dim_args, &extents, &loop_vars);

  const VarHandle i(loop_vars[0]);
  const Expr* element = body_func(i).node();

  Function* function = new Function(func_name, extents, loop_vars, element);
  return new Tensor(function, 0);
}
/// Two-dimensional convenience overload of Compute: `body_func` receives the
/// two loop variables as separate arguments, in dimension order.
///
/// @param func_name  Name given to the underlying Function.
/// @param dim_args   Must describe exactly two dimensions.
/// @param body_func  Callback mapping the loop variables to the element
///                   expression.
/// @return A newly allocated Tensor wrapping output 0 of a new Function.
/// @throws malformed_input if `dim_args` does not hold exactly two entries.
Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
        body_func) {
  // Reject an arity mismatch before unpacking or invoking the callback.
  if (dim_args.size() != 2) {
    throw malformed_input("mismatch between body and arg size (2)");
  }

  std::vector<const Expr*> extents;
  std::vector<const Var*> loop_vars;
  unpack_dim_args(dim_args, &extents, &loop_vars);

  const VarHandle i(loop_vars[0]);
  const VarHandle j(loop_vars[1]);
  const Expr* element = body_func(i, j).node();

  Function* function = new Function(func_name, extents, loop_vars, element);
  return new Tensor(function, 0);
}
/// Three-dimensional convenience overload of Compute: `body_func` receives
/// the three loop variables as separate arguments, in dimension order.
///
/// @param func_name  Name given to the underlying Function.
/// @param dim_args   Must describe exactly three dimensions.
/// @param body_func  Callback mapping the loop variables to the element
///                   expression.
/// @return A newly allocated Tensor wrapping output 0 of a new Function.
/// @throws malformed_input if `dim_args` does not hold exactly three entries.
Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<
        ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
        body_func) {
  // Reject an arity mismatch before unpacking or invoking the callback.
  if (dim_args.size() != 3) {
    throw malformed_input("mismatch between body and arg size (3)");
  }

  std::vector<const Expr*> extents;
  std::vector<const Var*> loop_vars;
  unpack_dim_args(dim_args, &extents, &loop_vars);

  const VarHandle i(loop_vars[0]);
  const VarHandle j(loop_vars[1]);
  const VarHandle k(loop_vars[2]);
  const Expr* element = body_func(i, j, k).node();

  Function* function = new Function(func_name, extents, loop_vars, element);
  return new Tensor(function, 0);
}
/// Four-dimensional convenience overload of Compute: `body_func` receives
/// the four loop variables as separate arguments, in dimension order.
///
/// @param func_name  Name given to the underlying Function.
/// @param dim_args   Must describe exactly four dimensions.
/// @param body_func  Callback mapping the loop variables to the element
///                   expression.
/// @return A newly allocated Tensor wrapping output 0 of a new Function.
/// @throws malformed_input if `dim_args` does not hold exactly four entries.
Tensor* Compute(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const std::function<ExprHandle(
        const VarHandle&,
        const VarHandle&,
        const VarHandle&,
        const VarHandle&)>& body_func) {
  // Reject an arity mismatch before unpacking or invoking the callback.
  if (dim_args.size() != 4) {
    throw malformed_input("mismatch between body and arg size (4)");
  }

  std::vector<const Expr*> extents;
  std::vector<const Var*> loop_vars;
  unpack_dim_args(dim_args, &extents, &loop_vars);

  // Wrap all loop variables at once, then spread them into the callback.
  const auto handles = VarVectorToVarHandleVector(loop_vars);
  const Expr* element =
      body_func(handles[0], handles[1], handles[2], handles[3]).node();

  Function* function = new Function(func_name, extents, loop_vars, element);
  return new Tensor(function, 0);
}
/// Builds the Store statement that writes output `index` of this Function.
///
/// The store targets `func_var(index)`. It is indexed by the first `ndim()`
/// entries of `args_`, writes `body(index)`, and carries a constant IntImm(1)
/// mask.
/// NOTE(review): args_ is indexed up to buf->ndim() without a bounds check;
/// presumably the two always agree — confirm against Function's constructor.
Stmt* Function::ElementStmt(size_t index) {
  const Buf* target = func_var(index);

  std::vector<const Expr*> store_indices;
  for (size_t dim = 0; dim < target->ndim(); ++dim) {
    store_indices.push_back(args_[dim]);
  }

  // Constant-1 mask; presumably means the store is unconditional — verify.
  const Expr* always_on = new IntImm(1);
  return new Store(target, store_indices, body(index), always_on);
}
/// Reduction whose body loads from `buffer` at whatever index list the
/// generic callback-based Reduce overload passes to the body callback.
///
/// @param func_name    Name of the resulting reduction.
/// @param dim_args     Output (non-reduced) dimensions.
/// @param reducer      Reduction operator to apply.
/// @param buffer       Source placeholder being reduced.
/// @param reduce_args  Dimensions that are reduced over.
/// @return The Tensor produced by the generic Reduce overload.
Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    const Placeholder& buffer,
    const std::vector<DimArg>& reduce_args) {
  auto load_from_buffer = [&buffer](ParameterList& indices) {
    return buffer.load(indices);
  };
  return Reduce(func_name, dim_args, reducer, load_from_buffer, reduce_args);
}
/// Reduction whose body calls `tensor` at whatever index list the generic
/// callback-based Reduce overload passes to the body callback.
///
/// @param func_name    Name of the resulting reduction.
/// @param dim_args     Output (non-reduced) dimensions.
/// @param reducer      Reduction operator to apply.
/// @param tensor       Source tensor being reduced (not owned here).
/// @param reduce_args  Dimensions that are reduced over.
/// @return The Tensor produced by the generic Reduce overload.
Tensor* Reduce(
    const std::string& func_name,
    const std::vector<DimArg>& dim_args,
    const Reducer& reducer,
    Tensor* tensor,
    const std::vector<DimArg>& reduce_args) {
  // Copying the pointer into the closure is enough; the Tensor is not owned.
  auto call_tensor = [tensor](ParameterList& indices) {
    return tensor->call(indices);
  };
  return Reduce(func_name, dim_args, reducer, call_tensor, reduce_args);
}
} // namespace tensorexpr
} // namespace jit
} // namespace torch
|