//===- FuncConversions.cpp - Function conversions -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::func;
namespace {
/// Pattern that rebuilds a CallOp with type-converted result types and the
/// already-converted operands supplied by the adaptor. Intended to be used
/// together with the FuncOpSignatureConversion.
struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern<CallOp>::OpConversionPattern;

  /// Replaces `callOp` with a new CallOp to the same callee. Returns failure
  /// (leaving the op untouched) when the result types cannot be converted or
  /// when the conversion does not yield exactly one type per original result.
  LogicalResult
  matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Run the original result types through the type converter.
    SmallVector<Type, 1> newResultTypes;

    // A failed conversion, or anything other than a one-to-one mapping, cannot
    // be handled here: there is no known way to aggregate the results.
    if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
                                           newResultTypes)) ||
        newResultTypes.size() != callOp->getNumResults())
      return failure();

    // Recreate the call with the converted result types and operands.
    rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
                                        newResultTypes, adaptor.getOperands());
    return success();
  }
};
} // namespace
/// Adds the CallOpSignatureConversion pattern, bound to `converter`, to
/// `patterns`.
void mlir::populateCallOpTypeConversionPattern(RewritePatternSet &patterns,
                                               TypeConverter &converter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<CallOpSignatureConversion>(converter, context);
}
namespace {
/// Only needed to support partial conversion of functions where this pattern
/// ensures that the branch operation arguments match up with the successor
/// block arguments.
class BranchOpInterfaceTypeConversion
    : public OpInterfaceConversionPattern<BranchOpInterface> {
public:
  using OpInterfaceConversionPattern<
      BranchOpInterface>::OpInterfaceConversionPattern;

  /// Creates the pattern with benefit 1.
  ///
  /// `shouldConvertBranchOperand`, when non-null, is called with the branch op
  /// and an operand index and decides whether that forwarded operand is
  /// replaced by its converted value. When it is null, every forwarded operand
  /// is replaced.
  ///
  /// NOTE(review): `function_ref` is a non-owning reference and it is stored
  /// in the pattern, so the referenced callable presumably has to outlive the
  /// pattern -- confirm at the call sites.
  BranchOpInterfaceTypeConversion(
      TypeConverter &typeConverter, MLIRContext *ctx,
      function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand)
      : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
        shouldConvertBranchOperand(shouldConvertBranchOperand) {}

  /// Updates `op` in place so that the operands forwarded to successor blocks
  /// are taken from the converted `operands`; all other operands keep their
  /// original values. Always returns success.
  LogicalResult
  matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // For a branch operation, only some operands go to the target blocks, so
    // only rewrite those. Start from the original operand list and overwrite
    // the forwarded entries selected by the filter.
    SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
    for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
         succIdx < succEnd; ++succIdx) {
      OperandRange forwardedOperands =
          op.getSuccessorOperands(succIdx).getForwardedOperands();
      if (forwardedOperands.empty())
        continue;

      // Indices are positions in the op's full operand list.
      for (int idx = forwardedOperands.getBeginOperandIndex(),
               eidx = idx + forwardedOperands.size();
           idx < eidx; ++idx) {
        if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
          newOperands[idx] = operands[idx];
      }
    }

    // The callback runs before updateRootInPlace returns, so capturing by
    // reference is safe and avoids copying the operand vector into the lambda.
    rewriter.updateRootInPlace(op, [&] { op->setOperands(newOperands); });
    return success();
  }

private:
  /// Optional filter selecting which forwarded operands are replaced; a null
  /// value means all of them.
  function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand;
};
} // namespace
namespace {
/// Pattern that updates a ReturnOp in place so that it returns the converted
/// operands provided by the conversion framework.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  /// Swaps every operand of `op` for its converted counterpart from `adaptor`.
  /// Always returns success.
  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // Every operand of a return flows to the results of the parent, so the
    // whole operand list is replaced.
    ValueRange convertedOperands = adaptor.getOperands();
    rewriter.updateRootInPlace(op,
                               [&] { op->setOperands(convertedOperands); });
    return success();
  }
};
} // namespace
/// Adds the BranchOpInterfaceTypeConversion pattern to `patterns`, forwarding
/// `typeConverter` and the optional `shouldConvertBranchOperand` filter to it.
void mlir::populateBranchOpInterfaceTypeConversionPattern(
    RewritePatternSet &patterns, TypeConverter &typeConverter,
    function_ref<bool(BranchOpInterface, int)> shouldConvertBranchOperand) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BranchOpInterfaceTypeConversion>(typeConverter, context,
                                                shouldConvertBranchOperand);
}
/// Returns true when `op` implements BranchOpInterface and, for every
/// successor, the types of the operands forwarded to that successor are legal
/// according to `converter`. Returns false for any op that does not implement
/// BranchOpInterface.
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
    Operation *op, TypeConverter &converter) {
  auto branchOp = dyn_cast<BranchOpInterface>(op);
  if (!branchOp)
    return false;

  // All successor operands of branch-like operations must be rewritten, so a
  // single illegal forwarded type makes the op illegal.
  const int numSuccessors = op->getBlock()->getNumSuccessors();
  for (int succIdx = 0; succIdx < numSuccessors; ++succIdx) {
    auto forwarded =
        branchOp.getSuccessorOperands(succIdx).getForwardedOperands();
    if (!converter.isLegal(forwarded.getTypes()))
      return false;
  }
  return true;
}
/// Adds the ReturnOpTypeConversion pattern, bound to `typeConverter`, to
/// `patterns`.
void mlir::populateReturnOpTypeConversionPattern(RewritePatternSet &patterns,
                                                 TypeConverter &typeConverter) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ReturnOpTypeConversion>(typeConverter, context);
}
/// Legality check for return-like operations.
///
/// For a ReturnOp with `returnOpAlwaysLegal` unset, the answer is whatever
/// `converter.isLegal(op)` reports. In every other case the op is legal
/// exactly when it carries the ReturnLike trait.
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
                                                   TypeConverter &converter,
                                                   bool returnOpAlwaysLegal) {
  // ReturnLike operations have to be legalized with their parent. For
  // `return` this is handled; other ops remain as they are.
  if (!isa<ReturnOp>(op) || returnOpAlwaysLegal)
    return op->hasTrait<OpTrait::ReturnLike>();

  // This is a `return` and the user pass wants to convert/transform across
  // function boundaries, so `converter` decides whether the op is legal.
  return converter.isLegal(op);
}
/// Returns false only when `op` might be a terminator, is the last operation
/// of its block, and its parent operation is a FuncOp. Returns true in every
/// other case.
bool mlir::isNotBranchOpInterfaceOrReturnLikeOp(Operation *op) {
  // Operations that cannot be terminators are ignored.
  if (!op->mightHaveTrait<OpTrait::IsTerminator>())
    return true;

  // Operations that are not the last one in their block are ignored too; this
  // also takes care of unknown operations.
  Block *parentBlock = op->getBlock();
  if (!parentBlock || &parentBlock->back() != op)
    return true;

  // Terminators in nested regions are not handled here and are assumed to be
  // always legal; only those directly inside a FuncOp are of interest.
  return !isa_and_nonnull<FuncOp>(op->getParentOp());
}
|