//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a sparse conditional constant propagation
// in MLIR. It identifies values known to be constant, propagates that
// information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
// proven otherwise.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// SCCP Analysis
//===----------------------------------------------------------------------===//
namespace {
/// Lattice value tracked by SCCP for every SSA value. A non-null `constant`
/// means the value is known to be that constant attribute; a null `constant`
/// is the non-constant (pessimistic) state.
struct SCCPLatticeValue {
  SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
      : constant(constant), constantDialect(dialect) {}

  /// Both overloads yield the pessimistic state, i.e. a value holding no
  /// constant. The argument is not inspected.
  static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
    return SCCPLatticeValue{};
  }
  static SCCPLatticeValue getPessimisticValueState(Value value) {
    return SCCPLatticeValue{};
  }

  /// Two lattice values compare equal when they hold the same constant
  /// attribute. The dialect the constant came from is intentionally ignored.
  bool operator==(const SCCPLatticeValue &rhs) const {
    return rhs.constant == constant;
  }

  /// Join two lattice values: if they agree, the left-hand value is kept;
  /// otherwise the result is the non-constant state.
  static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
                               const SCCPLatticeValue &rhs) {
    if (!(lhs == rhs))
      return SCCPLatticeValue{};
    return lhs;
  }

  /// The constant attribute value (null when not known to be constant).
  Attribute constant;

  /// The dialect the constant originated from. It does not participate in
  /// equality and is only used to materialize the held constant when needed.
  Dialect *constantDialect;
};
/// Sparse conditional constant propagation analysis. Each operation is
/// simulated by invoking its `fold` hook with the constant operand values
/// known so far, and branch/region successors are pruned using those same
/// constant operands.
struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
  using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
  ~SCCPAnalysis() override = default;

  /// Compute the lattice values of the results of `op` by simulating a fold
  /// with the currently known constant operands. All results are moved to the
  /// pessimistic fixpoint if `op` has regions, fails to fold, or folds
  /// in-place.
  ChangeResult
  visitOperation(Operation *op,
                 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
    // Don't try to simulate the results of a region operation as we can't
    // guarantee that folding will be out-of-place. We don't allow in-place
    // folds as the desire here is for simulated execution, and not general
    // folding.
    if (op->getNumRegions())
      return markAllPessimisticFixpoint(op->getResults());

    SmallVector<Attribute> constantOperands = getConstantOperands(operands);

    // Save the original operands and attributes just in case the operation
    // folds in-place. The constant passed in may not correspond to the real
    // runtime value, so in-place updates are not allowed.
    SmallVector<Value, 8> originalOperands(op->getOperands());
    DictionaryAttr originalAttrs = op->getAttrDictionary();

    // Simulate the result of folding this operation to a constant. If folding
    // fails, mark the results as overdefined.
    SmallVector<OpFoldResult, 8> foldResults;
    foldResults.reserve(op->getNumResults());
    if (failed(op->fold(constantOperands, foldResults)))
      return markAllPessimisticFixpoint(op->getResults());

    // A successful fold that produced no results was an in-place fold: restore
    // the operation to its original state and mark the results as overdefined.
    // We don't allow in-place folds as the desire here is for simulated
    // execution, and not general folding.
    if (foldResults.empty()) {
      op->setOperands(originalOperands);
      op->setAttrs(originalAttrs);
      return markAllPessimisticFixpoint(op->getResults());
    }

    // Merge the fold results into the lattice for this operation.
    assert(foldResults.size() == op->getNumResults() && "invalid result size");
    Dialect *dialect = op->getDialect();
    ChangeResult result = ChangeResult::NoChange;
    for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
      LatticeElement<SCCPLatticeValue> &lattice =
          getLatticeElement(op->getResult(i));

      // Merge in the result of the fold: either a constant attribute (tagged
      // with the folding op's dialect so it can be materialized later) or the
      // lattice state of an existing value.
      OpFoldResult foldResult = foldResults[i];
      if (Attribute attr = foldResult.dyn_cast<Attribute>())
        result |= lattice.join(SCCPLatticeValue(attr, dialect));
      else
        result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
    }
    return result;
  }

  /// Implementation of `getSuccessorsForOperands` that uses constant operands
  /// to potentially remove dead successors. If the constant operands let
  /// `branch` pick a single successor, that block is appended to `successors`
  /// and success is returned; otherwise failure is returned.
  LogicalResult getSuccessorsForOperands(
      BranchOpInterface branch,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<Block *> &successors) final {
    SmallVector<Attribute> constantOperands = getConstantOperands(operands);
    if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
      successors.push_back(singleSucc);
      return success();
    }
    return failure();
  }

  /// Implementation of `getSuccessorsForOperands` that uses constant operands
  /// to potentially remove dead region successors. The known constants are
  /// forwarded to `branch.getSuccessorRegions`, which fills `successors`.
  void getSuccessorsForOperands(
      RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<RegionSuccessor> &successors) final {
    SmallVector<Attribute> constantOperands = getConstantOperands(operands);
    branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
  }

