1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
|
//===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functionality to progressively decompose coarse-grained
// affine ops into finer-grained ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::affine;
#define DEBUG_TYPE "decompose-affine-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
/// Return how many consecutive enclosing loops, walking outwards from the op
/// that owns `operand`, the operand's value is defined outside of (i.e. how
/// many loops the operand could be hoisted above).
/// The walk stops at the first enclosing loop for which the value is not
/// defined outside of the loop; loops further out are not counted.
static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
  int64_t numLoops = 0;
  // Visit the chain of enclosing loop-like ops from innermost to outermost.
  for (auto loop = operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
       loop && loop.isDefinedOutsideOfLoop(operand.get());
       loop = loop->getParentOfType<LoopLikeOpInterface>())
    ++numLoops;
  return numLoops;
}
/// Reorder the operands of `op` in place so that operands invariant to more
/// enclosing loops come first, and rewrite the affine map accordingly.
/// The rewritten map is built with zero dims: every operand is referred to
/// through a symbol whose position is the operand's new position. The map and
/// operands are canonicalized before being installed on `op`.
void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
                                                 AffineApplyOp op) {
  // Score each operand by the number of enclosing loops it is invariant to.
  SmallVector<int64_t> hoistDepth;
  for (OpOperand &opOperand : op->getOpOperands())
    hoistDepth.push_back(numEnclosingInvariantLoops(opOperand));

  // Build the permutation of operand positions, highest score first. The sort
  // is stable so equally-hoistable operands keep their relative order.
  int64_t operandCount = op.getNumOperands();
  SmallVector<int64_t> permutation =
      llvm::to_vector(llvm::seq<int64_t>(0, operandCount));
  llvm::stable_sort(permutation, [&hoistDepth](size_t lhs, size_t rhs) {
    return hoistDepth[lhs] > hoistDepth[rhs];
  });

  // The operand at old position `oldPos` moves to `newPos`; in the map, the
  // expression at `oldPos` is replaced by symbol `newPos`.
  SmallVector<Value> newOperands(operandCount);
  SmallVector<AffineExpr> symbolRemap(operandCount);
  for (int64_t newPos = 0; newPos < operandCount; ++newPos) {
    int64_t oldPos = permutation[newPos];
    newOperands[newPos] = op.getOperand(oldPos);
    symbolRemap[oldPos] = getAffineSymbolExpr(newPos, op.getContext());
  }

  // Operands are ordered dims first, then symbols, so the leading entries of
  // the remap table substitute the dims and the rest substitute the symbols.
  AffineMap oldMap = op.getAffineMap();
  ArrayRef<AffineExpr> remapRef(symbolRemap);
  AffineMap remapped =
      oldMap.replaceDimsAndSymbols(remapRef.take_front(oldMap.getNumDims()),
                                   remapRef.drop_front(oldMap.getNumDims()),
                                   /*numResultDims=*/0,
                                   /*numResultSyms=*/operandCount);
  AffineExpr simplified =
      simplifyAffineExpr(remapped.getResult(0), 0, operandCount);
  AffineMap newMap =
      AffineMap::get(0, operandCount, simplified, op->getContext());
  canonicalizeMapAndOperands(&newMap, &newOperands);

  // Install the new map and operands, notifying the rewriter of the in-place
  // modification.
  rewriter.startRootUpdate(op);
  op.setMap(newMap);
  op->setOperands(newOperands);
  rewriter.finalizeRootUpdate(op);
}
/// Create an affine.apply computing `expr`, a subexpression of `originalOp`'s
/// affine map, over the same operands as `originalOp`.
/// The new map and operand list are canonicalized (deduplicating operands and
/// dropping unused ones) before the op is created. Maximal composition of
/// AffineApplyOp is deliberately not performed since it would defeat the
/// purpose of the decomposition.
static AffineApplyOp createSubApply(RewriterBase &rewriter,
                                    AffineApplyOp originalOp, AffineExpr expr) {
  // Reuse the dim/symbol counts of the original map so that `expr` keeps
  // referring to the same operand positions.
  AffineMap originalMap = originalOp.getAffineMap();
  auto subMap =
      AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
                     expr, originalOp->getContext());
  SmallVector<Value> subOperands = originalOp->getOperands();
  canonicalizeMapAndOperands(&subMap, &subOperands);
  return rewriter.create<AffineApplyOp>(originalOp.getLoc(), subMap,
                                        subOperands);
}
/// Decompose the affine.apply `op` into a chain of smaller affine.apply ops by
/// reassociating its top-level add or mul expression.
///
/// Returns a match failure (reported through `rewriter`) when:
///   - the map of `op` has dims,
///   - the result expression is not a binary expression,
///   - neither side of the top-level binary expression is itself binary,
///   - the top-level binary expression is neither an add nor a mul.
///
/// On success, `op` is replaced by the last affine.apply of the chain and that
/// op is returned. New ops are created through `rewriter`; this function does
/// not change the rewriter's insertion point.
FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
                                                 AffineApplyOp op) {
  // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
  // top-level binary expression that we can reassociate (i.e. add or mul).
  AffineMap m = op.getAffineMap();
  if (m.getNumDims() > 0)
    return rewriter.notifyMatchFailure(op, "expected no dims");

  AffineExpr remainingExp = m.getResult(0);
  auto binExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
  if (!binExpr)
    return rewriter.notifyMatchFailure(op, "terminal affine.apply");

  if (!binExpr.getLHS().isa<AffineBinaryOpExpr>() &&
      !binExpr.getRHS().isa<AffineBinaryOpExpr>())
    return rewriter.notifyMatchFailure(op, "terminal affine.apply");

  bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
                        (binExpr.getKind() == AffineExprKind::Mul));
  if (!supportedKind)
    return rewriter.notifyMatchFailure(
        op, "only add or mul binary expr can be reassociated");

  LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");

  // 2. Iteratively extract the RHS subexpressions while the top-level binary
  // expr kind remains the same. The last LHS that is not a binary expr of the
  // same kind is pushed as the terminal subexpression.
  MLIRContext *ctx = op->getContext();
  SmallVector<AffineExpr> subExpressions;
  while (true) {
    auto currentBinExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
    if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
      subExpressions.push_back(remainingExp);
      LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
      break;
    }
    subExpressions.push_back(currentBinExpr.getRHS());
    LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
    remainingExp = currentBinExpr.getLHS();
  }

  // 3. Reorder subExpressions by the max symbol they are a function of.
  // This also takes care of properly reordering local variables.
  // This however won't be able to split expression that cannot be reassociated
  // such as ones that involve divs and multiple symbols.
  // NOTE(review): presumably this relies on operands having been ordered by
  // reorderOperandsByHoistability first, so that low symbol positions are the
  // most hoistable ones -- confirm with callers.
  auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
    // Valid symbol positions are [0, numSymbols). Compute the start index in
    // signed arithmetic so a map without symbols yields -1 rather than
    // wrapping around.
    for (int64_t i = static_cast<int64_t>(m.getNumSymbols()) - 1; i >= 0; --i)
      if (e.isFunctionOfSymbol(i))
        return i;
    // `e` does not depend on any symbol (e.g. a constant).
    return -1;
  };
  llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
    return getMaxSymbol(e1) < getMaxSymbol(e2);
  });
  LLVM_DEBUG(
      llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
      llvm::dbgs() << "\n");

  // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
  // Each step combines the running result with the next subexpression through
  // a two-symbol map `s0 <kind> s1`.
  auto s0 = getAffineSymbolExpr(0, ctx);
  auto s1 = getAffineSymbolExpr(1, ctx);
  AffineMap binMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/2,
      getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);

  auto current = createSubApply(rewriter, op, subExpressions[0]);
  for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
    Value tmp = createSubApply(rewriter, op, subExpressions[i]);
    current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
                                             ValueRange{current, tmp});
    LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
  }

  // 5. Replace original op.
  rewriter.replaceOp(op, current.getResult());
  return current;
}
|