//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace fir {
#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"

namespace {
// Counter used to give each generated global constant a unique name.
unsigned uniqueLitId = 1;

class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
  const mlir::DominanceInfo &di;

public:
  using OpRewritePattern::OpRewritePattern;

  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
      : OpRewritePattern(ctx), di(_di) {}

  llvm::LogicalResult
  matchAndRewrite(fir::CallOp callOp,
                  mlir::PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
    auto module = callOp->getParentOfType<mlir::ModuleOp>();
    bool needUpdate = false;
    fir::FirOpBuilder builder(rewriter, module);
    llvm::SmallVector<mlir::Value> newOperands;
    llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
    for (const mlir::Value &a : callOp.getArgs()) {
      auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
      // Only arguments that are allocas marked with the adapt-to-by-ref
      // attribute (temporaries holding a value passed by reference) can be
      // converted. All other arguments are passed through unchanged.
      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
        newOperands.push_back(a);
        continue;
      }
      mlir::Type varTy = alloca.getInType();
      assert(!fir::hasDynamicSize(varTy) &&
             "only expect statically sized scalars to be by value");
      // Find the unique store to this alloca that dominates the call.
      mlir::Operation *store = nullptr;
      for (mlir::Operation *s : alloca->getUsers()) {
        if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
          // We can only deal with ONE store: if one was already found,
          // reset to nullptr and exit the loop.
          if (store) {
            store = nullptr;
            break;
          }
          store = s;
        }
      }
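      // For example (illustrative FIR, syntax approximate; @foo is a
      // hypothetical callee), the argument below is left unchanged because
      // two stores to the same alloca dominate the call, so the loop above
      // resets `store` to nullptr:
      //   %0 = fir.alloca i32 {adapt.valuebyref}
      //   fir.store %c1 to %0 : !fir.ref<i32>
      //   fir.store %c2 to %0 : !fir.ref<i32>
      //   fir.call @foo(%0) : (!fir.ref<i32>) -> ()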
      // If we didn't find any store, or found multiple stores, add the
      // argument as is and move on.
      if (!store) {
        newOperands.push_back(a);
        continue;
      }
      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
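      mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
      // If the stored value is not produced by a constant (including the case
      // where it is a block argument and has no defining op), add the
      // argument as is and move on.
      if (!mlir::isa_and_nonnull<mlir::arith::ConstantOp>(definingOp)) {
        // Unable to remove alloca arg
        newOperands.push_back(a);
        continue;
      }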
      LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
      std::string globalName =
          "_global_const_." + std::to_string(uniqueLitId++);
      assert(!builder.getNamedGlobal(globalName) &&
             "We should have a unique name here");
      if (llvm::none_of(allocas,
                        [alloca](auto x) { return x.first == alloca; })) {
        allocas.push_back(std::make_pair(alloca, store));
      }
      auto loc = callOp.getLoc();
      fir::GlobalOp global = builder.createGlobalConstant(
          loc, varTy, globalName,
          [&](fir::FirOpBuilder &builder) {
            mlir::Operation *cln = definingOp->clone();
            builder.insert(cln);
            mlir::Value val =
                builder.createConvert(loc, varTy, cln->getResult(0));
            builder.create<fir::HasValueOp>(loc, val);
          },
          builder.createInternalLinkage());
      mlir::Value addr = builder.create<fir::AddrOfOp>(
          loc, global.resultType(), global.getSymbol());
      newOperands.push_back(addr);
      needUpdate = true;
    }
    if (needUpdate) {
      auto loc = callOp.getLoc();
      llvm::SmallVector<mlir::Type> newResultTypes;
      newResultTypes.append(callOp.getResultTypes().begin(),
                            callOp.getResultTypes().end());
      fir::CallOp newOp = builder.create<fir::CallOp>(
          loc,
          callOp.getCallee().has_value() ? callOp.getCallee().value()
                                         : mlir::SymbolRefAttr{},
          newResultTypes, newOperands);
      // Copy all the attributes from the old to the new op.
      newOp->setAttrs(callOp->getAttrs());
      // Print before replaceOp: the rewriter erases the old call, so it must
      // not be used afterwards.
      LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                              << newOp << '\n');
      rewriter.replaceOp(callOp, newOp);
      for (auto a : allocas) {
        if (a.first->hasOneUse()) {
          // With the old call gone, if the store is the alloca's only
          // remaining use, both the store and the alloca are dead.
          rewriter.eraseOp(a.second);
          rewriter.eraseOp(a.first);
        }
      }
      return mlir::success();
    }
    // Failure here only means "this call could not be converted"; the greedy
    // rewrite driver treats it as the pattern not applying, not as an error.
    return mlir::failure();
  }
};

// This pass attempts to convert immediate scalar literals passed as call
// arguments into global constants, to enable transformations such as Dead
// Argument Elimination.
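//
// For example (illustrative FIR: @foo is a hypothetical callee, the global's
// numeric suffix depends on uniqueLitId, and the exact textual syntax may
// differ between Flang versions), a call such as
//   %0 = fir.alloca i32 {adapt.valuebyref}
//   %c5 = arith.constant 5 : i32
//   fir.store %c5 to %0 : !fir.ref<i32>
//   fir.call @foo(%0) : (!fir.ref<i32>) -> ()
// is rewritten to pass the address of an internal global constant instead:
//   fir.global internal @_global_const_.1 constant : i32 {
//     %c5 = arith.constant 5 : i32
//     fir.has_value %c5 : i32
//   }
//   ...
//   %1 = fir.address_of(@_global_const_.1) : !fir.ref<i32>
//   fir.call @foo(%1) : (!fir.ref<i32>) -> ()
// after which the store and alloca are erased if nothing else uses them.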
class ConstantArgumentGlobalisationOpt
    : public fir::impl::ConstantArgumentGlobalisationOptBase<
          ConstantArgumentGlobalisationOpt> {
public:
  ConstantArgumentGlobalisationOpt() = default;

  void runOnOperation() override {
    mlir::ModuleOp mod = getOperation();
    mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();

    auto *context = &getContext();
    mlir::RewritePatternSet patterns(context);
    mlir::GreedyRewriteConfig config;
    config.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;
    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;

    patterns.insert<CallOpRewriter>(context, *di);
    if (mlir::failed(
            mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
      mlir::emitError(mod.getLoc(),
                      "error in constant globalisation optimization");
      signalPassFailure();
    }
  }
};
} // namespace