File: LoopFusion.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (1512 lines) | stat: -rw-r--r-- 65,438 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
//===- LoopFusion.cpp - Code to perform loop fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements affine fusion.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <iomanip>
#include <optional>
#include <sstream>

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINELOOPFUSION
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"

using namespace mlir;
using namespace mlir::affine;

namespace {
/// Affine loop fusion pass. The pass currently implements a greedy fusion
/// policy that fuses loop nests connected by single-writer/single-reader
/// memref dependences, with the goal of improving locality.

// TODO: Support fusion of source loop nests which write to multiple
// memrefs, where each memref can have multiple users (if profitable).
// TODO: Extend this pass to check for fusion preventing dependences,
// and add support for more general loop fusion algorithms.

struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
  /// Default construction; no option is assigned here.
  LoopFusion() = default;

  /// Constructs the pass with explicit values for its options.
  ///   - 'fastMemorySpace': stored as-is in the option of the same name.
  ///   - 'localBufSizeThresholdBytes': a size in bytes; it is divided by 1024
  ///     (integer division, remainder dropped) before being stored in the
  ///     'localBufSizeThreshold' option.
  ///   - 'maximalFusion', 'affineFusionMode': stored as-is.
  LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes,
             bool maximalFusion, enum FusionMode affineFusionMode) {
    // Convert the byte count into the 1024-byte units kept by the option.
    const uint64_t thresholdIn1024ByteUnits = localBufSizeThresholdBytes / 1024;

    this->fastMemorySpace = fastMemorySpace;
    this->localBufSizeThreshold = thresholdIn1024ByteUnits;
    this->maximalFusion = maximalFusion;
    this->affineFusionMode = affineFusionMode;
  }

  /// Runs fusion on a single block (defined out of line).
  void runOnBlock(Block *block);
  /// Pass entry point, overriding the base class hook.
  void runOnOperation() override;
};

} // namespace

/// Returns true if node 'srcId' can be removed after fusing it with node
/// 'dstId'. The node can be removed if any of the following conditions are met:
///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
///   2. 'srcId' has no output dependences after fusion, has escaping memrefs
///       and the fusion slice is maximal.
///   3. 'srcId' has output dependences after fusion, the fusion slice is
///      maximal and the fusion insertion point dominates all the dependences.
static bool canRemoveSrcNodeAfterFusion(
    unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
    Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
    MemRefDependenceGraph *mdg) {

  Operation *dstNodeOp = mdg->getNode(dstId)->op;
  bool hasOutDepsAfterFusion = false;

  for (auto &outEdge : mdg->outEdges[srcId]) {
    Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
    // Skip dependence with dstOp since it will be removed after fusion.
    if (depNodeOp == dstNodeOp)
      continue;

    // Only fusion within the same block is supported. Use domination analysis
    // when needed.
    if (depNodeOp->getBlock() != dstNodeOp->getBlock())
      return false;

    // Check if the insertion point of the fused loop dominates the dependence.
    // Otherwise, the src loop can't be removed.
    if (fusedLoopInsPoint != depNodeOp &&
        !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't "
                                 "dominate dependence\n");
      return false;
    }

    hasOutDepsAfterFusion = true;
  }

  // If src loop has dependences after fusion or it writes to an live-out or
  // escaping memref, we can only remove it if the fusion slice is maximal so
  // that all the dependences are preserved.
  if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) {
    std::optional<bool> isMaximal = fusionSlice.isMaximal();
    if (!isMaximal) {
      LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine "
                                 "if fusion is maximal\n");
      return false;
    }

    if (!*isMaximal) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Src loop can't be removed: fusion is not maximal\n");
      return false;
    }
  }

  return true;
}

/// Returns in 'srcIdCandidates' the producer fusion candidates for consumer
/// 'dstId'. Candidates are sorted by node id order. This order corresponds to
/// the program order when the 'mdg' is created. However, program order is not
/// guaranteed and must not be required by the client. Program order won't be
/// held if the 'mdg' is reused from a previous fusion step or if the node
/// creation order changes in the future to support more advance cases.
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
                                  SmallVectorImpl<unsigned> &srcIdCandidates) {
  // Skip if no input edges along which to fuse.
  if (mdg->inEdges.count(dstId) == 0)
    return;

  // Gather memrefs from loads in 'dstId'.
  auto *dstNode = mdg->getNode(dstId);
  DenseSet<Value> consumedMemrefs;
  for (Operation *load : dstNode->loads)
    consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());

  // Traverse 'dstId' incoming edges and gather the nodes that contain a store
  // to one of the consumed memrefs.
  for (auto &srcEdge : mdg->inEdges[dstId]) {
    auto *srcNode = mdg->getNode(srcEdge.id);
    // Skip if 'srcNode' is not a loop nest.
    if (!isa<AffineForOp>(srcNode->op))
      continue;

    if (any_of(srcNode->stores, [&](Operation *op) {
          auto storeOp = cast<AffineWriteOpInterface>(op);
          return consumedMemrefs.count(storeOp.getMemRef()) > 0;
        }))
      srcIdCandidates.push_back(srcNode->id);
  }

  llvm::sort(srcIdCandidates);
  srcIdCandidates.erase(
      std::unique(srcIdCandidates.begin(), srcIdCandidates.end()),
      srcIdCandidates.end());
}

/// Node-id convenience overload: looks up nodes 'srcId' and 'dstId' in 'mdg'
/// and forwards the stores of the former and the loads of the latter to the
/// op-list overload, which fills 'producerConsumerMemrefs' with the memrefs
/// involved in a producer-consumer dependence between the two nodes.
static void
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                              MemRefDependenceGraph *mdg,
                              DenseSet<Value> &producerConsumerMemrefs) {
  auto &producerStores = mdg->getNode(srcId)->stores;
  auto &consumerLoads = mdg->getNode(dstId)->loads;
  gatherProducerConsumerMemrefs(producerStores, consumerLoads,
                                producerConsumerMemrefs);
}

/// A memref escapes in the context of the fusion pass if either:
///   1. it (or its alias) is a block argument, or
///   2. it is created by an op not known to guarantee alias freedom, or
///   3. it (or its alias) is used by ops other than affine dereferencing ops
///   (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
///   terminator ops, etc.); such ops do not dereference the memref in an
///   affine way, or
///   4. it has a user that is not nested anywhere inside the region that
///   holds 'block' (the memref is then visible beyond what this pass
///   analyzes, so it is conservatively treated as escaping).
/// Users that are inside the region of 'block' but whose ancestor in that
/// region sits in a different block are ignored.
static bool isEscapingMemref(Value memref, Block *block) {
  Operation *defOp = memref.getDefiningOp();
  // Check if 'memref' is a block argument.
  if (!defOp)
    return true;

  // Check if this is defined to be an alias of another memref.
  if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
    if (isEscapingMemref(viewOp.getViewSource(), block))
      return true;

  // Any op besides allocating ops wouldn't guarantee alias freedom
  if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref))
    return true;

  // Check if 'memref' is used by a non-dereferencing op (including unknown
  // ones) (e.g., call ops, alias creating ops, etc.).
  return llvm::any_of(memref.getUsers(), [&](Operation *user) {
    Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
    // No ancestor means 'user' lies outside the region of 'block' (possible
    // when 'memref' is defined in an enclosing region). Dereferencing the
    // null result would crash; report the memref as escaping instead.
    if (!ancestorOp)
      return true;
    // Ignore users outside of `block`.
    if (ancestorOp->getBlock() != block)
      return false;
    return !isa<AffineMapAccessInterface>(*user);
  });
}

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the block or are accessed in a non-affine way.
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                                  DenseSet<Value> &escapingMemRefs) {
  auto *node = mdg->getNode(id);
  for (Operation *storeOp : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (escapingMemRefs.count(memref))
      continue;
    if (isEscapingMemref(memref, &mdg->block))
      escapingMemRefs.insert(memref);
  }
}

