1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
//===- BasicPtxBuilderInterface.cpp - PTX builder interface ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements the interface that automatically builds PTX (Parallel Thread
// Execution) inline assembly from NVVM ops. It is used by the NVVM-to-LLVM
// conversion pass.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
using namespace mlir;
using namespace NVVM;
/// LLVM pointer address space treated as shared memory when choosing a
/// pointer's register constraint.
static constexpr int64_t kSharedMemorySpace = 3;

/// Maps an MLIR type to the inline-asm register constraint letter used in the
/// generated PTX:
///   i1 -> 'b', i16 -> 'h', i32 -> 'r', i64 -> 'l', f32 -> 'f', f64 -> 'd'.
/// LLVM pointers map to 'r' in the shared address space (32-bit pointers) and
/// to 'l' in every other address space.
///
/// Any other type (including structs, which PtxBuilder::insertValue expands
/// element by element before calling this) is a fatal programming error.
static char getRegisterType(Type type) {
  if (type.isInteger(1))
    return 'b';
  if (type.isInteger(16))
    return 'h';
  if (type.isInteger(32))
    return 'r';
  if (type.isInteger(64))
    return 'l';
  if (type.isF32())
    return 'f';
  if (type.isF64())
    return 'd';
  // Use the free-function cast, as elsewhere in this file; the member
  // `Type::dyn_cast` form is deprecated in MLIR.
  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
    // The shared address space is addressed with 32-bit pointers.
    return ptr.getAddressSpace() == kSharedMemorySpace ? 'r' : 'l';
  }
  // Register types for structs (or any other type) are not supported.
  // llvm_unreachable does not return, so no fallback value is needed.
  llvm_unreachable("The register type could not deduced from MLIR type");
}
/// Returns the inline-asm constraint letter for operand \p v.
///
/// A value defined by an `LLVM::ConstantOp` is emitted with the 'n'
/// (immediate) constraint. Any other value uses the letter derived from its
/// type.
static char getRegisterType(Value v) {
  // Non-constant values: defer to the type-based mapping.
  if (!v.getDefiningOp<LLVM::ConstantOp>())
    return getRegisterType(v.getType());
  return 'n';
}
/// Records \p v as an operand of the inline-asm op being built and appends
/// its constraint(s) to `registerConstraints` (each followed by a ',').
///
/// The modifier \p itype controls how the value is recorded:
///  * Read      - `v` is pushed onto `ptxOperands` as an asm input.
///  * Write     - `v` is not pushed; `hasResult` is set instead.
///  * ReadWrite - `v` is pushed onto `ptxOperands` and `hasResult` is set.
///                For scalars this modifier trips an assertion (in builds
///                without assertions the '+' modifier is emitted).
///
/// Struct-typed values are expanded field by field. For Write, the whole
/// struct is recorded once and one "=<reg>" constraint is emitted per field.
/// For Read and ReadWrite, each field is extracted with an
/// `LLVM::ExtractValueOp` and recorded individually.
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
  // Log the modifier's numeric value. Streaming `&itype` would print the
  // address of this parameter, which carries no information.
  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << static_cast<int>(itype)
                    << "\n");
  // Constraint prefix for the current modifier: '=' for outputs, none for
  // inputs.
  auto getModifier = [&]() -> const char * {
    if (itype == PTXRegisterMod::ReadWrite) {
      assert(false && "Read-Write modifier is not supported. Try setting the "
                      "same value as Write and Read seperately.");
      return "+";
    }
    if (itype == PTXRegisterMod::Write) {
      return "=";
    }
    return "";
  };
  // Inputs (Read/ReadWrite) become asm operands; anything that is written
  // marks the builder as producing a result.
  auto addValue = [&](Value v) {
    if (itype == PTXRegisterMod::Read) {
      ptxOperands.push_back(v);
      return;
    }
    if (itype == PTXRegisterMod::ReadWrite)
      ptxOperands.push_back(v);
    hasResult = true;
  };

  llvm::raw_string_ostream ss(registerConstraints);
  // Handle Structs
  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
    if (itype == PTXRegisterMod::Write) {
      addValue(v);
    }
    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
      if (itype != PTXRegisterMod::Write) {
        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
            interfaceOp->getLoc(), v, idx);
        addValue(extractValue);
      }
      if (itype == PTXRegisterMod::ReadWrite) {
        // Tied constraint: the input shares the output operand numbered by
        // the field index.
        // NOTE(review): this assumes the struct's fields are the first asm
        // outputs (operands 0..N-1) -- confirm against callers.
        ss << idx << ",";
      } else {
        ss << getModifier() << getRegisterType(t) << ",";
      }
      ss.flush();
    }
    return;
  }
  // Handle Scalars
  addValue(v);
  ss << getModifier() << getRegisterType(v) << ",";
  ss.flush();
}
/// Creates the `LLVM::InlineAsmOp` for `interfaceOp` from the operands and
/// constraints collected so far.
///
/// The op's PTX string is taken from `interfaceOp.getPtx()`. If the op has a
/// predicate set, the string is prefixed with "@%<idx> ", where <idx> is the
/// position of the last collected operand. Finally every '%' is rewritten to
/// '$', the operand sigil used by LLVM inline assembly. The result types,
/// side-effect flag and AT&T dialect are forwarded to the new op.
LLVM::InlineAsmOp PtxBuilder::build() {
  auto attDialect = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                              LLVM::AsmDialect::AD_ATT);

  // insertValue terminates every constraint with ','; drop the final one.
  if (!registerConstraints.empty() && registerConstraints.back() == ',')
    registerConstraints.pop_back();

  std::string asmText = interfaceOp.getPtx();

  // Guard the instruction with the predicate, referenced as the last operand.
  // NOTE(review): this index counts only the collected input operands. If the
  // op also produces asm outputs, which LLVM numbers first, the index may be
  // off -- confirm that predicated ops never return results.
  auto predicate = interfaceOp.getPredicate();
  if (predicate.has_value() && predicate.value())
    asmText = "@%" + std::to_string(ptxOperands.size() - 1) + " " + asmText;

  // TableGen does not accept '$', so PTX templates are written with '%';
  // convert them to the '$' placeholders that inline assembly expects.
  for (char &ch : asmText)
    if (ch == '%')
      ch = '$';

  return rewriter.create<LLVM::InlineAsmOp>(
      interfaceOp->getLoc(),
      /*result types=*/interfaceOp->getResultTypes(),
      /*operands=*/ptxOperands,
      /*asm_string=*/llvm::StringRef(asmText),
      /*constraints=*/registerConstraints.data(),
      /*has_side_effects=*/interfaceOp.hasSideEffect(),
      /*is_align_stack=*/false,
      /*asm_dialect=*/attDialect,
      /*operand_attrs=*/ArrayAttr());
}
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
rewriter.replaceOp(interfaceOp, inlineAsmOp);
} else {
rewriter.eraseOp(interfaceOp);
}
}
|