private:
  /// Collect the constant attribute held by each operand's lattice element,
  /// preserving operand order. An entry is null when that operand is not
  /// known to be constant.
  static SmallVector<Attribute>
  getConstantOperands(ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) {
    return SmallVector<Attribute>(
        llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
          return value->getValue().constant;
        }));
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// SCCP Rewrites
//===----------------------------------------------------------------------===//
/// Try to replace all uses of `value` with a materialized constant. This
/// happens only when the analysis holds a lattice element for `value` whose
/// state carries a constant attribute, and `folder` manages to materialize
/// that attribute using the dialect recorded next to it. Returns success if
/// the uses were replaced, failure otherwise (in which case `value` is left
/// untouched).
static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
                                         OpBuilder &builder,
                                         OperationFolder &folder, Value value) {
  // The analysis may have no lattice element for this value.
  auto *element = analysis.lookupLatticeElement(value);
  if (element == nullptr)
    return failure();

  // A null attribute means the value is not known to be constant.
  SCCPLatticeValue &state = element->getValue();
  Attribute constAttr = state.constant;
  if (!constAttr)
    return failure();

  // Materialization can fail (null result); leave the value alone then.
  Value materialized =
      folder.getOrCreateConstant(builder, state.constantDialect, constAttr,
                                 value.getType(), value.getLoc());
  if (!materialized)
    return failure();

  value.replaceAllUsesWith(materialized);
  return success();
}
/// Rewrite the given regions using the results of the computed analysis. This
/// replaces the uses of every value computed to be constant with a
/// materialized constant, and erases operations whose results were all
/// replaced and that are trivially dead. Nested regions are processed too.
///
/// \param analysis        the SCCP analysis that has already been run.
/// \param context         context used to build the folder and builder.
/// \param initialRegions  the top-level regions to rewrite.
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  SmallVector<Block *> worklist;
  // Blocks of each region are pushed in reverse so that, within a region,
  // popping from the back of the worklist visits them in program order.
  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
    for (Region &region : regions)
      for (Block &block : llvm::reverse(region))
        worklist.push_back(&block);
  };

  // An operation folder used to create and unique constants.
  OperationFolder folder(context);
  OpBuilder builder(context);

  addToWorklist(initialRegions);
  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();

    // Early-increment iteration because `op` may be erased below.
    for (Operation &op : llvm::make_early_inc_range(*block)) {
      builder.setInsertionPoint(&op);

      // Replace any result with constants. An operation without results
      // never counts as "all replaced".
      bool replacedAll = op.getNumResults() != 0;
      for (Value res : op.getResults())
        replacedAll &=
            succeeded(replaceWithConstant(analysis, builder, folder, res));

      // If all of the results of the operation were replaced, try to erase
      // the operation completely.
      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
        assert(op.use_empty() && "expected all uses to be replaced");
        op.erase();
        continue;
      }

      // Add the regions of this operation to the worklist so that nested
      // blocks are rewritten as well.
      addToWorklist(op.getRegions());
    }

    // Replace any block arguments with constants.
    builder.setInsertionPointToStart(block);
    for (BlockArgument arg : block->getArguments())
      (void)replaceWithConstant(analysis, builder, folder, arg);
  }
}
//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//
namespace {
/// Pass wrapper that runs sparse conditional constant propagation on the
/// regions of the operation it is scheduled on.
struct SCCP : SCCPBase<SCCP> {
  void runOnOperation() override;
};
} // namespace
void SCCP::runOnOperation() {
Operation *op = getOperation();
SCCPAnalysis analysis(op->getContext());
analysis.run(op);
rewrite(analysis, op->getContext(), op->getRegions());
}
/// Create a new instance of the sparse conditional constant propagation pass.
std::unique_ptr<Pass> mlir::createSCCPPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SCCP>();
  return pass;
}