File: ReduceValuesToReturn.cpp

package info (click to toggle)
llvm-toolchain-21 1%3A21.1.6-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,245,028 kB
  • sloc: cpp: 7,619,726; ansic: 1,434,018; asm: 1,058,748; python: 252,740; f90: 94,671; objc: 70,685; lisp: 42,813; pascal: 18,401; sh: 8,601; ml: 5,111; perl: 4,720; makefile: 3,675; awk: 3,523; javascript: 2,409; xml: 892; fortran: 770
file content (267 lines) | stat: -rw-r--r-- 9,382 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Try to reduce a function by inserting new return instructions. Try to insert
// an early return for each instruction value at that point. This requires
// mutating the return type, or finding instructions with a compatible type.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "llvm-reduce"

#include "ReduceValuesToReturn.h"

#include "Delta.h"
#include "Utils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

/// Check whether \p F may be re-emitted with a non-void return type.
///
/// \returns false if \p F has an sret argument or if its calling convention
/// does not support non-void return types; true otherwise.
static bool canUseNonVoidReturnType(const Function &F) {
  // A function that returns through an sret argument must itself return void.
  if (F.hasStructRetAttr())
    return false;

  return CallingConv::supportsNonVoidReturnType(F.getCallingConv());
}

/// Check whether \p Ty can be used as the replacement return type of a
/// function.
///
/// This is stricter than FunctionType::isValidReturnType: token types and
/// types that are not first class are rejected as well.
static bool isReallyValidReturnType(Type *Ty) {
  if (!FunctionType::isValidReturnType(Ty))
    return false;

  if (Ty->isTokenTy())
    return false;

  return Ty->isFirstClassType();
}

