1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
|
//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file uses tablegen definitions of the LLVM IR Dialect operations to
// generate the code building the LLVM IR from it.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
static bool emitError(const Twine &message) {
llvm::errs() << message << "\n";
return false;
}
namespace {
// Helper structure describing where a substring lives inside a string, as
// returned by findNextVariable. A `pos` of `std::string::npos` means that
// nothing was found.
struct StringLoc {
  // Offset of the first character of the substring.
  size_t pos;
  // Number of characters in the substring.
  size_t length;

  // Take the substring identified by this location in the given string.
  StringRef in(StringRef str) const { return str.substr(pos, length); }

  // A location is invalid if its position is `npos`. Marked const so that it
  // can also be queried on const locations.
  explicit operator bool() const { return pos != std::string::npos; }
};
} // namespace
// Locate the next TableGen variable in `str`. A variable begins with `$` and
// continues over alphanumeric characters and underscores; the returned
// location spans the whole variable, leading `$` included. The escape `$$` is
// reported as a two-character location. If `str` contains no `$` at all, the
// returned location has `pos == std::string::npos` (i.e. converts to false).
static StringLoc findNextVariable(StringRef str) {
  const size_t dollarPos = str.find('$');
  if (dollarPos == std::string::npos)
    return {dollarPos, 0};

  const size_t afterDollar = dollarPos + 1;

  // An immediately following second `$` is the "$$" escape sequence.
  const bool hasFollowingChar = afterDollar < str.size();
  if (hasFollowingChar && str[afterDollar] == '$')
    return {dollarPos, 2};

  // Otherwise the identifier runs until the first character that is neither
  // alphanumeric nor '_', or until the end of the string.
  auto isIdentifierChar = [](char c) { return isAlnum(c) || c == '_'; };
  size_t identEnd = str.find_if_not(isIdentifierChar, afterDollar);
  if (identEnd == std::string::npos)
    identEnd = str.size();
  return {dollarPos, identEnd - dollarPos};
}
// Return true if `name` names the variadic operand of `op`. Only the last
// operand may be variadic, so it is the only one inspected.
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
  const unsigned operandCount = op.getNumOperands();
  if (operandCount > 0) {
    const auto &lastOperand = op.getOperand(operandCount - 1);
    return lastOperand.isVariableLength() && lastOperand.name == name;
  }
  return false;
}
// Return true if `name` is the name of one of the results of `op`.
static bool isResultName(const tblgen::Operator &op, StringRef name) {
  const int numResults = op.getNumResults();
  for (int idx = 0; idx != numResults; ++idx) {
    if (op.getResultName(idx) == name)
      return true;
  }
  return false;
}
// Return true if `name` is the name of one of the attributes of `op`.
static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
  for (const tblgen::NamedAttribute &namedAttr : op.getAttributes()) {
    if (namedAttr.name == name)
      return true;
  }
  return false;
}
// Return true if `name` is the name of one of the operands of `op`.
static bool isOperandName(const tblgen::Operator &op, StringRef name) {
  const int numOperands = op.getNumOperands();
  for (int idx = 0; idx != numOperands; ++idx) {
    if (op.getOperand(idx).name == name)
      return true;
  }
  return false;
}
// Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
// for one definition of a LLVM IR Dialect operation. Return true on success.
static bool emitOneBuilder(const Record &record, raw_ostream &os) {
auto op = tblgen::Operator(record);
if (!record.getValue("llvmBuilder"))
return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
// Return early if there is no builder specified.
auto builderStrRef = record.getValueAsString("llvmBuilder");
if (builderStrRef.empty())
return true;
// Progressively create the builder string by replacing $-variables with
// value lookups. Keep only the not-yet-traversed part of the builder pattern
// to avoid re-traversing the string multiple times.
std::string builder;
llvm::raw_string_ostream bs(builder);
while (auto loc = findNextVariable(builderStrRef)) {
auto name = loc.in(builderStrRef).drop_front();
// First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind.
bool isVariadicOperand = isVariadicOperandName(op, name);
if (isOperandName(op, name)) {
auto result = isVariadicOperand
? formatv("lookupValues(op.{0}())", name)
: formatv("valueMapping.lookup(op.{0}())", name);
bs << result;
} else if (isAttributeName(op, name)) {
bs << formatv("op.{0}()", name);
} else if (isResultName(op, name)) {
bs << formatv("valueMapping[op.{0}()]", name);
} else if (name == "_resultType") {
bs << "op.getResult().getType().cast<LLVM::LLVMType>()."
"getUnderlyingType()";
} else if (name == "_hasResult") {
bs << "opInst.getNumResults() == 1";
} else if (name == "_location") {
bs << "opInst.getLoc()";
} else if (name == "_numOperands") {
bs << "opInst.getNumOperands()";
} else if (name == "$") {
bs << '$';
} else {
return emitError(name + " is neither an argument nor a result of " +
op.getOperationName());
}
// Finally, only keep the untraversed part of the string.
builderStrRef = builderStrRef.substr(loc.pos + loc.length);
}
// Output the check and the rewritten builder string.
os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
<< ">(opInst)) {\n";
os << bs.str() << builderStrRef << "\n";
os << " return success();\n";
os << "}\n";
return true;
}
// Emit all builders. Returns false on success because of the generator
// registration requirements.
static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
if (!emitOneBuilder(*def, os))
return true;
}
return false;
}
namespace {
// Wrapper class around a TableGen definition of an LLVM enum attribute case.
class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
public:
  using tblgen::EnumAttrCase::EnumAttrCase;

