File: WaveShuffleIndexSinking.cpp

package info (click to toggle)
intel-graphics-compiler2 2.16.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 106,644 kB
  • sloc: cpp: 805,640; lisp: 287,672; ansic: 16,414; python: 3,952; yacc: 2,588; lex: 1,666; pascal: 313; sh: 186; makefile: 35
file content (701 lines) | stat: -rw-r--r-- 30,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
/*========================== begin_copyright_notice ============================

Copyright (C) 2024 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

#include <GenISAIntrinsics/GenIntrinsicInst.h>
#include "WaveShuffleIndexSinking.hpp"
#include "Compiler/IGCPassSupport.h"
#include <igc_regkeys.hpp>
#include "common/LLVMWarningsPush.hpp"
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/IR/Dominators.h>
#include "common/LLVMWarningsPop.hpp"

#define DEBUG_TYPE "igc-wave-shuffle-index-sinking"

using namespace IGC;
using namespace llvm;

namespace IGC {
class WaveShuffleIndexSinkingImpl {
  class ShuffleGroup {
    // Group of WaveShuffleIndex instructions with constant lane indexes that are each followed by one or more
    // identical instructions, ex.
    // %0 = ...
    // %1 = ...
    // %2 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 0, i32 0)
    // %3 = add i32 %2, %1
    // %4 = shl i32 %3, 2
    // %use_4 = call @f(%4)
    // %5 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 1, i32 0)
    // %6 = add i32 %5, %1
    // %7 = shl i32 %6, 2
    // %use_7 = call @f(%7)
    // %8 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 2, i32 0)
    // %9 = add i32 %8, %1
    // %10 = shl i32 %9, 2
    // %use_10 = call @f(%10)
    //
    // This can be transformed to the following, since each WaveShuffleIndex is essentially a broadcast operation
    // (from having a constant lane index). Uniform (constant) operands in operations following a WaveShuffleIndex
    // can be hoisted to the source. Depending on the distributive properties of the instructions, the shl in this
    // example can be hoisted above the add, and afterwards hoisted before the WaveShuffleIndex:
    // %0 = ...
    // %1 = ...
    // %2 = shl i32 %0, 2
    // %3 = shl i32 %1, 2
    // %4 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %2, i32 0, i32 0)
    // %5 = add i32 %3, %4
    // %use_4 = call @f(%5)
    // %6 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %2, i32 1, i32 0)
    // %7 = add i32 %3, %6
    // %use_7 = call @f(%7)
    // %8 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %2, i32 2, i32 0)
    // %9 = add i32 %3, %8
    // %use_10 = call @f(%9)
    //
    // This reduces the number of instructions in this particular ShuffleGroup from
    // 3 * (# of WaveShuffleIndex) to 2 + 2 * (# of WaveShuffleIndex).
    // The more WaveShuffleIndex instructions in a ShuffleGroup, the more effective this transformation is.
  public:
    // Start a group containing only shuffleInst; instruction chains are filled in by later match() calls.
    ShuffleGroup(WaveShuffleIndexIntrinsic *shuffleInst) {
      ShuffleOps.emplace_back(shuffleInst);
    }

    // Try to add shuffleInst to this group.
    // Compares the instructions following the group's first shuffle with those following shuffleInst via
    // compareWaveShuffleIndexes(). Returns false, leaving the group untouched, when no hoistable instruction is
    // found; otherwise records shuffleInst and its chain and returns true.
    bool match(WaveShuffleIndexIntrinsic *shuffleInst) {
      SmallVector<BinaryOperator *> candidateChain;
      SmallVector<bool> candidateKinds;

      if (ShuffleOps.size() == 1) {
        // Fresh group: no chain recorded yet, so discover the longest common chain for both shuffles at once.
        SmallVector<BinaryOperator *> leaderChain;
        const unsigned hoistableCount = compareWaveShuffleIndexes(ShuffleOps.front(), shuffleInst, leaderChain,
                                                                  candidateChain, candidateKinds);
        if (hoistableCount == 0)
          return false;

        InstChains.push_back(leaderChain);
      } else {
        // Developed group: the first recorded chain serves as the reference (and as the search limit).
        const unsigned hoistableCount = compareWaveShuffleIndexes(ShuffleOps.front(), shuffleInst, InstChains.front(),
                                                                  candidateChain, candidateKinds);
        if (hoistableCount == 0)
          return false;

        // The new chain may be shorter than the recorded ones. Shrink every recorded chain to the common length;
        // the instructions cut off here are left for a later iteration, where they can be matched within a smaller
        // group that excludes the new chain.
        const auto commonLength = candidateChain.size();
        for (auto &recordedChain : InstChains)
          recordedChain.truncate(commonLength);
      }

      // Commit: the classification vector is expected to have one entry per instruction of the common chain.
      HoistOrAnchorInstsIdx = std::move(candidateKinds);
      InstChains.push_back(candidateChain);
      ShuffleOps.push_back(shuffleInst);
      return true;
    }

    // Once entire ShuffleGroup is gathered, pre-process and find the instructions that are actually profitable to hoist
    // ex. #WS = 2, Inst Cost = #WS * 6 (WS inst + mul + add + shl + add + shl) = 12
    // %0 = ...
    // %1 = ...
    // %2 = ...
    // %3 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 0, i32 0)
    // %4 = mul i32 %3, 5 <- Hoistable
    // %5 = add i32 %4, %1 <- Anchor
    // %6 = shl i32 %5, 2 <- Hoistable, profitable to hoist past Anchor
    // %7 = add i32 %6, %2 <- Anchor
    // %8 = shl i32 %7, 3  <- Hoistable, not profitable to hoist past Anchor, demoted to Anchor
    // %9 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 1, i32 0)
    // %10 = mul i32 %9, 5 <- Hoistable
    // %11 = add i32 %10, %1 <- Anchor
    // %12 = shl i32 %11, 2 <- Hoistable past Anchor
    // %13 = add i32 %12, %2 <- Anchor
    // %14 = shl i32 %13, 3 <- Hoistable, not profitable to hoist past Anchor, demoted to Anchor
    //
    // Result: Inst Cost = 3 (mul + shl + shl) + #WS * 4 (WS inst + add + add + shl) = 11
    // %0 = ...
    // %1 = ...
    // %2 = ...
    // %3 = mul i32 %0, 5
    // %4 = shl i32 %3, 2
    // %5 = shl i32 %1, 2
    // %6 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %4, i32 0, i32 0)
    // %7 = add i32 %6, %5 <- Anchor
    // %8 = add i32 %7, %2 <- Anchor
    // %9 = shl i32 %8, 3 <- Demoted Anchor
    // %10 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %4, i32 1, i32 0)
    // %11 = add i32 %10, %5 <- Anchor
    // %12 = add i32 %11, %2 <- Anchor
    // %13 = shl i32 %12, 3 <- Demoted Anchor

    // Terms:
    // - Anchor: BinaryOperator of which one operand is the preceding value in the InstChain
    //           (or the WaveShuffle, if the Anchor is the first inst), and the other operand is a non-constant
    // - Hoistable: BinaryOperator of which one operand is the preceding value in the InstChain
    //              (or the WaveShuffle, if the Hoistable is the first inst), and the other operand is a constant
    // - Hoistable past Anchor: Through distributive properties, a Hoistable further along the InstChain
    //                          can have its operation distributed to the operands of an Anchor
    // - Profitable to Hoist past Anchor: A "Hoistable past Anchor" instruction that when hoisted, does not result
    //                                    in more overall instructions than pre-hoist
    // Applies the profitability rule to HoistOrAnchorInstsIdx and returns how many instructions remain marked as
    // hoistable (0 when the group has no instruction chain).
    unsigned preprocess() {
      // Nothing to check
      if (InstChains.empty() || InstChains.front().empty())
        return 0;

      // Profitability:
      // when an instruction is fully hoisted,
      // - one new instruction is added to the second operand of each preceding anchor instruction
      // - one new instruction is added to the single source of all the WaveShuffleIndex instructions in the group
      // - one instruction is removed per WaveShuffleIndex instruction in the group
      // so once there are at least as many preceding anchors as shuffle ops, hoisting no longer reduces the
      // instruction count and every remaining instruction is demoted to anchor.
      const auto numShuffles = ShuffleOps.size();
      unsigned previousAnchorCount = 0;
      unsigned numProfitableHoistable = 0;
      for (unsigned idx = 0; idx < HoistOrAnchorInstsIdx.size(); ++idx) {
        if (previousAnchorCount >= numShuffles) {
          // Not profitable anymore. previousAnchorCount is deliberately left unchanged: the condition stays true
          // for all remaining entries.
          HoistOrAnchorInstsIdx[idx] = false;
          continue;
        }

        if (HoistOrAnchorInstsIdx[idx])
          ++numProfitableHoistable;
        else
          ++previousAnchorCount;
      }

      return numProfitableHoistable;
    }

    // Performs the transformation for this group, driven by HoistOrAnchorInstsIdx:
    //  1. every instruction of the first chain that is marked hoistable is cloned in front of the first shuffle op
    //     and rewired to consume the (possibly already hoisted) shuffle source;
    //  2. for every anchor that precedes such an instruction in the chain, a second clone is created that applies
    //     the same operation to the anchor's operand coming from outside the chain, and the anchors of all chains
    //     are rewired to use it;
    //  3. all shuffle ops of the group are switched to the last hoisted value as their source;
    //  4. the original hoistable instructions are bypassed in every chain and erased.
    // Clones that do not already sit in the common dominator block of the group's shuffle ops are appended to
    // MoveToCommonDominatorInstMap; this function itself does not move them.
    // Returns false without touching the IR when findCommonDominator() yields no block, true otherwise.
    bool hoist(DenseMap<BasicBlock *, SmallVector<Instruction *, 4>> &MoveToCommonDominatorInstMap, DominatorTree &DT) {
      // If there is no common dominator abort hoisting
      BasicBlock *CommonDominator = findCommonDominator(DT);
      if (!CommonDominator)
        return false;

      // Track the new source for all the ShuffleOps
      auto *prev = ShuffleOps.front()->getSrc();

      for (unsigned idx = 0; idx < HoistOrAnchorInstsIdx.size(); idx++) {
        bool moveToCommonDominator = false;
        if (HoistOrAnchorInstsIdx[idx]) {
          // clone the inst to be hoisted
          auto *hoistedInst = InstChains.front()[idx]->clone();
          hoistedInst->setName(InstChains.front()[idx]->getName() + "_hoisted");
          hoistedInst->insertBefore(ShuffleOps.front());

          // The clone currently lives in the first shuffle's block; queue it for relocation if that block is not
          // the common dominator.
          if (CommonDominator != hoistedInst->getParent()) {
            moveToCommonDominator = true;
            MoveToCommonDominatorInstMap[CommonDominator].emplace_back(hoistedInst);
          }

          // Replace the correct operand
          auto *hoistedOp0 = hoistedInst->getOperand(0);
          Instruction *hoistedOpPrev = (idx == 0) ? cast<Instruction>(ShuffleOps.front()) : InstChains.front()[idx - 1];
          unsigned chainOpIdx = 0; // Record which operand is the previous inst in the InstChain
          if (hoistedOp0 == hoistedOpPrev) {
            hoistedInst->setOperand(0, prev);
          } else {
            chainOpIdx = 1;
            hoistedInst->setOperand(1, prev);
          }

          prev = hoistedInst;

          // Create copies for each anchor instruction further up the chain
          for (unsigned anchorIdx = 0; anchorIdx < idx; anchorIdx++) {
            // found anchor
            if (!HoistOrAnchorInstsIdx[anchorIdx]) {
              // clone the inst to be hoisted
              auto *anchorHoistedInst = hoistedInst->clone();
              anchorHoistedInst->setName(hoistedInst->getName() + "_for_" + InstChains.front()[anchorIdx]->getName());
              anchorHoistedInst->insertBefore(InstChains.front()[anchorIdx]);

              // Replace the correct operand
              // ex.
              // %0 = ...
              // %1 = ...
              // %2 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %0, i32 0, i32 0)
              // %3 = add i32 %2, %1 <- Anchor
              // %4 = shl i32 %3, 2 <- Hoistable past Anchor
              // Result:
              // %0 = ...
              // %1 = ...
              // %2 = shl i32 %0, 2 <- hoistedInst (WaveShuffle path)
              // %3 = shl i32 %1, 2 <- anchorHoistedInst (Anchor path)
              // %4 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %2, i32 0, i32 0)
              // %5 = add i32 %4, %3 <- Anchor
              // Find the operand that originates from outside the chain to use in anchorHoistedInst
              auto *anchorOp0 = InstChains.front()[anchorIdx]->getOperand(0);
              auto *anchorOp1 = InstChains.front()[anchorIdx]->getOperand(1);
              Instruction *anchorOpPrev =
                  (anchorIdx == 0) ? cast<Instruction>(ShuffleOps.front()) : InstChains.front()[anchorIdx - 1];
              if (anchorOp0 == anchorOpPrev) {
                anchorHoistedInst->setOperand(chainOpIdx, anchorOp1);
              } else {
                anchorHoistedInst->setOperand(chainOpIdx, anchorOp0);
              }

              // Properly set the anchor instructions in all chains to use the new anchorHoistedInst
              // (chain i belongs to ShuffleOps[i]; the operand that is NOT the chain predecessor gets replaced)
              for (unsigned i = 0; i < ShuffleOps.size(); i++) {
                auto *anchorOp0 = InstChains[i][anchorIdx]->getOperand(0);
                Instruction *anchorOpPrev =
                    (anchorIdx == 0) ? cast<Instruction>(ShuffleOps[i]) : InstChains[i][anchorIdx - 1];
                if (anchorOp0 == anchorOpPrev) {
                  InstChains[i][anchorIdx]->setOperand(1, anchorHoistedInst);
                } else {
                  InstChains[i][anchorIdx]->setOperand(0, anchorHoistedInst);
                }
              }

              // If hoisted instruction is moved, it's safe to move anchor as well.
              // NOTE(review): this relies on the anchor's outside operand being available in the common
              // dominator block - confirm for groups spanning several basic blocks.
              if (moveToCommonDominator) {
                MoveToCommonDominatorInstMap[CommonDominator].emplace_back(anchorHoistedInst);
              }
            }
          }
        }
      }

      // prev is last hoisted instruction, use as new src operand for all the shuffle ops in ShuffleGroup
      for (auto *waveShuffle : ShuffleOps) {
        waveShuffle->setSrc(prev);
      }

      // Rewire all operations around the hoisted instructions
      // This means removing the hoisted instructions in the InstChains path
      // Done for all InstChains
      for (unsigned i = 0; i < InstChains.size(); i++) {
        // Index of the most recent anchor seen in this chain; -1 means "none yet", i.e. the shuffle op itself is
        // the value the next anchor has to consume.
        int lastAnchorIdx = -1;
        for (unsigned rewireIdx = 0; rewireIdx < HoistOrAnchorInstsIdx.size(); rewireIdx++) {
          if (HoistOrAnchorInstsIdx[rewireIdx]) {
            // no-op for hoisted insts
            continue;
          } else if (lastAnchorIdx + 1 == rewireIdx) {
            // already wired correctly, just increment
            lastAnchorIdx++;
          } else {
            // At least one hoisted inst sits between the previous anchor (or the shuffle) and this anchor:
            // make this anchor consume the previous anchor / shuffle directly.
            Instruction *lastAnchor =
                lastAnchorIdx == -1 ? cast<Instruction>(ShuffleOps[i]) : InstChains[i][lastAnchorIdx];
            // operand to be replaced
            Instruction *rewirePrev = InstChains[i][rewireIdx - 1];
            unsigned rewireOpIdx = InstChains[i][rewireIdx]->getOperand(0) == rewirePrev ? 0 : 1;
            InstChains[i][rewireIdx]->setOperand(rewireOpIdx, lastAnchor);
            lastAnchorIdx = rewireIdx;
          }
        }

        // NOTE(review): signed/unsigned comparison; lastAnchorIdx == -1 converts to the maximum unsigned value.
        if (lastAnchorIdx != HoistOrAnchorInstsIdx.size() - 1) {
          // one or more hoisted insts between last anchor and end of InstChain, one last rewire
          Instruction *lastAnchor =
              lastAnchorIdx == -1 ? cast<Instruction>(ShuffleOps[i]) : InstChains[i][lastAnchorIdx];
          InstChains[i].back()->replaceAllUsesWith(lastAnchor);
        }
      }

      // The original hoistable instructions have been bypassed above; remove them from every chain.
      for (auto &instChain : InstChains) {
        for (unsigned i = 0; i < HoistOrAnchorInstsIdx.size(); i++) {
          if (HoistOrAnchorInstsIdx[i])
            instChain[i]->eraseFromParent();
        }
      }

      return true;
    }

    // Public, unlike the other data members: WaveShuffleIndexSinkingImpl::sinkShuffleGroups() iterates it to mark
    // the shuffles of a group as visited. The first element is the shuffle the group was created from.
    SmallVector<WaveShuffleIndexIntrinsic *> ShuffleOps; // all the WaveShuffleIndex instructions in the group
  private:
    // Returns the nearest block dominating the parent blocks of all shuffle ops in the group, or nullptr as soon
    // as the dominator tree cannot provide one.
    BasicBlock *findCommonDominator(DominatorTree &DT) {
      BasicBlock *DomBB = ShuffleOps.front()->getParent();
      for (auto *inst : ShuffleOps) {
        DomBB = DT.findNearestCommonDominator(DomBB, inst->getParent());
        // Stop at the first failure: a null result must not be fed back into the next query.
        if (!DomBB)
          return nullptr;
      }

      return DomBB;
    }

    // InstChains[i] is the chain that follows ShuffleOps[i] (hoist() indexes both with the same i).
    // match() truncates all chains to a common length; HoistOrAnchorInstsIdx is expected to have one entry per
    // position of that common chain.
    SmallVector<SmallVector<BinaryOperator *>>
        InstChains;                          // all common instructions shared by the shuffle ops, some can be hoisted
    SmallVector<bool> HoistOrAnchorInstsIdx; // Type of each Binary Operator in each InstChain: true -
                                             // Hoistable/Hoistable past previous Anchors, false - Anchor
  }; // ShuffleGroup

public:
  // Binds the worker to the function it will transform; no analysis is computed until run().
  WaveShuffleIndexSinkingImpl(Function &F) : F(F) {}
  // Runs the whole transformation on F; returns true if any step reported a change.
  bool run();

private:
  // Clones constant-channel WaveShuffleIndex instructions that have no unique user, one clone per use.
  bool splitWaveShuffleIndexes();
  // Merges WaveShuffleIndex instructions sharing the same source and the same constant channel.
  bool mergeWaveShuffleIndexes();
  // Moves every instruction queued in MoveToCommonDominatorInstMap into its target block.
  bool moveToCommonDominator();
  // Populates ShuffleGroupMap from the not-yet-visited constant-channel shuffles in F.
  void gatherShuffleGroups();
  // Pre-processes each gathered group and hoists the profitable ones.
  bool sinkShuffleGroups();
  // Walks the instruction chains following both shuffles in lockstep, fills the chains and the
  // hoistable/anchor classification, and returns the number of hoistable instructions found.
  static unsigned compareWaveShuffleIndexes(WaveShuffleIndexIntrinsic *waveShuffleIndex,
                                            WaveShuffleIndexIntrinsic *newWaveShuffleIndex,
                                            SmallVector<BinaryOperator *> &InstChain,
                                            SmallVector<BinaryOperator *> &newInstChain,
                                            SmallVector<bool> &hoistOrAnchor);
  // True if one operand of inst is a ConstantInt or ConstantFP.
  static bool isHoistable(BinaryOperator *inst);
  // True if instToHoist may be distributed over the operands of anchorInst.
  static bool isHoistableOverAnchor(BinaryOperator *instToHoist, BinaryOperator *anchorInst);
  Function &F;      // function being transformed
  DominatorTree DT; // recalculated at the start of run()
  // Target block -> instructions to be moved to it by moveToCommonDominator()
  DenseMap<BasicBlock *, SmallVector<Instruction *, 4>> MoveToCommonDominatorInstMap;
  // Shuffle source value -> groups of shuffles reading that source; cleared between iterations in run()
  DenseMap<Value *, SmallVector<ShuffleGroup, 4>> ShuffleGroupMap;
  // Shuffles of groups for which nothing was profitable to hoist; skipped by gatherShuffleGroups()
  DenseSet<WaveShuffleIndexIntrinsic *> Visited;
};

// FunctionPass front end; the actual work is done by WaveShuffleIndexSinkingImpl (see runOnFunction).
class WaveShuffleIndexSinking : public FunctionPass {
public:
  WaveShuffleIndexSinking() : FunctionPass(ID) {}

  // Entry point invoked once per function.
  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "WaveShuffleIndexSinking";
  }

  // Pass identification; defined at the bottom of this file.
  static char ID;
};

// Factory: returns a newly allocated WaveShuffleIndexSinking pass.
FunctionPass *createWaveShuffleIndexSinking() {
  FunctionPass *newPass = new WaveShuffleIndexSinking();
  return newPass;
}
} // namespace IGC

// Split any WaveShuffleIndex instructions that have more than one user
// This may uncover more hoisting opportunities
// If none of the instructions were able to be hoisted, the split instructions will be merged back together at the end
bool WaveShuffleIndexSinkingImpl::splitWaveShuffleIndexes() {
  bool Changed = false;
  SmallVector<WaveShuffleIndexIntrinsic *> InstsToSplit;
  for (auto &BB : F) {
    for (auto &I : BB) {
      if (auto *waveShuffleInst = dyn_cast<WaveShuffleIndexIntrinsic>(&I)) {
        if (auto *constantChannel = dyn_cast<ConstantInt>(waveShuffleInst->getChannel())) {
          // Do not split WaveShuffleIndex insts that do not have a constant index since they cannot be optimized by
          // this pass anyways
          if (!waveShuffleInst->getUniqueUndroppableUser()) {
            // More than one user, split to potentially uncover more chances sink each individual WaveShuffleIndex
            Changed = true;
            InstsToSplit.push_back(waveShuffleInst);
          }
        }
      }
    }
  }

  for (auto *instToSplit : InstsToSplit) {
    SmallVector<std::pair<Instruction *, Instruction *>> ReplacementPairs;
    // Multiple users, split instruction
    for (auto *user : instToSplit->users()) {
      auto *userInst = cast<Instruction>(user);
      auto *clonedWaveShuffleInst = instToSplit->clone();
      clonedWaveShuffleInst->setName(instToSplit->getName() + "_clone");
      clonedWaveShuffleInst->insertBefore(instToSplit);
      // Track replacement to perform after loop since iterators will be messed up if performed mid loop
      ReplacementPairs.emplace_back(userInst, clonedWaveShuffleInst);
    }

    for (auto &p : ReplacementPairs) {
      p.first->replaceUsesOfWith(instToSplit, p.second);
    }

    // Each user is now using a cloned instruction, original should be safe to remove
    if (instToSplit->isSafeToRemove()) {
      instToSplit->eraseFromParent();
    }
  }

  return Changed;
}

bool WaveShuffleIndexSinkingImpl::moveToCommonDominator() {
  // hoisted intruction needs to be moved to common dominator BB.
  // If instructions in shuffle group are from different basic blocks
  // there is a risk of non-dominating all users.
  bool Changed = false;
  for (auto &bb : MoveToCommonDominatorInstMap) {
    auto instrInsertPtr = (&*bb.first->getFirstInsertionPt());
    for (auto &inst : bb.second) {
      inst->moveBefore(instrInsertPtr);
      Changed = true;
    }
  }

  return Changed;
}

// Merge WaveShuffleIndex instructions that have the same source operand and the same constant lane/channel operand
bool WaveShuffleIndexSinkingImpl::mergeWaveShuffleIndexes() {
  // Map from Source to (Map from Lane to list of duplicate instructions)
  DenseMap<Value *, DenseMap<ConstantInt *, SmallVector<WaveShuffleIndexIntrinsic *>>> mergeMap;
  for (auto &BB : F) {
    for (auto &I : BB) {
      if (auto *waveShuffleInst = dyn_cast<WaveShuffleIndexIntrinsic>(&I)) {
        if (auto *constantChannel = dyn_cast<ConstantInt>(waveShuffleInst->getChannel())) {
          mergeMap[waveShuffleInst->getSrc()][constantChannel].push_back(waveShuffleInst);
        }
      }
    }
  }

  bool Changed = false;
  for (auto &srcToLaneAndInstsMap : mergeMap) {
    for (auto &laneToInstsMap : srcToLaneAndInstsMap.second) {
      // Only 1 WaveShuffleIndex using the same src with the same constant channel index, nothing to merge
      auto &duplicateInsts = laneToInstsMap.second;
      if (duplicateInsts.size() < 2)
        continue;
      Changed = true;
      auto *mainShuffleIndex = duplicateInsts.front();

      // Find common dominator for main WaveShuffleIndex
      bool moveToCommonDominator = false;
      BasicBlock *DomBB = mainShuffleIndex->getParent();

      for (unsigned i = 1; i < duplicateInsts.size(); i++) {
        BasicBlock *UseBB = duplicateInsts[i]->getParent();
        DomBB = DT.findNearestCommonDominator(DomBB, UseBB);
      }

      if (!DomBB) {
        // Do not merge if Common Dominator is not found
        Changed = false;
        continue;
      }

      moveToCommonDominator = DomBB != mainShuffleIndex->getParent() ? true : false;

      // replace uses of other WaveShuffleIndex with the first one
      for (unsigned i = 1; i < duplicateInsts.size(); i++) {
        duplicateInsts[i]->replaceAllUsesWith(mainShuffleIndex);
        duplicateInsts[i]->eraseFromParent();
      }

      if (moveToCommonDominator) {
        MoveToCommonDominatorInstMap[DomBB].emplace_back(mainShuffleIndex);
      }
    }
  }

  return Changed;
}

// Walk the function and sort every eligible WaveShuffleIndex into a ShuffleGroup, keyed by its source operand.
// A shuffle joins the first existing group (for the same source) whose match() accepts it; otherwise it starts a
// new group of its own.
void WaveShuffleIndexSinkingImpl::gatherShuffleGroups() {
  for (auto &BB : F) {
    for (auto &I : BB) {
      auto *candidate = dyn_cast<WaveShuffleIndexIntrinsic>(&I);
      if (!candidate)
        continue;

      // Skip shuffles that were already processed without any change in a prior iteration, as well as shuffles
      // without a constant channel: no group is (re-)created for them.
      if (Visited.count(candidate) || !isa<ConstantInt>(candidate->getChannel()))
        continue;

      // operator[] creates an empty list for a source seen for the first time, which then simply yields no match.
      auto &groupsForSource = ShuffleGroupMap[candidate->getSrc()];

      bool joinedExistingGroup = false;
      for (auto &existingGroup : groupsForSource) {
        if (existingGroup.match(candidate)) {
          joinedExistingGroup = true;
          break;
        }
      }

      if (!joinedExistingGroup)
        groupsForSource.emplace_back(candidate);
    }
  }
}

// Evaluate profitability for every gathered ShuffleGroup and hoist where it pays off.
// Groups with nothing profitable to hoist have all their shuffles recorded in Visited.
// Returns true if any group reported a successful hoist.
bool WaveShuffleIndexSinkingImpl::sinkShuffleGroups() {
  bool Changed = false;
  for (auto &sourceAndGroups : ShuffleGroupMap) {
    for (auto &group : sourceAndGroups.second) {
      if (group.preprocess() == 0) {
        // Nothing to do for this group: remember its shuffles so they are not gathered again.
        for (auto *shuffle : group.ShuffleOps)
          Visited.insert(shuffle);
        continue;
      }

      // Pre-processing left at least one profitable instruction to hoist.
      if (group.hoist(MoveToCommonDominatorInstMap, DT))
        Changed = true;
    }
  }
  return Changed;
}

// Walks, in lockstep, the single-use instruction chains that follow waveShuffleIndex and newWaveShuffleIndex, for as
// long as both continue with equivalent BinaryOperators (same operation, same non-chain operand).
//
// InstChain     - in/out. If empty on entry it is filled with the chain following waveShuffleIndex. If non-empty
//                 it is treated as the already known chain of waveShuffleIndex: it is not modified and its size
//                 limits how far the walk goes.
// NewInstChain  - out. Receives the matching chain following newWaveShuffleIndex.
// hoistOrAnchor - out. One entry is appended per matched position: true if the instruction is hoistable (and can
//                 be hoisted over every anchor recorded before it), false if it is an anchor.
//
// Returns the number of positions classified as hoistable.
unsigned WaveShuffleIndexSinkingImpl::compareWaveShuffleIndexes(WaveShuffleIndexIntrinsic *waveShuffleIndex,
                                                                WaveShuffleIndexIntrinsic *newWaveShuffleIndex,
                                                                SmallVector<BinaryOperator *> &InstChain,
                                                                SmallVector<BinaryOperator *> &NewInstChain,
                                                                SmallVector<bool> &hoistOrAnchor) {
  // Only search up to the number of existing instructions in InstChain, if it is prepopulated.
  // InstChain will be pre-populated if newWaveShuffleIndex is being compared to a developed ShuffleGroup (two or
  // more shuffle ops in group).
  std::optional<unsigned> limit;
  if (!InstChain.empty()) {
    limit = InstChain.size();
  }

  Instruction *curInstA = waveShuffleIndex;
  Instruction *curInstB = newWaveShuffleIndex;
  unsigned idx = 0;
  unsigned numHoistable = 0;
  while (curInstA->hasOneUse() && curInstB->hasOneUse() && (!limit.has_value() || idx < limit)) {
    // hasOneUse() also counts droppable uses, in which case there is no unique undroppable user. Check for null
    // before casting: dyn_cast<> must not be given a null pointer.
    auto *userA = curInstA->getUniqueUndroppableUser();
    auto *userB = curInstB->getUniqueUndroppableUser();
    if (!userA || !userB)
      break;

    // Only attempt to search past BinaryOperator for now
    auto *instA = dyn_cast<BinaryOperator>(userA);
    auto *instB = dyn_cast<BinaryOperator>(userB);
    if (!instA || !instB)
      break;

    if (!instA->isSameOperationAs(instB))
      break;

    // Check that both operands match
    auto *opA0 = instA->getOperand(0);
    auto *opA1 = instA->getOperand(1);
    auto *opB0 = instB->getOperand(0);
    auto *opB1 = instB->getOperand(1);

    // Chain value in the same operand slot on both sides, other operand identical:
    //   add i32 %ws1, %a  | add i32 %a, %ws1
    //   add i32 %ws2, %a  | add i32 %a, %ws2
    const bool sameSlots = (opA0 == curInstA && opB0 == curInstB && opA1 == opB1) ||
                           (opA1 == curInstA && opB1 == curInstB && opA0 == opB0);
    // Chain value in opposite operand slots, other operand identical (only valid for commutative operations):
    //   add i32 %ws1, %a  | add i32 %a, %ws1
    //   add i32 %a, %ws2  | add i32 %ws2, %a
    const bool swappedSlots = (opA0 == curInstA && opB1 == curInstB && opA1 == opB0) ||
                              (opA1 == curInstA && opB0 == curInstB && opA0 == opB1);
    if (!sameSlots && !(instA->isCommutative() && swappedSlots))
      break;

    bool markHoistable = false;
    if (isHoistable(instA)) {
      markHoistable = true;
      // instA has to be hoistable over every anchor recorded so far; start checking from the last instruction
      // and stop at the first anchor that blocks it.
      for (int i = hoistOrAnchor.size() - 1; i >= 0; i--) {
        if (!hoistOrAnchor[i] && !isHoistableOverAnchor(instA, InstChain[i])) {
          markHoistable = false;
          break;
        }
      }
    }

    if (markHoistable)
      numHoistable++;
    hoistOrAnchor.push_back(markHoistable);

    if (!limit.has_value()) {
      // Only update InstChain if it was a fresh vector
      InstChain.push_back(instA);
    }
    NewInstChain.push_back(instB);
    curInstA = instA;
    curInstB = instB;
    idx++;
  }

  return numHoistable;
}

// An instruction qualifies for hoisting when one of its two operands is a ConstantInt or ConstantFP: the constant
// represents a uniform value, so the operation can be performed on all simd lanes prior to the broadcast.
bool WaveShuffleIndexSinkingImpl::isHoistable(BinaryOperator *inst) {
  for (unsigned operandIdx = 0; operandIdx < 2; ++operandIdx) {
    Value *operand = inst->getOperand(operandIdx);
    if (isa<ConstantInt>(operand) || isa<ConstantFP>(operand))
      return true;
  }
  return false;
}

// Returns true if instToHoist, applied to the result of anchorInst, can instead be applied to each operand of
// anchorInst, i.e. (Y anchor Z) hoist C == (Y hoist C) anchor (Z hoist C).
// Based on leftDistributesOverRight and rightDistributesOverLeft from LLVM's InstructionCombining.cpp, extended
// with shl over add/sub, which holds in wrapping integer arithmetic.
// NOTE(review): poison-generating flags (nsw/nuw/exact) of the instructions are not considered here.
bool WaveShuffleIndexSinkingImpl::isHoistableOverAnchor(BinaryOperator *instToHoist, BinaryOperator *anchorInst) {
  const Instruction::BinaryOps FirstOp = anchorInst->getOpcode();
  const Instruction::BinaryOps SecondOp = instToHoist->getOpcode();

  if (instToHoist->isCommutative()) {
    // X & (Y | Z) <--> (X & Y) | (X & Z)
    // X & (Y ^ Z) <--> (X & Y) ^ (X & Z)
    // In practice, FirstOp is unlikely to be And, Or, or Xor as they would themselves be hoistable and thus, never
    // an anchor inst
    if (SecondOp == Instruction::And)
      return FirstOp == Instruction::Or || FirstOp == Instruction::Xor;

    // X | (Y & Z) <--> (X | Y) & (X | Z)
    if (SecondOp == Instruction::Or)
      return FirstOp == Instruction::And;

    // X * (Y + Z) <--> (X * Y) + (X * Z)
    // X * (Y - Z) <--> (X * Y) - (X * Z)
    if (SecondOp == Instruction::Mul)
      return FirstOp == Instruction::Add || FirstOp == Instruction::Sub;

    return false;
  }

  // Among the non-commutative operations only shifts distribute; e.g. (Y & Z) - C != (Y - C) & (Z - C).
  if (!instToHoist->isShift())
    return false;

  // The value coming from the chain has to be the shifted value. If the constant is operand 0, the chain value is
  // the shift amount, and C << (Y + Z) is not (C << Y) + (C << Z).
  if (!isa<ConstantInt>(instToHoist->getOperand(1)))
    return false;

  // (Y {&|^} Z) shift C <--> (Y shift C) {&|^} (Z shift C) for all shifts
  if (anchorInst->isBitwiseLogicOp())
    return true;

  // (Y +/- Z) << C <--> (Y << C) +/- (Z << C)
  // This does not hold for lshr/ashr: (1 + 1) >> 1 == 1, but (1 >> 1) + (1 >> 1) == 0.
  if (SecondOp == Instruction::Shl)
    return FirstOp == Instruction::Add || FirstOp == Instruction::Sub;

  return false;
}

// Drives the pass on F: split multi-user shuffles, then repeatedly gather and sink shuffle groups until nothing
// changes or the WaveShuffleIndexSinkingMaxIterations limit is reached, and finally merge duplicate shuffles and
// move the queued instructions to their common dominator blocks.
// Returns true if any of these steps reported a change.
bool WaveShuffleIndexSinkingImpl::run() {
  DT.recalculate(F);

  bool Changed = splitWaveShuffleIndexes();

  for (unsigned iteration = 0; iteration < IGC_GET_FLAG_VALUE(WaveShuffleIndexSinkingMaxIterations); ++iteration) {
    gatherShuffleGroups();
    if (!sinkShuffleGroups())
      break; // fixed point reached

    Changed = true;
    // The groups refer to instructions that have just been rewritten; rebuild them in the next iteration.
    ShuffleGroupMap.clear();
  }

  if (mergeWaveShuffleIndexes())
    Changed = true;
  if (moveToCommonDominator())
    Changed = true;
  return Changed;
}

// Pass entry point: runs a WaveShuffleIndexSinkingImpl worker on F and forwards its result.
bool WaveShuffleIndexSinking::runOnFunction(Function &F) {
  return WaveShuffleIndexSinkingImpl(F).run();
}

// Definition of the pass's static ID member, which the constructor hands to FunctionPass.
char WaveShuffleIndexSinking::ID = 0;

// Arguments for the IGC pass-initialization macros below.
// PASS_FLAG holds the same string as DEBUG_TYPE at the top of this file.
#define PASS_FLAG "igc-wave-shuffle-index-sinking"
#define PASS_DESCRIPTION "WaveShuffleIndexSinking"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
// No dependencies are declared between BEGIN and END.
IGC_INITIALIZE_PASS_BEGIN(WaveShuffleIndexSinking, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_END(WaveShuffleIndexSinking, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS)