//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation performed by the linalg op.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... arg0, arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This reduces the amount of computation performed by the linalg operation.
///
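/// For example (an illustrative sketch only; shapes and value names are made
/// up for the example), extracting a 16x8 tile from the result of a matmul
/// becomes a matmul on tiles of its operands, with the reduction dimension
/// kept whole:
///
/// ```mlir
/// %a_s = tensor.extract_slice %a[0, 0] [16, 128] [1, 1]
/// %b_s = tensor.extract_slice %b[0, 0] [128, 8] [1, 1]
/// %c_s = tensor.extract_slice %c[0, 0] [16, 8] [1, 1]
/// %1   = linalg.matmul ins(%a_s, %b_s : ...) outs(%c_s : ...)
/// ```
///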
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }
    // TODO: we might relax this if we want heuristics to detect that all uses
    // are a small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }
    if (linalgOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    if (!linalgOp.hasTensorSemantics()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected tensor of linalg op");
    }

    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }
    OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
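    // Requiring a projected permutation for the output map means every result
    // of the map is a distinct loop dimension, so the slice's offsets and
    // sizes can be translated directly into loop offsets and sizes below.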
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }

    auto linalgLoc = linalgOp.getLoc();
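    // Compute the loop ranges of the op by applying the shapes-to-loops map
    // to the concatenated dimension sizes of all operands.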
    SmallVector<OpFoldResult> allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }
    SmallVector<OpFoldResult> sizeBounds =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);
    // The offsets and sizes from the slice operation only give the tile
    // extents of the output. Use them to compute the tile offsets and sizes
    // of the loops. For loops not used to access the output, set the tile
    // size to the loop bound and the offset to 0.
    SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
                                          rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> tileSizes = sizeBounds;
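    // Map each output dimension back to its loop and overwrite that loop's
    // offset and size with the slice's values.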
    for (auto const &result : enumerate(indexingMap.getResults())) {
      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
      tileOffsets[position] = sliceOp.getMixedOffsets()[result.index()];
      tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
    }
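    // Extract slices of all operands (inputs and inits) for the computed loop
    // offsets and sizes. The sizes are taken from an existing in-bounds
    // extract_slice, so the affine.min-based partial-tile check is omitted.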
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands =
        makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                        tileOffsets, tileSizes, sizeBounds,
                        /*omitPartialTileCheck=*/true);
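    // The result types of the cloned op must match the types of the (now
    // sliced) init operands.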
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : linalgOp.getDpsInitOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    Operation *newOp =
        clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace
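
// A minimal usage sketch (hypothetical caller inside a pass, not part of this
// file), relying on the greedy driver header included above:
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateBubbleUpExtractSliceOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();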
void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<BubbleUpExtractSliceOpPattern>(context);
}