File: NVPTXLowerUnreachable.cpp

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (126 lines) | stat: -rw-r--r-- 3,860 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//===-- NVPTXLowerUnreachable.cpp - Lower unreachables to exit =====--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PTX does not have a notion of `unreachable`, which results in emitted basic
// blocks having an edge to the next block:
//
//   block1:
//     call @does_not_return();
//     // unreachable
//   block2:
//     // ptxas will create a CFG edge from block1 to block2
//
// This may result in significant changes to the control flow graph, e.g., when
// LLVM moves unreachable blocks to the end of the function. That's a problem
// in the context of divergent control flow, as `ptxas` uses the CFG to
// determine divergent regions, and some instructions may not be executed
// divergently.
//
// For example, `bar.sync` is not allowed to be executed divergently on Pascal
// or earlier. If we start with the following:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   unlikely:
//     ...
//     // unreachable
//   cont:
//     // end of divergent region
//     bar.sync 0;
//     bra.uni exit;
//   exit:
//     ret;
//
// it is transformed by the branch-folder and block-placement passes to:
//
//   entry:
//     // start of divergent region
//     @%p0 bra cont;
//     @%p1 bra unlikely;
//     ...
//     bra.uni cont;
//   cont:
//     bar.sync 0;
//     bra.uni exit;
//   unlikely:
//     ...
//     // unreachable
//   exit:
//     // end of divergent region
//     ret;
//
// After moving the `unlikely` block to the end of the function, it has an edge
// to the `exit` block, which widens the divergent region and makes the
// `bar.sync` instruction happen divergently.
//
// To work around this, we add an `exit` instruction before every `unreachable`,
// as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
// equivalent, and only future versions of `ptxas` will model it like `exit`.
//
//===----------------------------------------------------------------------===//

#include "NVPTX.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"

using namespace llvm;

namespace llvm {
// Forward declaration of the registration hook for this pass, taking the
// PassRegistry to register with.
// NOTE(review): the definition is presumably generated by the INITIALIZE_PASS
// invocation further down in this file -- confirm against llvm/PassSupport.h.
void initializeNVPTXLowerUnreachablePass(PassRegistry &);
}

namespace {
class NVPTXLowerUnreachable : public FunctionPass {
  bool runOnFunction(Function &F) override;

public:
  static char ID; // Pass identification, replacement for typeid
  NVPTXLowerUnreachable() : FunctionPass(ID) {}
  StringRef getPassName() const override {
    return "add an exit instruction before every unreachable";
  }
};
} // namespace

// Only the address of ID is used to identify the pass, so the initial value is
// irrelevant; 0 is the conventional initializer for LLVM pass IDs.
char NVPTXLowerUnreachable::ID = 0;

// Pass registration: associates the class with the argument string
// "nvptx-lower-unreachable" and the display name "Lower Unreachable".
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG-only" and "is analysis" flags of INITIALIZE_PASS -- confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
                "Lower Unreachable", false, false)

// =============================================================================
// Main function for this pass.
// =============================================================================
/// Inserts a call to the inline-asm snippet `exit;` immediately before the
/// `unreachable` instruction that terminates a basic block of \p F.
///
/// \param F the function to transform in place.
/// \returns true if at least one call was inserted; false otherwise, including
/// when skipFunction() requests that \p F be left alone.
bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  LLVMContext &C = F.getContext();
  // `void()` signature shared by the inline asm value and every call to it.
  FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
  // Empty constraint string; the trailing `true` marks the asm as having side
  // effects.
  InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);

  bool Changed = false;
  for (auto &BB : F) {
    // `unreachable` is a terminator, so it can only be the last instruction of
    // a block: inspecting the terminator is sufficient and avoids walking
    // every instruction. dyn_cast_or_null tolerates a block without a
    // terminator (getTerminator() returns null in that case).
    auto *Unreachable = dyn_cast_or_null<UnreachableInst>(BB.getTerminator());
    if (!Unreachable)
      continue;

    // Insert the `exit;` call right in front of the `unreachable`.
    CallInst::Create(ExitFTy, Exit, "", Unreachable);
    Changed = true;
  }
  return Changed;
}

/// Factory for this pass: returns a newly heap-allocated NVPTXLowerUnreachable
/// through a FunctionPass pointer.
/// NOTE(review): ownership presumably transfers to whichever pass manager the
/// result is added to -- confirm against callers.
FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
  return new NVPTXLowerUnreachable();
}