
|
/*========================== begin_copyright_notice ============================
Copyright (C) 2017-2021 Intel Corporation
SPDX-License-Identifier: MIT
============================= end_copyright_notice ===========================*/
#include "AdaptorCommon/ImplicitArgs.hpp"
#include "Compiler/Optimizer/OpenCLPasses/OpenCLPrintf/OpenCLPrintfAnalysis.hpp"
#include "Compiler/IGCPassSupport.h"
#include "common/LLVMWarningsPush.hpp"
#include <llvm/IR/Module.h>
#include <llvm/IR/Function.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Demangle/Demangle.h>
#include "common/LLVMWarningsPop.hpp"
#include <set>
using namespace llvm;
using namespace IGC;
using namespace IGC::IGCMD;
// Register pass to igc-opt
// The macros below register OpenCLPrintfAnalysis under the flag
// "igc-opencl-printf-analysis". PASS_ANALYSIS is false because, despite the
// pass name, runOnModule() may add implicit arguments and save IGC metadata.
#define PASS_FLAG "igc-opencl-printf-analysis"
#define PASS_DESCRIPTION "Analyzes OpenCL printf calls"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
IGC_INITIALIZE_PASS_BEGIN(OpenCLPrintfAnalysis, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
// runOnModule() obtains MetaDataUtils through getAnalysis<MetaDataUtilsWrapper>().
IGC_INITIALIZE_PASS_DEPENDENCY(MetaDataUtilsWrapper)
IGC_INITIALIZE_PASS_END(OpenCLPrintfAnalysis, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
// Pass identifier passed to ModulePass(ID); the legacy pass manager uses the
// address of this variable as the pass identity, so its value does not matter.
char OpenCLPrintfAnalysis::ID = 0;
/// Constructs the pass and makes sure it is registered with the global
/// PassRegistry so it can be instantiated by name (e.g. from igc-opt).
OpenCLPrintfAnalysis::OpenCLPrintfAnalysis() : ModulePass(ID) {
  PassRegistry &registry = *PassRegistry::getPassRegistry();
  initializeOpenCLPrintfAnalysisPass(registry);
}
// TODO: move to a common place
// Callee names used to recognize printf-like calls:
//  - OPENCL_PRINTF_FUNCTION_NAME: exact match against the function name
//    (isOpenCLPrintf, visitCallInst).
//  - ONEAPI_PRINTF_FUNCTION_NAME: searched for as a substring of the demangled
//    function name (isOneAPIPrintf).
//  - BUILTIN_PRINTF_FUNCTION_NAME: exact match against the function name
//    (isBuiltinPrintf).
// Each StringRef refers to a string literal, so it never dangles.
const StringRef OpenCLPrintfAnalysis::OPENCL_PRINTF_FUNCTION_NAME = "printf";
const StringRef OpenCLPrintfAnalysis::ONEAPI_PRINTF_FUNCTION_NAME = "ext::oneapi::experimental::printf";
const StringRef OpenCLPrintfAnalysis::BUILTIN_PRINTF_FUNCTION_NAME = "__builtin_IB_printf_to_buffer";
/// Returns true if the name of \p F is exactly OPENCL_PRINTF_FUNCTION_NAME
/// ("printf"). \p F must be non-null.
bool OpenCLPrintfAnalysis::isOpenCLPrintf(const llvm::Function *F) {
  const StringRef calleeName = F->getName();
  return OPENCL_PRINTF_FUNCTION_NAME == calleeName;
}
/// Returns true if the demangled name of \p F contains
/// ONEAPI_PRINTF_FUNCTION_NAME ("ext::oneapi::experimental::printf").
/// \p F must be non-null.
bool OpenCLPrintfAnalysis::isOneAPIPrintf(const llvm::Function *F) {
  const std::string demangledName = llvm::demangle(F->getName().str());
  // Search through StringRef so the needle length comes from the StringRef
  // itself; StringRef::data() is not guaranteed to be NUL-terminated, so it
  // must not be handed to std::string::find(const char*).
  return StringRef(demangledName).find(ONEAPI_PRINTF_FUNCTION_NAME) != StringRef::npos;
}
/// Returns true if the name of \p F is exactly BUILTIN_PRINTF_FUNCTION_NAME
/// ("__builtin_IB_printf_to_buffer"). \p F must be non-null.
bool OpenCLPrintfAnalysis::isBuiltinPrintf(const llvm::Function *F) {
  const StringRef calleeName = F->getName();
  return BUILTIN_PRINTF_FUNCTION_NAME == calleeName;
}
/// Collects every function that directly calls OpenCL `printf` (via visit(M))
/// and requests a PRINTF_BUFFER implicit argument for each of them, then saves
/// the IGC metadata back into the module.
/// \returns true if any function was updated.
bool OpenCLPrintfAnalysis::runOnModule(Module &M) {
  m_pMDUtils = getAnalysis<MetaDataUtilsWrapper>().getMetaDataUtils();
  // Drop results of any previous run so Function pointers recorded for an
  // earlier module can never be matched (or reported) here.
  m_hasPrintfs.clear();
  visit(M);
  if (m_hasPrintfs.empty())
    return false;

  bool changed = false;
  // Walk the module's function list rather than m_hasPrintfs so functions are
  // updated in module order, independent of the set's iteration order.
  for (Function &func : M.getFunctionList()) {
    if (!func.isDeclaration() && m_hasPrintfs.find(&func) != m_hasPrintfs.end()) {
      addPrintfBufferArgs(func);
      changed = true;
    }
  }
  // Update LLVM metadata based on IGC MetadataUtils
  if (changed)
    m_pMDUtils->save(M.getContext());
  return changed;
}
/// Visitor callback for call instructions (presumably dispatched by visit(M)
/// in runOnModule()). Records the function containing \p callInst when the call
/// directly targets OpenCL `printf`; indirect calls are ignored.
void OpenCLPrintfAnalysis::visitCallInst(llvm::CallInst &callInst) {
  const Function *callee = callInst.getCalledFunction();
  if (callee == nullptr)
    return;

  Function *caller = callInst.getParent()->getParent();
  // A caller only needs to be recorded once.
  if (m_hasPrintfs.find(caller) != m_hasPrintfs.end())
    return;

  if (isOpenCLPrintf(callee))
    m_hasPrintfs.insert(caller);
}
/// Requests the PRINTF_BUFFER implicit argument for \p F through
/// ImplicitArgs::addImplicitArgs, using this pass's MetaDataUtils (which
/// runOnModule() saves back to the module afterwards).
void OpenCLPrintfAnalysis::addPrintfBufferArgs(Function &F) {
  SmallVector<ImplicitArg::ArgType, 1> printfArgs{ImplicitArg::PRINTF_BUFFER};
  ImplicitArgs::addImplicitArgs(F, printfArgs, m_pMDUtils);
}
// Recursive worker for OpenCLPrintfAnalysis::isPrintfOnlyStringConstant().
//
// Returns true if every use of `v` ends in an OpenCL / oneAPI / builtin printf
// call. A use may reach that call either through a chain of permitted
// intermediate instructions or by being passed as an argument into another
// function, in which case the matching formal parameter is followed.
// Returns false for a value with no uses, or as soon as any use reaches
// anything else.
// `visited` collects users already examined so cycles (e.g. through PHIs) are
// walked only once; an already-visited user is skipped, not treated as failure.
bool isPrintfOnlyStringConstantImpl(const llvm::Value *v, std::set<const llvm::User *> &visited) {
  // Base case: a value with no users never reaches a printf call.
  if (v->use_empty()) {
    return false;
  }
  // Check users recursively with a list of permitted in-between uses. Here we
  // follow OpenCLPrintfResolution::argIsString() to check if they are one of
  // CastInst, GEP with all-zero indices, SelectInst, and PHINode.
  for (const llvm::Use &use : v->uses()) {
    const llvm::User *user = use.getUser();
    // insert().second is false when the user has been seen before: skip it.
    if (!visited.insert(user).second)
      continue;

    bool res = false;
    if (const llvm::CallInst *call = llvm::dyn_cast<llvm::CallInst>(user)) {
      // Stop when reaching a call and check if it is an opencl/oneapi/builtin
      // printf call.
      const llvm::Function *target = call->getCalledFunction();
      if (!target) {
        // Indirect call: the callee is unknown, so this path cannot be shown
        // to end in printf (and the is*Printf predicates require non-null).
        res = false;
      } else if (OpenCLPrintfAnalysis::isOpenCLPrintf(target) ||
                 OpenCLPrintfAnalysis::isOneAPIPrintf(target) ||
                 OpenCLPrintfAnalysis::isBuiltinPrintf(target)) {
        res = true;
      } else if (call->isDataOperand(&use)) {
        // The value is forwarded into a non-printf function: follow the
        // matching formal parameter. Data operands past the declared
        // parameters (variadic arguments, operand-bundle operands) have no
        // formal parameter, so indexing arg_begin() with them would be UB.
        unsigned int opIndex = call->getDataOperandNo(&use);
        if (opIndex < target->arg_size())
          res = isPrintfOnlyStringConstantImpl(target->arg_begin() + opIndex, visited);
      }
    } else if (llvm::isa<llvm::CastInst>(user) || llvm::isa<llvm::SelectInst>(user) ||
               llvm::isa<llvm::PHINode>(user)) {
      res = isPrintfOnlyStringConstantImpl(user, visited);
    } else if (const llvm::GetElementPtrInst *gep = llvm::dyn_cast<llvm::GetElementPtrInst>(user)) {
      if (gep->hasAllZeroIndices())
        res = isPrintfOnlyStringConstantImpl(user, visited);
    }
    if (!res)
      return false;
  }
  // Return true as every top level user is a printf call.
  return true;
}
// Check paths from a string literal to printf calls and return true if every
// path leads to a printf call.
//
// Only globals whose initializer is an i8 string containing a NUL byte, or a
// zero-initialized i8 array, are considered; any other global (including an
// external declaration without an initializer) yields false.
bool OpenCLPrintfAnalysis::isPrintfOnlyStringConstant(const llvm::GlobalVariable *GV) {
  // getInitializer() asserts hasInitializer() and never returns null, so a
  // declaration has to be rejected before calling it.
  if (!GV->hasInitializer()) {
    return false;
  }
  const llvm::Constant *Initializer = GV->getInitializer();

  bool IsNullTerminatedString = false;
  if (const auto *cds = llvm::dyn_cast<llvm::ConstantDataSequential>(Initializer)) {
    if (cds->isString()) {
      StringRef Str = cds->getAsString();
      IsNullTerminatedString = Str.contains('\0');
    }
  }
  bool IsZeroInitCharArray = Initializer->isZeroValue() && isa<ArrayType>(Initializer->getType()) &&
                             Initializer->getType()->getArrayElementType()->isIntegerTy(8);
  if (IsNullTerminatedString || IsZeroInitCharArray) {
    std::set<const llvm::User *> Visited;
    return isPrintfOnlyStringConstantImpl(GV, Visited);
  }
  return false;
}
|