  // Constructs a case from a non LLVM-specific enum attribute case.
  explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
      : tblgen::EnumAttrCase(&other.getDef()) {}

  // Returns the C++ enumerant for the LLVM API, taken from the
  // `llvmEnumerant` field of the case's record.
  StringRef getLLVMEnumerant() const {
    return def->getValueAsString("llvmEnumerant");
  }
};

// Wrapper class around a TableGen definition of an LLVM enum attribute.
class LLVMEnumAttr : public tblgen::EnumAttr {
public:
  using tblgen::EnumAttr::EnumAttr;

  // Returns the C++ enum name for the LLVM API, taken from the
  // `llvmClassName` field of the attribute's record.
  StringRef getLLVMClassName() const {
    return def->getValueAsString("llvmClassName");
  }

  // Returns all associated cases viewed as LLVM-specific enum cases. This
  // deliberately hides tblgen::EnumAttr::getAllCases.
  std::vector<LLVMEnumAttrCase> getAllCases() const {
    auto baseCases = tblgen::EnumAttr::getAllCases();
    std::vector<LLVMEnumAttrCase> cases;
    // The result has exactly one entry per base case; allocate once.
    cases.reserve(baseCases.size());
    for (const auto &c : baseCases)
      cases.emplace_back(c);
    return cases;
  }
};
} // namespace
// Emits the function "LLVMClass convert<Enum>ToLLVM(Enum)". Its switch maps
// every case of the MLIR LLVM dialect enum attribute (Enum) to the matching
// LLVM API enumerant and ends with an llvm_unreachable fallback.
static void emitOneEnumToConversion(const llvm::Record *record,
                                    raw_ostream &os) {
  LLVMEnumAttr enumAttr(record);
  StringRef enumName = enumAttr.getEnumClassName();
  StringRef nsName = enumAttr.getCppNamespace();
  StringRef llvmName = enumAttr.getLLVMClassName();

  // Function signature and start of the switch.
  os << formatv("static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmName,
                enumName, nsName);
  os << " switch (value) {\n";

  // One `case` per enumerant, returning the LLVM counterpart.
  for (const LLVMEnumAttrCase &enumCase : enumAttr.getAllCases()) {
    StringRef mlirSymbol = enumCase.getSymbol();
    StringRef llvmSymbol = enumCase.getLLVMEnumerant();
    os << formatv(" case {0}::{1}::{2}:\n", nsName, enumName, mlirSymbol);
    os << formatv(" return {0}::{1};\n", llvmName, llvmSymbol);
  }

  // Close the switch and the function.
  os << " }\n";
  os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", enumName);
  os << "}\n\n";
}
// Emits the function "Enum convert<Enum>FromLLVM(LLVMClass)". Its switch maps
// every LLVM API enumerant to the corresponding case of the MLIR LLVM dialect
// enum attribute (Enum) and ends with an llvm_unreachable fallback.
static void emitOneEnumFromConversion(const llvm::Record *record,
                                      raw_ostream &os) {
  LLVMEnumAttr enumAttr(record);
  StringRef llvmClass = enumAttr.getLLVMClassName();
  StringRef cppClassName = enumAttr.getEnumClassName();
  StringRef cppNamespace = enumAttr.getCppNamespace();

  // Emit the function converting the enum attribute from its LLVM counterpart.
  os << formatv("static {0}::{1} convert{1}FromLLVM({2} value) {{\n",
                cppNamespace, cppClassName, llvmClass);
  os << " switch (value) {\n";
  for (const auto &enumerant : enumAttr.getAllCases()) {
    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
    StringRef cppEnumerant = enumerant.getSymbol();
    os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
    os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
                  cppEnumerant);
  }
  os << " }\n";
  // Terminate the line so the closing brace below lands on its own line,
  // matching the output of emitOneEnumToConversion.
  os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", llvmClass);
  os << "}\n\n";
}
// Emits conversion functions between MLIR enum attribute case and corresponding
// LLVM API enumerants for all registered LLVM dialect enum attributes.
template <bool ConvertTo>
static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
if (ConvertTo)
emitOneEnumToConversion(def, os);
else
emitOneEnumFromConversion(def, os);
return false;
}
// Registers the generator producing the LLVM IR builders of all LLVM IR
// Dialect operations.
static mlir::GenRegistration genLLVMIRConversions(
    "gen-llvmir-conversions", "Generate LLVM IR conversions", emitBuilders);

// Registers the generator producing MLIR-enum -> LLVM-enum conversion
// functions.
static mlir::GenRegistration genEnumToLLVMConversion(
    "gen-enum-to-llvmir-conversions",
    "Generate conversions of EnumAttrs to LLVM IR",
    emitEnumConversionDefs</*ConvertTo=*/true>);

// Registers the generator producing LLVM-enum -> MLIR-enum conversion
// functions.
static mlir::GenRegistration genEnumFromLLVMConversion(
    "gen-enum-from-llvmir-conversions",
    "Generate conversions of EnumAttrs from LLVM IR",
    emitEnumConversionDefs</*ConvertTo=*/false>);
|