1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::tensor;
namespace {
/// Rewrites a reshape whose source is a tensor.empty into a new tensor.empty
/// that directly has the reshape's (reified) result sizes. Instantiated in
/// this file for tensor::ExpandShapeOp and tensor::CollapseShapeOp.
template <typename ReshapeOp>
struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
                               bool foldSingleUseOnly = false)
      : OpRewritePattern<ReshapeOp>(ctx, benefit),
        foldSingleUseOnly(foldSingleUseOnly) {}

  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Only reshapes fed by a tensor.empty are candidates.
    EmptyOp sourceEmpty =
        reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
    if (!sourceEmpty)
      return failure();

    // Optionally refuse to fold when the tensor.empty has other users.
    if (foldSingleUseOnly && !llvm::hasSingleElement(sourceEmpty->getUses()))
      return failure();

    // The replacement needs the reshape's result sizes; require exactly one
    // reified shape (one per result).
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(rewriter, reshapeOp, reifiedShapes)))
      return failure();
    if (!llvm::hasSingleElement(reifiedShapes))
      return failure();

    // TODO: Do not drop tensor type encoding.
    auto resultType = reshapeOp.getResultType();
    Value freshEmpty =
        rewriter.create<EmptyOp>(reshapeOp.getLoc(), reifiedShapes.front(),
                                 resultType.getElementType());

    // The created type can differ from the declared result type (e.g. the
    // encoding is dropped above), so insert a tensor.cast in that case.
    if (freshEmpty.getType() == resultType) {
      rewriter.replaceOp(reshapeOp, freshEmpty);
      return success();
    }
    rewriter.replaceOpWithNewOp<tensor::CastOp>(reshapeOp, resultType,
                                                freshEmpty);
    return success();
  }

private:
  /// When true, only fold tensor.empty ops that have exactly one use.
  bool foldSingleUseOnly = false;
};
/// tensor.empty does not define any tensor contents, so a slice of a
/// tensor.empty can be folded to a smaller tensor.empty.
struct FoldEmptyTensorWithExtractSliceOp
    : public OpRewritePattern<ExtractSliceOp> {
  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
                                    PatternBenefit benefit = 1,
                                    bool foldSingleUseOnly = false)
      : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
        foldSingleUseOnly(foldSingleUseOnly) {}

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the sliced tensor comes from a tensor.empty that we are
    // allowed to fold (single use, if that restriction is enabled).
    EmptyOp sourceEmpty = sliceOp.getSource().getDefiningOp<EmptyOp>();
    bool cannotFold =
        !sourceEmpty ||
        (foldSingleUseOnly && !llvm::hasSingleElement(sourceEmpty->getUses()));
    if (cannotFold)
      return failure();

    // tensor.extract_slice may be rank-reducing, so reuse the slice's own
    // result type (shape, element type and encoding) together with its
    // dynamic size operands for the new tensor.empty.
    rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, sliceOp.getType(),
                                         sliceOp.getSizes());
    return success();
  }

private:
  /// When true, only fold tensor.empty ops that have exactly one use.
  bool foldSingleUseOnly = false;
};
/// tensor.empty does not define any tensor contents, so an unpadded pack
/// can be folded away.
struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // Only packs whose source is a tensor.empty are candidates.
    if (!packOp.getSource().getDefiningOp<EmptyOp>())
      return failure();

    // Packing with padding cannot be simply removed.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    // Forward the destination operand in place of the pack result.
    Value destination = packOp.getDest();
    rewriter.replaceOp(packOp, destination);
    return success();
  }
};
/// tensor.empty does not define any tensor contents, so an unpack
/// can be folded away.
struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    // Only unpacks whose source is a tensor.empty are candidates.
    Value packedSource = unPackOp.getSource();
    if (!packedSource.getDefiningOp<EmptyOp>())
      return failure();

    // Forward the destination operand in place of the unpack result.
    Value destination = unPackOp.getDest();
    rewriter.replaceOp(unPackOp, destination);
    return success();
  }
};
/// Fold a tensor.concat whose operands are all produced by tensor.empty into a
/// single tensor.empty with the concat's reified result sizes: none of the
/// operands define any contents, so neither does the concatenation.
struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty())
      return failure();

    // Every operand (including the first) must come from a tensor.empty.
    auto isDefinedByEmptyOp = [](Value v) -> bool {
      return v.getDefiningOp<tensor::EmptyOp>();
    };
    if (!llvm::all_of(concatOperands, isDefinedByEmptyOp))
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by an empty op");

    // Reify the result sizes; exactly one shape (for the single result) is
    // required before it is indexed below.
    ReifiedRankedShapedTypeDims resultShapes;
    if (failed(concatOp.reifyResultShapes(rewriter, resultShapes)) ||
        !llvm::hasSingleElement(resultShapes))
      return rewriter.notifyMatchFailure(concatOp,
                                         "failed to get result shape");

    // Keep the result's encoding on the new tensor.empty. If the materialized
    // type still differs from the declared result type, bridge it with a
    // tensor.cast (as FoldEmptyTensorWithReshapeOp does) instead of replacing
    // the concat with a value of a different type.
    RankedTensorType resultType = concatOp.getResultType();
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        concatOp.getLoc(), resultShapes[0], resultType.getElementType(),
        resultType.getEncoding());
    if (emptyTensor.getType() != resultType) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(concatOp, resultType,
                                                  emptyTensor);
    } else {
      rewriter.replaceOp(concatOp, emptyTensor);
    }
    return success();
  }
};
} // namespace
/// Registers the tensor.empty folding patterns defined in this file.
/// `foldSingleUseOnly` is forwarded only to the extract_slice and
/// expand/collapse_shape patterns; the concat, pack and unpack patterns are
/// registered without it.
void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                                   bool foldSingleUseOnly) {
  MLIRContext *context = patterns.getContext();
  PatternBenefit benefit = 1;

  // Patterns that honor the single-use restriction.
  patterns.add<FoldEmptyTensorWithExtractSliceOp,
               FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
               FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
      context, benefit, foldSingleUseOnly);

  // Patterns that ignore it.
  patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
               FoldEmptyTensorWithUnPackOp>(context, benefit);
}
|