//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
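
/// Returns true if `op` carries the ElementwiseMappable traits and all of its
/// operands are ranked tensors. Scalar or mixed scalar/tensor operands are
/// deliberately not handled by the conversion below (see the TODO).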
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  // TODO: The conversion pattern can be made to work for `any_of` here, but
  // it's more complex as it requires tracking which operands are scalars.
  return llvm::all_of(op->getOperandTypes(),
                      [](Type type) { return isa<RankedTensorType>(type); });
}

/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
/// the result types and return a list of values such that, for each result
/// type `t` and value `v` at the same index `idx`:
///   1. `v.getType() == t`
///   2. If an operand of `op` has type `t`, let `operand_first` be the first
///      such operand. Then `v == operand_first`.
///   3. Otherwise, `v` is a newly created `tensor::EmptyOp` with:
///      a. Static and dynamic dims extracted from the first operand of `op`.
///      b. Elemental type equal to the elemental type of `t`.
///
/// This is sufficient because ElementwiseMappable guarantees that "The static
/// types of all vector (resp. tensor) operands and results must have the same
/// shape".
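///
/// For example, for an `arith.cmpf` whose operands have type `tensor<4x?xf32>`
/// and whose result has type `tensor<4x?xi1>`, no operand type matches the
/// result type, so a `tensor.empty` of `i1` with sizes taken from the first
/// operand is created; for an `arith.addf` on `tensor<4x?xf32>`, the first
/// operand is simply reused as the matching value.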
static SmallVector<Value, 4>
getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) {
  assert(isElementwiseMappableOpOnRankedTensors(op));
  Location loc = op->getLoc();
  ValueRange operands = op->getOperands();
  TypeRange rankedTensorTypes = op->getResultTypes();
  SmallVector<Value, 4> res;
  res.reserve(rankedTensorTypes.size());
  for (Type t : rankedTensorTypes) {
    // Try to find an operand with type matching the result tensor.
    bool found = false;
    for (Value v : operands) {
      if (v.getType() == t) {
        found = true;
        res.push_back(v);
        break;
      }
    }
    if (found)
      continue;

    // Extract static / dynamic shape mix from the first operand.
    res.push_back(b.create<tensor::EmptyOp>(
        loc, tensor::getMixedSizes(b, loc, operands.front()),
        cast<RankedTensorType>(t).getElementType()));
  }
  return res;
}

namespace {
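/// Converts any op with the ElementwiseMappable traits whose operands are all
/// ranked tensors into a `linalg.generic` that applies the same op to the
/// scalar elements. The output operand reuses an input with a matching type
/// when one exists and a `tensor.empty` otherwise. Roughly (with `#id`
/// standing for the identity affine map of the tensor rank):
///
///   %0 = arith.addf %a, %b : tensor<4x?xf32>
///
/// becomes
///
///   %0 = linalg.generic
///          {indexing_maps = [#id, #id, #id],
///           iterator_types = ["parallel", "parallel"]}
///          ins(%a, %b : tensor<4x?xf32>, tensor<4x?xf32>)
///          outs(%a : tensor<4x?xf32>) {
///        ^bb0(%lhs: f32, %rhs: f32, %out: f32):
///          %s = arith.addf %lhs, %rhs : f32
///          linalg.yield %s : f32
///        } -> tensor<4x?xf32>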
struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
  ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const final {
    if (!isElementwiseMappableOpOnRankedTensors(op))
      return rewriter.notifyMatchFailure(
          op, "requires elementwise op on ranked tensors");

    auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
    SmallVector<AffineMap, 3> indexingMaps(
        op->getNumResults() + op->getNumOperands(),
        rewriter.getMultiDimIdentityMap(rank));
    SmallVector<utils::IteratorType, 6> iteratorTypes(
        rank, utils::IteratorType::parallel);
    auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/op->getOperands(),
        /*outputs=*/outputs,
        /*indexingMaps=*/indexingMaps,
        /*iteratorTypes=*/iteratorTypes,
        /*bodyBuilder=*/
        [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
          auto resultTypes = llvm::to_vector<6>(
              llvm::map_range(op->getResultTypes(), [](Type type) {
                return cast<TensorType>(type).getElementType();
              }));
          auto *scalarOp =
              builder.create(loc, op->getName().getIdentifier(),
                             regionArgs.take_front(op->getNumOperands()),
                             resultTypes, op->getAttrs());
          builder.create<linalg::YieldOp>(loc, scalarOp->getResults());
        });
    return success();
  }
};
} // namespace

void mlir::linalg::populateElementwiseToLinalgConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ConvertAnyElementwiseMappableOpOnRankedTensors>(
      patterns.getContext());
}

namespace {
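/// Pass that marks every elementwise-mappable op on ranked tensors as
/// dynamically illegal and applies the patterns above via partial conversion,
/// so such ops are rewritten to `linalg.generic` while all other ops are left
/// untouched.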
class ConvertElementwiseToLinalgPass
    : public impl::ConvertElementwiseToLinalgBase<
          ConvertElementwiseToLinalgPass> {
  void runOnOperation() final {
    auto *func = getOperation();
    auto *context = &getContext();
    ConversionTarget target(*context);
    RewritePatternSet patterns(context);

    mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns);
    target.markUnknownOpDynamicallyLegal([](Operation *op) {
      return !isElementwiseMappableOpOnRankedTensors(op);
    });

    if (failed(applyPartialConversion(func, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
  return std::make_unique<ConvertElementwiseToLinalgPass>();
}