//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

/// Rewrite tensor.generate with arith.constant if the yielded value is a
/// constant and the tensor type is static.
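///
/// Example (illustrative IR):
///
///   %cst = arith.constant 1.0 : f32
///   %0 = tensor.generate {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<2x3xf32>
///
/// becomes a splat constant of the same static type:
///
///   %0 = arith.constant dense<1.000000e+00> : tensor<2x3xf32>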
struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const override {
    auto tensorType =
        llvm::cast<RankedTensorType>(generateOp.getResult().getType());
    if (!tensorType.hasStaticShape())
      return failure();
    auto terminatorOp =
        cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
    Attribute attr;
    if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
      return failure();
    Operation *constantOp =
        rewriter.getContext()
            ->getLoadedDialect<TensorDialect>()
            ->materializeConstant(rewriter,
                                  DenseElementsAttr::get(tensorType, attr),
                                  tensorType, generateOp->getLoc());
    if (!constantOp)
      return failure();
    rewriter.replaceOp(generateOp, constantOp->getResults());
    return success();
  }
};

/// Transform a linear index from one indexing space to another given:
///
/// - the shape of the source indexing space,
/// - the strides of the target indexing space,
/// - a linear index into the source indexing space.
///
/// This function is logically a sequence of linearize/delinearize over
/// different bases but avoids allocating intermediate SmallVectors.
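///
/// For example, with inputShape = {2, 3}, outputStrides = {5, 1} and
/// srcLinearIndex = 4 (position (1, 1) of the 2x3 source), the loop peels off
/// remainder 1 for dim 1 and remainder 1 for dim 0, producing
/// dstLinearIndex = 1 * 1 + 1 * 5 = 6, which is position (1, 1) of an
/// indexing space with row stride 5.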
int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
                            ArrayRef<int64_t> outputStrides,
                            int64_t srcLinearIndex) {
  assert(inputShape.size() == outputStrides.size());

  int64_t dstLinearIndex = 0;

  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
    // Compute the index into the current dimension of the source tensor.
    // `quotient` is the remaining linear index after accounting for the
    // current dimension.
    //
    // `remainder` is the index into the source tensor for the current
    // dimension.
    //
    // Plain `/` and `%` are used because the members of the `div_t` family
    // returned by `std::div` may appear in either order, so a positional
    // structured binding on them is not portable.
    int64_t quotient = srcLinearIndex / inputShape[dim];
    int64_t remainder = srcLinearIndex % inputShape[dim];

    srcLinearIndex = quotient;

    // Add the contribution of the current dimension to the output using the
    // output strides.
    dstLinearIndex += outputStrides[dim] * remainder;
  }

  return dstLinearIndex;
}

template <typename ElemType, typename AttrType>
Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
                        DenseElementsAttr input, AttrType padValue,
                        ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
  auto inputValues = input.tryGetValues<ElemType>();
  if (failed(inputValues))
    return nullptr;

  auto oldShape = input.getType().getShape();

  // Compute the output shape of the new value.
  auto newShape =
      llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
                          [](std::tuple<int64_t, int64_t, int64_t> pack) {
                            auto [old, low, high] = pack;
                            return old + low + high;
                          });

  int64_t outputSize = computeProduct(newShape);

  // Fully initialize the vector with the padding value.
  // The non-padded area will then be copied.
  SmallVector<ElemType> values(outputSize, padValue.getValue());

  // The output strides are used to map each element's position in the input
  // indexing space to its position in the output indexing space.
  SmallVector<int64_t> outputStrides = computeStrides(newShape);

  // The contribution of the low padding to the offset in the output tensor.
  // This is the starting position of the source tensor within the padded
  // tensor.
  int64_t startingOffset = linearize(padLow, outputStrides);
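
  // Worked example (illustrative): padding a 2x3 input with padLow = {1, 1}
  // and padHigh = {0, 1} gives newShape = {3, 5} and outputStrides = {5, 1},
  // so startingOffset = 1 * 5 + 1 * 1 = 6. Input element (1, 2) has input
  // linear index 5; transformIndexSpace maps it to 1 * 5 + 2 * 1 = 7, and
  // 7 + 6 = 13 is the linear index of position (2, 3) in the 3x5 output,
  // i.e. the input position shifted by the low padding. Output positions
  // never written by the copy below keep the padding value.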

  // Copy values from the input tensor to the corresponding sub-region
  // of the output tensor.
  for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
    auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
    values[outputIndex + startingOffset] = inputValue;
  }

  // Create an attribute for the folded value.
  auto newType = input.getType().clone(newShape);
  auto newAttr = DenseElementsAttr::get(newType, values);

  Operation *constantOp =
      rewriter.getContext()
          ->getLoadedDialect<TensorDialect>()
          ->materializeConstant(rewriter, newAttr, newType, loc);

  return constantOp ? constantOp->getResult(0) : nullptr;
}

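/// Rewrite a tensor.pad of a constant source with a constant padding value and
/// constant low/high padding amounts as a single constant, unless the pad is
/// marked nofold or the control function rejects it. For example
/// (illustrative IR):
///
///   %src = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
///   %zero = arith.constant 0 : i32
///   %0 = tensor.pad %src low[1, 0] high[0, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %zero : i32
///   } : tensor<2x2xi32> to tensor<3x3xi32>
///
/// folds to
///
///   %0 = arith.constant dense<[[0, 0, 0], [1, 2, 0], [3, 4, 0]]>
///       : tensor<3x3xi32>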
struct PadOpToConstant final : public OpRewritePattern<PadOp> {
  PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (padTensorOp.getNofold())
      return rewriter.notifyMatchFailure(
          padTensorOp, "refusing to fold nofold pad operation");

    TypedValue<RankedTensorType> input = padTensorOp.getSource();
    RankedTensorType resultType = padTensorOp.getResult().getType();

    DenseElementsAttr inputAttr = nullptr;
    if (!matchPattern(input, m_Constant(&inputAttr)))
      return failure();

    Value paddingValue = padTensorOp.getConstantPaddingValue();

    // Extract the constant value used for padding or bail out.
    Attribute paddingAttr = nullptr;
    if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to get constant value");

    // Try to extract the constant values of the low and high padding.
    auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
    auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());

    // If the padding cannot be extracted, bail out.
    if (!lowPad || !highPad)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "unable to extract constant padding");

    // We have a potential candidate, consult the control function to
    // determine if the op should fold.
    if (!controlFn(&padTensorOp.getSourceMutable()))
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "not folding due to cost function");

    Location loc = padTensorOp.getLoc();

    // Try constant folding the supported cases of integer and float values.
    Value newOp =
        llvm::TypeSwitch<Attribute, Value>(paddingAttr)
            .Case([&](FloatAttr floatAttr) {
              return constantFoldPadOp<llvm::APFloat>(
                  rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
            })
            .Case([&](IntegerAttr integerAttr) {
              return constantFoldPadOp<llvm::APInt>(
                  rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
            })
            .Default(Value());

    if (!newOp)
      return rewriter.notifyMatchFailure(padTensorOp,
                                         "tensor type not supported");
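
    // The folded constant always has a static shape, but the pad's declared
    // result type may not: padding a tensor<2x2xi32> with low[%c1, 0]
    // high[0, 1], where %c1 is a constant index, may be typed
    // tensor<?x3xi32>. A tensor.cast reconciles the two types.
    if (newOp.getType() != resultType)
      newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);

    rewriter.replaceOp(padTensorOp, newOp);
    return success();
  }

private:
  ControlFoldFn controlFn;
};

} // namespace
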
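// Example usage (a sketch, not part of this file): `ctx` and `rootOp` are
// hypothetical, and the single-use policy is only an illustrative cost model.
// The control function receives the pad's source operand, so it can inspect
// the constant being folded. This assumes an MLIR version that still provides
// applyPatternsAndFoldGreedily (declared in
// "mlir/Transforms/GreedyPatternRewriteDriver.h"); newer releases call it
// applyPatternsGreedily.
//
//   RewritePatternSet patterns(ctx);
//   tensor::populateRewriteAsConstantPatterns(
//       patterns, [](OpOperand *source) { return source->get().hasOneUse(); });
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));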
void mlir::tensor::populateRewriteAsConstantPatterns(
    RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
  patterns.add<GenerateToConstant>(patterns.getContext());
  patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
}