/// Insert a ret inst after \p NewRetValue, which returns the value it produces.
///
/// \p OldF is replaced by a newly created function whose return type is the
/// type of \p NewRetValue and whose parameter list and vararg-ness are those of
/// \p OldF. If \p NewRetValue is an instruction the new return is placed
/// directly after it; otherwise (e.g. a function argument) it is placed at the
/// start of the entry block.
///
/// The body, name and attributes of \p OldF are transferred to the new
/// function, all uses of \p OldF are redirected to it, and \p OldF is erased.
/// \p OldF must therefore not be used after this function returns.
///
/// \param OldF the function to rewrite; erased on return.
/// \param NewRetValue the value the rewritten function should return.
static void rewriteFuncWithReturnType(Function &OldF, Value *NewRetValue) {
  Type *NewRetTy = NewRetValue->getType();
  FunctionType *OldFuncTy = OldF.getFunctionType();

  // Same signature as the old function, except for the return type.
  FunctionType *NewFuncTy =
      FunctionType::get(NewRetTy, OldFuncTy->params(), OldFuncTy->isVarArg());

  LLVMContext &Ctx = OldF.getContext();
  BasicBlock &EntryBB = OldF.getEntryBlock();

  // Non-instruction values (NewRetI == nullptr) are returned from the entry
  // block.
  Instruction *NewRetI = dyn_cast<Instruction>(NewRetValue);
  BasicBlock *NewRetBlock = NewRetI ? NewRetI->getParent() : &EntryBB;

  // Position at which the block is cut: right after the returned instruction,
  // or at the very beginning of the entry block.
  BasicBlock::iterator NewValIt =
      NewRetI ? std::next(NewRetI->getIterator()) : EntryBB.begin();

  Type *OldRetTy = OldFuncTy->getReturnType();

  // Hack up any return values in other blocks, we can't leave them as returning
  // OldRetTy.
  if (OldRetTy != NewRetTy) {
    for (BasicBlock &OtherRetBB : OldF) {
      if (&OtherRetBB != NewRetBlock) {
        auto *OrigRI = dyn_cast<ReturnInst>(OtherRetBB.getTerminator());
        if (!OrigRI)
          continue;

        // Swap the old return for one returning a default value of the new
        // return type.
        OrigRI->eraseFromParent();
        ReturnInst::Create(Ctx, getDefaultValue(NewRetTy), &OtherRetBB);
      }
    }
  }

  // If we're returning an instruction, split the basic block so we can let
  // simpleSimplifyCFG cleanup the successors.
  BasicBlock *TailBB = NewRetBlock->splitBasicBlock(NewValIt);

  // Replace the unconditional branch splitBasicBlock created
  NewRetBlock->getTerminator()->eraseFromParent();
  ReturnInst::Create(Ctx, NewRetValue, NewRetBlock);

  // Now prune any CFG edges we have to deal with.
  simpleSimplifyCFG(OldF, {TailBB}, /*FoldBlockIntoPredecessor=*/false);

  // Drop the incompatible attributes before we copy over to the new function.
  if (OldRetTy != NewRetTy) {
    AttributeList AL = OldF.getAttributes();
    AttributeMask IncompatibleAttrs =
        AttributeFuncs::typeIncompatible(NewRetTy, AL.getRetAttrs());
    OldF.removeRetAttrs(IncompatibleAttrs);
  }

  // Now we need to remove any returned attributes from parameters.
  for (Argument &A : OldF.args())
    OldF.removeParamAttr(A.getArgNo(), Attribute::Returned);

  // Create the replacement function unnamed; it takes over OldF's name below.
  Function *NewF =
      Function::Create(NewFuncTy, OldF.getLinkage(), OldF.getAddressSpace(), "",
                       OldF.getParent());

  // Re-insert the new function directly after OldF in the module's function
  // list, then transfer OldF's name and attributes to it.
  NewF->removeFromParent();
  OldF.getParent()->getFunctionList().insertAfter(OldF.getIterator(), NewF);
  NewF->takeName(&OldF);
  NewF->copyAttributesFrom(&OldF);

  // Adjust the callsite uses to the new return type. We pre-filtered cases
  // where the original call type was incorrectly non-void.
  for (User *U : make_early_inc_range(OldF.users())) {
    // Only direct calls of OldF are rewritten here; every other use is handled
    // by the replaceAllUsesWith below.
    if (auto *CB = dyn_cast<CallBase>(U);
        CB && CB->getCalledOperand() == &OldF) {
      if (CB->getType()->isVoidTy()) {
        FunctionType *CallType = CB->getFunctionType();

        // The callsite may not match the new function type, in an undefined
        // behavior way. Only mutate the local return type.
        FunctionType *NewCallType = FunctionType::get(
            NewRetTy, CallType->params(), CallType->isVarArg());

        CB->mutateType(NewRetTy);
        CB->setCalledFunction(NewCallType, NewF);
      } else {
        assert(CB->getType() == NewRetTy &&
               "only handle exact return type match with non-void returns");
      }
    }
  }

  // Move the whole body over to the new function and redirect every remaining
  // use of the old one.
  NewF->splice(NewF->begin(), &OldF);
  OldF.replaceAllUsesWith(NewF);

  // Preserve the parameters of OldF.
  for (auto Z : zip_first(OldF.args(), NewF->args())) {
    Argument &OldArg = std::get<0>(Z);
    Argument &NewArg = std::get<1>(Z);

    OldArg.replaceAllUsesWith(&NewArg);
    NewArg.takeName(&OldArg);
  }

  OldF.eraseFromParent();
}

// Decide whether every use of \p F tolerates changing the return type of \p F
// to \p NewRetTy. Uses that are not a direct call of \p F never block the
// change; a direct call is acceptable if its type is void or already equals
// \p NewRetTy.
//
// TODO: We could make better effort to handle call type mismatches.
static bool canReplaceFuncUsers(const Function &F, Type *NewRetTy) {
  for (const Use &FuncUse : F.uses()) {
    const CallBase *Call = dyn_cast<CallBase>(FuncUse.getUser());

    // Normal pointer uses are trivially replaceable: this covers both non-call
    // users and calls where F is an operand other than the callee.
    if (!Call || !Call->isCallee(&FuncUse))
      continue;

    // Void call sites are correct and trivially replaceable, and so are calls
    // whose return type happened to match the new return type.
    Type *CallTy = Call->getType();
    if (CallTy->isVoidTy() || CallTy == NewRetTy)
      continue;

    // TODO: If all callsites have no uses, we could mutate the type of all the
    // callsites. This will complicate the visit and rewrite ordering though.
    LLVM_DEBUG(dbgs() << "Cannot replace used callsite with wrong type: "
                      << *Call << '\n');
    return false;
  }

  return true;
}

/// Decide whether returning \p Replacement from \p BB would change anything.
///
/// \returns false only when \p BB already ends in a return whose returned
/// value is \p Replacement; true in every other case, including when \p BB
/// does not end in a return at all.
static bool shouldReplaceNonVoidReturnValue(const BasicBlock &BB,
                                            const Value *Replacement) {
  const auto *ExistingRet = dyn_cast<ReturnInst>(BB.getTerminator());
  if (!ExistingRet)
    return true;

  return ExistingRet->getReturnValue() != Replacement;
}

static bool shouldForwardValueToReturn(const BasicBlock &BB, const Value *V,
                                       Type *RetTy) {
  if (!isReallyValidReturnType(V->getType()))
    return false;

  return (RetTy->isVoidTy() || shouldReplaceNonVoidReturnValue(BB, V)) &&
         canReplaceFuncUsers(*BB.getParent(), V->getType());
}

/// Look for the first non-terminator instruction in \p F that is eligible to
/// become the return value and that the oracle \p O does not want to keep.
///
/// On success the (function, instruction) pair is appended to
/// \p FuncsToReplace and the search stops, so at most one entry is added per
/// call.
///
/// \returns true if an entry was appended to \p FuncsToReplace.
static bool tryForwardingInstructionsToReturn(
    Function &F, Oracle &O,
    std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {

  // TODO: Should we try to expand returns to aggregate for function that
  // already have a return value?
  Type *CurRetTy = F.getReturnType();

  for (BasicBlock &Block : F) {
    // Stop before the terminator: we can't insert a second terminator to
    // return its value.
    for (auto It = Block.begin(), TermIt = std::prev(Block.end());
         It != TermIt; ++It) {
      Instruction &Candidate = *It;

      if (!shouldForwardValueToReturn(Block, &Candidate, CurRetTy))
        continue;

      // Only eligible candidates are offered to the oracle.
      if (O.shouldKeep())
        continue;

      FuncsToReplace.emplace_back(&F, &Candidate);
      return true;
    }
  }

  return false;
}

/// Look for the first argument of \p F that is eligible to become the return
/// value and that the oracle \p O does not want to keep.
///
/// Eligibility is evaluated against the entry block of \p F. On success the
/// (function, argument) pair is appended to \p FuncsToReplace and the search
/// stops, so at most one entry is added per call.
///
/// \returns true if an entry was appended to \p FuncsToReplace.
static bool tryForwardingArgumentsToReturn(
    Function &F, Oracle &O,
    std::vector<std::pair<Function *, Value *>> &FuncsToReplace) {

  BasicBlock &Entry = F.getEntryBlock();
  Type *CurRetTy = F.getReturnType();

  for (Argument &Arg : F.args()) {
    if (!shouldForwardValueToReturn(Entry, &Arg, CurRetTy))
      continue;

    // Only eligible candidates are offered to the oracle.
    if (O.shouldKeep())
      continue;

    FuncsToReplace.emplace_back(&F, &Arg);
    return true;
  }

  return false;
}

/// Delta pass that tries to make each defined function return one of its own
/// arguments.
///
/// Candidates are collected for the whole module first and rewritten only
/// afterwards.
void llvm::reduceArgumentsToReturnDeltaPass(Oracle &O,
                                            ReducerWorkItem &WorkItem) {
  Module &M = WorkItem.getModule();

  // Rewriting a function also hacks on its users in other functions, so gather
  // the full worklist of replacements before touching anything.
  std::vector<std::pair<Function *, Value *>> Worklist;

  for (Function &Func : M.functions()) {
    if (Func.isDeclaration() || !canUseNonVoidReturnType(Func))
      continue;

    tryForwardingArgumentsToReturn(Func, O, Worklist);
  }

  for (const auto &Entry : Worklist)
    rewriteFuncWithReturnType(*Entry.first, Entry.second);
}

/// Delta pass that tries to make each defined function return early with the
/// value of one of its own instructions.
///
/// Candidates are collected for the whole module first and rewritten only
/// afterwards.
void llvm::reduceInstructionsToReturnDeltaPass(Oracle &O,
                                               ReducerWorkItem &WorkItem) {
  Module &M = WorkItem.getModule();

  // Rewriting a function also hacks on its users in other functions, so gather
  // the full worklist of replacements before touching anything.
  std::vector<std::pair<Function *, Value *>> Worklist;

  for (Function &Func : M.functions()) {
    if (Func.isDeclaration() || !canUseNonVoidReturnType(Func))
      continue;

    tryForwardingInstructionsToReturn(Func, O, Worklist);
  }

  for (const auto &Entry : Worklist)
    rewriteFuncWithReturnType(*Entry.first, Entry.second);
}