File: GenXLCECalculation.cpp

package info (click to toggle)
intel-graphics-compiler 1.0.17791.18-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 102,312 kB
  • sloc: cpp: 935,343; lisp: 286,143; ansic: 16,196; python: 3,279; yacc: 2,487; lex: 1,642; pascal: 300; sh: 174; makefile: 27
file content (317 lines) | stat: -rw-r--r-- 9,924 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
/*========================== begin_copyright_notice ============================

Copyright (C) 2024 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//
/// GenXLCECalculation
/// -------------------
///
/// GenXLCECalculation is a function pass that analyzes loop bounds and tries
/// to calculate the loop count expression in the form
/// 'Factor * Symbol + Addend', where Symbol describes a kernel argument.
/// E.g. for the loop in the following kernel:
///    void foo(int N, ...) {
///      for (int i = 0; i < (N / 2 + 1); ++i) {...}
///    }
/// LoopCountExpr = 0.5 * Symbol(N) + 1, where
/// Symbol(N) = {0, 0, 4, false}: direct symbol with Num=0, Offset=0, Size=4
//
//===----------------------------------------------------------------------===//

#include "GenX.h"

#include "vc/Utils/GenX/CostInfo.h"
#include "vc/Utils/GenX/KernelInfo.h"

#include "llvmWrapper/IR/IRBuilder.h"

#include <llvm/Analysis/LoopInfo.h>
#include <llvm/Analysis/ScalarEvolution.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/GetElementPtrTypeIterator.h>
#include <llvm/IR/InstVisitor.h>
#include <llvm/IR/Module.h>
#include <llvm/InitializePasses.h>
#include <llvm/Pass.h>

using namespace llvm;
using namespace genx;

namespace {

// Wrapper around vc::LoopCountExpr that adds the small amount of arithmetic
// needed to combine loop bounds into a loop count expression of the form
// 'Factor * Symbol + Addend'. A default-constructed wrapper is what the rest
// of this file uses as the "undefined" result.
class LoopCountExprWrapper {
  // LCEFinder updates the wrapped expression in place while traversing IR.
  friend class LCEFinder;
  vc::LoopCountExpr Expr;

public:
  LoopCountExprWrapper() = default;

  // Builds a defined expression F * S + A.
  LoopCountExprWrapper(float F, vc::ArgSym S, float A) {
    Expr.Symbol = S;
    Expr.Factor = F;
    Expr.Addend = A;
    Expr.IsUndef = false;
  }

  // True when the expression could not be computed.
  bool isUndef() const { return Expr.IsUndef; }

  // Forwards the expression to vc::saveLCEToMetadata for loop L in module M
  // and returns its result.
  bool save(const llvm::Loop &L, llvm::Module &M) const {
    return vc::saveLCEToMetadata(L, M, Expr);
  }

  // Component-wise difference of two expressions. The result is undefined
  // when either operand is undefined or when both operands depend on
  // different symbols.
  LoopCountExprWrapper operator-(LoopCountExprWrapper const &RHS) const {
    // Propagate undef.
    if (Expr.IsUndef || RHS.Expr.IsUndef)
      return LoopCountExprWrapper{};
    // A zero factor means the expression is a plain constant.
    const bool LHSHasSymbol = Expr.Factor != 0.0;
    const bool RHSHasSymbol = RHS.Expr.Factor != 0.0;
    // If both expressions are not constant they should share
    // the same symbol.
    if (LHSHasSymbol && RHSHasSymbol && Expr.Symbol != RHS.Expr.Symbol)
      return LoopCountExprWrapper{};
    return LoopCountExprWrapper(Expr.Factor - RHS.Expr.Factor,
                                LHSHasSymbol ? Expr.Symbol : RHS.Expr.Symbol,
                                Expr.Addend - RHS.Expr.Addend);
  }

  // Divides both the factor and the addend by Val. The result is undefined
  // when the expression is undefined or when Val is zero: dividing by zero
  // would otherwise produce a non-finite expression that is still marked as
  // defined.
  LoopCountExprWrapper operator/(unsigned Val) const {
    if (Expr.IsUndef || Val == 0)
      return LoopCountExprWrapper{};
    return LoopCountExprWrapper(Expr.Factor / Val, Expr.Symbol,
                                Expr.Addend / Val);
  }
};

// This class traverses IR to find LCE for the loop bound.
// Each visit* method folds one instruction into the expression accumulated in
// LCE and returns the value to continue the traversal from, or nullptr to
// stop it.
class LCEFinder : public InstVisitor<LCEFinder, Value *> {
  // Expression accumulated so far; starts as '1 * <default symbol> + 0'.
  LoopCountExprWrapper LCE;
  // Data layout used to compute type and pointer sizes.
  const DataLayout *DL;

public:
  // Members are initialized in declaration order (LCE, then DL); the
  // initializer list is written in the same order to avoid -Wreorder.
  explicit LCEFinder(const DataLayout *DataL)
      : LCE(1.0, {}, 0.0), DL(DataL) {}

  // Computes the expression for the value Start.
  LoopCountExprWrapper getLCE(Value &Start);

  Value *visitInstruction(Instruction &I);
  Value *visitBinaryOperator(BinaryOperator &BO);
  Value *visitCastInst(CastInst &CI);
  Value *visitGetElementPtrInst(GetElementPtrInst &GEP);
  Value *visitLoadInst(LoadInst &LI);
};

// Function pass that computes a loop count expression for every loop of a
// kernel function and saves it through LoopCountExprWrapper::save().
class GenXLCECalculation : public FunctionPass {
  using LoopDirection = Loop::LoopBounds::Direction;
  // Per-loop data collected by processLoop().
  struct LCELoopInfo {
    // The initial value of induction variable.
    LoopCountExprWrapper Init;
    // The final value of induction variable.
    LoopCountExprWrapper Final;
    // NOTE(review): never assigned in the code visible in this file.
    LoopCountExprWrapper TripCount;
    // Direction reported by the loop bounds; Unknown until processLoop()
    // sets it.
    LoopDirection Direction = LoopDirection::Unknown;
    // Absolute value of the constant loop step; 0 until processLoop() sets it.
    unsigned AbsStepValue = 0;
  };
  // NOTE(review): not referenced by any code visible in this file.
  DenseMap<Loop *, LCELoopInfo> LoopMap;

public:
  static char ID;
  explicit GenXLCECalculation() : FunctionPass(ID) {}
  StringRef getPassName() const override {
    return "GenX loop count expression calculation";
  }
  // Requires LoopInfo and ScalarEvolution and declares both as preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }
  bool runOnFunction(Function &F) override;

private:
  // Collects the step, direction and initial/final induction variable
  // expressions for loop L of function F.
  LCELoopInfo processLoop(const Loop &L, const Function &F, ScalarEvolution &SE,
                          LoopInfo &LI) const;
};

} // end namespace

// Static pass identifier; the GenXLCECalculation constructor hands it to the
// FunctionPass base class.
char GenXLCECalculation::ID = 0;
// Declaration of the pass initializer that createGenXLCECalculationPass()
// calls before constructing the pass.
namespace llvm {
void initializeGenXLCECalculationPass(PassRegistry &);
} // end namespace llvm

// Pass registration. The listed dependencies match the analyses requested in
// GenXLCECalculation::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(GenXLCECalculation, "GenXLCECalculation",
                      "GenXLCECalculation", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(GenXLCECalculation, "GenXLCECalculation",
                    "GenXLCECalculation", false, false)

// Registers the pass with the global pass registry and returns a newly
// allocated instance of it.
FunctionPass *llvm::createGenXLCECalculationPass() {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeGenXLCECalculationPass(Registry);
  return new GenXLCECalculation();
}

// Fallback for every instruction kind without a dedicated visitor: returning
// nullptr ends the traversal.
Value *LCEFinder::visitInstruction(Instruction & /*I*/) {
  return nullptr;
}

