File: ir.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (295 lines) | stat: -rw-r--r-- 8,009 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
#include <torch/csrc/jit/tensorexpr/ir.h>

#include <torch/csrc/jit/tensorexpr/tensor.h>

#include <c10/util/irange.h>

namespace torch {
namespace jit {
namespace tensorexpr {

// Constructs a Dtype from the buffer's dtype and the lane count taken from
// the index dtype.
static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) {
  const auto lane_count = index_dtype.lanes();
  return Dtype(buffer_dtype, lane_count);
}

// Returns the dtype of the first index expression. When there are no indices
// at all (scalar buffer access) kInt is returned so callers still receive a
// usable dtype.
static Dtype dtypeOfIndices(const std::vector<ExprPtr>& indices) {
  if (indices.empty()) {
    return kInt;
  }
  return indices.front()->dtype();
}

void castIndicesToInts(std::vector<ExprPtr>& indices) {
  // Cast all indices to either Int or Long
  auto index_dtype = ScalarType::Int;
  for (auto& index : indices) {
    if (index->dtype().scalar_type() == ScalarType::Long) {
      // If any of the indexes is Long, cast all of them to Long
      index_dtype = ScalarType::Long;
      break;
    }
  }

  for (auto& index : indices) {
    const Dtype& dt = index->dtype();
    if (c10::isIntegralType(dt.scalar_type(), true) &&
        dt.scalar_type() != index_dtype) {
      index = alloc<Cast>(Dtype(index_dtype, dt.lanes()), index);
    }
  }
}

// Constructs a Load of `buf` at `indices` with an explicitly chosen result
// dtype.
//
// `buf` and `indices` are by-value sink parameters, so both are moved into the
// members rather than copied a second time. After construction the stored
// indices are normalized by castIndicesToInts().
Load::Load(Dtype dtype, BufPtr buf, std::vector<ExprPtr> indices)
    : ExprNodeBase(dtype),
      buf_(std::move(buf)),
      indices_(std::move(indices)) {
  castIndicesToInts(indices_);
}

// Constructs a Load whose result dtype is derived instead of passed in:
// ChooseDtype() combines the buffer's dtype with the lane count of the dtype
// reported by dtypeOfIndices() (the first index, or kInt when `indices` is
// empty). All remaining work is delegated to the three-argument constructor.
Load::Load(BufPtr buf, const std::vector<ExprPtr>& indices)
    : Load(ChooseDtype(buf->dtype(), dtypeOfIndices(indices)), buf, indices) {}

// Handle-level factory: unwraps the buffer and index handles into their
// underlying nodes, allocates a Load with the given dtype, and returns it
// wrapped in an ExprHandle.
ExprHandle Load::make(
    Dtype dtype,
    const BufHandle& buf,
    const std::vector<ExprHandle>& indices) {
  std::vector<ExprPtr> index_nodes = ExprHandleVectorToExprVector(indices);
  auto load_node = alloc<Load>(dtype, buf.node(), std::move(index_nodes));
  return ExprHandle(load_node);
}

// Convenience overload: builds a Load whose dtype is exactly the buffer's
// dtype, forwarding to the dtype-taking Load::make.
// NOTE(review): unlike the Load(BufPtr, indices) constructor, this path does
// not run ChooseDtype(), so the lane count of the index dtype is not applied
// here - confirm this difference is intended.
ExprHandle Load::make(
    const BufHandle& buf,
    const std::vector<ExprHandle>& indices) {
  return Load::make(buf.dtype(), buf, indices);
}

// Constructs a Store of `value` into `buf` at `indices`.
//
// All three arguments are by-value sink parameters, so each is moved into its
// member instead of being copied a second time. After construction the stored
// indices are normalized by castIndicesToInts().
Store::Store(BufPtr buf, std::vector<ExprPtr> indices, ExprPtr value)
    : buf_(std::move(buf)),
      indices_(std::move(indices)),
      value_(std::move(value)) {
  castIndicesToInts(indices_);
}

// Handle-level factory: unwraps the buffer, index and value handles into
// their underlying nodes and allocates the Store statement.
StorePtr Store::make(
    const BufHandle& buf,
    const std::vector<ExprHandle>& indices,
    const ExprHandle& value) {
  std::vector<ExprPtr> index_nodes = ExprHandleVectorToExprVector(indices);
  return alloc<Store>(buf.node(), std::move(index_nodes), value.node());
}

// Creates a Store statement that writes `value` into this buffer at the
// position given by `args`; simply forwards to Store::make with *this as the
// destination buffer.
StorePtr BufHandle::store(
    const std::vector<ExprHandle>& args,
    const ExprHandle& value) const {
  return Store::make(*this, args, value);
}

// Collapses a multi-dimensional index into a single expression:
//   0 + indices[0] * strides[0] + ... + indices[n-1] * strides[n-1]
//
// - A single index is treated as already flattened and returned unchanged,
//   before any size validation takes place.
// - Throws malformed_input when the number of indices or strides differs from
//   the number of dims.
// - With zero dims the result is the LongImm constant 0.
// `dims` is only consulted for its size.
ExprPtr flatten_index(
    const std::vector<ExprPtr>& dims,
    const std::vector<ExprPtr>& indices,
    const std::vector<ExprPtr>& strides) {
  // Already flattened: nothing to do.
  if (indices.size() == 1) {
    return indices.front();
  }

  const size_t rank = dims.size();
  if (indices.size() != rank) {
    throw malformed_input("dimensions mismatch in flatten_index");
  }
  if (strides.size() != rank) {
    throw malformed_input("strides mismatch in flatten_index");
  }
  if (rank == 0) {
    return alloc<LongImm>(0);
  }

  // Start from a zero immediate shaped like the first index, then accumulate
  // one index * stride term per dimension.
  ExprPtr flat = immLike(indices.front(), 0);
  for (size_t d = 0; d < rank; ++d) {
    const auto term = alloc<Mul>(indices[d], strides[d]);
    flat = alloc<Add>(flat, term);
  }
  return flat;
}

// Result dtype of a one-argument intrinsic: the argument dtype itself, except
// for kIsNan, which yields the same dtype with its scalar type replaced by Int.
Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) {
  // TODO: check the op_type and make a real decision
  if (op_type != kIsNan) {
    return dt1;
  }
  return dt1.cloneWithScalarType(ScalarType::Int);
}

// Result dtype of a two-argument intrinsic. Currently this is always the
// dtype of the first argument; `op_type` and `dt2` are not consulted.
Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) {
  // TODO: check the op_type and make a real decision
  return dt1;
}