// Initializes the data dependence graph by walking the operations in `block`.
// Each node is given the next value of 'nextNodeId' in the order its op is
// visited in `block`.
//
// The function proceeds in three phases:
//   1. Create one node per relevant top-level op of `block` and record, per
//      memref, the ids of the nodes whose ops access it.
//   2. Add edges from nodes that define SSA values (nodes without stores) to
//      the loop nest nodes that contain users of those values.
//   3. Add memref edges between every pair of nodes accessing the same memref
//      when at least one node of the pair stores to it.
//
// Returns false if an unsupported op is found: a loop nest for which the
// collector reports a non-affine region op, or any other top-level op that
// holds regions. Returns true otherwise.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Map from each top-level 'affine.for' op of `block` to the id of the node
  // created for it; used in phase 2 to find the node of a user's loop nest.
  DenseMap<Operation *, unsigned> forToNodeMap;
  // Phase 1: node creation. The order of the else-if chain matters: an op is
  // classified by the first condition it satisfies.
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (auto loadOp = dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (auto storeOp = dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if another region is found (not currently supported).
      return false;
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (isa<CallOpInterface>(op)) {
      // Create graph node for top-level Call Op that takes any argument of
      // memref type. Call Op that returns one or more memref type results
      // is already taken care of, by the previous conditions.
      if (llvm::any_of(op.getOperandTypes(),
                       [&](Type t) { return isa<MemRefType>(t); })) {
        Node node(nextNodeId++, &op);
        nodes.insert({node.id, node});
      }
    } else if (hasEffect<MemoryEffects::Write, MemoryEffects::Free>(&op)) {
      // Create graph node for top-level op, which could have a memory write
      // side effect.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    }
  }

  // Debug-only dump of the created nodes; the (void) cast silences the
  // unused-variable warning when LLVM_DEBUG compiles to nothing.
  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Phase 2: add dependence edges between nodes which produce SSA values and
  // their users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        // Users not enclosed in any 'affine.for' get no edge.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        if (loops.empty())
          continue;
        // NOTE(review): relies on 'loops[0]' being one of the top-level loops
        // registered in 'forToNodeMap' above; this is only checked by the
        // assert — confirm it holds when `block` is itself nested in a loop.
        assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[loops[0]];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Phase 3: walk memref access lists and add graph edges between dependent
  // nodes. For each memref, every pair (i, j) with i < j in the access list is
  // considered, and an edge i -> j is added unless both nodes only read it.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Node wrapper around the AffineForOp overload of sinkSequentialLoops, which
// sinks all sequential loops to the innermost levels (while preserving
// relative order among them) and moves all parallel loops to the outermost
// (while again preserving relative order among them).
// This can increase the loop depth at which we can fuse a slice, since we are
// pushing loop carried dependence to a greater depth in the loop nest.
// The op of 'node' must be an 'affine.for'; it is replaced by the root loop
// returned by the transformation.
static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  assert(isa<AffineForOp>(node->op));
  auto currentRoot = cast<AffineForOp>(node->op);
  AffineForOp sunkRoot = sinkSequentialLoops(currentRoot);
  // Keep the graph node pointing at the (possibly different) root loop.
  node->op = sunkRoot;
}

// Creates and returns a private (single-user) memref for the fused loop rooted
// at 'forOp'. Its shape is the constant bounding shape of the MemRefRegion
// written by 'srcStoreOpInst' at depth 'dstLoopDepth', so it may be smaller
// than the original memref. Uses of the original memref that pass the
// 'domOpFilter' given to replaceAllMemRefUsesWith (the first op of the body of
// 'forOp') are rewritten to the new memref, with indices shifted by the
// region's lower bounds.
// The new memref is placed in 'fastMemorySpace' when that is provided and the
// buffer size in bytes does not exceed 'localBufSizeThreshold'; otherwise it
// keeps the memory space of the original memref.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();

  // Builder anchored at 'forOp'; below it is only used to build affine dim
  // expressions.
  OpBuilder b(forInst);
  // Builder anchored at the region holding 'forOp'; used for the offset
  // expressions and for the alloc op.
  OpBuilder top(forInst->getParentRegion());

  // The memref being privatized, its type and its rank.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  const unsigned rank = oldMemRefType.getRank();

  // Region written by 'srcStoreOpInst' at depth 'dstLoopDepth'. The compute
  // call must run in all build modes; only its check is assert-only.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");

  // Ask the region for its constant bounding shape together with, per
  // dimension, the lower bound coefficients and the lower bound divisor.
  SmallVector<int64_t, 4> newShape;
  std::vector<SmallVector<int64_t, 4>> lbs;
  SmallVector<int64_t, 8> lbDivisors;
  lbs.reserve(rank);
  std::optional<int64_t> numElements =
      region.getConstantBoundingSizeAndShape(&newShape, &lbs, &lbDivisors);
  assert(numElements && "non-constant number of elts in local buffer");

  const FlatAffineValueConstraints *cst = region.getConstraints();
  // 'outerIVs' holds the values that this memory region is symbolic/parametric
  // on; this would correspond to loop IVs surrounding the level at which the
  // slice is being materialized.
  SmallVector<Value, 8> outerIVs;
  cst->getValues(rank, cst->getNumVars(), &outerIVs);

  // Each lower bound has one entry per constraint column beyond the first
  // 'rank' ones; its last entry is the constant term.
  const unsigned numLbEntries = cst->getNumCols() - rank;
  const unsigned constTermPos = numLbEntries - 1;

  // Turn every lower bound into an affine expression:
  //   (sum_j lbs[d][j] * d_j + constant) floordiv lbDivisors[d].
  SmallVector<AffineExpr, 4> offsets;
  offsets.reserve(rank);
  for (unsigned d = 0; d < rank; ++d) {
    assert(lbs[d].size() == numLbEntries && "incorrect bound size");

    AffineExpr offset = top.getAffineConstantExpr(0);
    for (unsigned j = 0; j < constTermPos; j++)
      offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
    assert(lbDivisors[d] > 0);
    offset = (offset + lbs[d][constTermPos]).floorDiv(lbDivisors[d]);
    offsets.push_back(offset);
  }

  // Size of the private buffer in bytes, used to pick the memory space.
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
  assert(eltSize && "memrefs with size elt types expected");
  uint64_t bufSize = *eltSize * *numElements;
  const bool fitsInFastMemory =
      bufSize <= localBufSizeThreshold && fastMemorySpace.has_value();
  unsigned newMemSpace = fitsInFastMemory
                             ? *fastMemorySpace
                             : oldMemRefType.getMemorySpaceAsInt();

  // Type of the private memref: bounding shape, same element type, chosen
  // memory space.
  auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
                                       {}, newMemSpace);

  // Create new private memref for fused loop 'forOp'. 'newShape' is always
  // a constant shape.
  // TODO: Create/move alloc ops for private memrefs closer to their
  // consumer loop nests to reduce their live range. Currently they are added
  // at the beginning of the block, because loop nests can be reordered
  // during the fusion pass.
  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);

  // Index remapping: the map takes the outer IVs followed by the 'rank'
  // original indices, and yields each original index minus its offset.
  const unsigned numOuterIVs = outerIVs.size();
  const unsigned numRemapDims = numOuterIVs + rank;
  SmallVector<AffineExpr, 4> remapExprs;
  remapExprs.reserve(rank);
  for (unsigned i = 0; i < rank; i++) {
    AffineExpr shiftedIndex = b.getAffineDimExpr(numOuterIVs + i) - offsets[i];
    remapExprs.push_back(simplifyAffineExpr(shiftedIndex, numRemapDims, 0));
  }
  auto indexRemap =
      AffineMap::get(numRemapDims, 0, remapExprs, forOp.getContext());

  // Redirect uses of 'oldMemRef' to 'newMemRef', filtered by the first op in
  // the body of 'forOp'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
  return newMemRef;
}

