1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
#include <torch/csrc/lazy/core/helpers.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <limits>
namespace torch {
namespace lazy {
// Returns a copy of `sizes` with the entries at the positions listed in
// `drop_dims` removed.
//
// `sizes` is walked once from front to back, and at each position only the
// next pending entry of `drop_dims` is considered. As a consequence
// `drop_dims` must list the positions in strictly increasing order; any entry
// that is never matched (out of range, negative, duplicated or out of order)
// makes the final TORCH_CHECK fail.
std::vector<int64_t> DropDimensions(
    c10::ArrayRef<int64_t> sizes,
    c10::ArrayRef<int64_t> drop_dims) {
  std::vector<int64_t> new_dims;
  size_t drop_index = 0;
  for (size_t pos = 0; pos < sizes.size(); ++pos) {
    // Bounds test comes first so drop_dims is never indexed past its end.
    const bool drop_here = drop_index < drop_dims.size() &&
        static_cast<int64_t>(pos) == drop_dims[drop_index];
    if (drop_here) {
      ++drop_index;
      continue;
    }
    new_dims.push_back(sizes[pos]);
  }
  // Every requested position must have been consumed by the scan above.
  TORCH_CHECK(drop_index == drop_dims.size());
  return new_dims;
}
// Converts a possibly negative dimension index into its non-negative form.
//
// `dim` must lie in [-rank, rank - 1]; otherwise a TORCH_CHECK fails with a
// message reporting the accepted range and the offending value. Negative
// values are offset by `rank`, so -1 maps to rank - 1. The returned index is
// always in [0, rank).
int64_t GetCanonicalDimensionIndex(int64_t dim, int64_t rank) {
  const int64_t max_shape_dim = rank - 1;
  const int64_t min_shape_dim = -rank;
  TORCH_CHECK(
      min_shape_dim <= dim && dim <= max_shape_dim,
      "Value out of range (expected to be in range of [",
      min_shape_dim,
      ", ",
      max_shape_dim,
      "], but got ",
      dim,
      ")");
  // Negative indices count backwards from the end.
  int64_t dim_index = dim;
  if (dim < 0) {
    dim_index += rank;
  }
  // Sanity checks on the result; the range check above already implies them.
  TORCH_CHECK(dim_index >= 0);
  TORCH_CHECK(dim_index < rank);
  return dim_index;
}
// Applies GetCanonicalDimensionIndex to every entry of `dimensions`, using the
// same `rank` for all of them, and returns the results in the same order.
// Entries are processed front to back, so the first invalid entry is the one
// that triggers the failure inside GetCanonicalDimensionIndex.
std::vector<int64_t> GetCanonicalDimensionIndices(
    c10::ArrayRef<int64_t> dimensions,
    int64_t rank) {
  // Start from a copy of the inputs and canonicalize each slot in place.
  std::vector<int64_t> canonical(dimensions.begin(), dimensions.end());
  for (int64_t& entry : canonical) {
    entry = GetCanonicalDimensionIndex(entry, rank);
  }
  return canonical;
}
// Resolves a position `pos` along dimension `dim` of a shape whose sizes are
// given by `dimensions`.
//
// `dim` is canonicalized against the number of dimensions first. Then:
//  - a negative `pos` is canonicalized against the size of that dimension
//    (and therefore fails the range check if it is below -size);
//  - a non-negative `pos` is clamped so it does not exceed the size of that
//    dimension (the size itself is an allowed result).
int64_t GetCanonicalPosition(
    c10::ArrayRef<int64_t> dimensions,
    int64_t dim,
    int64_t pos) {
  dim = GetCanonicalDimensionIndex(dim, dimensions.size());
  const int64_t extent = dimensions[dim];
  if (pos < 0) {
    return GetCanonicalDimensionIndex(pos, extent);
  }
  return extent < pos ? extent : pos;
}
// Builds the permutation vector describing a transpose of `dim0` and `dim1`
// for a shape with `rank` dimensions.
//
// Both dimensions are canonicalized (negative values allowed) before the
// sequence produced by Iota<int64_t>(rank) is created; the two corresponding
// entries of that sequence are then swapped.
// NOTE(review): Iota is defined outside this file; presumably it yields
// 0, 1, ..., rank - 1 - confirm against its definition.
std::vector<int64_t> MakeTransposePermutation(
    int64_t dim0,
    int64_t dim1,
    int64_t rank) {
  // Validate both inputs before allocating anything.
  const int64_t first = GetCanonicalDimensionIndex(dim0, rank);
  const int64_t second = GetCanonicalDimensionIndex(dim1, rank);
  std::vector<int64_t> permutation = Iota<int64_t>(rank);
  std::swap(permutation[first], permutation[second]);
  return permutation;
}
// Computes the sizes obtained by broadcasting two shapes against each other.
//
// The shapes are aligned on their trailing dimensions. Leading dimensions that
// exist only in the higher-rank shape are copied through unchanged. For each
// aligned pair of dimensions the sizes must be equal, or one of them must be
// 1; otherwise a TORCH_CHECK fails with a message listing both input shapes.
// The resulting size of an aligned pair is 0 if either side is 0, and the
// larger of the two otherwise.
std::vector<int64_t> GetPromotedShape(
    c10::ArrayRef<int64_t> shape1_dims,
    c10::ArrayRef<int64_t> shape2_dims) {
  std::vector<int64_t> dimensions;
  // The result always has as many entries as the higher-rank input.
  dimensions.reserve(std::max(shape1_dims.size(), shape2_dims.size()));
  // If the rank of one shape is bigger than the other, fill up the first
  // dimensions with the ones of the bigger.
  // Example:
  //   shape1 = [9, 7, 6, 5, 2]
  //   shape2 = [6, 1, 2]
  // Insert [9, 7] into the dimensions vector.
  if (shape1_dims.size() > shape2_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape1_dims.begin(),
        shape1_dims.begin() + (shape1_dims.size() - shape2_dims.size()));
  } else if (shape2_dims.size() > shape1_dims.size()) {
    dimensions.insert(
        dimensions.end(),
        shape2_dims.begin(),
        shape2_dims.begin() + (shape2_dims.size() - shape1_dims.size()));
  }
  // For the common dimensions, they must match, or one of them be 1.
  size_t min_size = std::min(shape1_dims.size(), shape2_dims.size());
  for (const auto i : c10::irange(min_size)) {
    int64_t dim1 = shape1_dims[shape1_dims.size() - min_size + i];
    int64_t dim2 = shape2_dims[shape2_dims.size() - min_size + i];
    // Report BOTH operand shapes on failure (the second one used to be a
    // repeat of the first, hiding the actual mismatch).
    TORCH_CHECK(
        dim1 == dim2 || dim1 == 1 || dim2 == 1,
        "(",
        c10::Join(", ", shape1_dims),
        ") and (",
        c10::Join(", ", shape2_dims),
        ")");
    // A zero-sized dimension wins over 1: taking the max of (0, 1) would
    // wrongly produce 1.
    if (dim1 == 0 || dim2 == 0) {
      dimensions.push_back(0);
    } else {
      dimensions.push_back(std::max<int64_t>(dim1, dim2));
    }
  }
  return dimensions;
}
// Builds the result Shape for a binary operation on `shape1` and `shape2`:
// the scalar type is whatever promoteTypes returns for the two input scalar
// types, and the sizes come from GetPromotedShape applied to the two size
// lists (which fails if the sizes cannot be broadcast together).
Shape GetPromotedBinaryOpShape(const Shape& shape1, const Shape& shape2) {
  const auto promoted_type =
      promoteTypes(shape1.scalar_type(), shape2.scalar_type());
  auto promoted_sizes = GetPromotedShape(shape1.sizes(), shape2.sizes());
  return Shape(promoted_type, std::move(promoted_sizes));
}
// Splits `text` on the character `delim` and returns the non-empty pieces as
// owned strings, in order of appearance.
//
// Runs of consecutive delimiters, as well as leading and trailing delimiters,
// are skipped, so the result never contains empty tokens. An input that is
// empty or consists only of delimiters yields an empty vector.
std::vector<std::string> StrSplit(c10::string_view text, char delim) {
  std::vector<std::string> pieces;
  size_t cursor = 0;
  for (;;) {
    // Skip any delimiters in front of the next token.
    const size_t token_begin = text.find_first_not_of(delim, cursor);
    if (token_begin == std::string::npos) {
      break;
    }
    // `cursor` becomes npos when no further delimiter exists; substr then
    // takes the remainder of the text and the next search ends the loop.
    cursor = text.find(delim, token_begin);
    const auto piece = text.substr(token_begin, cursor - token_begin);
    pieces.emplace_back(piece.begin(), piece.end());
  }
  return pieces;
}
} // namespace lazy
} // namespace torch
|