// Result dtype of an intrinsic given its parameter expressions.
//
// Dispatches on the number of parameters: one or two parameters are routed to
// the matching dtype-based overload, three or more use the dtype of the first
// parameter, and an empty parameter list throws malformed_input.
Dtype Intrinsics::IntrinsicsDtype(
    IntrinsicsOp op_type,
    const std::vector<ExprPtr>& params) {
  // TODO: check the op_type and make a real decision
  // Doesn't this fail with kRand?
  switch (params.size()) {
    case 0:
      throw malformed_input("invalid params in Intrinsics");
    case 1:
      return IntrinsicsDtype(op_type, params[0]->dtype());
    case 2:
      return IntrinsicsDtype(op_type, params[0]->dtype(), params[1]->dtype());
    default:
      return params[0]->dtype();
  }
}

// Returns how many arguments the given intrinsic takes (0, 1 or 2).
// Throws std::runtime_error for an op_type that is not listed below.
int Intrinsics::OpArgCount(IntrinsicsOp op_type) {
  switch (op_type) {
    // Nullary intrinsics.
    case kRand:
      return 0;

    // Unary intrinsics.
    case kSin:
    case kCos:
    case kTan:
    case kAsin:
    case kAcos:
    case kAtan:
    case kSinh:
    case kCosh:
    case kTanh:
    case kSigmoid:
    case kExp:
    case kExpm1:
    case kAbs:
    case kLog:
    case kLog2:
    case kLog10:
    case kLog1p:
    case kErf:
    case kErfc:
    case kSqrt:
    case kRsqrt:
    case kCeil:
    case kFloor:
    case kRound:
    case kTrunc:
    case kFrac:
    case kLgamma:
    case kIsNan:
      return 1;

    // Binary intrinsics.
    case kAtan2:
    case kFmod:
    case kPow:
    case kRemainder:
      return 2;

    default:
      throw std::runtime_error("invalid op_type: " + c10::to_string(op_type));
  }
}

// Handle-level factory for ExternalCall: unwraps the output buffer handle, the
// buffer-argument handles and the expression-argument handles into their
// underlying nodes and allocates the ExternalCall node from them.
ExternalCallPtr ExternalCall::make(
    BufHandle buf,
    const std::string& func_name,
    const std::vector<BufHandle>& buf_args,
    const std::vector<ExprHandle>& args) {
  std::vector<BufPtr> input_nodes;
  input_nodes.reserve(buf_args.size());
  for (const auto& handle : buf_args) {
    input_nodes.push_back(handle.node());
  }

  std::vector<ExprPtr> arg_nodes = ExprHandleVectorToExprVector(args);
  return alloc<ExternalCall>(
      buf.node(), func_name, input_nodes, std::move(arg_nodes));
}

