1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
|
//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringRef.h"
#include <functional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
#define PASS_NAME "convert-cf-to-llvm"
static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
std::string prefix = "assert_msg_";
int counter = 0;
while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
++counter;
return prefix + std::to_string(counter);
}
/// Generate IR that prints the given string to stderr.
static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef msg, LLVMTypeConverter &typeConverter) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();
// Create a zero-terminated byte representation and allocate global symbol.
SmallVector<uint8_t> elementVals;
elementVals.append(msg.begin(), msg.end());
elementVals.push_back(0);
auto dataAttrType = RankedTensorType::get(
{static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
auto dataAttr =
DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
auto arrayTy =
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
dataAttr);
// Emit call to `printStr` in runtime library.
builder.restoreInsertionPoint(ip);
auto msgAddr = builder.create<LLVM::AddressOfOp>(
loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep = builder.create<LLVM::GEPOp>(
loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
indices);
Operation *printer = LLVM::lookupOrCreatePrintStrFn(
moduleOp, typeConverter.useOpaquePointers());
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
}
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
bool abortOnFailedAssert = true)
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
abortOnFailedAssert(abortOnFailedAssert) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto module = op->getParentOfType<ModuleOp>();
// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
if (!abortFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
"abort", abortFuncTy);
}
rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
rewriter.create<LLVM::UnreachableOp>(loc);
} else {
rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
}
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getArg(), continuationBlock, failureBlock);
return success();
}
private:
/// If set to `false`, messages are printed but program execution continues.
/// This is useful for testing asserts.
bool abortOnFailedAssert = true;
};
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
/// to first have updated types which should be handled by a pattern operating
/// on the parent op.
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
ValueRange operands,
ValueRange blockArgs, Location loc,
llvm::StringRef messagePrefix) {
for (const auto &idxAndTypes :
llvm::enumerate(llvm::zip(blockArgs, operands))) {
int64_t i = idxAndTypes.index();
Value argValue =
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
Type operandType = std::get<1>(idxAndTypes.value()).getType();
// In the case of an invalid jump, the block argument will have been
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
// there might still be a no-op conversion cast with both types being equal.
// Consider both of these details to see if the jump would be invalid.
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
argValue.getDefiningOp())) {
if (op.getOperandTypes().front() != operandType) {
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
diag << messagePrefix;
diag << "mismatched types from operand # " << i << " ";
diag << operandType;
diag << " not compatible with destination block argument type ";
diag << op.getOperandTypes().front();
diag << " which should be converted with the parent op.";
});
}
}
}
return success();
}
/// Ensure that all block types were updated and then create an LLVM::BrOp
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
op.getSuccessor()->getArguments(),
op.getLoc(),
/*messagePrefix=*/"")))
return failure();
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(cf::CondBranchOp op,
typename cf::CondBranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
op.getFalseDest()->getArguments(),
op.getLoc(), "in false case branch ")))
return failure();
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
op.getTrueDest()->getArguments(),
op.getLoc(), "in true case branch ")))
return failure();
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
op.getDefaultDestination()->getArguments(),
op.getLoc(), "in switch default case ")))
return failure();
for (const auto &i : llvm::enumerate(
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
if (failed(verifyMatchingValues(
rewriter, std::get<0>(i.value()),
std::get<1>(i.value())->getArguments(), op.getLoc(),
"in switch case " + std::to_string(i.index()) + " "))) {
return failure();
}
}
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
return success();
}
};
} // namespace
void mlir::cf::populateControlFlowToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// clang-format off
patterns.add<
AssertOpLowering,
BranchOpLowering,
CondBranchOpLowering,
SwitchOpLowering>(converter);
// clang-format on
}
void mlir::cf::populateAssertToLLVMConversionPattern(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool abortOnFailure) {
patterns.add<AssertOpLowering>(converter, abortOnFailure);
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct ConvertControlFlowToLLVM
: public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
using Base::Base;
/// Run the dialect converter on the module.
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(&getContext());
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
options.useOpaquePointers = useOpaquePointers;
LLVMTypeConverter converter(&getContext(), options);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
|