File: WaveAllJointReduction.cpp

package info (click to toggle)
intel-graphics-compiler2 2.16.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 106,644 kB
  • sloc: cpp: 805,640; lisp: 287,672; ansic: 16,414; python: 3,952; yacc: 2,588; lex: 1,666; pascal: 313; sh: 186; makefile: 35
file content (146 lines) | stat: -rw-r--r-- 5,596 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
/*========================== begin_copyright_notice ============================

Copyright (C) 2024 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

#include <GenISAIntrinsics/GenIntrinsicInst.h>
#include "WaveAllJointReduction.hpp"
#include "Compiler/IGCPassSupport.h"
#include "common/LLVMWarningsPush.hpp"
#include <llvm/IR/InstVisitor.h>
#include <llvm/ADT/SmallVector.h>
#include "common/LLVMWarningsPop.hpp"

#define DEBUG_TYPE "igc-wave-all-joint-reduction"

using namespace IGC;
using namespace llvm;

namespace IGC {
/// Worker that scans one function for runs of consecutive scalar WaveAll
/// intrinsic calls and folds each run into a single WaveAll call operating on a
/// vector. Driven through llvm::InstVisitor: run() visits the function, and
/// visitCallInst() is invoked for every call instruction encountered.
class WaveAllJointReductionImpl : public InstVisitor<WaveAllJointReductionImpl> {
public:
  /// \param F function to transform; held by reference, so it must outlive this object.
  // explicit: a Function& should never silently convert into a worker object.
  explicit WaveAllJointReductionImpl(Function &F) : F(F) {}

  /// Visits every instruction of F, then erases the scalar WaveAll calls that were merged.
  /// \returns true if at least one joint WaveAll was created.
  bool run();

  /// InstVisitor callback; performs the merge when callInst starts a run of compatible WaveAll calls.
  void visitCallInst(CallInst &callInst);

private:
  /// Builds a vector value holding the source operand of every WaveAll in mergeList, in list order.
  Value *createInsertElements(SmallVector<WaveAllIntrinsic *, 16> &mergeList);

  /// Replaces the uses of each WaveAll in mergeList with the matching element of waveAllJoint.
  void createExtractElements(SmallVector<WaveAllIntrinsic *, 16> &mergeList, WaveAllIntrinsic *waveAllJoint);

  /// Function being processed.
  Function &F;
  /// Scalar WaveAll calls already folded into a joint call; erased at the end of run().
  DenseSet<WaveAllIntrinsic *> ToDelete;
  /// Set once any joint WaveAll has been created.
  bool Changed = false;
};

/// FunctionPass wrapper around the WaveAll joint-reduction transformation.
/// runOnFunction() is defined out of line further down in this file.
class WaveAllJointReduction : public FunctionPass {
public:
  WaveAllJointReduction() : FunctionPass(ID) {}

  /// Entry point invoked once per function.
  bool runOnFunction(Function &F) override;

  /// Human-readable pass name.
  llvm::StringRef getPassName() const override { return "WaveAllJointReduction"; }

  /// Pass identification; its address (not its value) identifies the pass.
  static char ID;
};

/// Factory for the pass. The returned object is heap-allocated and is not
/// released anywhere in this file, so the caller is responsible for it.
FunctionPass *createWaveAllJointReduction() {
  return new WaveAllJointReduction();
}
} // namespace IGC

/// Packs the source operand of every WaveAll in mergeList into one vector.
///
/// The insertelement chain is emitted immediately before the first WaveAll of
/// the list. Element i of the result holds the source of mergeList[i]; the
/// element type is taken from the first WaveAll's source.
///
/// \param mergeList non-empty list of WaveAll calls to be merged.
/// \returns the last insertelement of the chain, i.e. the fully populated vector.
Value *WaveAllJointReductionImpl::createInsertElements(SmallVector<WaveAllIntrinsic *, 16> &mergeList) {
  WaveAllIntrinsic *firstWaveAll = mergeList.front();
  IRBuilder<> builder(firstWaveAll);

  auto *packedType = VectorType::get(firstWaveAll->getSrc()->getType(), mergeList.size(), false);

  // Start from an undef vector and fill one lane per merged WaveAll.
  Value *packed = UndefValue::get(packedType);
  uint64_t lane = 0;
  for (WaveAllIntrinsic *scalarWaveAll : mergeList) {
    packed = builder.CreateInsertElement(packed, scalarWaveAll->getSrc(), lane, "waveAllSrc");
    ++lane;
  }
  return packed;
}

/// Redirects the users of each merged scalar WaveAll to the joint result.
///
/// For every entry i of mergeList an extractelement of lane i of waveAllJoint
/// is emitted immediately before the first WaveAll of the list, and all uses of
/// mergeList[i] are replaced with it. The scalar calls themselves are left in
/// place; this function does not erase them.
///
/// \param mergeList    WaveAll calls that were merged, in lane order.
/// \param waveAllJoint the vector WaveAll that replaces them.
void WaveAllJointReductionImpl::createExtractElements(SmallVector<WaveAllIntrinsic *, 16> &mergeList,
                                                      WaveAllIntrinsic *waveAllJoint) {
  IRBuilder<> builder(mergeList.front());

  uint64_t lane = 0;
  for (WaveAllIntrinsic *scalarWaveAll : mergeList) {
    Value *laneResult = builder.CreateExtractElement(waveAllJoint, lane, "waveAllDst");
    scalarWaveAll->replaceAllUsesWith(laneResult);
    ++lane;
  }
}

/// InstVisitor callback, invoked for every call instruction of the function.
///
/// If callInst is a scalar WaveAll that has not been merged yet, collects the
/// run of WaveAll calls that immediately follow it in the same basic block and
/// are compatible with it, and replaces the whole run with one WaveAll on a
/// vector: sources are packed with insertelement, the joint call is emitted
/// before the first scalar call, and each scalar result is replaced by an
/// extractelement. The merged scalar calls are recorded in ToDelete and erased
/// later by run().
///
/// A following WaveAll joins the run only if all of the following hold:
///  - its source is a scalar of exactly the same type as the first call's source;
///  - it has the same operation kind;
///  - it has the same operand 2 as the first call, because the joint call
///    forwards only the first call's operand 2;
///  - its source is not the result of a WaveAll already in the run.
///
/// \param callInst call instruction currently being visited.
void WaveAllJointReductionImpl::visitCallInst(CallInst &callInst) {
  auto *waveAllInst = dyn_cast<WaveAllIntrinsic>(&callInst);
  if (!waveAllInst) {
    return;
  }

  // marked as delete because it was already merged with prior insts
  if (ToDelete.count(waveAllInst)) {
    return;
  }

  // Optimization already happened, first operand is already vector
  if (waveAllInst->getSrc()->getType()->isVectorTy()) {
    return;
  }

  SmallVector<WaveAllIntrinsic *, 16> mergeList{waveAllInst};

  // True when src is the result of a WaveAll that is already part of the run.
  // The packed source vector and the joint call are emitted BEFORE the first
  // WaveAll of the run, so such a source would be used before its definition;
  // after replaceAllUsesWith it would even make the joint call depend on its
  // own extracted lane. Because only consecutive instructions are considered,
  // a direct check against the run is sufficient: every other source is
  // defined before the first WaveAll.
  auto isResultOfMergedWaveAll = [&mergeList](const Value *src) {
    for (const WaveAllIntrinsic *merged : mergeList) {
      if (merged == src) {
        return true;
      }
    }
    return false;
  };

  // For locality, only look at consecutive instructions since non-consecutive instructions may require sinking the
  // final vector WaveAll instruction to where the last joined WaveAll is to satisfy proper domination of each
  // WaveAll's Src
  // TODO: If needed, a complicated analysis could find non-consecutive WaveAll instructions that are able to
  // participate in WaveAll joint reduction, but seems like an edge case for now
  Instruction *const terminator = waveAllInst->getParent()->getTerminator();
  // The null check guards against running off the end of a block that has no terminator.
  for (Instruction *I = waveAllInst->getNextNode(); I && I != terminator; I = I->getNextNode()) {
    auto *nextWaveAllInst = dyn_cast<WaveAllIntrinsic>(I);
    if (!nextWaveAllInst || nextWaveAllInst->getSrc()->getType()->isVectorTy() ||
        nextWaveAllInst->getSrc()->getType() != waveAllInst->getSrc()->getType() ||
        nextWaveAllInst->getOpKind() != waveAllInst->getOpKind()) {
      break;
    }

    // Operand 2 (presumably the helper lane mode - TODO confirm) is taken from the first WaveAll only when the
    // joint call is built, so every merged call has to agree on it or its semantics would silently change.
    if (nextWaveAllInst->getOperand(2) != waveAllInst->getOperand(2)) {
      break;
    }

    // Chained reduction: this call consumes the result of a call in the run and must stay separate. It will be
    // visited on its own later and may start a new run.
    if (isResultOfMergedWaveAll(nextWaveAllInst->getSrc())) {
      break;
    }

    mergeList.push_back(nextWaveAllInst);
  }

  if (mergeList.size() <= 1) {
    // Nothing to join with
    return;
  }

  // Multiple WaveAll operations eligible to participate in joint operation
  auto *arg0 = createInsertElements(mergeList);
  IRBuilder<> builder(mergeList.front());
  Type *funcType[] = {arg0->getType(), Type::getInt8Ty(builder.getContext()),
                      Type::getInt32Ty(builder.getContext())};
  Function *waveAllJointFunc =
      GenISAIntrinsic::getDeclaration(mergeList.front()->getModule(), GenISAIntrinsic::GenISA_WaveAll, funcType);

  auto *waveAllJoint = builder.CreateCall(
      waveAllJointFunc, {arg0, waveAllInst->getOperand(1), waveAllInst->getOperand(2)}, "waveAllJoint");
  createExtractElements(mergeList, cast<WaveAllIntrinsic>(waveAllJoint));

  // Mark merged WaveAll ops participating in joint operation for deletion
  for (auto *mergedInst : mergeList) {
    ToDelete.insert(mergedInst);
  }
  Changed = true;
}

bool WaveAllJointReductionImpl::run() {
  visit(F);
  for (auto *mergedWaveAllInst : ToDelete) {
    mergedWaveAllInst->eraseFromParent();
  }
  return Changed;
}

/// Pass entry point: delegates to a fresh WaveAllJointReductionImpl for F.
/// \returns true if the function was modified.
bool WaveAllJointReduction::runOnFunction(Function &F) {
  return WaveAllJointReductionImpl(F).run();
}

// Storage for the pass identifier declared in the class and passed to the FunctionPass constructor.
char WaveAllJointReduction::ID = 0;

// Arguments for the IGC pass-initialization macros below.
// PASS_FLAG holds the same string as DEBUG_TYPE defined at the top of this file.
#define PASS_FLAG "igc-wave-all-joint-reduction"
#define PASS_DESCRIPTION "WaveAllJointReduction"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
// No pass dependencies are declared between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(WaveAllJointReduction, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_END(WaveAllJointReduction, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)