1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
|
//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Fix bitcasted functions.
///
/// WebAssembly requires caller and callee signatures to match, however in LLVM,
/// some amount of slop is vaguely permitted. Detect mismatch by looking for
/// bitcasts of functions and rewrite them to use wrapper functions instead.
///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and casted in another, but it works for many common cases.
///
/// Note that LLVM already optimizes away function bitcasts in common cases by
/// dropping arguments as needed, so this pass only ends up getting used in less
/// common cases.
///
//===----------------------------------------------------------------------===//
#include "WebAssembly.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
#define DEBUG_TYPE "wasm-fix-function-bitcasts"
namespace {
// Module pass that replaces calls through bitcasted functions with calls to
// wrapper functions (see runOnModule for the actual work).
class FixFunctionBitcasts final : public ModulePass {
public:
  // Pass identifier handed to the ModulePass base class constructor.
  static char ID;

  FixFunctionBitcasts() : ModulePass(ID) {}

private:
  // Human-readable name of this pass.
  StringRef getPassName() const override {
    return "WebAssembly Fix Function Bitcasts";
  }

  // Declares that this pass preserves the CFG, then defers to the base class
  // for everything else.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    ModulePass::getAnalysisUsage(AU);
  }

  // Defined out of line below.
  bool runOnModule(Module &M) override;
};
} // End anonymous namespace
// Out-of-line definition of the pass identifier declared in the class.
char FixFunctionBitcasts::ID = 0;
// Registers the pass under the DEBUG_TYPE name with a short description. The
// two trailing `false` arguments are the macro's CFG-only and is-analysis
// flags (see llvm/PassSupport.h).
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
"Fix mismatching bitcasts for WebAssembly", false, false)
// Factory entry point: heap-allocates a fresh FixFunctionBitcasts pass and
// returns it as a ModulePass pointer.
ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
  ModulePass *Pass = new FixFunctionBitcasts();
  return Pass;
}
// Walk the use list of V, recursing through bitcasts, and record every call
// site that calls V directly while V's type differs from F's own type.
//
// V is either F itself or a (possibly nested) bitcast of F. Each recorded
// entry pairs the callee Use with F. A constant bitcast is recorded at most
// once (tracked in ConstantBCs), because the caller later rewrites constants
// with replaceAllUsesWith, which updates every user in one go.
static void findUses(Value *V, Function &F,
                     SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
                     SmallPtrSetImpl<Constant *> &ConstantBCs) {
  for (Use &TheUse : V->uses()) {
    User *Usr = TheUse.getUser();

    // Bitcasts are looked through: their users are what matter.
    if (auto *Cast = dyn_cast<BitCastOperator>(Usr)) {
      findUses(Cast, F, Uses, ConstantBCs);
      continue;
    }

    // A use at the function's own type needs no fixing.
    Value *Used = TheUse.get();
    if (Used->getType() == F.getType())
      continue;

    // Only calls in which V is the callee are of interest; skip uses that
    // aren't immediately called and calls where V is merely an argument.
    CallSite Call(Usr);
    if (!Call || Call.getCalledValue() != V)
      continue;

    // Only add constant bitcasts to the list once; they get RAUW'd.
    if (auto *ConstCallee = dyn_cast<Constant>(Used))
      if (!ConstantBCs.insert(ConstCallee).second)
        continue;

    Uses.emplace_back(&TheUse, &F);
  }
}
// Create a wrapper function with type Ty that calls F (which may have a
// different type). Attempt to support common bitcasted function idioms:
// - Call with more arguments than needed: arguments are dropped
// - Call with fewer arguments than needed: arguments are filled in with undef
// - Return value is not needed: drop it
// - Return value needed but not present: supply an undef
//
// If all the argument types are trivially castable to one another (i.e.
// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
// instead).
//
// If there is a type mismatch that we know would result in an invalid wasm
// module then generate a wrapper that contains unreachable (i.e. abort at
// runtime). Such programs are deep into undefined behaviour territory,
// but we choose to fail at runtime rather than generate an invalid module
// or fail at compile time. The reason we delay the error is that we want
// to support CMake, which expects to be able to compile and link programs
// that refer to functions with entirely incorrect signatures (this is how
// CMake detects the existence of a function in a toolchain).
//
// For bitcasts that involve struct types we don't know at this stage if they
// would be equivalent at the wasm level and so we can't know if we need to
// generate a wrapper.
//
// F is the function that is actually called; Ty is the function type the
// call site uses. Returns the newly created private wrapper (named
// "<F>_bitcast", or "<F>_bitcast_invalid" when it only contains
// `unreachable`), or nullptr when no wrapper is required.
static Function *createWrapper(Function *F, FunctionType *Ty) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  FunctionType *CalleeTy = F->getFunctionType();

  // Build the wrapper optimistically; it is erased again below if it turns
  // out to be unnecessary or invalid.
  Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                                       F->getName() + "_bitcast", M);
  BasicBlock *BB = BasicBlock::Create(Ctx, "body", Wrapper);
  const DataLayout &DL = M->getDataLayout();

  // Determine what arguments to pass.
  SmallVector<Value *, 4> Args;
  Function::arg_iterator AI = Wrapper->arg_begin();
  Function::arg_iterator AE = Wrapper->arg_end();
  FunctionType::param_iterator PI = CalleeTy->param_begin();
  FunctionType::param_iterator PE = CalleeTy->param_end();
  bool TypeMismatch = false;

  Type *ExpectedRtnType = CalleeTy->getReturnType();
  Type *RtnType = Ty->getReturnType();

  // A wrapper is needed whenever the overall shape of the signature differs:
  // parameter count, variadic-ness or return type.
  bool WrapperNeeded = CalleeTy->getNumParams() != Ty->getNumParams() ||
                       CalleeTy->isVarArg() != Ty->isVarArg() ||
                       ExpectedRtnType != RtnType;

  // Pair up the wrapper's arguments with the callee's parameters for as long
  // as both have entries left.
  for (; AI != AE && PI != PE; ++AI, ++PI) {
    Type *ArgType = AI->getType();
    Type *ParamType = *PI;

    if (ArgType == ParamType) {
      Args.push_back(&*AI);
    } else if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
      // Differing but losslessly castable: forward the argument via a cast.
      Instruction *PtrCast =
          CastInst::CreateBitOrPointerCast(&*AI, ParamType, "cast");
      BB->getInstList().push_back(PtrCast);
      Args.push_back(PtrCast);
    } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
      // Struct equivalence at the wasm level is unknown at this stage, so
      // give up on wrapping rather than guess.
      LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
                        << F->getName() << "\n");
      WrapperNeeded = false;
    } else {
      LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
                        << F->getName() << "\n");
      LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
                        << *ParamType << " Got: " << *ArgType << "\n");
      TypeMismatch = true;
      break;
    }
  }

  if (WrapperNeeded && !TypeMismatch) {
    // Parameters the caller did not supply are filled in with undef.
    for (; PI != PE; ++PI)
      Args.push_back(UndefValue::get(*PI));
    // Surplus caller arguments are forwarded only to a variadic callee;
    // otherwise they are dropped.
    if (F->isVarArg())
      for (; AI != AE; ++AI)
        Args.push_back(&*AI);

    CallInst *Call = CallInst::Create(F, Args, "", BB);

    // Determine what value to return.
    if (RtnType->isVoidTy()) {
      ReturnInst::Create(Ctx, BB);
    } else if (ExpectedRtnType->isVoidTy()) {
      LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
      ReturnInst::Create(Ctx, UndefValue::get(RtnType), BB);
    } else if (RtnType == ExpectedRtnType) {
      ReturnInst::Create(Ctx, Call, BB);
    } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
                                                    DL)) {
      Instruction *Cast =
          CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
      BB->getInstList().push_back(Cast);
      ReturnInst::Create(Ctx, Cast, BB);
    } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
      LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
                        << F->getName() << "\n");
      WrapperNeeded = false;
    } else {
      LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
                        << F->getName() << "\n");
      LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
                        << " Got: " << *RtnType << "\n");
      TypeMismatch = true;
    }
  }

  if (TypeMismatch) {
    // Create a new wrapper that simply contains `unreachable`. The function
    // is created directly under its final name, so no rename is needed.
    Wrapper->eraseFromParent();
    Wrapper = Function::Create(Ty, Function::PrivateLinkage,
                               F->getName() + "_bitcast_invalid", M);
    BasicBlock *InvalidBB = BasicBlock::Create(Ctx, "body", Wrapper);
    new UnreachableInst(Ctx, InvalidBB);
  } else if (!WrapperNeeded) {
    LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
                      << "\n");
    Wrapper->eraseFromParent();
    return nullptr;
  }

  LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
  return Wrapper;
}
// Decide whether a `main` of type FuncTy should get a wrapper of type MainTy.
//
// Only the standard zero-argument form is rewritten: non-variadic, no
// parameters, and the same return type as MainTy. That way the standard
// cases work as expected, and users see signature mismatches from the linker
// for non-standard cases.
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
  if (FuncTy->isVarArg() || FuncTy->getNumParams() != 0)
    return false;
  return FuncTy->getReturnType() == MainTy->getReturnType();
}
// Pass entry point.
//
// 1. Collect every call through a bitcasted function (plus, for a zero-arg
//    `main`, one synthetic call at the `i32 (i32, i8**)` type).
// 2. Create (at most one per function/type pair) a wrapper for each and
//    redirect the call to it.
// 3. If a wrapper was made for `main`, swap names so that the wrapper becomes
//    `main` and the user's function becomes `__original_main`.
//
// Always returns true, i.e. conservatively reports the module as modified.
bool FixFunctionBitcasts::runOnModule(Module &M) {
  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");

  Function *Main = nullptr;
  CallInst *CallMain = nullptr;
  SmallVector<std::pair<Use *, Function *>, 0> Uses;
  SmallPtrSet<Constant *, 2> ConstantBCs;

  // Collect all the places that need wrappers.
  for (Function &F : M) {
    findUses(&F, F, Uses, ConstantBCs);

    // If we have a "main" function, and its type isn't
    // "int main(int argc, char *argv[])", create an artificial call with it
    // bitcasted to that type so that we generate a wrapper for it, so that
    // the C runtime can call it.
    if (F.getName() != "main")
      continue;

    Main = &F;
    LLVMContext &C = M.getContext();
    Type *MainArgTys[] = {Type::getInt32Ty(C),
                          PointerType::get(Type::getInt8PtrTy(C), 0)};
    FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
                                             /*isVarArg=*/false);
    if (!shouldFixMainFunction(F.getFunctionType(), MainTy))
      continue;

    LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
                      << *F.getFunctionType() << "\n");
    Value *Args[] = {UndefValue::get(MainArgTys[0]),
                     UndefValue::get(MainArgTys[1])};
    Constant *Casted =
        ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
    // This call is never inserted into a basic block; it exists only to carry
    // a callee Use and is deleted again below.
    CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");

    // Constant bitcasts are rewritten with replaceAllUsesWith, so each one
    // must be queued only once. If findUses already queued a real call
    // through this very constant, that entry will redirect CallMain as well;
    // queuing it a second time would later try to replace the wrapper with
    // itself.
    if (ConstantBCs.insert(Casted).second) {
      // Operand 2 is taken to be the callee slot (after the two arguments).
      Use *UseMain = &CallMain->getOperandUse(2);
      Uses.push_back(std::make_pair(UseMain, &F));
    }
  }

  // Cache of wrappers already built, keyed by (callee, call-site type). A
  // cached nullptr means "no wrapper needed" for that pair.
  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;

  for (auto &UseFunc : Uses) {
    Use *U = UseFunc.first;
    Function *F = UseFunc.second;
    auto *PTy = cast<PointerType>(U->get()->getType());
    auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());

    // If the function is casted to something like i8* as a "generic pointer"
    // to be later casted to something else, we can't generate a wrapper for it.
    // Just ignore such casts for now.
    if (!Ty)
      continue;

    auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
    if (Pair.second)
      Pair.first->second = createWrapper(F, Ty);

    Function *Wrapper = Pair.first->second;
    if (!Wrapper)
      continue;

    // Constants are shared, so redirect every user at once; otherwise patch
    // just this one operand.
    if (isa<Constant>(U->get()))
      U->get()->replaceAllUsesWith(Wrapper);
    else
      U->set(Wrapper);
  }

  // If we created a wrapper for main, rename the wrapper so that it's the
  // one that gets called from startup.
  if (CallMain) {
    Main->setName("__original_main");
    auto *MainWrapper =
        cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
    delete CallMain;
    if (Main->isDeclaration()) {
      // The wrapper is not needed in this case as we don't need to export
      // it to anyone else. Only erase it when nothing else calls it: real
      // call sites that shared the same bitcast now point at the wrapper.
      if (MainWrapper->use_empty())
        MainWrapper->eraseFromParent();
    } else {
      // Otherwise give the wrapper the same linkage as the original main
      // function, so that it can be called from the same places.
      MainWrapper->setName("main");
      MainWrapper->setLinkage(Main->getLinkage());
      MainWrapper->setVisibility(Main->getVisibility());
    }
  }

  return true;
}
|