// Folds a binary operator that has exactly one usable constant operand into
// the expression accumulated in LCE and returns the other operand so that the
// traversal can continue from it.
//
// The accumulated expression is 'Factor * X + Addend', where X is the value
// produced by BO; the switch below rewrites it in terms of BO's non-constant
// operand.
//
// Returns nullptr (which stops the traversal) when:
//  * neither operand is a scalar integer or floating-point constant;
//  * the constant is the first operand of a non-commutative operation other
//    than a subtraction;
//  * the constant is an integer wider than 64 bits;
//  * the operation divides by zero or shifts by at least the bit width;
//  * the opcode is not handled.
Value *LCEFinder::visitBinaryOperator(BinaryOperator &BO) {
  // Operand of the instruction which is a constant factor/addend.
  Value *C = nullptr;
  // Operand of the instruction which will be traversed further.
  Value *V = nullptr;
  // True when constant is the first operand and value is the second.
  bool IsReversed = false;

  // Only scalar constants can be folded. Other kinds of constant data (undef,
  // zeroinitializer, constant vectors, ...) are rejected up front because the
  // code below needs a ConstantInt or a ConstantFP.
  auto IsScalarConst = [](const Value *Op) {
    return isa<ConstantInt>(Op) || isa<ConstantFP>(Op);
  };

  if (IsScalarConst(BO.getOperand(0))) {
    IsReversed = true;
    C = BO.getOperand(0);
    V = BO.getOperand(1);
  } else if (IsScalarConst(BO.getOperand(1))) {
    C = BO.getOperand(1);
    V = BO.getOperand(0);
  } else
    return nullptr;

  const auto Opcode = BO.getOpcode();
  if (IsReversed && !BO.isCommutative() && Opcode != Instruction::Sub &&
      Opcode != Instruction::FSub)
    return nullptr;

  const auto *CInt = dyn_cast<ConstantInt>(C);
  // getSExtValue()/getZExtValue() require the value to fit into 64 bits.
  if (CInt && CInt->getBitWidth() > 64)
    return nullptr;

  float ConstAsFP = 0.0f;
  if (CInt) {
    ConstAsFP = static_cast<float>(CInt->getSExtValue());
  } else {
    // Round to single precision explicitly so that constants of other
    // floating-point types (e.g. double) are handled as well.
    APFloat FPVal = cast<ConstantFP>(C)->getValueAPF();
    bool LosesInfo = false;
    (void)FPVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                        &LosesInfo);
    ConstAsFP = FPVal.convertToFloat();
  }

  switch (Opcode) {
  default:
    return nullptr;
  case Instruction::Add:
  case Instruction::FAdd:
    LCE.Expr.Addend += ConstAsFP * LCE.Expr.Factor;
    break;
  case Instruction::Sub:
  case Instruction::FSub: {
    const float Scaled = ConstAsFP * LCE.Expr.Factor;
    if (!IsReversed)
      // sub V, C -> add V, -C
      LCE.Expr.Addend += -Scaled;
    else {
      // sub C, V -> add -V, C
      LCE.Expr.Factor = -LCE.Expr.Factor;
      LCE.Expr.Addend += Scaled;
    }
    break;
  }
  case Instruction::Mul:
  case Instruction::FMul:
    LCE.Expr.Factor *= ConstAsFP;
    break;
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::FDiv:
    // A zero divisor would turn the factor into inf/NaN.
    if (ConstAsFP == 0.0f)
      return nullptr;
    LCE.Expr.Factor /= ConstAsFP;
    break;
  case Instruction::Shl:
  case Instruction::LShr: {
    if (!CInt)
      return nullptr;
    // The shift amount is an unsigned quantity. The constant is at most
    // 64 bits wide (checked above), so rejecting amounts that reach the
    // operand width also keeps the host-side shift below well defined.
    const uint64_t ShAmt = CInt->getZExtValue();
    if (ShAmt >= BO.getType()->getScalarSizeInBits())
      return nullptr;
    const float Scale = static_cast<float>(1ULL << ShAmt);
    if (Opcode == Instruction::Shl)
      LCE.Expr.Factor *= Scale;
    else
      LCE.Expr.Factor /= Scale;
    break;
  }
  }
  return V;
}