/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
/// true. Otherwise, return false.
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       Value memref,
                                       MemRefDependenceGraph *mdg) {
  auto *srcNode = mdg->getNode(srcId);
  auto *dstNode = mdg->getNode(dstId);
  Value::user_range users = memref.getUsers();
  // For each MemRefDependenceGraph's node that is between 'srcNode' and
  // 'dstNode' (exclusive of 'srcNodes' and 'dstNode'), check whether any
  // non-affine operation in the node accesses the 'memref'.
  for (auto &idAndNode : mdg->nodes) {
    Operation *op = idAndNode.second.op;
    // Take care of operations between 'srcNode' and 'dstNode'.
    if (srcNode->op->isBeforeInBlock(op) && op->isBeforeInBlock(dstNode->op)) {
      // Walk inside the operation to find any use of the memref.
      // Interrupt the walk if found.
      auto walkResult = op->walk([&](Operation *user) {
        // Skip affine ops.
        if (isa<AffineMapAccessInterface>(*user))
          return WalkResult::advance();
        // Find a non-affine op that uses the memref.
        if (llvm::is_contained(users, user))
          return WalkResult::interrupt();
        return WalkResult::advance();
      });
      if (walkResult.wasInterrupted())
        return true;
    }
  }
  return false;
}

/// Check whether a memref value in node 'srcId' has a non-affine that
/// is between node 'srcId' and node 'dstId' (exclusive of 'srcNode' and
/// 'dstNode').
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
                                       MemRefDependenceGraph *mdg) {
  // Collect memref values in node 'srcId'.
  auto *srcNode = mdg->getNode(srcId);
  llvm::SmallDenseSet<Value, 2> memRefValues;
  srcNode->op->walk([&](Operation *op) {
    // Skip affine ops.
    if (isa<AffineForOp>(op))
      return WalkResult::advance();
    for (Value v : op->getOperands())
      // Collect memref values only.
      if (isa<MemRefType>(v.getType()))
        memRefValues.insert(v);
    return WalkResult::advance();
  });
  // Looking for users between node 'srcId' and node 'dstId'.
  return llvm::any_of(memRefValues, [&](Value memref) {
    return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg);
  });
}

