1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
|
//===- Traits.cpp - Common op traits shared by dialects -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>
using namespace mlir;
/// Two-shape convenience overload: copies `shape1` and `shape2` into a small
/// array of extent lists and forwards to the list-based overload.
bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                 ArrayRef<int64_t> shape2) {
  SmallVector<int64_t, 6> operandShapes[] = {
      SmallVector<int64_t, 6>(shape1.begin(), shape1.end()),
      SmallVector<int64_t, 6>(shape2.begin(), shape2.end())};
  return staticallyKnownBroadcastable(operandShapes);
}
/// Returns true if all of `shapes` are guaranteed to broadcast together,
/// judging from static information only.
///
/// Shapes are right-aligned; a shape that is too short to reach a column is
/// treated as having extent 1 there. A column is accepted when its non-1
/// extents are all the same static value, or when it contains exactly one
/// dynamic extent and every other extent is 1. Requires at least one shape.
bool OpTrait::util::staticallyKnownBroadcastable(
    ArrayRef<SmallVector<int64_t, 6>> shapes) {
  assert(!shapes.empty() && "Expected at least one shape");

  size_t maxRank = 0;
  for (const SmallVector<int64_t, 6> &shape : shapes)
    maxRank = std::max(maxRank, shape.size());

  // Walk the columns starting from the trailing (innermost) dimension.
  for (size_t fromBack = 0; fromBack < maxRank; ++fromBack) {
    bool sawDynamic = false;
    // First non-1 extent encountered in this column, if any.
    std::optional<int64_t> column;
    for (ArrayRef<int64_t> shape : shapes) {
      int64_t extent = 1;
      if (fromBack < shape.size())
        extent = shape[shape.size() - 1 - fromBack];
      if (extent == 1)
        continue;

      // A dynamic extent is only provably safe if it is the sole non-1 entry.
      if (ShapedType::isDynamic(extent)) {
        if (sawDynamic || column.has_value())
          return false;
        sawDynamic = true;
      }

      // Otherwise every non-1 extent must agree with the first one seen.
      if (column.has_value() && *column != extent)
        return false;
      column = extent;
    }
  }
  return true;
}
/// Computes the broadcast of `shape1` and `shape2` into `resultShape`.
///
/// Shapes are compared element-wise starting from the trailing dimensions.
/// Two static extents are compatible when they are equal or one of them is 1.
/// The result has the rank of the longer input, and leading dimensions present
/// only in that input are copied through. When either extent is dynamic,
/// TensorFlow's behavior is followed:
///   - if either extent is greater than 1, the program is assumed correct and
///     that extent is the output;
///   - else if either extent is 1, the other extent is the output;
///   - otherwise the output extent is dynamic.
///
/// Returns true on success. On failure returns false and leaves `resultShape`
/// empty.
///
/// The result is assembled in a local buffer and only written to `resultShape`
/// after both inputs have been fully read. As a result, `shape1`/`shape2` may
/// safely be views of `resultShape` itself (e.g. when folding many shapes into
/// one accumulator).
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // Seed the result with the higher-rank shape (shape2 on ties).
  ArrayRef<int64_t> longer = shape1.size() > shape2.size() ? shape1 : shape2;
  SmallVector<int64_t, 4> result(longer.begin(), longer.end());

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = result.rbegin();

  // Check each overlapping (trailing) dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
      // One or both extents are unknown; see the TensorFlow rules above.
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = ShapedType::kDynamic;
      }
      continue;
    }

    if (*i1 == *i2 || *i2 == 1) {
      *iR = *i1;
    } else if (*i1 == 1) {
      *iR = *i2;
    } else {
      // This dimension of the two operand shapes is incompatible.
      resultShape.clear();
      return false;
    }
  }

  // Inputs are no longer read; it is now safe to overwrite the output even if
  // it aliases one of them.
  resultShape.assign(result.begin(), result.end());
  return true;
}
/// Returns the shape of `type` if it is a ShapedType. Any other type (a
/// scalar) is treated as rank-0 and yields an empty shape.
static ArrayRef<int64_t> getShape(Type type) {
  auto shaped = dyn_cast<ShapedType>(type);
  if (!shaped)
    return ArrayRef<int64_t>();
  return shaped.getShape();
}
/// Computes the type obtained by broadcasting `type1` against `type2` under
/// NumPy broadcast semantics. Returns a null Type when the two are not
/// broadcast-compatible.
///
/// If `elementType` is null, both inputs must have the same element type (as
/// given by getElementTypeOrSelf), and that type becomes the result element
/// type. Otherwise `elementType` is used for the result.
///
/// An unranked tensor input makes the result an unranked tensor; pairing one
/// with a vector is rejected. Vectors and ranked tensors keep their kind in
/// the result, and mixing a vector with a ranked tensor is rejected. If
/// neither input is a vector or ranked tensor, the bare element type is
/// returned. The result shape may contain dynamic extents when the inputs do.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  if (!elementType) {
    Type lhsElement = getElementTypeOrSelf(type1);
    if (lhsElement != getElementTypeOrSelf(type2))
      return {};
    elementType = lhsElement;
  }

  bool anyUnranked =
      isa<UnrankedTensorType>(type1) || isa<UnrankedTensorType>(type2);
  if (anyUnranked) {
    bool anyVector = isa<VectorType>(type1) || isa<VectorType>(type2);
    if (anyVector)
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Only vectors and ranked tensors are "composite" kinds carried into the
  // result; anything else is treated as a scalar operand.
  auto compositeKindOf = [](Type type) -> std::optional<TypeID> {
    if (!isa<VectorType, RankedTensorType>(type))
      return std::nullopt;
    return type.getTypeID();
  };
  std::optional<TypeID> lhsKind = compositeKindOf(type1);
  std::optional<TypeID> rhsKind = compositeKindOf(type2);

  // Vectors and tensors may not be mixed.
  if (lhsKind && rhsKind && *lhsKind != *rhsKind)
    return {};
  std::optional<TypeID> resultKind = lhsKind ? lhsKind : rhsKind;

  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  if (resultKind == VectorType::getTypeID())
    return VectorType::get(resultShape, elementType);
  if (resultKind == RankedTensorType::getTypeID())
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}
/// Returns a tuple corresponding to whether range has tensor or vector type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
return std::make_tuple(
llvm::any_of(types, [](Type t) { return isa<TensorType>(t); }),
llvm::any_of(types, [](Type t) { return isa<VectorType>(t); }));
}
/// Returns true if the declared result shape `existing` is consistent with the
/// shape `inferred` from broadcasting the operands. The ranks must match, and
/// each existing dimension must be dynamic or equal to the inferred one:
///
///   inferredDim  existingDim  Behavior
///   -----------  -----------  --------
///   dynamic      dynamic      OK
///   dynamic      static       Error
///   static       dynamic      OK
///   static       static       OK if equal
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
                                            ArrayRef<int64_t> existing) {
  if (inferred.size() != existing.size())
    return false;
  for (size_t d = 0, rank = inferred.size(); d < rank; ++d) {
    int64_t want = inferred[d];
    int64_t have = existing[d];
    // A dynamic declared dimension accepts anything; a static one must match.
    if (!ShapedType::isDynamic(have) && have != want)
      return false;
  }
  return true;
}
/// Renders `shape` for diagnostics as a single-quoted, 'x'-separated list of
/// extents, using '?' for dynamic ones (e.g. '2x?x4'). An empty shape renders
/// as ''.
static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: should replace with printing shape more uniformly across here and
  // when in type.
  std::string text(1, '\'');
  for (size_t idx = 0; idx < shape.size(); ++idx) {
    if (idx != 0)
      text += 'x';
    int64_t dim = shape[idx];
    if (ShapedType::isDynamic(dim))
      text += '?';
    else
      text += std::to_string(dim);
  }
  text += '\'';
  return text;
}
/// Verifies that the operands of `op` are broadcast-compatible with each other
/// and that every ranked-tensor result agrees with their broadcast shape.
///
/// Fails if any tensor type and any vector type both appear among the operand
/// and result types. Only ranked-tensor operands and results take part in the
/// shape check: with no ranked-tensor operands, any result shape is accepted.
/// For each ranked-tensor result, its trailing dimensions (as many as the
/// broadcast shape has) must be compatible with the broadcast shape. Dynamic
/// result dimensions are accepted in any position.
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  auto [operandsHaveTensor, operandsHaveVector] =
      hasTensorOrVectorType(op->getOperandTypes());
  auto [resultsHaveTensor, resultsHaveVector] =
      hasTensorOrVectorType(op->getResultTypes());
  bool anyTensor = operandsHaveTensor || resultsHaveTensor;
  bool anyVector = operandsHaveVector || resultsHaveVector;
  if (anyTensor && anyVector)
    return op->emitError("cannot broadcast vector with tensor");

  auto isRankedTensor = [](Type t) { return isa<RankedTensorType>(t); };
  auto rankedOperands =
      make_filter_range(op->getOperandTypes(), isRankedTensor);
  // With only unranked operands, every result shape is possible.
  if (rankedOperands.empty())
    return success();

  // Fold all ranked operand shapes into one broadcast shape, seeded with the
  // first ranked operand's shape.
  ArrayRef<int64_t> firstShape = getShape(*rankedOperands.begin());
  SmallVector<int64_t, 4> broadcastShape(firstShape.begin(), firstShape.end());
  for (Type operandType : rankedOperands) {
    // Snapshot the accumulator so it is not read while being overwritten.
    SmallVector<int64_t, 4> accumulated(broadcastShape);
    if (!util::getBroadcastedShape(accumulated, getShape(operandType),
                                   broadcastShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  // Every ranked result must agree with the broadcast shape on its trailing
  // dimensions; unranked results are not checked.
  auto rankedResults = make_filter_range(op->getResultTypes(), isRankedTensor);
  for (Type resultType : rankedResults) {
    ArrayRef<int64_t> resultDims = getShape(resultType);
    if (isCompatibleInferredReturnShape(
            broadcastShape, resultDims.take_back(broadcastShape.size())))
      continue;
    return op->emitOpError()
           << "result type " << getShapeString(resultDims)
           << " not broadcast compatible with broadcasted operands's shapes "
           << getShapeString(broadcastShape);
  }
  return success();
}
|