// Casts leave the accumulated expression untouched; the traversal simply
// continues from the value being cast.
Value *LCEFinder::visitCastInst(CastInst &CI) {
  Value *Source = CI.getOperand(0);
  return Source;
}

// Adds the constant offset computed by GEP to the symbol offset and continues
// the traversal from the GEP pointer operand. The traversal stops (nullptr)
// when the symbol has not been marked indirect yet, when an index is not a
// constant integer, or when an index steps into a struct type.
Value *LCEFinder::visitGetElementPtrInst(GetElementPtrInst &GEP) {
  if (!LCE.Expr.Symbol.IsIndirect)
    return nullptr;

  // The type iterator walks the index operands together with the types they
  // index into.
  for (auto GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP); GTI != GTE;
       ++GTI) {
    // TODO: Should struct types be handled?
    if (GTI.getStructTypeOrNull())
      return nullptr;
    auto *ConstIdx = dyn_cast<ConstantInt>(GTI.getOperand());
    if (!ConstIdx)
      return nullptr;
    LCE.Expr.Symbol.Offset +=
        DL->getTypeAllocSize(GTI.getIndexedType()) * ConstIdx->getSExtValue();
  }
  return GEP.getPointerOperand();
}

// Marks the symbol as indirect and continues the traversal from the loaded
// address. Only one load is accepted: a second one stops the traversal.
Value *LCEFinder::visitLoadInst(LoadInst &LI) {
  auto &Sym = LCE.Expr.Symbol;
  if (Sym.IsIndirect)
    return nullptr;
  Sym.IsIndirect = true;
  return LI.getPointerOperand();
}

