1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
|
#include <ATen/core/dynamic_type.h>
#include <ATen/core/type_factory.h>
#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
namespace torch::jit {
// Converts a Python-style index into a front-relative one for a list holding
// `list_size` elements: a negative `idx` is shifted by `list_size` (so -1 maps
// to list_size - 1), while a non-negative `idx` is returned untouched.
// No range validation happens here; the result may still be out of bounds.
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
  return idx >= 0 ? idx : list_size + idx;
}
// Reads the single element stored at `data` and boxes it as an IValue of the
// TorchScript leaf type `ty`. Supported leaf types are int, float, complex and
// bool; any other type raises via TORCH_CHECK. For float and complex,
// `scalar_ty` selects the in-memory width (Float/Double resp.
// ComplexFloat/ComplexDouble) and the value is widened to double resp.
// c10::complex<double>.
static IValue readScalarForToList(
    const char* data,
    const at::TypePtr& ty,
    at::ScalarType scalar_ty) {
  if (ty == at::IntType::get()) {
    // NOTE(review): always reads an int64_t; assumes the caller has already
    // verified that the tensor dtype is Long — TODO confirm.
    return IValue(*reinterpret_cast<const int64_t*>(data));
  }
  if (ty == at::FloatType::get()) {
    TORCH_INTERNAL_ASSERT(
        scalar_ty == at::ScalarType::Float ||
            scalar_ty == at::ScalarType::Double,
        "Unexpected scalar type for Tensor");
    const double scalar = scalar_ty == at::ScalarType::Float
        ? *reinterpret_cast<const float*>(data)
        : *reinterpret_cast<const double*>(data);
    return IValue(scalar);
  }
  if (ty == at::ComplexType::get()) {
    TORCH_INTERNAL_ASSERT(
        scalar_ty == at::ScalarType::ComplexFloat ||
            scalar_ty == at::ScalarType::ComplexDouble,
        "Unexpected scalar type for Tensor");
    const c10::complex<double> scalar =
        scalar_ty == at::ScalarType::ComplexFloat
        ? *reinterpret_cast<const c10::complex<float>*>(data)
        : *reinterpret_cast<const c10::complex<double>*>(data);
    return IValue(scalar);
  }
  // Bool is the only remaining supported leaf type. Checking it here (rather
  // than ending with TORCH_CHECK(false, ...)) keeps every path returning.
  TORCH_CHECK(
      ty == at::BoolType::get(),
      ty->repr_str(),
      " is not one of the supported types for tolist: int, float, complex, bool");
  // NOTE(review): assumes the caller has verified a Bool dtype — TODO confirm.
  return IValue(*reinterpret_cast<const bool*>(data));
}

// Converts the tensor slice starting at `data` into a (possibly nested)
// TorchScript list, or into a single scalar when `ty` is not a list type.
//
// Parameters:
//   data            - pointer to the first element of the slice for this call.
//   cur_dim         - tensor dimension processed by this invocation.
//   num_tensor_dims - forwarded unchanged to recursive calls; not otherwise
//                     read here.
//   ty              - requested output type: a ListType (one nesting level
//                     per remaining dimension) or a leaf type int, float,
//                     complex or bool.
//   scalar_ty       - element dtype of the tensor.
//   sizes, strides  - per-dimension extents and strides (in elements).
//   element_size    - size in bytes of one element; strides are scaled by it
//                     to advance `data`.
// Returns: the boxed scalar when `ty` is a leaf type, otherwise a GenericList
// with sizes[cur_dim] entries, each produced by recursing into cur_dim + 1.
// Raises: if the leaf type is unsupported, or if `ty` is nested more deeply
// than `sizes`/`strides` have dimensions.
IValue tensorToListRecursive(
    char* data,
    int64_t cur_dim,
    int64_t num_tensor_dims,
    at::TypePtr ty,
    at::ScalarType scalar_ty,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    size_t element_size) {
  auto list_type = ty->cast<at::ListType>();
  if (!list_type) {
    // Leaf of the recursion: read exactly one element.
    return readScalarForToList(data, ty, scalar_ty);
  }
  const at::TypePtr elem_ty = list_type->getElementType();

  // IntArrayRef::operator[] is unchecked, so guard against a list type that
  // is nested more deeply than the tensor has dimensions.
  TORCH_CHECK(
      cur_dim >= 0 && static_cast<size_t>(cur_dim) < sizes.size() &&
          static_cast<size_t>(cur_dim) < strides.size(),
      "tolist: output list type is nested more deeply than the tensor has dimensions");

  // One output element per index along cur_dim. The byte step is computed in
  // signed arithmetic so the stride is never reinterpreted as unsigned.
  const int64_t count = sizes[cur_dim];
  const int64_t byte_step =
      strides[cur_dim] * static_cast<int64_t>(element_size);

  auto result = c10::impl::GenericList(elem_ty);
  result.reserve(count);
  for (int64_t i = 0; i < count; ++i) {
    // The recursive call already yields a correctly-kinded IValue (a nested
    // list or a boxed scalar), so it can be stored without re-boxing.
    result.push_back(tensorToListRecursive(
        data,
        cur_dim + 1,
        num_tensor_dims,
        elem_ty,
        scalar_ty,
        sizes,
        strides,
        element_size));
    data += byte_step;
  }
  return result;
}
} // namespace torch::jit
|