//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// registered as regular canonicalization patterns on the ops themselves,
// because doing so would introduce undesired additional dialect dependencies.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine if the loop is shape
/// preserving, i.e., the type of the `arg`-th yielded value is the same as the
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
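///
/// For example, the following loop is shape preserving in its single iter_arg
/// (an illustrative sketch: `tensor.insert_slice` writes into its destination
/// without changing the destination's runtime shape):
///
/// ```
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.insert_slice %v into %arg0 [...] [...] [...]
///       : tensor<1x?xf32> into tensor<?x?xf32>
///   scf.yield %1 : tensor<?x?xf32>
/// }
/// ```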
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
         "arg is out of bounds");
  Value value = yieldOp.getResults()[arg];
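  // Walk the yielded value backwards through the loop body: if the chain of
  // shape-preserving ops leads back to the corresponding iter_arg, the loop
  // provably preserves the runtime shape of that value.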
  while (value) {
    if (value == forOp.getRegionIterArgs()[arg])
      return true;
    OpResult opResult = dyn_cast<OpResult>(value);
    if (!opResult)
      return false;

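    // Step through ops that are known to preserve the runtime shape of the
    // value: tensor.insert_slice forwards its destination, and a nested
    // scf.for forwards its init arg if it is itself shape preserving.
    // Anything else conservatively ends the traversal.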
    using tensor::InsertSliceOp;
    value =
        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
            .template Case<InsertSliceOp>(
                [&](InsertSliceOp op) { return op.getDest(); })
            .template Case<ForOp>([&](ForOp forOp) {
              return isShapePreserving(forOp, opResult.getResultNumber())
                         ? forOp.getIterOperands()[opResult.getResultNumber()]
                         : Value();
            })
            .Default([&](auto op) { return Value(); });
  }
  return false;
}

namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
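    // Block argument 0 of an scf.for body is the induction variable;
    // iter_args start at block argument index 1, hence the "- 1" below.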
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();
    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
    return success();
  }
};

/// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
    if (!forOp)
      return failure();
    auto opResult = cast<OpResult>(dimOp.getSource());
    unsigned resultNumber = opResult.getResultNumber();
    if (!isShapePreserving(forOp, resultNumber))
      return failure();
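    // The result's shape is loop invariant, so the dim op can read it from
    // the corresponding init arg, which dominates the dim op.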
    rewriter.updateRootInPlace(dimOp, [&]() {
      dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
    });
    return success();
  }
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
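///
/// For example (an illustrative sketch), with the induction variable ranging
/// over [0, 128) in steps of 4, the `affine.min` below is provably always 4:
///
/// ```
/// scf.for %iv = %c0 to %c128 step %c4 {
///   %0 = affine.min affine_map<(d0) -> (4, 128 - d0)>(%iv)
///   ...
/// }
/// ```
///
/// and can be replaced with `%c4 = arith.constant 4 : index`.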
template <typename OpTy>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop);
  }
};

struct SCFForLoopCanonicalization
    : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
           AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
           DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
           DimOfLoopResultFolder<tensor::DimOp>,
           DimOfLoopResultFolder<memref::DimOp>>(ctx);
}

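// A typical way to exercise these patterns from the command line (assuming a
// standard mlir-opt build; the flag is derived from the pass definition in
// Passes.td):
//
//   mlir-opt input.mlir -scf-for-loop-canonicalization
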
std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}