//===- TestMakeIsolatedFromAbove.cpp - Test makeIsolatedFromAbove method -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
using namespace mlir;
/// Helper function that calls `makeRegionIsolatedFromAbove` to convert
/// `test.one_region_op` to `test.isolated_one_region_op`.
static LogicalResult
makeIsolatedFromAboveImpl(RewriterBase &rewriter,
                          test::OneRegionWithOperandsOp regionOp,
                          llvm::function_ref<bool(Operation *)> callBack) {
  Region &region = regionOp.getRegion();
  SmallVector<Value> capturedValues =
      makeRegionIsolatedFromAbove(rewriter, region, callBack);
  SmallVector<Value> operands = regionOp.getOperands();
  operands.append(capturedValues);
  auto isolatedRegionOp =
      rewriter.create<test::IsolatedOneRegionOp>(regionOp.getLoc(), operands);
  rewriter.inlineRegionBefore(region, isolatedRegionOp.getRegion(),
                              isolatedRegionOp.getRegion().begin());
  rewriter.eraseOp(regionOp);
  return success();
}
namespace {
/// Simple test for making region isolated from above without cloning any
/// operations.
struct SimpleMakeIsolatedFromAbove
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    return makeIsolatedFromAboveImpl(rewriter, regionOp,
                                     [](Operation *) { return false; });
  }
};
/// Test for making region isolated from above while cloning operations
/// with no operands.
struct MakeIsolatedFromAboveAndCloneOpsWithNoOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    return makeIsolatedFromAboveImpl(rewriter, regionOp, [](Operation *op) {
      return op->getNumOperands() == 0;
    });
  }
};
/// Test for making region isolated from above while cloning all operations,
/// including those with operands.
struct MakeIsolatedFromAboveAndCloneOpsWithOperands
    : OpRewritePattern<test::OneRegionWithOperandsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(test::OneRegionWithOperandsOp regionOp,
                                PatternRewriter &rewriter) const override {
    return makeIsolatedFromAboveImpl(rewriter, regionOp,
                                     [](Operation *op) { return true; });
  }
};
/// Test pass for testing the `makeRegionIsolatedFromAbove` function.
struct TestMakeIsolatedFromAbovePass
    : public PassWrapper<TestMakeIsolatedFromAbovePass,
                         OperationPass<func::FuncOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMakeIsolatedFromAbovePass)

  TestMakeIsolatedFromAbovePass() = default;
  TestMakeIsolatedFromAbovePass(const TestMakeIsolatedFromAbovePass &pass)
      : PassWrapper(pass) {}

  StringRef getArgument() const final {
    return "test-make-isolated-from-above";
  }

  StringRef getDescription() const final {
    return "Test making a region isolated from above";
  }

  Option<bool> simple{
      *this, "simple",
      llvm::cl::desc("Test simple case with no cloning of operations"),
      llvm::cl::init(false)};

  Option<bool> cloneOpsWithNoOperands{
      *this, "clone-ops-with-no-operands",
      llvm::cl::desc("Test case with cloning of operations with no operands"),
      llvm::cl::init(false)};

  Option<bool> cloneOpsWithOperands{
      *this, "clone-ops-with-operands",
      llvm::cl::desc("Test case with cloning of operations with operands"),
      llvm::cl::init(false)};

  void runOnOperation() override;
};
} // namespace
void TestMakeIsolatedFromAbovePass::runOnOperation() {
  MLIRContext *context = &getContext();
  func::FuncOp funcOp = getOperation();

  if (simple) {
    RewritePatternSet patterns(context);
    patterns.insert<SimpleMakeIsolatedFromAbove>(context);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
    return;
  }

  if (cloneOpsWithNoOperands) {
    RewritePatternSet patterns(context);
    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithNoOperands>(context);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
    return;
  }

  if (cloneOpsWithOperands) {
    RewritePatternSet patterns(context);
    patterns.insert<MakeIsolatedFromAboveAndCloneOpsWithOperands>(context);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }
    return;
  }
}
namespace mlir {
namespace test {
void registerTestMakeIsolatedFromAbovePass() {
  PassRegistration<TestMakeIsolatedFromAbovePass>();
}
} // namespace test
} // namespace mlir