// Checks the profitability of fusing a backwards slice of the loop nest
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
// the memref being produced and consumed, which is an input to the cost model.
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
// 'srcOpInst', as we are slicing w.r.t. that producer. For input-reuse
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
// unique store op in the src node, which will be used to check that the write
// region is the same after input-reuse fusion. Computation slices are provided
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
// profitable to fuse the candidate loop nests. Returns false otherwise.
// `dstLoopDepth` is set to the most profitable depth at which to materialize
// the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
//    computation slice of the loop nest surrounding 'srcOpInst' is
//    represented by modified src loop bounds in 'sliceState', which are
//    functions of loop IVs in the loop nest surrounding 'srcOpInst'.
// *) Computes the cost of unfused src/dst loop nests (currently the cost of a
//    loop nest is the total number of dynamic operation instances in the loop
//    nest).
// *) Computes the cost of fusing a slice of the src loop nest into the dst
//    loop nest at various values of dst loop depth, attempting to fuse
//    the largest computation slice at the maximal dst loop depth (closest to
//    the load) to minimize reuse distance and potentially enable subsequent
//    load/store forwarding.
//    NOTE: 'dstLoopDepth' refers to the loop depth within the destination loop
//    nest, at which the src computation slice is inserted/fused.
//    NOTE: We attempt to maximize the dst loop depth, but there are cases
//    where a particular setting for 'dstLoopDepth' might fuse an unsliced
//    loop (within the src computation slice) at a depth which results in
//    excessive recomputation (see unit tests for examples).
// *) Compares the total cost of the unfused loop nests to the min cost fused
//    loop nest computed in the previous step, and returns true if the latter
//    is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                               AffineForOp dstForOp,
                               ArrayRef<ComputationSliceState> depthSliceUnions,
                               unsigned maxLegalFusionDepth,
                               unsigned *dstLoopDepth,
                               double computeToleranceThreshold) {
  LLVM_DEBUG({
    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
    llvm::dbgs() << dstForOp << "\n";
  });

  if (maxLegalFusionDepth == 0) {
    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
    return false;
  }

  // Depth 'i' is looked up as 'depthSliceUnions[i - 1]' below, so the caller
  // must have provided one entry per candidate depth.
  assert(maxLegalFusionDepth <= depthSliceUnions.size() &&
         "expected a slice union entry for every legal fusion depth");

  // Compute cost of sliced and unsliced src loop nest.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);

  // 'srcLoopIVs[0]' is used as the root of the src loop nest everywhere below;
  // bail out instead of indexing an empty vector.
  if (srcLoopIVs.empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: src op is not surrounded by an affine.for\n");
    return false;
  }

  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
    return false;

  // Walk dst loop nest and collect stats.
  LoopNestStats dstLoopNestStats;
  if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
    return false;

  // Search for the best value for 'dstLoopDepth'. For each depth from
  // 'maxLegalFusionDepth' down to '1', the precomputed slice union in
  // 'depthSliceUnions' is used to calculate the cost of the slice inserted
  // into the dst loop nest at that depth and the size of the region the slice
  // would write.
  uint64_t minFusedLoopNestComputeCost = std::numeric_limits<uint64_t>::max();
  double maxStorageReduction = 0.0;
  std::optional<uint64_t> sliceMemEstimate;

  // The best loop depth at which to materialize the slice.
  std::optional<unsigned> bestDstLoopDepth;

  // True once at least one depth made it all the way to the cost model; used
  // only to pick the right diagnostic when no depth is selected.
  bool anyDepthEvaluated = false;

  // Compute op instance count for the src loop nest without iteration slicing.
  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);

  // Compute src loop nest write region size.
  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
    LLVM_DEBUG(llvm::dbgs()
               << "Unable to compute MemRefRegion for source operation\n");
    return false;
  }

  std::optional<int64_t> maybeSrcWriteRegionSizeBytes =
      srcWriteRegion.getRegionSize();
  if (!maybeSrcWriteRegionSizeBytes.has_value())
    return false;
  int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

  // Compute op instance count for the dst loop nest.
  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

  // Evaluate all depth choices for materializing the slice in the destination
  // loop nest.
  for (unsigned i = maxLegalFusionDepth; i >= 1; --i) {
    const ComputationSliceState &slice = depthSliceUnions[i - 1];
    // Skip slice union if it wasn't computed for this depth.
    if (slice.isEmpty())
      continue;

    int64_t fusedLoopNestComputeCost;
    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                              dstLoopNestStats, slice,
                              &fusedLoopNestComputeCost)) {
      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
      continue;
    }

    // Fraction of extra work relative to running both nests unfused
    // (0 means no redundant computation).
    double additionalComputeFraction =
        fusedLoopNestComputeCost /
            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
        1;

    // Determine what the slice write MemRefRegion would be, if the src loop
    // nest slice 'slice' were to be inserted into the dst loop nest at loop
    // depth 'i'.
    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
                                        &slice))) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to compute slice write region at loopDepth: " << i
                 << "\n");
      continue;
    }

    // A zero-sized region is rejected as well: it is the divisor of the
    // storage reduction factor computed below.
    std::optional<int64_t> maybeSliceWriteRegionSizeBytes =
        sliceWriteRegion.getRegionSize();
    if (!maybeSliceWriteRegionSizeBytes.has_value() ||
        *maybeSliceWriteRegionSizeBytes == 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << "Failed to get slice write region size at loopDepth: " << i
                 << "\n");
      continue;
    }
    int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;

    // If we are fusing for reuse, check that write regions remain the same.
    // TODO: Write region check should check sizes and offsets in
    // each dimension, so that we are sure they are covering the same memref
    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
    if (srcOpInst != srcStoreOpInst &&
        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
      continue;

    double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                              static_cast<double>(sliceWriteRegionSizeBytes);

    LLVM_DEBUG({
      std::stringstream msg;
      msg << "  evaluating fusion profitability at depth : " << i << "\n"
          << std::fixed << std::setprecision(2)
          << "   additional compute fraction: "
          << 100.0 * additionalComputeFraction << "%\n"
          << "   storage reduction factor: " << storageReduction << "x\n"
          << "   fused nest cost: " << fusedLoopNestComputeCost << "\n"
          << "   src write region size: " << srcWriteRegionSizeBytes << "\n"
          << "   slice write region size: " << sliceWriteRegionSizeBytes
          << "\n";
      llvm::dbgs() << msg.str();
    });

    anyDepthEvaluated = true;

    // TODO: This is a placeholder cost model.
    // Among all choices that add an acceptable amount of redundant computation
    // (as per computeToleranceThreshold), we will simply pick the one that
    // reduces the intermediary size the most.
    if ((storageReduction > maxStorageReduction) &&
        (additionalComputeFraction < computeToleranceThreshold)) {
      maxStorageReduction = storageReduction;
      bestDstLoopDepth = i;
      minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
      sliceMemEstimate = sliceWriteRegionSizeBytes;
    }
  }

  // No depth was selected: report whether that is because nothing could be
  // evaluated at all or because every evaluated depth was rejected.
  if (!bestDstLoopDepth) {
    if (!anyDepthEvaluated) {
      LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n");
    } else {
      LLVM_DEBUG(
          llvm::dbgs()
          << "All fusion choices involve more than the threshold amount of "
             "redundant computation; NOT fusing.\n");
    }
    return false;
  }

  // Set dstLoopDepth based on best values from search.
  *dstLoopDepth = *bestDstLoopDepth;

  LLVM_DEBUG(
      llvm::dbgs() << " LoopFusion fusion stats:"
                   << "\n  best loop depth: " << bestDstLoopDepth
                   << "\n  src loop nest compute cost: " << srcLoopNestCost
                   << "\n  dst loop nest compute cost: " << dstLoopNestCost
                   << "\n  fused loop nest compute cost: "
                   << minFusedLoopNestComputeCost << "\n");

  // A simple cost model: fuse if it does not increase the memory footprint.
  auto dstMemSize = getMemoryFootprintBytes(dstForOp);
  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);

  std::optional<double> storageReduction;

  if (!dstMemSize || !srcMemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << "  fusion memory benefit cannot be evaluated; NOT fusing.\n");
    return false;
  }

  auto srcMemSizeVal = *srcMemSize;
  auto dstMemSizeVal = *dstMemSize;

  assert(sliceMemEstimate && "expected value");
  auto fusedMem = dstMemSizeVal + *sliceMemEstimate;

  LLVM_DEBUG(llvm::dbgs() << "   src mem: " << srcMemSizeVal << "\n"
                          << "   dst mem: " << dstMemSizeVal << "\n"
                          << "   fused mem: " << fusedMem << "\n"
                          << "   slice mem: " << sliceMemEstimate << "\n");

  // Compare as 64-bit values: 'long' is only 32 bits wide on LLP64 targets
  // and would truncate large footprints.
  if (static_cast<int64_t>(fusedMem) > srcMemSizeVal + dstMemSizeVal) {
    LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n");
    return false;
  }
  storageReduction =
      100.0 *
      (1.0 - fusedMem / (static_cast<double>(srcMemSizeVal) + dstMemSizeVal));

  double additionalComputeFraction =
      100.0 * (minFusedLoopNestComputeCost /
                   (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
               1);
  // Only read inside LLVM_DEBUG; silence the unused warning in release builds.
  (void)additionalComputeFraction;
  LLVM_DEBUG({
    std::stringstream msg;
    msg << " fusion is most profitable at depth " << *dstLoopDepth << " with "
        << std::setprecision(2) << additionalComputeFraction
        << "% redundant computation and a ";
    msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>");
    msg << "% storage reduction.\n";
    llvm::dbgs() << msg.str();
  });

  return true;
}

namespace {

// GreedyFusion greedily fuses loop nests which have a producer/consumer or
// input-reuse relationship on a memref, with the goal of improving locality.
//
// The steps of the producer-consumer fusion algorithm are as follows:
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
//   *) Pop an AffineForOp off the worklist. This 'dstAffineForOp' will be a
//      candidate destination AffineForOp into which fusion will be attempted.
//   *) Add each LoadOp currently in 'dstAffineForOp' into list 'dstLoadOps'.
//   *) For each LoadOp in 'dstLoadOps' do:
//      *) Look up dependent loop nests which have a single store op to the same
//         memref.
//      *) Check if dependences would be violated by the fusion.
//      *) Get a computation slice of 'srcLoopNest', which adjusts its loop
//         bounds to be functions of 'dstLoopNest' IVs and symbols.
//      *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
//         at a loop depth determined by the cost model in 'isFusionProfitable'.
//      *) Add the newly fused load/store operations to the state,
//         and also add newly fused load ops to 'dstLoadOps' to be considered
//         as fusion dst load ops in another iteration.
//      *) Remove old src loop nest and its associated state.
//
// The steps of the input-reuse fusion algorithm are as follows:
//
// *) Initialize 'worklist' with node ids from the dependence graph.
// *) For each 'dstNode' in the worklist:
//   *) Find a candidate sibling node 'sibNode' to fuse with 'dstNode' which
//      loads from the same memref, but which has no dependence paths to/from.
//   *) Get a computation slice of 'sibLoopNest', which adjusts its loop
//      bounds to be functions of 'dstLoopNest' IVs and symbols.
//   *) Fuse the 'sibLoopNest' computation slice into the 'dstLoopNest',
//      at a loop depth determined by the cost model in 'isFusionProfitable'.
//      This function also checks that the memref write region of 'sibLoopNest',
//      is preserved in the fused loop nest.
//   *) Update graph state to reflect the fusion of 'sibNode' into 'dstNode'.
//
// Given a graph where top-level operations are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
// This greedy algorithm is not 'maximal' due to the current restriction of
// fusing along single producer-consumer edges.
// TODO: Lift the single producer-consumer edge restriction.
//
// TODO: Experiment with other fusion policies.
struct GreedyFusion {
public:
  // The data dependence graph to traverse during fusion.
  MemRefDependenceGraph *mdg;
  // Worklist of graph nodes visited during the fusion pass.
  SmallVector<unsigned, 8> worklist;
  // Parameter for local buffer size threshold.
  unsigned localBufSizeThreshold;
  // Parameter for fast memory space.
  std::optional<unsigned> fastMemorySpace;
  // If true, ignore any additional (redundant) computation tolerance threshold
  // that would have prevented fusion.
  bool maximalFusion;
  // The amount of additional computation that is tolerated while fusing
  // pair-wise as a fraction of the total computation.
  double computeToleranceThreshold;

  using Node = MemRefDependenceGraph::Node;

  /// Stores the dependence graph pointer and the fusion parameters in the
  /// corresponding members. 'worklist' is not touched here; it is filled by
  /// init().
  GreedyFusion(MemRefDependenceGraph *mdg, unsigned localBufSizeThreshold,
               std::optional<unsigned> fastMemorySpace, bool maximalFusion,
               double computeToleranceThreshold)
      : mdg(mdg), localBufSizeThreshold(localBufSizeThreshold),
        fastMemorySpace(fastMemorySpace), maximalFusion(maximalFusion),
        computeToleranceThreshold(computeToleranceThreshold) {}

  /// Initializes 'worklist' with nodes from 'mdg'.
  void init() {
    // TODO: Add a priority queue for prioritizing nodes by different
    // metrics (e.g. arithmetic intensity/flops-to-bytes ratio).

    // Start from scratch, then record the id of every node currently in the
    // graph, in the graph's own iteration order.
    worklist.clear();
    for (const auto &entry : mdg->nodes)
      worklist.push_back(entry.second.id);
  }
  /// Run only sibling fusion on the `mdg`.
  void runSiblingFusionOnly() {
    // Single sibling-fusion pass over all graph nodes (the pass re-initializes
    // the worklist itself).
    fuseSiblingNodes();
    // Post-fusion cleanup step, run after all fusion attempts are done.
    eraseUnusedMemRefAllocations();
  }

  /// Run only producer/consumer fusion on the `mdg`.
  void runProducerConsumerFusionOnly() {
    // Passing the largest unsigned value as the user-count limit means the
    // 'out edge count > maxSrcUserCount' filter in performFusionsIntoDest
    // cannot reject a producer.
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    // Post-fusion cleanup step, run after all fusion attempts are done.
    eraseUnusedMemRefAllocations();
  }

  // Run the GreedyFusion pass.
  // *) First pass through the nodes fuses single-use producer nodes into their
  //    unique consumer.
  // *) Second pass fuses sibling nodes which share no dependence edges.
  // *) Third pass fuses any remaining producer nodes into their users.
  void runGreedyFusion() {
    // TODO: Run this repeatedly until a fixed-point is reached.
    // Pass 1: only producers with at most one user per shared memref.
    fuseProducerConsumerNodes(/*maxSrcUserCount=*/1);
    // Pass 2: input-reuse (sibling) fusion.
    fuseSiblingNodes();
    // Pass 3: producer/consumer fusion again, this time without a limit on the
    // number of users of the producer.
    fuseProducerConsumerNodes(
        /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max());
    // Post-fusion cleanup step, run once after all three passes.
    eraseUnusedMemRefAllocations();
  }

  /// Returns true if a private memref can be created for `memref` given
  /// the fusion scenario reflected by the other arguments.
  bool canCreatePrivateMemRef(Value memref,
                              const DenseSet<Value> &srcEscapingMemRefs,
                              unsigned producerId, unsigned consumerId,
                              bool removeSrcNode) {
    const Node *dstNode = mdg->getNode(consumerId);

    // Privatizing an escaping memref redirects the writes meant for it to the
    // private copy, leaving the escaping memref unmodified. That is not
    // acceptable when the source nest goes away after fusion, nor when the
    // destination itself writes to the memref.
    if (srcEscapingMemRefs.count(memref) > 0) {
      if (removeSrcNode)
        return false;
      if (dstNode->getStoreOpCount(memref) > 0)
        return false;
    }

    // The producer must have no incoming accesses on 'memref' ...
    if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0)
      return false;
    // ... and the consumer must have no out edges on it.
    if (mdg->getOutEdgeCount(consumerId, memref) > 0)
      return false;

    // When the producer stays in place there is nothing left to check.
    if (!removeSrcNode)
      return true;

    // The producer is going to be removed: dependences on 'memref' towards
    // nodes other than the consumer must be preserved, which rules out a
    // private memref.
    auto feedsAnotherNode = [&](const auto &edge) {
      return edge.value == memref && edge.id != consumerId;
    };
    return !any_of(mdg->outEdges[producerId], feedsAnotherNode);
  }

  /// Perform fusions with node `dstId` as the destination of fusion. No
  /// fusion is performed with a producer whose user count is greater than
  /// `maxSrcUserCount` for any of the memrefs involved.
  void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
    // Skip if this node was removed (fused into another node).
    if (mdg->nodes.count(dstId) == 0)
      return;
    // Get 'dstNode' into which to attempt fusion.
    auto *dstNode = mdg->getNode(dstId);
    // Skip if 'dstNode' is not a loop nest.
    if (!isa<AffineForOp>(dstNode->op))
      return;
    // Skip if 'dstNode' is a loop nest returning values.
    // TODO: support loop nests that return values.
    if (dstNode->op->getNumResults() > 0)
      return;

    LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");

    // Sink sequential loops in 'dstNode' (and thus raise parallel loops)
    // while preserving relative order. This can increase the maximum loop
    // depth at which we can fuse a slice of a producer loop nest into a
    // consumer loop nest.
    sinkSequentialLoops(dstNode);
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    // Try to fuse 'dstNode' with candidate producer loops until a fixed point
    // is reached. Fusing two loops may expose new fusion opportunities.
    bool dstNodeChanged;
    do {
      // Gather src loop candidates for 'dstNode' and visit them in "quasi"
      // reverse program order to minimize the number of iterations needed to
      // reach the fixed point. Note that this is a best effort approach since
      // 'getProducerCandidates' does not always guarantee that program order
      // in 'srcIdCandidates'.
      dstNodeChanged = false;
      SmallVector<unsigned, 16> srcIdCandidates;
      getProducerCandidates(dstId, mdg, srcIdCandidates);

      for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
        // Get 'srcNode' from which to attempt fusion into 'dstNode'.
        auto *srcNode = mdg->getNode(srcId);
        auto srcAffineForOp = cast<AffineForOp>(srcNode->op);
        LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
                                << " for dst loop " << dstId << "\n");

        // Skip if 'srcNode' is a loop nest returning values.
        // TODO: support loop nests that return values.
        if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
          continue;

        DenseSet<Value> producerConsumerMemrefs;
        gatherProducerConsumerMemrefs(srcId, dstId, mdg,
                                      producerConsumerMemrefs);

        // Skip if 'srcNode' out edge count on any memref is greater than
        // 'maxSrcUserCount'.
        if (any_of(producerConsumerMemrefs, [&](Value memref) {
              return mdg->getOutEdgeCount(srcNode->id, memref) >
                     maxSrcUserCount;
            }))
          continue;

        // Gather memrefs in 'srcNode' that are written and escape out of the
        // block (e.g., memref block arguments, returned memrefs,
        // memrefs passed to function calls, etc.).
        DenseSet<Value> srcEscapingMemRefs;
        gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);

        // Skip if there are non-affine operations in between the 'srcNode'
        // and 'dstNode' using their memrefs. If so, we wouldn't be able to
        // compute a legal insertion point for now. 'srcNode' and 'dstNode'
        // memrefs with non-affine operation users would be considered
        // escaping memrefs so we can limit this check to only scenarios with
        // escaping memrefs.
        if (!srcEscapingMemRefs.empty() &&
            hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: non-affine users in between the loops\n");
          continue;
        }

        // Compute an operation list insertion point for the fused loop
        // nest which preserves dependences.
        Operation *fusedLoopInsPoint =
            mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id);
        if (fusedLoopInsPoint == nullptr)
          continue;

        // Compute the innermost common loop depth for dstNode
        // producer-consumer loads/stores.
        SmallVector<Operation *, 2> dstMemrefOps;
        for (Operation *op : dstNode->loads)
          if (producerConsumerMemrefs.count(
                  cast<AffineReadOpInterface>(op).getMemRef()) > 0)
            dstMemrefOps.push_back(op);
        for (Operation *op : dstNode->stores)
          if (producerConsumerMemrefs.count(
                  cast<AffineWriteOpInterface>(op).getMemRef()))
            dstMemrefOps.push_back(op);
        unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps);

        // Check the feasibility of fusing src loop nest into dst loop nest
        // at loop depths in range [1, dstLoopDepthTest].
        unsigned maxLegalFusionDepth = 0;
        SmallVector<ComputationSliceState, 8> depthSliceUnions;
        depthSliceUnions.resize(dstLoopDepthTest);
        FusionStrategy strategy(FusionStrategy::ProducerConsumer);
        for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
          FusionResult result = affine::canFuseLoops(
              srcAffineForOp, dstAffineForOp,
              /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

          if (result.value == FusionResult::Success)
            maxLegalFusionDepth = i;
        }

        if (maxLegalFusionDepth == 0) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Can't fuse: fusion is not legal at any depth\n");
          continue;
        }

        // Check if fusion would be profitable. We skip profitability analysis
        // for maximal fusion since we already know the maximal legal depth to
        // fuse.
        unsigned bestDstLoopDepth = maxLegalFusionDepth;
        if (!maximalFusion) {
          // Retrieve producer stores from the src loop.
          SmallVector<Operation *, 2> producerStores;
          for (Operation *op : srcNode->stores)
            if (producerConsumerMemrefs.count(
                    cast<AffineWriteOpInterface>(op).getMemRef()))
              producerStores.push_back(op);

          // TODO: Suppport multiple producer stores in profitability
          // analysis. We limit profitability analysis to only scenarios with
          // a single producer store for now. Note that some multi-store
          // producer scenarios will still go through profitability analysis
          // if only one of the stores is involved the producer-consumer
          // relationship of the candidate loops.
          assert(!producerStores.empty() && "Expected producer store");
          if (producerStores.size() > 1)
            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                       "supported for this case\n");
          else if (!isFusionProfitable(producerStores[0], producerStores[0],
                                       dstAffineForOp, depthSliceUnions,
                                       maxLegalFusionDepth, &bestDstLoopDepth,
                                       computeToleranceThreshold))
            continue;
        }

        assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
        ComputationSliceState &bestSlice =
            depthSliceUnions[bestDstLoopDepth - 1];
        assert(!bestSlice.isEmpty() && "Missing slice union for depth");

        // Determine if 'srcId' can be removed after fusion, taking into
        // account remaining dependences, escaping memrefs and the fusion
        // insertion point.
        bool removeSrcNode = canRemoveSrcNodeAfterFusion(
            srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
            mdg);

        DenseSet<Value> privateMemrefs;
        for (Value memref : producerConsumerMemrefs) {
          if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
                                     removeSrcNode)) {
            // Create a private version of this memref.
            LLVM_DEBUG(llvm::dbgs()
                       << "Creating private memref for " << memref << '\n');
            // Create a private version of this memref.
            privateMemrefs.insert(memref);
          }
        }

        // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
        fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice);
        dstNodeChanged = true;

        LLVM_DEBUG(llvm::dbgs()
                   << "Fused src loop " << srcId << " into dst loop " << dstId
                   << " at depth " << bestDstLoopDepth << ":\n"
                   << dstAffineForOp << "\n");

        // Move 'dstAffineForOp' before 'insertPointInst' if needed.
        if (fusedLoopInsPoint != dstAffineForOp)
          dstAffineForOp->moveBefore(fusedLoopInsPoint);

        // Update edges between 'srcNode' and 'dstNode'.
        mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
                         removeSrcNode);

        // Create private memrefs.
        if (!privateMemrefs.empty()) {
          // Gather stores for all the private-to-be memrefs.
          DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
          dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
            Value storeMemRef = storeOp.getMemRef();
            if (privateMemrefs.count(storeMemRef) > 0)
              privateMemRefToStores[storeMemRef].push_back(storeOp);
          });

          // Replace original memrefs with private memrefs. Note that all the
          // loads and stores on these memrefs will be replaced with a new
          // loads and stores. Any reference to the original ones becomes
          // invalid after this point.
          for (auto &memrefToStoresPair : privateMemRefToStores) {
            // TODO: Use union of memref write regions to compute
            // private memref footprint.
            SmallVector<Operation *, 4> &storesForMemref =
                memrefToStoresPair.second;
            Value newMemRef = createPrivateMemRef(
                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
                fastMemorySpace, localBufSizeThreshold);
            // Create new node in dependence graph for 'newMemRef' alloc op.
            unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
            // Add edge from 'newMemRef' node to dstNode.
            mdg->addEdge(newMemRefNodeId, dstId, newMemRef);
          }
          // One or more entries for 'newMemRef' alloc op are inserted into
          // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to
          // reallocate, update dstNode.
          dstNode = mdg->getNode(dstId);
        }

        // Collect dst loop stats after memref privatization transformation.
        LoopNestStateCollector dstLoopCollector;
        dstLoopCollector.collect(dstAffineForOp);

        // Clear and add back loads and stores.
        mdg->clearNodeLoadAndStores(dstNode->id);
        mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
                       dstLoopCollector.storeOpInsts);

        if (removeSrcNode) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Removing src loop " << srcId << " after fusion\n");
          // srcNode is no longer valid after it is removed from mdg.
          srcAffineForOp.erase();
          mdg->removeNode(srcId);
          srcNode = nullptr;
        }
      }
    } while (dstNodeChanged);
  }

  /// Visit each node in the graph, and for each node, attempt to fuse it with
  /// producer-consumer candidates. No fusion is performed when producers with a
  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
  /// are encountered.
  void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
    LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
    init();
    // Drain the worklist from the back; each entry is removed before it is
    // tried as a fusion destination.
    while (!worklist.empty())
      performFusionsIntoDest(worklist.pop_back_val(), maxSrcUserCount);
  }

  // Visits each node in the graph, and for each node, attempts to fuse it with
  // its sibling nodes (nodes which share a parent, but no dependence edges).
  void fuseSiblingNodes() {
    LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n");
    init();
    while (!worklist.empty()) {
      const unsigned candidateId = worklist.pop_back_val();

      // A node that was fused into another one is no longer in the graph.
      if (!mdg->nodes.count(candidateId))
        continue;

      // Only loop nests can act as the destination of a sibling fusion.
      auto *candidate = mdg->getNode(candidateId);
      if (isa<AffineForOp>(candidate->op))
        fuseWithSiblingNodes(candidate);
    }
  }

  // Attempt to fuse 'dstNode' with sibling nodes in the graph.
  void fuseWithSiblingNodes(Node *dstNode) {
    DenseSet<unsigned> visitedSibNodeIds;
    std::pair<unsigned, Value> idAndMemref;
    auto dstAffineForOp = cast<AffineForOp>(dstNode->op);

    // Each iteration handles one (sibling node id, shared memref) candidate
    // produced by findSiblingNodeToFuse; 'visitedSibNodeIds' is handed to it
    // so it can keep track of the candidates it already returned.
    while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) {
      unsigned sibId = idAndMemref.first;
      Value memref = idAndMemref.second;
      // TODO: Check that 'sibStoreOpInst' post-dominates all other
      // stores to the same memref in 'sibNode' loop nest.
      auto *sibNode = mdg->getNode(sibId);
      // Compute an operation list insertion point for the fused loop
      // nest which preserves dependences. The query is always made with the
      // node that comes first in the block as the first argument.
      assert(sibNode->op->getBlock() == dstNode->op->getBlock());
      Operation *insertPointInst =
          sibNode->op->isBeforeInBlock(dstNode->op)
              ? mdg->getFusedLoopNestInsertionPoint(sibNode->id, dstNode->id)
              : mdg->getFusedLoopNestInsertionPoint(dstNode->id, sibNode->id);
      if (insertPointInst == nullptr)
        continue;

      // Check if fusion would be profitable and at what depth.

      // Get unique 'sibNode' load op to 'memref'.
      SmallVector<Operation *, 2> sibLoadOpInsts;
      sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
      // Currently findSiblingNodeToFuse searches for siblings with one load.
      assert(sibLoadOpInsts.size() == 1);
      Operation *sibLoadOpInst = sibLoadOpInsts[0];

      // Gather 'dstNode' load ops to 'memref'.
      SmallVector<Operation *, 2> dstLoadOpInsts;
      dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts);
      // The first dst load determines the candidate fusion depths below; skip
      // this sibling rather than index an empty vector if 'dstNode' turns out
      // not to load from 'memref'.
      if (dstLoadOpInsts.empty()) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Can't fuse: dst loop has no load on the shared memref\n");
        continue;
      }

      SmallVector<AffineForOp, 4> dstLoopIVs;
      getAffineForIVs(*dstLoadOpInsts[0], &dstLoopIVs);
      unsigned dstLoopDepthTest = dstLoopIVs.size();
      auto sibAffineForOp = cast<AffineForOp>(sibNode->op);

      // Compute loop depth and slice union for fusion. The slice union for
      // depth 'i' is stored at index 'i - 1'.
      SmallVector<ComputationSliceState, 8> depthSliceUnions;
      depthSliceUnions.resize(dstLoopDepthTest);
      unsigned maxLegalFusionDepth = 0;
      FusionStrategy strategy(memref);
      for (unsigned i = 1; i <= dstLoopDepthTest; ++i) {
        FusionResult result = affine::canFuseLoops(
            sibAffineForOp, dstAffineForOp,
            /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy);

        if (result.value == FusionResult::Success)
          maxLegalFusionDepth = i;
      }

      // Skip if fusion is not feasible at any loop depths.
      if (maxLegalFusionDepth == 0)
        continue;

      unsigned bestDstLoopDepth = maxLegalFusionDepth;
      if (!maximalFusion) {
        // Check if fusion would be profitable. For sibling fusion, the sibling
        // load op is treated as the src "store" op for fusion profitability
        // purposes. The footprint of the load in the slice relative to the
        // unfused source's determines reuse.
        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
                                depthSliceUnions, maxLegalFusionDepth,
                                &bestDstLoopDepth, computeToleranceThreshold))
          continue;
      }

      assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
             "Fusion depth has no computed slice union");
      // Check if source loop is being inserted in the innermost
      // destination loop. Based on this, the fused loop may be optimized
      // further inside `fuseLoops`.
      bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
      // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
      affine::fuseLoops(sibAffineForOp, dstAffineForOp,
                        depthSliceUnions[bestDstLoopDepth - 1],
                        isInnermostInsertion);

      // Update operation position of fused loop nest (if needed).
      // 'dstAffineForOp' wraps the same 'dstNode->op' that the legality check
      // and the fusion above operated on.
      if (insertPointInst != dstAffineForOp)
        dstAffineForOp->moveBefore(insertPointInst);
      // Update data dependence graph state post fusion.
      updateStateAfterSiblingFusion(sibNode, dstNode);
    }
  }

  /// Searches for a sibling loop nest of 'dstNode' that can be fused with it
  /// for input reuse, i.e. a different node which loads from a memref that
  /// 'dstNode' also loads from and which has no dependence path to or from
  /// 'dstNode'.
  ///
  /// Candidates are looked up in two steps:
  ///   1. users of the arguments of the block containing 'dstNode' that are
  ///      affine reads nested in a loop nest;
  ///   2. nodes reachable through an output edge of a node that feeds
  ///      'dstNode' through a read-after-write edge on the same memref.
  ///
  /// 'visitedSibNodeIds' holds the ids of siblings which must not be
  /// returned again; the id of a returned sibling is added to it.
  /// Returns true and sets 'idAndMemrefToFuse' to the sibling node id and the
  /// shared memref on success. Returns false (leaving 'idAndMemrefToFuse'
  /// untouched) otherwise.
  bool findSiblingNodeToFuse(Node *dstNode,
                             DenseSet<unsigned> *visitedSibNodeIds,
                             std::pair<unsigned, Value> *idAndMemrefToFuse) {
    // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
    // on 'memref'.
    auto canFuseWithSibNode = [&](Node *sibNode, Value memref) {
      // Skip if 'sibNode' does not have exactly one load op on 'memref'.
      // TODO: Remove the single load op restriction.
      if (sibNode->getLoadOpCount(memref) != 1)
        return false;
      // Skip if there exists a path of dependent edges between
      // 'sibNode' and 'dstNode'.
      if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
          mdg->hasDependencePath(dstNode->id, sibNode->id))
        return false;
      // Skip sib node if it loads from (and stores to) a memref on which it
      // also has an input dependence edge.
      DenseSet<Value> loadAndStoreMemrefSet;
      sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
      // The predicate is invoked synchronously by 'any_of', so capturing by
      // reference is safe; the parameter is named differently from the outer
      // 'memref' to avoid shadowing it.
      if (llvm::any_of(loadAndStoreMemrefSet, [&](Value loadStoreMemref) {
            return mdg->getIncomingMemRefAccesses(sibNode->id,
                                                  loadStoreMemref) > 0;
          }))
        return false;

      // Check that all stores are to the same memref if any.
      DenseSet<Value> storeMemrefs;
      for (auto *storeOpInst : sibNode->stores) {
        storeMemrefs.insert(
            cast<AffineWriteOpInterface>(storeOpInst).getMemRef());
      }
      if (storeMemrefs.size() > 1)
        return false;

      // Skip if a memref value in one node is used by a non-affine memref
      // access that lies between 'dstNode' and 'sibNode'.
      if (hasNonAffineUsersOnThePath(dstNode->id, sibNode->id, mdg) ||
          hasNonAffineUsersOnThePath(sibNode->id, dstNode->id, mdg))
        return false;
      return true;
    };

    // Search for siblings which load the same memref block argument.
    Block *block = dstNode->op->getBlock();
    for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
      for (Operation *user : block->getArgument(i).getUsers()) {
        auto loadOp = dyn_cast<AffineReadOpInterface>(user);
        if (!loadOp)
          continue;
        // Gather loops surrounding 'user'.
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Skip 'user' if it is not within a loop nest.
        if (loops.empty())
          continue;
        Node *sibNode = mdg->getForOpNode(loops[0]);
        // Skip 'user' if its outermost loop is not a node of this graph
        // (e.g. the loop nest is not located in the block the graph was
        // built for). Dereferencing a null node here would be undefined
        // behavior in builds where asserts are compiled out.
        if (!sibNode)
          continue;
        // Skip 'user' if it is not a sibling to 'dstNode'.
        if (sibNode->id == dstNode->id)
          continue;
        // Skip 'user' if its node has been visited.
        if (visitedSibNodeIds->count(sibNode->id) > 0)
          continue;
        // Skip 'user' if it does not load from the same memref as 'dstNode'.
        auto memref = loadOp.getMemRef();
        if (dstNode->getLoadOpCount(memref) == 0)
          continue;
        // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
        if (canFuseWithSibNode(sibNode, memref)) {
          visitedSibNodeIds->insert(sibNode->id);
          idAndMemrefToFuse->first = sibNode->id;
          idAndMemrefToFuse->second = memref;
          return true;
        }
      }
    }

    // Search for siblings by following edges through an intermediate src node.
    // Collect candidate 'dstNode' input edges in 'inEdges'.
    SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
    mdg->forEachMemRefInputEdge(
        dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
          // Add 'inEdge' if it is a read-after-write dependence.
          if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
              mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
            inEdges.push_back(inEdge);
        });

    // Search for sibling nodes to fuse by visiting output edges from each input
    // edge in 'inEdges'.
    for (auto &inEdge : inEdges) {
      // Collect candidate output edges from each node 'inEdge.id' in 'inEdges'.
      SmallVector<MemRefDependenceGraph::Edge, 2> outEdges;
      mdg->forEachMemRefOutputEdge(
          inEdge.id, [&](MemRefDependenceGraph::Edge outEdge) {
            unsigned sibNodeId = outEdge.id;
            // Skip output edge if its target has already been visited.
            if (visitedSibNodeIds->count(sibNodeId) > 0)
              return;
            // Skip output edge if not a sibling using the same memref.
            if (outEdge.id == dstNode->id || outEdge.value != inEdge.value)
              return;
            // Only loop nest nodes are fusion candidates.
            auto *sibNode = mdg->getNode(sibNodeId);
            if (!isa<AffineForOp>(sibNode->op))
              return;
            // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
            if (canFuseWithSibNode(sibNode, outEdge.value)) {
              // Add candidate 'outEdge' to sibling node.
              outEdges.push_back(outEdge);
            }
          });

      // Add first candidate if any were returned.
      if (!outEdges.empty()) {
        visitedSibNodeIds->insert(outEdges[0].id);
        idAndMemrefToFuse->first = outEdges[0].id;
        idAndMemrefToFuse->second = outEdges[0].value;
        return true;
      }
    }
    return false;
  }

  /// Brings the dependence graph back in sync after 'sibNode' has been
  /// sibling-fused into 'dstNode': edges are updated, the load/store lists of
  /// 'dstNode' are rebuilt from its loop nest, and the sibling's node and op
  /// are dropped once the graph reports no outgoing edges for it.
  void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) {
    // Let the graph rewrite the input/output edges of both nodes.
    mdg->updateEdges(sibNode->id, dstNode->id);

    // Re-collect the loads and stores of the destination loop nest and
    // replace the stale lists recorded on its node with them.
    auto fusedForOp = cast<AffineForOp>(dstNode->op);
    LoopNestStateCollector fusedNestCollector;
    fusedNestCollector.collect(fusedForOp);
    mdg->clearNodeLoadAndStores(dstNode->id);
    mdg->addToNode(dstNode->id, fusedNestCollector.loadOpInsts,
                   fusedNestCollector.storeOpInsts);

    // Keep the old sibling loop nest around while it still has outgoing
    // dependence edges.
    if (mdg->getOutEdgeCount(sibNode->id) != 0)
      return;

    // Grab the op before the node is removed from the graph, then erase it.
    Operation *sibOp = sibNode->op;
    mdg->removeNode(sibNode->id);
    sibOp->erase();
  }

  // Clean up any allocs with no users.
  void eraseUnusedMemRefAllocations() {
    for (auto &pair : mdg->memrefEdgeCount) {
      if (pair.second > 0)
        continue;
      auto memref = pair.first;
      // Skip if there exist other uses (return operation or function calls).
      if (!memref.use_empty())
        continue;
      // Use list expected to match the dep graph info.
      auto *op = memref.getDefiningOp();
      if (isa_and_nonnull<memref::AllocOp>(op))
        op->erase();
    }
  }
};

} // namespace

/// Builds the memref dependence graph of `block` and runs the fusion strategy
/// selected by the `affineFusionMode` option on it. Nothing is done when the
/// graph cannot be initialized.
void LoopFusion::runOnBlock(Block *block) {
  MemRefDependenceGraph depGraph(*block);
  if (!depGraph.init()) {
    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
    return;
  }

  // Forward the fast memory space only when the option actually has a value.
  std::optional<unsigned> fastMemSpace;
  if (fastMemorySpace.hasValue())
    fastMemSpace = fastMemorySpace;

  // The threshold option is scaled by 1024 to obtain a byte count.
  // NOTE(review): the product is stored in an `unsigned` - confirm that very
  // large thresholds cannot be truncated here.
  unsigned thresholdInBytes = localBufSizeThreshold * 1024;

  GreedyFusion fusion(&depGraph, thresholdInBytes, fastMemSpace,
                      maximalFusion, computeToleranceThreshold);

  // Dispatch on the requested mode; any mode other than the two handled
  // explicitly runs the greedy strategy.
  if (affineFusionMode == FusionMode::ProducerConsumer) {
    fusion.runProducerConsumerFusionOnly();
    return;
  }
  if (affineFusionMode == FusionMode::Sibling) {
    fusion.runSiblingFusionOnly();
    return;
  }
  fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
  for (Region &region : getOperation()->getRegions())
    for (Block &block : region.getBlocks())
      runOnBlock(&block);
}

/// Creates a loop fusion pass instance, forwarding all arguments unchanged to
/// the `LoopFusion` constructor.
std::unique_ptr<Pass> mlir::affine::createLoopFusionPass(
    unsigned fastMemorySpace, uint64_t localBufSizeThreshold,
    bool maximalFusion, enum FusionMode affineFusionMode) {
  std::unique_ptr<Pass> fusionPass = std::make_unique<LoopFusion>(
      fastMemorySpace, localBufSizeThreshold, maximalFusion, affineFusionMode);
  return fusionPass;
}