//===-- CUFDeviceGlobal.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
namespace fir {
#define GEN_PASS_DEF_CUFDEVICEGLOBAL
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
namespace {
/// Add the fir.global referenced by \p addrOfOp to \p candidates. When
/// \p recurseInGlobal is set, the globals referenced from that global's body
/// are added as well, recursively.
static void processAddrOfOp(fir::AddrOfOp addrOfOp,
                            mlir::SymbolTable &symbolTable,
                            llvm::DenseSet<fir::GlobalOp> &candidates,
                            bool recurseInGlobal) {
  if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
          addrOfOp.getSymbol().getRootReference().getValue())) {
    // TODO: Limit candidates to non-scalars. Scalars appear to have been
    // folded in already.
    if (recurseInGlobal)
      globalOp.walk([&](fir::AddrOfOp op) {
        processAddrOfOp(op, symbolTable, candidates, recurseInGlobal);
      });
    candidates.insert(globalOp);
  }
}

/// If \p emboxOp boxes a derived-type value and the type descriptor global of
/// that derived type is not already a candidate, add the type descriptor to
/// \p candidates together with every global referenced (recursively) from its
/// body.
static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable,
                           llvm::DenseSet<fir::GlobalOp> &candidates) {
  if (auto recTy = mlir::dyn_cast<fir::RecordType>(
          fir::unwrapRefType(emboxOp.getMemref().getType()))) {
    if (auto globalOp = symbolTable.lookup<fir::GlobalOp>(
            fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) {
      if (!candidates.contains(globalOp)) {
        globalOp.walk([&](fir::AddrOfOp op) {
          processAddrOfOp(op, symbolTable, candidates,
                          /*recurseInGlobal=*/true);
        });
        candidates.insert(globalOp);
      }
    }
  }
}

/// For a function that carries a CUDA procedure attribute other than `host`
/// (that is, code that may run on the device), collect the globals it
/// references directly and the type descriptors required by its fir.embox
/// operations.
static void
prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp,
                             mlir::SymbolTable &symbolTable,
                             llvm::DenseSet<fir::GlobalOp> &candidates) {
  auto cudaProcAttr{
      funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())};
  if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) {
    funcOp.walk([&](fir::AddrOfOp op) {
      processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false);
    });
    funcOp.walk(
        [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); });
  }
}

/// Copy every fir.global that device code may reference into the GPU module.
class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
public:
  void runOnOperation() override {
    mlir::Operation *op = getOperation();
    mlir::ModuleOp mod = mlir::dyn_cast<mlir::ModuleOp>(op);
    if (!mod)
      return signalPassFailure();

    llvm::DenseSet<fir::GlobalOp> candidates;
    mlir::SymbolTable symTable(mod);

    // Collect the globals used by device procedures.
    mod.walk([&](mlir::func::FuncOp funcOp) {
      prepareImplicitDeviceGlobals(funcOp, symTable, candidates);
      return mlir::WalkResult::advance();
    });

    // Collect the globals referenced inside cuf.kernel operations.
    mod.walk([&](cuf::KernelOp kernelOp) {
      kernelOp.walk([&](fir::AddrOfOp addrOfOp) {
        processAddrOfOp(addrOfOp, symTable, candidates,
                        /*recurseInGlobal=*/false);
      });
    });

    // Copy the device global variables into the GPU module.
    mlir::SymbolTable parentSymTable(mod);
    auto gpuMod = cuf::getOrCreateGPUModule(mod, parentSymTable);
    if (!gpuMod)
      return signalPassFailure();
    mlir::SymbolTable gpuSymTable(gpuMod);

    // Registered device globals are always copied.
    for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
      if (cuf::isRegisteredDeviceGlobal(globalOp))
        candidates.insert(globalOp);
    }

    for (auto globalOp : candidates) {
      auto globalName{globalOp.getSymbol().getValue()};
      // Skip globals that already exist in the GPU module, but keep copying
      // the remaining candidates.
      if (gpuSymTable.lookup<fir::GlobalOp>(globalName))
        continue;
      gpuSymTable.insert(globalOp->clone());
    }
  }
};
} // namespace