File: GenericCastToPtrOpt.cpp

package info (click to toggle)
intel-graphics-compiler2 2.22.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 107,676 kB
  • sloc: cpp: 809,645; lisp: 288,070; ansic: 16,397; python: 4,010; yacc: 2,588; lex: 1,666; pascal: 314; sh: 186; makefile: 38
file content (124 lines) | stat: -rw-r--r-- 4,364 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*========================== begin_copyright_notice ============================

Copyright (C) 2025 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

#include "Compiler/Optimizer/OpenCLPasses/GenericCastToPtrOpt/GenericCastToPtrOpt.hpp"

#include "common/LLVMWarningsPush.hpp"
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Analysis/CallGraph.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/Pass.h>
#include "common/LLVMWarningsPop.hpp"

#include "CISACodeGen/CastToGASAnalysis.h"
#include "Probe/Assertion.h"
#include "Compiler/IGCPassSupport.h"
#include "llvmWrapper/ADT/Optional.h"

using namespace llvm;
using namespace IGC;

// Pass registration data for the legacy pass manager: command-line flag,
// human-readable description, and whether the pass is CFG-only / an analysis.
#define PASS_FLAG "igc-generic-cast-to-ptr-opt"
#define PASS_DESCRIPTION "Optimize GenericCastToPtrExplicit casts"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
// Register the pass together with the analyses it requires (see
// getAnalysisUsage): CastToGASAnalysis and the call graph.
IGC_INITIALIZE_PASS_BEGIN(GenericCastToPtrOpt, PASS_FLAG, PASS_DESCRIPTION,
                          PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_DEPENDENCY(CastToGASAnalysis)
IGC_INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
IGC_INITIALIZE_PASS_END(GenericCastToPtrOpt, PASS_FLAG, PASS_DESCRIPTION,
                        PASS_CFG_ONLY, PASS_ANALYSIS)

// Pass identifier; its address (handed to ModulePass(ID) in the constructor)
// identifies the pass, the stored value itself is not meaningful.
char GenericCastToPtrOpt::ID = 0;

// Substring searched for in callee names (runOnModule uses `contains`), so the
// builtin is recognized regardless of any mangling prefix/suffix around it.
constexpr std::string_view GENERIC_CAST_TO_PTR_FN_NAME =
    "spirv_GenericCastToPtrExplicit_ToGlobal";

/// Rewrites a direct call to the GenericCastToPtrExplicit_ToGlobal builtin as
/// an `addrspacecast` of its first argument to the call's result type.
///
/// The caller must already have established that the rewrite is legal for the
/// enclosing function (see runOnModule). After the rewrite, every
/// `addrspacecast` user that converts the newly created cast straight back to
/// the argument's type is replaced by the argument itself and erased.
///
/// \param TargetFnCall Call to rewrite; it is erased when rewritten. Calls
///        that do not take a pointer as their first argument and return a
///        pointer are left untouched.
static void replaceGenericCastToPtrCall(CallInst *TargetFnCall) {
  // Only the pointer -> pointer form can be expressed as an addrspacecast;
  // anything else whose name merely matches the pattern is left alone.
  if (TargetFnCall->arg_size() == 0 || !TargetFnCall->getType()->isPointerTy())
    return;
  Value *GenericPtr = TargetFnCall->getArgOperand(0);
  if (!GenericPtr->getType()->isPointerTy())
    return;

  // Inserts before the call and inherits its debug location.
  IRBuilder<> Builder(TargetFnCall);
  Value *Replacement = Builder.CreateAddrSpaceCast(
      GenericPtr, TargetFnCall->getType(), "generic_cast_to_ptr");
  TargetFnCall->replaceAllUsesWith(Replacement);
  TargetFnCall->eraseFromParent();

  // CreateAddrSpaceCast returns the operand itself when no conversion is
  // needed and a folded Constant when the operand is constant, so the
  // replacement is not necessarily a freshly inserted instruction. Only a cast
  // created here is cleaned up (the original code dereferenced null here).
  auto *NewCast = dyn_cast<AddrSpaceCastInst>(Replacement);
  if (!NewCast || Replacement == GenericPtr)
    return;

  // Fold round trips: addrspacecast(addrspacecast(p, RetTy), typeof(p)) -> p.
  // Collect first: erasing users while walking the use list invalidates it.
  SmallVector<AddrSpaceCastInst *, 32> RoundTrips;
  for (User *U : NewCast->users()) {
    auto *BackCast = dyn_cast<AddrSpaceCastInst>(U);
    if (BackCast && BackCast->getDestTy() == NewCast->getSrcTy())
      RoundTrips.push_back(BackCast);
  }
  for (AddrSpaceCastInst *BackCast : RoundTrips) {
    BackCast->replaceAllUsesWith(GenericPtr);
    BackCast->eraseFromParent();
  }
}

// Constructs the module pass (identified by ID) and makes sure it is
// initialized in the global PassRegistry.
GenericCastToPtrOpt::GenericCastToPtrOpt() : ModulePass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeGenericCastToPtrOptPass(Registry);
}

