1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
|
//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.WhileOp's into SCF.ForOp's.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
namespace {
/// Rewrite pattern matching every `scf.while` and delegating the actual
/// transformation to `upliftWhileToForLoop`. The pattern succeeds exactly when
/// the helper produced a new `scf.for`.
struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::WhileOp loop,
                                PatternRewriter &rewriter) const override {
    // The created loop itself is not needed here, only whether uplifting
    // succeeded.
    FailureOr<scf::ForOp> uplifted = upliftWhileToForLoop(rewriter, loop);
    if (failed(uplifted))
      return failure();
    return success();
  }
};
} // namespace
/// Tries to rewrite `loop` (an `scf.while`) into an equivalent `scf.for`.
///
/// The rewrite applies only when the loop has the following counted form:
///   * the `before` block contains exactly one non-terminator op, an
///     `arith.cmpi` whose single use is the `scf.condition` condition;
///   * the comparison is `slt` with a `before` block argument on the left, or
///     `sgt` with a `before` block argument on the right; that argument is the
///     induction variable and has no uses other than the `cmpi` and the
///     `scf.condition`;
///   * the other comparison operand properly dominates the loop; it becomes
///     the upper bound;
///   * all `before` block arguments are forwarded to `scf.condition`
///     unchanged and in order;
///   * the `after` block yields `arith.addi(iv, step)` (either operand order)
///     for the induction variable, with `step` properly dominating the loop.
///
/// On success the `scf.while` is replaced: the remaining loop-carried values
/// become `iter_args` of the new `scf.for`, and the induction variable result
/// is recomputed after the loop. Returns the new `scf.for`.
///
/// On any mismatch a match failure is reported through `rewriter` and failure
/// is returned; no IR is created or modified in that case.
///
/// NOTE(review): the step is not checked to be positive. The final induction
/// value computation below (and, presumably, `scf.for` itself) assumes it is.
FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
                                                      scf::WhileOp loop) {
  Block *beforeBody = loop.getBeforeBody();
  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");

  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
  if (!cmp)
    return rewriter.notifyMatchFailure(loop,
                                       "Loop body must have single cmp op");

  scf::ConditionOp beforeTerm = loop.getConditionOp();
  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected single condition use: " << *cmp;
    });

  // All `before` block args must be directly forwarded to ConditionOp.
  // They will be converted to `scf.for` `iter_args`, except the induction var.
  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
    return rewriter.notifyMatchFailure(loop, "Invalid args order");

  using Pred = arith::CmpIPredicate;
  Pred predicate = cmp.getPredicate();
  if (predicate != Pred::slt && predicate != Pred::sgt)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
    });

  BlockArgument inductionVar;
  Value ub;
  DominanceInfo dom;

  // Check if cmp has a suitable form. One of the arguments must be a `before`
  // block arg, the other must be defined outside `scf.while` and will be
  // treated as the upper bound. `iv slt ub` and `ub sgt iv` are both accepted.
  for (bool reverse : {false, true}) {
    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
    if (predicate != expectedPred)
      continue;

    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();

    auto blockArg = dyn_cast<BlockArgument>(arg1);
    if (!blockArg || blockArg.getOwner() != beforeBody)
      continue;

    if (!dom.properlyDominates(arg2, loop))
      continue;

    inductionVar = blockArg;
    ub = arg2;
    break;
  }

  if (!inductionVar)
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized cmp form: " << *cmp;
    });

  // inductionVar must have 2 uses: one is in `cmp` and the other is the
  // `condition` arg.
  if (!llvm::hasNItems(inductionVar.getUses(), 2))
    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
      diag << "Unrecognized induction var: " << inductionVar;
    });

  Block *afterBody = loop.getAfterBody();
  scf::YieldOp afterTerm = loop.getYieldOp();
  unsigned argNumber = inductionVar.getArgNumber();
  Value afterTermIndArg = afterTerm.getResults()[argNumber];
  Value inductionVarAfter = afterBody->getArgument(argNumber);

  // Find a suitable `addi` op inside the `after` block: one of its args must
  // be the induction var passed from the `before` block and the second arg
  // must be defined outside of the loop; it is considered the step value.
  // TODO: Add `subi` support?
  auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
  if (!addOp)
    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");

  Value step;
  if (addOp.getLhs() == inductionVarAfter) {
    step = addOp.getRhs();
  } else if (addOp.getRhs() == inductionVarAfter) {
    step = addOp.getLhs();
  }

  if (!step || !dom.properlyDominates(step, loop))
    return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");

  Value lb = loop.getInits()[argNumber];

  assert(lb.getType().isIntOrIndex());
  assert(lb.getType() == ub.getType());
  assert(lb.getType() == step.getType());

  llvm::SmallVector<Value> newArgs;

  // Populate inits for the new `scf.for`, skipping the induction var.
  newArgs.reserve(loop.getInits().size());
  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(init);
  }

  Location loc = loop.getLoc();

  // Build the replacement right where the original loop lives, regardless of
  // where the caller left the rewriter, and restore the caller's insertion
  // point on exit.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // With `builder == nullptr`, ForOp::build will try to insert a terminator at
  // the end of the newly created block and we don't want it. Provide an empty
  // dummy builder instead.
  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
  auto newLoop =
      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);

  Block *newBody = newLoop.getBody();

  // Populate block args for the `scf.for` body. The `scf.for` body has the
  // induction var as its first argument, while the `after` block expects it at
  // position `argNumber`, so remap accordingly.
  newArgs.clear();
  ValueRange newBodyArgs = newBody->getArguments();
  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
    if (i < argNumber) {
      newArgs.emplace_back(newBodyArgs[i + 1]);
    } else if (i == argNumber) {
      newArgs.emplace_back(newBodyArgs.front());
    } else {
      newArgs.emplace_back(newBodyArgs[i]);
    }
  }

  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
                             newArgs);

  auto term = cast<scf::YieldOp>(newBody->getTerminator());

  // Populate new yield args, skipping the induction var.
  newArgs.clear();
  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
    if (i == argNumber)
      continue;

    newArgs.emplace_back(arg);
  }

  rewriter.setInsertionPoint(term);
  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);

  // Compute the induction var value the original `scf.while` returned. That
  // result is the value forwarded by `scf.condition` when the condition turns
  // false, i.e. the first value of the sequence lb, lb + step, ... that is not
  // less than `ub`:
  //   lb >= ub : lb                                   (loop body never runs)
  //   lb <  ub : lb + ceildiv(ub - lb, step) * step
  rewriter.setInsertionPointAfter(newLoop);
  Value one;
  if (isa<IndexType>(step.getType())) {
    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  } else {
    one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
  }

  // tripCount = (ub - lb + step - 1) / step; only meaningful when lb < ub.
  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
  Value tripCount = rewriter.create<arith::SubIOp>(loc, ub, lb);
  tripCount = rewriter.create<arith::AddIOp>(loc, tripCount, stepDec);
  tripCount = rewriter.create<arith::DivSIOp>(loc, tripCount, step);
  Value res = rewriter.create<arith::MulIOp>(loc, tripCount, step);
  res = rewriter.create<arith::AddIOp>(loc, lb, res);

  // For a zero-trip loop the while result is just the initial value.
  Value hasIterations = rewriter.create<arith::CmpIOp>(loc, Pred::slt, lb, ub);
  res = rewriter.create<arith::SelectOp>(loc, hasIterations, res, lb);

  // Reconstruct `scf.while` results, inserting the final induction var value
  // into its original position.
  newArgs.clear();
  llvm::append_range(newArgs, newLoop.getResults());
  newArgs.insert(newArgs.begin() + argNumber, res);
  rewriter.replaceOp(loop, newArgs);
  return newLoop;
}
/// Adds the `scf.while` -> `scf.for` uplifting pattern to `patterns`, using
/// the context the pattern set was created with.
void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<UpliftWhileOp>(context);
}
|