// Computes the expression for Start.
//  * A constant integer yields a pure constant (zero factor).
//  * Otherwise the chain of instructions is followed through the visit*
//    methods until a non-instruction value is reached or a visitor stops the
//    traversal. The expression is defined only if the chain ends at a
//    function argument, which then becomes the symbol.
LoopCountExprWrapper LCEFinder::getLCE(Value &Start) {
  if (auto *ConstBound = dyn_cast<ConstantInt>(&Start)) {
    LCE.Expr.Addend = ConstBound->getSExtValue();
    LCE.Expr.Factor = 0.0f;
    return LCE;
  }

  // Each visitor returns the next value to look at, or nullptr to give up.
  Value *Cur = &Start;
  while (auto *Inst = dyn_cast_or_null<Instruction>(Cur))
    Cur = visit(Inst);

  auto *Arg = dyn_cast_or_null<Argument>(Cur);
  if (!Arg)
    return LoopCountExprWrapper{};

  // Implicit arguments must come at the end.
  // So we don't take them into an account.
  auto &Sym = LCE.Expr.Symbol;
  Sym.Num = Arg->getArgNo();

  auto *ArgTy = Arg->getType();
  Sym.Size = ArgTy->isPointerTy() ? DL->getPointerTypeSize(ArgTy)
                                  : DL->getTypeSizeInBits(ArgTy) / 8;
  return LCE;
}

// Collects the data needed to build the loop count expression for L: the
// absolute constant step, the direction and the expressions for the initial
// and final induction variable values. A default LCELoopInfo is returned when
// the loop bounds are unavailable or the step is not a constant integer.
GenXLCECalculation::LCELoopInfo
GenXLCECalculation::processLoop(const Loop &L, const Function &F,
                                ScalarEvolution &SE, LoopInfo &LI) const {
  LCELoopInfo Info;

  auto Bounds = L.getBounds(SE);
  if (!Bounds)
    return Info;

  // We analyze only loops with a constant StepValue.
  auto *ConstStep = dyn_cast_or_null<ConstantInt>(Bounds->getStepValue());
  if (!ConstStep)
    return Info;

  Info.AbsStepValue = std::abs(ConstStep->getSExtValue());
  Info.Direction = Bounds->getDirection();

  const DataLayout *Layout = &F.getParent()->getDataLayout();

  // The initial value is only worth analyzing when the final one is known.
  Info.Final = LCEFinder{Layout}.getLCE(Bounds->getFinalIVValue());
  if (Info.Final.isUndef())
    return Info;

  Info.Init = LCEFinder{Layout}.getLCE(Bounds->getInitialIVValue());
  return Info;
}

// Computes and saves a loop count expression for every loop of a kernel.
// Functions that are not kernels are skipped. Returns true when saving the
// expression reported a change for at least one loop.
bool GenXLCECalculation::runOnFunction(Function &F) {
  if (!vc::isKernel(&F))
    return false;

  bool Changed = false;
  auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  SmallVector<Loop *, 4> Loops = LI.getLoopsInPreorder();
  for (auto *L : Loops) {
    const auto LCEInfo = processLoop(*L, F, SE, LI);

    // The trip count is (Final - Init) / |Step| for increasing loops and
    // (Init - Final) / |Step| for decreasing ones. It stays undefined when the
    // direction is unknown or the step is zero: neither formula applies to an
    // unknown direction, and a zero step cannot be used as a divisor.
    LoopCountExprWrapper Res;
    if (LCEInfo.AbsStepValue != 0) {
      if (LCEInfo.Direction == LoopDirection::Increasing)
        Res = (LCEInfo.Final - LCEInfo.Init) / LCEInfo.AbsStepValue;
      else if (LCEInfo.Direction == LoopDirection::Decreasing)
        Res = (LCEInfo.Init - LCEInfo.Final) / LCEInfo.AbsStepValue;
    }
    Changed |= Res.save(*L, *F.getParent());
  }
  return Changed;
}