// Handle-level factory for ExternalCallWithAlloc: unwraps the output-buffer
// handles, the buffer-argument handles and the expression-argument handles
// into their underlying nodes and allocates the node from them.
ExternalCallWithAllocPtr ExternalCallWithAlloc::make(
    const std::string& func_name,
    const std::vector<BufHandle>& buf_out_args,
    const std::vector<BufHandle>& buf_args,
    const std::vector<ExprHandle>& args) {
  std::vector<BufPtr> output_nodes;
  output_nodes.reserve(buf_out_args.size());
  for (const auto& out_handle : buf_out_args) {
    output_nodes.push_back(out_handle.node());
  }

  std::vector<BufPtr> input_nodes;
  input_nodes.reserve(buf_args.size());
  for (const auto& in_handle : buf_args) {
    input_nodes.push_back(in_handle.node());
  }

  std::vector<ExprPtr> arg_nodes = ExprHandleVectorToExprVector(args);
  return alloc<ExternalCallWithAlloc>(
      func_name, output_nodes, input_nodes, std::move(arg_nodes));
}

// Handle-level factory for FreeExt: unwraps every buffer handle into its
// underlying node and allocates a FreeExt node over that list.
FreeExtPtr FreeExt::make(const std::vector<BufHandle>& bufs) {
  std::vector<BufPtr> nodes;
  nodes.reserve(bufs.size());
  for (const auto& handle : bufs) {
    nodes.push_back(handle.node());
  }
  return alloc<FreeExt>(nodes);
}

std::vector<ExprPtr> ExprHandleVectorToExprVector(
    const std::vector<ExprHandle>& v) {
  std::vector<ExprPtr> result(v.size());
  for (const auto i : c10::irange(v.size())) {
    result[i] = v[i].node();
  }
  return result;
}

std::vector<ExprHandle> ExprVectorToExprHandleVector(
    const std::vector<ExprPtr>& v) {
  std::vector<ExprHandle> result(v.size());
  for (const auto i : c10::irange(v.size())) {
    result[i] = ExprHandle(v[i]);
  }
  return result;
}

std::vector<VarPtr> VarHandleVectorToVarVector(
    const std::vector<VarHandle>& v) {
  std::vector<VarPtr> result(v.size());
  for (const auto i : c10::irange(v.size())) {
    result[i] = v[i].node();
  }
  return result;
}

std::vector<VarHandle> VarVectorToVarHandleVector(
    const std::vector<VarPtr>& v) {
  std::vector<VarHandle> result(v.size());
  for (const auto i : c10::irange(v.size())) {
    result[i] = VarHandle(v[i]);
  }
  return result;
}

// Returns true when `e` is an immediate of one of the scalar types listed
// below and its value is strictly less than zero. Returns false when `e`
// matches none of the listed immediate types.
bool immediateIsNegative(ExprPtr e) {
// TYPE_CASE: if converting `e` to the <Name>Imm node type yields a non-null
// pointer, return the sign test on that immediate's value.
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() < 0;                 \
  }
  // Only Half and BFloat16 are added to the base scalar types here; Bool is
  // not part of this list, unlike in immediateIsPositive / immediateIsZero.
  AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  // Not one of the immediate types handled above.
  return false;
}

// Returns true when `e` is an immediate of one of the scalar types listed
// below and its value is strictly greater than zero. Returns false when `e`
// matches none of the listed immediate types.
bool immediateIsPositive(ExprPtr e) {
// TYPE_CASE: if converting `e` to the <Name>Imm node type yields a non-null
// pointer, return the sign test on that immediate's value.
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() > 0;                 \
  }
  // Bool, Half and BFloat16 are added to the base scalar types here.
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  // Not one of the immediate types handled above.
  return false;
}

// Returns true when `e` is an immediate of one of the scalar types listed
// below and its value compares equal to zero. Returns false when `e` matches
// none of the listed immediate types.
bool immediateIsZero(ExprPtr e) {
// TYPE_CASE: if converting `e` to the <Name>Imm node type yields a non-null
// pointer, return the zero test on that immediate's value.
#define TYPE_CASE(Type, Name)                \
  if (Name##ImmPtr imm = to<Name##Imm>(e)) { \
    return imm->value() == 0;                \
  }
  // Bool, Half and BFloat16 are added to the base scalar types here.
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
#undef TYPE_CASE
  // Not one of the immediate types handled above.
  return false;
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch