File: NVPTXLowerUnreachable.cpp

package info (click to toggle)
llvm-toolchain-18 1%3A18.1.8-18
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,908,340 kB
  • sloc: cpp: 6,667,937; ansic: 1,440,452; asm: 883,619; python: 230,549; objc: 76,880; f90: 74,238; lisp: 35,989; pascal: 16,571; sh: 10,229; perl: 7,459; ml: 5,047; awk: 3,523; makefile: 2,987; javascript: 2,149; xml: 892; fortran: 649; cs: 573
file content (156 lines) | stat: -rw-r--r-- 5,190 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PTX does not have a notion of `unreachable`, which results in emitted basic
// blocks having an edge to the next block:
//
//   block1:
//     call @does_not_return();
//     // unreachable
//   block2:
//     // ptxas will create a CFG edge from block1 to block2
//
// This may result in significant changes to the control flow graph, e.g., when
// LLVM moves unreachable blocks to the end of the function. That's a problem
// in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
// divergently.
//
// For example, `bar.sync` is not allowed to be executed divergently on Pascal
// or earlier. If we start with the following:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   unlikely:
//     ...
//     // unreachable
//   cont:
//     // end of divergent region
//     bar.sync 0;
//     bra.uni exit;
//   exit:
//     ret;
//
// it is transformed by the branch-folder and block-placement passes to:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   cont:
//     bar.sync 0;
//     bra.uni exit;
//   unlikely:
//     ...
//     // unreachable
//   exit:
//     // end of divergent region
//     ret;
//
// After moving the `unlikely` block to the end of the function, it has an edge
// to the `exit` block, which widens the divergent region and makes the
// `bar.sync` instruction happen divergently.
//
// To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. We only do this if
// `unreachable` is not lowered to `trap`, which has the same effect (although
// with current versions of `ptxas` only because it is emitted as
// `trap; exit;`).
//
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

// Forward declaration of the pass-registration hook for this pass; the
// definition is expected to come from the INITIALIZE_PASS invocation further
// down in this file.
namespace llvm {
void initializeNVPTXLowerUnreachablePass(PassRegistry &);
}

namespace {
/// Function pass that places a PTX `exit;` in front of `unreachable`
/// instructions that will not be lowered to a `trap`.
///
/// The two constructor flags are stored verbatim and consulted by
/// isLoweredToTrap() and runOnFunction() to decide which `unreachable`
/// instructions need the extra `exit;`.
class NVPTXLowerUnreachable : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  /// \param TrapUnreachable      whether `unreachable` is lowered to a trap.
  /// \param NoTrapAfterNoreturn  whether that trap is omitted directly after
  ///                             a call that does not return.
  NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
      : FunctionPass(ID), TrapUnreachable(TrapUnreachable),
        NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}

private:
  // FunctionPass overrides; kept private, they are reached through the base
  // class interface.
  StringRef getPassName() const override;
  bool runOnFunction(Function &F) override;

  // Predicts whether instruction selection will emit a trap for I.
  bool isLoweredToTrap(const UnreachableInst &I) const;

  // Configuration captured at construction time.
  bool TrapUnreachable;
  bool NoTrapAfterNoreturn;
};
} // namespace

// Storage for the pass identifier declared in the class. LLVM identifies a
// pass by the address of this variable, so its initial value is not
// significant.
char NVPTXLowerUnreachable::ID = 1;

// Registers the pass under the command-line name "nvptx-lower-unreachable"
// with the description "Lower Unreachable". The two trailing `false`
// arguments are the macro's "CFG only" and "is analysis" flags.
INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
                "Lower Unreachable", false, false)

/// Returns the human-readable name reported for this pass.
StringRef NVPTXLowerUnreachable::getPassName() const {
  static const char Name[] =
      "add an exit instruction before every unreachable";
  return Name;
}

// =============================================================================
// Returns whether a `trap` will be emitted for the unreachable instruction I.
//
// This mirrors the logic in SelectionDAGBuilder::visitUnreachable():
//   * without TrapUnreachable, no unreachable is lowered to a trap;
//   * with TrapUnreachable but without NoTrapAfterNoreturn, every unreachable
//     is lowered to a trap;
//   * with both set, a trap is emitted unless I is immediately preceded by a
//     call that does not return.
// =============================================================================
bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
  if (!TrapUnreachable)
    return false;
  if (!NoTrapAfterNoreturn)
    return true;
  // NoTrapAfterNoreturn suppresses the trap only directly behind a noreturn
  // call. getPrevNode() may be null (I is the first instruction of its block),
  // hence dyn_cast_or_null; in that case the trap is still emitted.
  const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode());
  return !(Call && Call->doesNotReturn());
}

// =============================================================================
// Pass entry point.
//
// Walks every instruction of F and, for each `unreachable` that
// isLoweredToTrap() does not report as becoming a trap, inserts a call to an
// `exit;` inline-asm blob directly in front of it. Returns true iff at least
// one such call was inserted.
// =============================================================================
bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  // In this configuration isLoweredToTrap() answers true for every
  // unreachable, so there would be nothing to insert.
  const bool EveryUnreachableTraps = TrapUnreachable && !NoTrapAfterNoreturn;
  if (EveryUnreachableTraps)
    return false;

  // Build the callee once: a `void()` inline asm whose body is `exit;`, with
  // no constraints and the side-effect flag set.
  LLVMContext &Ctx = F.getContext();
  FunctionType *VoidFnTy = FunctionType::get(Type::getVoidTy(Ctx), false);
  InlineAsm *ExitAsm = InlineAsm::get(VoidFnTy, "exit;", "", true);

  bool Modified = false;
  for (auto &Block : F) {
    for (auto &Inst : Block) {
      auto Unreachable = dyn_cast<UnreachableInst>(&Inst);
      if (!Unreachable)
        continue;
      // trap is emitted as `trap; exit;`, so no extra exit is needed.
      if (isLoweredToTrap(*Unreachable))
        continue;
      // Inserting in front of the instruction currently being visited leaves
      // the ongoing iteration untouched.
      CallInst::Create(VoidFnTy, ExitAsm, "", Unreachable);
      Modified = true;
    }
  }
  return Modified;
}

/// Factory for the pass: heap-allocates a new NVPTXLowerUnreachable configured
/// with the two given flags and returns it as a FunctionPass pointer.
FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
                                                    bool NoTrapAfterNoreturn) {
  auto *LowerUnreachable =
      new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
  return LowerUnreachable;
}