//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Chained elemental operations like a + b + c can inline the first elemental
// at the hlfir.apply in the body of the second one (as described in
// docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering
// so that it happens after the HLFIR intrinsic simplification pass.
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
namespace hlfir {
#define GEN_PASS_DEF_INLINEELEMENTALS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir
/// If the elemental has exactly two uses and those two are an apply operation
/// and a destroy operation, return those two, otherwise return std::nullopt.
static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
getTwoUses(hlfir::ElementalOp elemental) {
  mlir::Operation::user_range users = elemental->getUsers();
  // don't inline anything with more than one use (plus hlfir.destroy)
  if (std::distance(users.begin(), users.end()) != 2) {
    return std::nullopt;
  }

  // If the ElementalOp must produce a temporary (e.g. for
  // finalization purposes), then we cannot inline it.
  if (hlfir::elementalOpMustProduceTemp(elemental))
    return std::nullopt;

  hlfir::ApplyOp apply;
  hlfir::DestroyOp destroy;
  for (mlir::Operation *user : users)
    mlir::TypeSwitch<mlir::Operation *, void>(user)
        .Case([&](hlfir::ApplyOp op) { apply = op; })
        .Case([&](hlfir::DestroyOp op) { destroy = op; });

  if (!apply || !destroy)
    return std::nullopt;

  // we can't inline if the return type of the yield doesn't match the return
  // type of the apply
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>(
      elemental.getRegion().back().back());
  assert(yield && "hlfir.elemental should always end with a yield");
  if (apply.getResult().getType() != yield.getElementValue().getType())
    return std::nullopt;

  return std::pair{apply, destroy};
}
namespace {
class InlineElementalConversion
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override {
    std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple =
        getTwoUses(elemental);
    if (!maybeTuple)
      return rewriter.notifyMatchFailure(
          elemental, "hlfir.elemental does not have two uses");

    if (elemental.isOrdered()) {
      // We can only inline the ordered elemental into a loop-like
      // construct that processes the indices in-order and does not
      // have the side effects itself. Adhere to conservative behavior
      // for the time being.
      return rewriter.notifyMatchFailure(elemental,
                                         "hlfir.elemental is ordered");
    }

    auto [apply, destroy] = *maybeTuple;

    assert(elemental.getRegion().hasOneBlock() &&
           "expect elemental region to have one block");

    fir::FirOpBuilder builder{rewriter, elemental.getOperation()};
    builder.setInsertionPointAfter(apply);
    hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
        elemental.getLoc(), builder, elemental, apply.getIndices());

    // remove the old elemental and all of the bookkeeping
    rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue());
    rewriter.eraseOp(yield);
    rewriter.eraseOp(apply);
    rewriter.eraseOp(destroy);
    rewriter.eraseOp(elemental);

    return mlir::success();
  }
};
class InlineElementalsPass
    : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> {
public:
  void runOnOperation() override {
    mlir::func::FuncOp func = getOperation();
    mlir::MLIRContext *context = &getContext();

    mlir::GreedyRewriteConfig config;
    // Prevent the pattern driver from merging blocks.
    config.enableRegionSimplification = false;

    mlir::RewritePatternSet patterns(context);
    patterns.insert<InlineElementalConversion>(context);

    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
            func, std::move(patterns), config))) {
      mlir::emitError(func->getLoc(), "failure in HLFIR elemental inlining");
      signalPassFailure();
    }
  }
};
} // namespace
std::unique_ptr<mlir::Pass> hlfir::createInlineElementalsPass() {
  return std::make_unique<InlineElementalsPass>();
}