void GenericCastToPtrOpt::getAnalysisUsage(llvm::AnalysisUsage &AU) const {
  // The rewrite only replaces instructions inside existing blocks; no blocks
  // or terminators are added or removed.
  AU.setPreservesCFG();

  // Generic-address-space facts per function, and the call graph used to
  // locate direct calls to the cast builtin.
  AU.addRequired<CastToGASAnalysis>();
  AU.addRequired<CallGraphWrapperPass>();
}

/// Rewrites direct calls to the GenericCastToPtrExplicit_ToGlobal builtin into
/// plain `addrspacecast`s where that is legal.
///
/// The rewrite is enabled for the whole module when GASInfo reports both that
/// the no-local-to-generic option is enabled and that private memory is
/// allocated in global memory. Otherwise it is enabled per function when
/// GASInfo reports that generic pointers in that function can point neither
/// to local nor to private memory. Declarations and `optnone` functions are
/// skipped.
///
/// \returns true if any call was handed to replaceGenericCastToPtrCall.
bool GenericCastToPtrOpt::runOnModule(Module &M) {
  if (skipModule(M)) {
    return false;
  }

  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
  GASInfo &GI = getAnalysis<CastToGASAnalysis>().getGASInfo();
  const bool noGenericPtToLocalOrPrivate =
      GI.isNoLocalToGenericOptionEnabled() &&
      GI.isPrivateAllocatedInGlobalMemory();

  // Loop-invariant: the callee-name pattern to search for.
  const llvm::StringRef TargetFnName(GENERIC_CAST_TO_PTR_FN_NAME.data(),
                                     GENERIC_CAST_TO_PTR_FN_NAME.size());

  bool modified = false;

  for (auto &[Fn, FnCallGraph] : CG) {
    // The null-keyed entry is the call graph's external node; declarations
    // have no body (and hence no call sites) to rewrite.
    if (!Fn || Fn->isDeclaration() || Fn->hasOptNone()) {
      continue;
    }

    // Without the module-wide guarantee, require the per-function one.
    Function &F = *FnCallGraph->getFunction();
    if (!noGenericPtToLocalOrPrivate &&
        (GI.canGenericPointToLocal(F) || GI.canGenericPointToPrivate(F))) {
      continue;
    }

    for (auto &[CallInstVH, CallRecordNode] : *FnCallGraph) {
      // Records without a call site carry nothing to rewrite.
      auto CallInstVHOptional = IGCLLVM::makeOptional(CallInstVH);
      if (!CallInstVHOptional.has_value()) {
        continue;
      }
      // We check only direct calls: indirect ones have no callee function.
      Function *CallRecordFn = CallRecordNode->getFunction();
      if (!CallRecordFn || !CallRecordFn->getName().contains(TargetFnName)) {
        continue;
      }
      // The tracked call site may have been deleted (null handle) or may be
      // an invoke rather than a plain call; `cast<CallInst>` would assert on
      // either, so only genuine CallInsts are rewritten.
      Value *CallSite = CallInstVHOptional.value();
      if (!CallSite) {
        continue;
      }
      if (auto *Call = dyn_cast<CallInst>(CallSite)) {
        replaceGenericCastToPtrCall(Call);
        modified = true;
      }
    }
  }
  return modified;
}