File: TileAllocation.cpp

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (855 lines) | stat: -rw-r--r-- 34,522 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made for
// our use case. Note that this is a greedy allocator (so it may not always find
// the most optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
//      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>

// Pull in the generated pass definition selected by
// GEN_PASS_DEF_TESTTILEALLOCATION from Passes.h.inc.
namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

/// One bit per 128-bit ZA tile: bit 15 (MSB) is ZA0.Q and bit 0 (LSB) is
/// ZA15.Q. Each wider tile's mask is the union of the 128-bit tiles it
/// overlaps, so two tiles conflict exactly when their masks intersect.
enum class TileMask : unsigned {
  // clang-format off
  kZA0B  = 0b1111'1111'1111'1111,

  kZA0H  = 0b1010'1010'1010'1010,
  kZA1H  = 0b0101'0101'0101'0101,

  kZA0S  = 0b1000'1000'1000'1000,
  kZA1S  = 0b0100'0100'0100'0100,
  kZA2S  = 0b0010'0010'0010'0010,
  kZA3S  = 0b0001'0001'0001'0001,

  kZA0D  = 0b1000'0000'1000'0000,
  kZA1D  = 0b0100'0000'0100'0000,
  kZA2D  = 0b0010'0000'0010'0000,
  kZA3D  = 0b0001'0000'0001'0000,
  kZA4D  = 0b0000'1000'0000'1000,
  kZA5D  = 0b0000'0100'0000'0100,
  kZA6D  = 0b0000'0010'0000'0010,
  kZA7D  = 0b0000'0001'0000'0001,

  kZA0Q  = 0b1000'0000'0000'0000,
  kZA1Q  = 0b0100'0000'0000'0000,
  kZA2Q  = 0b0010'0000'0000'0000,
  kZA3Q  = 0b0001'0000'0000'0000,
  kZA4Q  = 0b0000'1000'0000'0000,
  kZA5Q  = 0b0000'0100'0000'0000,
  kZA6Q  = 0b0000'0010'0000'0000,
  kZA7Q  = 0b0000'0001'0000'0000,
  kZA8Q  = 0b0000'0000'1000'0000,
  kZA9Q  = 0b0000'0000'0100'0000,
  kZA10Q = 0b0000'0000'0010'0000,
  kZA11Q = 0b0000'0000'0001'0000,
  kZA12Q = 0b0000'0000'0000'1000,
  kZA13Q = 0b0000'0000'0000'0100,
  kZA14Q = 0b0000'0000'0000'0010,
  kZA15Q = 0b0000'0000'0000'0001,

  kNone  = 0b0000'0000'0000'0000,
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};

/// Returns the set of masks relevant for the given type.
///
/// Entry N of the returned array is the ZA mask of tile ID N for `type`, so
/// the array length is the number of tiles available at that element size.
/// The arrays have static storage, so the returned ArrayRef never dangles.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
                                            TileMask::kZA2S, TileMask::kZA3S};
  static constexpr std::array ZA_D_MASKS = {
      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
  static constexpr std::array ZA_Q_MASKS = {
      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
  switch (type) {
  case ArmSMETileType::ZAB:
    return ZA_B_MASKS;
  case ArmSMETileType::ZAH:
    return ZA_H_MASKS;
  case ArmSMETileType::ZAS:
    return ZA_S_MASKS;
  case ArmSMETileType::ZAD:
    return ZA_D_MASKS;
  case ArmSMETileType::ZAQ:
    return ZA_Q_MASKS;
  }
  // Every enumerator is handled above. Without this, an out-of-range `type`
  // would fall off the end of a non-void function (undefined behaviour), and
  // GCC warns about it under -Wreturn-type.
  llvm_unreachable("unknown type in getMasks");
}

class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    tilesInUse ^= tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  TileMask tilesInUse = TileMask::kNone;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};

/// For every `cf.cond_br` that has a tile operand, routes each destination
/// through a new intermediate block. That block contains a plain `cf.br`
/// carrying the original operands, and the `cf.cond_br` itself is left with
/// no destination operands. This prevents spurious liveness overlaps caused
/// by the copies inserted at branches.
///
///  BEFORE:
///  ```mlir
///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
///  ```
///
///  AFTER:
///  ```mlir
///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
///  ^bb1_copy:
///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ^bb2_copy:
///    cf.br ^bb2
///  ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  // Gather the candidates up front and rewrite afterwards, so the IR is not
  // modified while it is being walked.
  SmallVector<cf::CondBranchOp> tileCondBranches;
  function.walk([&](cf::CondBranchOp condBranch) {
    bool passesTile =
        llvm::any_of(condBranch->getOperands(), [](Value operand) {
          return isValidSMETileVectorType(operand.getType());
        });
    if (passesTile)
      tileCondBranches.push_back(condBranch);
  });

  for (cf::CondBranchOp condBranch : tileCondBranches) {
    Location loc = condBranch.getLoc();
    Block *parent = condBranch->getBlock();
    // Splitting at the end of `parent` yields a new, empty block each time.
    Block *trueForwarder = rewriter.splitBlock(parent, parent->end());
    Block *falseForwarder = rewriter.splitBlock(parent, parent->end());

    // Each forwarder jumps on to the original destination with the original
    // operands.
    rewriter.setInsertionPointToEnd(trueForwarder);
    rewriter.create<cf::BranchOp>(loc, condBranch.getTrueDest(),
                                  condBranch.getTrueDestOperands());
    rewriter.setInsertionPointToEnd(falseForwarder);
    rewriter.create<cf::BranchOp>(loc, condBranch.getFalseDest(),
                                  condBranch.getFalseDestOperands());

    // Point the conditional branch at the operand-less forwarders.
    rewriter.modifyOpInPlace(condBranch, [&] {
      condBranch.getFalseDestOperandsMutable().clear();
      condBranch.getTrueDestOperandsMutable().clear();
      condBranch.setSuccessor(trueForwarder, 0);
      condBranch.setSuccessor(falseForwarder, 1);
    });
  }
}

/// For every block ending in a `cf.br`, inserts an `arm_sme.copy_tile` just
/// before the branch for each tile-typed operand, and makes the branch pass
/// the copy instead of the original tile.
///
///  BEFORE:
///  ```mlir
///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
///  ```
///
///  AFTER:
///  ```mlir
///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
///  ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    auto branch = dyn_cast<cf::BranchOp>(block.getTerminator());
    if (!branch)
      continue;
    // All copies for this branch are created directly in front of it.
    rewriter.setInsertionPoint(branch);
    for (OpOperand &use : branch->getOpOperands()) {
      Value tile = use.get();
      if (!isValidSMETileVectorType(tile.getType()))
        continue;
      auto copy = rewriter.create<CopyTileOp>(branch.getLoc(), tile);
      rewriter.modifyOpInPlace(branch, [&] { use.assign(copy); });
    }
  }
}

/// Rewrites the control flow of `function` into the shape tile allocation
/// expects.
///
/// First, tile-carrying conditional branches are split (`splitCondBranches`)
/// so that the copies made for the true and false edges cannot overlap each
/// other (see `tile-allocation-copies.mlir` and
/// `tile-allocation-liveness.mlir` for examples). Then a tile copy is inserted
/// for every tile passed through a `cf.br` (`insertCopiesAtBranches`). These
/// copies break up live ranges and keep the program's semantics intact once
/// it is taken out of SSA form.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // Order matters: the split introduces the `cf.br`s that the copy step then
  // instruments.
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}

/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
  using Allocator = RangeSet::Allocator;
  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
  static constexpr uint8_t kValidLiveRange = 0xff;

  LiveRange(Allocator &allocator)
      : ranges(std::make_unique<RangeSet>(allocator)) {}

  /// Returns true if this range overlaps with `otherRange`.
  bool overlaps(LiveRange const &otherRange) const {
    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
                                                         *otherRange.ranges)
        .valid();
  }

  /// Returns true if this range is active at `point` in the program.
  bool overlaps(uint64_t point) const {
    return ranges->lookup(point) == kValidLiveRange;
  }

  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
  void unionWith(LiveRange const &otherRange) {
    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
         ++it)
      ranges->insert(it.start(), it.stop(), kValidLiveRange);
    values.set_union(otherRange.values);
  }

  /// Inserts an interval [start, end) for `value` into this range.
  void insert(Value value, unsigned start, unsigned end) {
    values.insert(value);
    if (start != end)
      ranges->insert(start, end, kValidLiveRange);
  }

  bool empty() const { return ranges->empty(); }
  unsigned start() const { return ranges->start(); }
  unsigned end() const { return ranges->stop(); }
  bool operator<(LiveRange const &other) const {
    return start() < other.start();
  }

  ArmSMETileType getTileType() const {
    return *getSMETileType(cast<VectorType>(values[0].getType()));
  }

  /// The values contained in this live range.
  SetVector<Value> values;

  /// A set of (non-overlapping) intervals that mark where any value in `values`
  /// is live.
  std::unique_ptr<RangeSet> ranges;

  /// The tile ID (or none) assigned to this live range.
  std::optional<unsigned> tileId;
};

/// Numbers the operations of `function` so that live ranges can be computed.
/// Blocks are visited in the order given by `getBlocksSortedByDominance`, and
/// operations within a block receive consecutive numbers. One number is
/// skipped before each block, so the point just before a block's first
/// operation (used for its arguments and live-ins) has an index of its own.
/// ArmSME tile ops nested inside other operations' regions are not supported;
/// debug builds assert that none exist.
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  DenseMap<Operation *, unsigned> operationToIndexMap;
  unsigned nextIndex = 0;
  for (Block *block : getBlocksSortedByDominance(function.getFunctionBody())) {
    // Reserve the block-entry slot (block arguments get this number).
    ++nextIndex;
    for (Operation &op : *block) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, nextIndex);
      ++nextIndex;
    }
  }
  return operationToIndexMap;
}

/// Builds a live range for every SME tile value of `function` from the MLIR
/// liveness analysis. Interval bounds are indices from `operationToIndexMap`.
/// Live ranges are allocated from `liveRangeAllocator`.
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;

  /// Adds the part of `value`'s lifetime within the current block to its live
  /// range, creating the range the first time `value` is seen. Non-tile values
  /// are ignored. `firstUseOrDef` is where `value` becomes live in the block.
  /// When `liveAtBlockEntry` is set it is the block's first operation, and the
  /// interval is extended back to the block-entry slot.
  auto recordInterval = [&](Value value, Operation *firstUseOrDef,
                            LivenessBlockInfo const &blockLiveness,
                            bool liveAtBlockEntry) {
    if (!isValidSMETileVectorType(value.getType()))
      return;
    LiveRange &range =
        liveRanges.try_emplace(value, liveRangeAllocator).first->second;
    Operation *lastUse = blockLiveness.getEndOperation(value, firstUseOrDef);
    unsigned startIdx = operationToIndexMap.at(firstUseOrDef);
    if (liveAtBlockEntry)
      --startIdx; // Step back onto the slot reserved for block entry.
    unsigned endIdx = operationToIndexMap.at(lastUse);
    range.insert(value, startIdx, endIdx);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *blockLiveness = liveness.getLiveness(&block);
    // Block arguments and live-in values are live from block entry.
    for (Value argument : block.getArguments())
      recordInterval(argument, &block.front(), *blockLiveness,
                     /*liveAtBlockEntry=*/true);
    for (Value liveIn : blockLiveness->in())
      recordInterval(liveIn, &block.front(), *blockLiveness,
                     /*liveAtBlockEntry=*/true);
    // Values defined in this block are live from their defining operation.
    for (Operation &op : block)
      for (Value result : op.getResults())
        recordInterval(result, &op, *blockLiveness,
                       /*liveAtBlockEntry=*/false);
  }

  return liveRanges;
}

/// Calls `callback` with each value that a predecessor passes for `blockArg`.
/// Only `cf.br` and `cf.cond_br` terminators are inspected; other terminators
/// are skipped. For a `cf.cond_br`, the false edge is reported before the true
/// edge, and each edge is reported only if it targets the owning block.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *owner = blockArg.getOwner();
  unsigned index = blockArg.getArgNumber();
  for (Block *predecessor : owner->getPredecessors()) {
    Operation *terminator = predecessor->getTerminator();
    if (auto branch = dyn_cast<cf::BranchOp>(terminator)) {
      callback(branch.getDestOperands()[index]);
      continue;
    }
    auto condBranch = dyn_cast<cf::CondBranchOp>(terminator);
    if (!condBranch)
      continue;
    if (condBranch.getFalseDest() == owner)
      callback(condBranch.getFalseDestOperands()[index]);
    if (condBranch.getTrueDest() == owner)
      callback(condBranch.getTrueDestOperands()[index]);
  }
}

/// Merges the live ranges of related tile values whenever they do not overlap,
/// so they can share a tile and no moves are needed between them. Two pairs
/// are related: a block argument and the values its predecessors pass in, and
/// an ArmSME op's result and that op's tile operands.
/// Returns the distinct, non-empty live ranges sorted by start point. The
/// returned pointers point into `initialLiveRanges`.
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  // Maps every tile value to the (possibly merged) live range it belongs to.
  DenseMap<Value, LiveRange *> rangeOf;
  for (auto &entry : initialLiveRanges)
    rangeOf.insert({entry.first, &entry.second});

  // Merges the ranges holding `a` and `b` unless they are already one range
  // or they overlap. Afterwards every value of `b`'s range maps to `a`'s.
  auto tryMerge = [&](Value a, Value b) {
    LiveRange *into = rangeOf.at(a);
    LiveRange *from = rangeOf.at(b);
    if (into == from || into->overlaps(*from))
      return;
    into->unionWith(*from);
    for (Value member : from->values)
      rangeOf[member] = into;
  };

  // An ArmSME op's result joins the ranges of its tile operands.
  auto unifyDefinitionsWithOperands = [&](Value value) {
    auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!tileOp)
      return;
    for (Value operand : tileOp->getOperands())
      if (isValidSMETileVectorType(operand.getType()))
        tryMerge(value, operand);
  };

  // A block argument joins the ranges of the values its predecessors pass in.
  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
    if (auto blockArg = dyn_cast<BlockArgument>(value))
      forEachPredecessorTileValue(blockArg, [&](Value incoming) {
        tryMerge(blockArg, incoming);
      });
  };

  // Apply the block-argument merges over all values before the definition
  // merges.
  for (auto &entry : initialLiveRanges)
    unifyBlockArgumentsWithPredecessors(entry.first);
  for (auto &entry : initialLiveRanges)
    unifyDefinitionsWithOperands(entry.first);

  // Several values may now share one range: deduplicate and drop empties.
  SetVector<LiveRange *> distinct;
  for (auto &entry : rangeOf)
    if (!entry.second->empty())
      distinct.insert(entry.second);

  // Order by start point, as the linear-scan allocation requires.
  SmallVector<LiveRange *> sorted(distinct.begin(), distinct.end());
  std::sort(sorted.begin(), sorted.end(),
            [](LiveRange *lhs, LiveRange *rhs) { return *lhs < *rhs; });
  return sorted;
}

/// Picks the live range to spill when no tile is free for `newRange`. The
/// result is either an element of `overlappingRanges` (a range yielding
/// `LiveRange &`) or `newRange` itself.
///
/// Heuristics, in order:
///   1. Spill a range whose tile type is at least as large as `newRange`'s,
///      that holds exactly one value defined by a trivially cloneable tile op
///      (usually free to spill). `newRange` itself is checked first.
///   2. Otherwise spill the compatible range that ends last, falling back to
///      `newRange` if no overlapping range beats it.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  ArmSMETileType const newTileType = newRange->getTileType();

  // Heuristic 1: a single trivially re-creatable value is cheap to spill.
  auto isCheapToSpill = [&](LiveRange &candidate) {
    if (!isTileTypeGreaterOrEqual(candidate.getTileType(), newTileType))
      return false;
    if (candidate.values.size() != 1)
      return false;
    auto definingOp =
        candidate.values[0].getDefiningOp<ArmSMETileOpInterface>();
    return isTriviallyCloneableTileOp(definingOp);
  };
  if (isCheapToSpill(*newRange))
    return newRange;
  for (LiveRange &candidate : overlappingRanges)
    if (isCheapToSpill(candidate))
      return &candidate;

  // Heuristic 2: spill whichever compatible range ends last.
  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
           a.end() < b.end();
  };
  LiveRange &latest =
      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
                        isSmallerTileTypeOrEndsEarlier);
  return isSmallerTileTypeOrEndsEarlier(latest, *newRange) ? newRange
                                                            : &latest;
}

/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
///
/// The input must be ordered by `LiveRange::start()` (as produced by
/// `coalesceTileLiveRanges`). Each range is processed at its start point:
/// ranges that have ended are retired, and ranges inside a 'hole' are parked
/// as inactive. The new range then gets a free tile ID; if none is left,
/// either it or an overlapping range is spilled. On return every range's
/// `tileId` is set, and IDs >= `kInMemoryTileIdBase` mark spilled (in-memory)
/// ranges.
void allocateTilesToLiveRanges(
    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
  TileAllocator tileAllocator;
  // `activeRanges` = Live ranges that need to be in a tile at the
  // `currentPoint` in the program.
  SetVector<LiveRange *> activeRanges;
  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
  // at the `currentPoint` in the program but could become active again later.
  // An inactive section of a live range can be seen as a 'hole' in the live
  // range, where it is possible to reuse the live range's tile ID _before_ it
  // has ended. By identifying 'holes', the allocator can reuse tiles more
  // often, which helps avoid costly tile spills.
  SetVector<LiveRange *> inactiveRanges;
  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
    auto currentPoint = nextRange->start();
    // 1. Update the `activeRanges` at `currentPoint`.
    activeRanges.remove_if([&](LiveRange *activeRange) {
      // Check for live ranges that have expired.
      if (activeRange->end() <= currentPoint) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        return true;
      }
      // Check for live ranges that have become inactive.
      if (!activeRange->overlaps(currentPoint)) {
        tileAllocator.releaseTileId(activeRange->getTileType(),
                                    *activeRange->tileId);
        inactiveRanges.insert(activeRange);
        return true;
      }
      return false;
    });
    // 2. Update the `inactiveRanges` at `currentPoint`.
    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
      // Check for live ranges that have expired.
      if (inactiveRange->end() <= currentPoint) {
        return true;
      }
      // Check for live ranges that have become active.
      if (inactiveRange->overlaps(currentPoint)) {
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        activeRanges.insert(inactiveRange);
        return true;
      }
      return false;
    });

    // 3. Collect inactive live ranges that overlap with the new live range.
    // Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
    // whereas this checks if there is an overlap at any future point too.
    SmallVector<LiveRange *> overlappingInactiveRanges;
    for (LiveRange *inactiveRange : inactiveRanges) {
      if (inactiveRange->overlaps(*nextRange)) {
        // We need to reserve the tile IDs of overlapping inactive ranges to
        // prevent two (overlapping) live ranges from getting the same tile ID.
        tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                    *inactiveRange->tileId);
        overlappingInactiveRanges.push_back(inactiveRange);
      }
    }

    // 4. Allocate a tile ID to `nextRange`.
    auto rangeTileType = nextRange->getTileType();
    auto tileId = tileAllocator.allocateTileId(rangeTileType);
    if (succeeded(tileId)) {
      nextRange->tileId = *tileId;
    } else {
      // Create an iterator over all overlapping live ranges.
      auto allOverlappingRanges = llvm::concat<LiveRange>(
          llvm::make_pointee_range(activeRanges.getArrayRef()),
          llvm::make_pointee_range(overlappingInactiveRanges));
      // Choose an overlapping live range to spill.
      LiveRange *rangeToSpill =
          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
      if (rangeToSpill != nextRange) {
        // Spill an (in)active live range (so release its tile ID first).
        tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                    *rangeToSpill->tileId);
        // This will always succeed after a spill (of an active live range).
        nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
        // Remove the live range from the active/inactive sets.
        if (!activeRanges.remove(rangeToSpill)) {
          bool removed = inactiveRanges.remove(rangeToSpill);
          assert(removed && "expected a range to be removed!");
          (void)removed;
        }
      }
      // The spilled range (possibly `nextRange` itself) gets an in-memory ID.
      rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
    }

    // 5. Insert the live range into the active ranges.
    // (`nextRange->tileId` was set in step 4 on every path; only ranges that
    // hold a real tile become active.)
    if (nextRange->tileId < kInMemoryTileIdBase)
      activeRanges.insert(nextRange);

    // 6. Release tiles reserved for inactive live ranges (in step 3).
    // A range spilled in step 4 already released its tile and now has an
    // in-memory ID, so it is skipped here.
    for (LiveRange *range : overlappingInactiveRanges) {
      if (*range->tileId < kInMemoryTileIdBase)
        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
    }
  }
}

/// Writes `tileIdAttr` onto the ArmSME op that defines `value` (if there is
/// one). Also writes it onto every ArmSME op that uses `value` but has no tile
/// result, so that ArmSME ops which don't produce a value still get a tile ID.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  auto setTileId = [&](ArmSMETileOpInterface op) {
    rewriter.modifyOpInPlace(op, [&] { op.setTileId(tileIdAttr); });
  };
  if (auto definingOp = value.getDefiningOp<ArmSMETileOpInterface>())
    setTileId(definingOp);
  for (Operation *user : value.getUsers()) {
    auto userTileOp = dyn_cast<ArmSMETileOpInterface>(user);
    if (userTileOp && !hasTileResult(userTileOp))
      setTileId(userTileOp);
  }
}

/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
///
/// For each value of each live range in `allocatedLiveRanges` (every range is
/// expected to already hold a `tileId`, which is dereferenced below):
///   1. the range's tile ID is written onto the value's ArmSME ops via
///      `assignTileIdToValue`;
///   2. if the value is produced by a `CopyTileOp` whose source is already on
///      the same tile, uses of the copy are redirected to that source and the
///      remaining steps are skipped for this value;
///   3. if the value is a block argument, every predecessor tile value must be
///      on the same tile, otherwise an error is emitted;
///   4. if the value's defining tile op has a tile operand on a different tile,
///      the operand's defining op is cloned onto this op's tile when it is
///      trivially cloneable; otherwise an error (with a note) is emitted.
///
/// Returns failure as soon as step 3 or step 4 fails, success otherwise.
///
/// NOTE(review): the `function` parameter is not referenced in this body.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
    IRRewriter &rewriter, FunctionOpInterface function,
    ArrayRef<LiveRange *> allocatedLiveRanges) {
  for (LiveRange const *liveRange : allocatedLiveRanges) {
    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
    // A value counts as being on this range's tile if its defining ArmSME op
    // already carries `tileIdAttr`, or if it is one of this range's values.
    auto isAllocatedToSameTile = [&](Value value) {
      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
          tileOp && tileOp.getTileId() == tileIdAttr)
        return true;
      return liveRange->values.contains(value);
    };

    /// Eliminates copies where the operand has the same tile ID.
    /// Only the uses of the copy are rewritten; the CopyTileOp itself is not
    /// erased here. Returns failure when nothing was folded.
    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
      auto copyOp = value.getDefiningOp<CopyTileOp>();
      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
        return failure();
      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
      return success();
    };

    /// Validates each predecessor to a tile block argument has been assigned
    /// the same tile ID.
    /// At most one error is emitted (for the first mismatching predecessor);
    /// later predecessors are ignored once a mismatch has been found.
    auto validateBlockArguments = [&](Value value) {
      auto blockArg = dyn_cast<BlockArgument>(value);
      if (!blockArg) {
        // Not a block argument (nothing to validate).
        return success();
      }
      bool tileMismatch = false;
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        if (tileMismatch)
          return;
        if (!isAllocatedToSameTile(predecessorTile)) {
          // NOTE(review): "virtial" is a typo for "virtual"; kept verbatim in
          // case diagnostic tests match this exact text -- confirm first.
          blockArg.getOwner()->getParentOp()->emitOpError(
              "block argument not allocated to the same SME virtial tile as "
              "predecessors");
          tileMismatch = true;
        }
      });
      return success(/*isSuccess=*/!tileMismatch);
    };

    /// Attempts to resolve (trivial) tile ID conflicts.
    /// A conflict is a tile operand of the value's defining op that is not on
    /// the same tile. It is resolved only by re-materializing (cloning) the
    /// operand's defining op; anything else is reported as an error.
    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
      // NOTE(review): `tileOp` is null when `value` is not defined by an ArmSME
      // tile op (e.g. a block argument); this relies on `getTileOpOperand`
      // returning null for a null op -- confirm.
      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
      OpOperand *tileOperand = getTileOpOperand(tileOp);
      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
        // Operand already allocated to the correct tile.
        // No conflict to resolve.
        return success();
      }
      auto operandTileOp =
          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
      if (!isTriviallyCloneableTileOp(operandTileOp)) {
        // NOTE(review): "virtial" typo kept verbatim (see note above).
        auto error =
            tileOp.emitOpError("tile operand allocated to different SME "
                               "virtial tile (move required)");
        error.attachNote(tileOperand->get().getLoc())
            << "tile operand is: " << tileOperand->get();
        return error;
      }
      // Cloning prevents a move/spill (though may require recomputation).
      // The clone is inserted directly before `tileOp` and is given `tileOp`'s
      // tile ID rather than the operand's original one.
      rewriter.setInsertionPoint(tileOp);
      auto clonedOp = operandTileOp.clone();
      rewriter.modifyOpInPlace(clonedOp,
                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
      rewriter.insert(clonedOp);
      if (isa<CopyTileOp>(tileOp)) {
        // The copy is now redundant: forward all of its uses to the clone.
        rewriter.replaceAllUsesWith(tileOp->getResult(0),
                                    clonedOp->getResult(0));
      } else {
        // Otherwise rewire only the conflicting tile operand to the clone.
        rewriter.modifyOpInPlace(
            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
      }
      return success();
    };

    for (Value value : liveRange->values) {
      // 1. Assign the tile ID to the value.
      assignTileIdToValue(rewriter, value, tileIdAttr);

      // 2. Attempt to eliminate redundant tile copies.
      if (succeeded(foldRedundantCopies(value)))
        continue;

      // 3. Validate tile block arguments.
      if (failed(validateBlockArguments(value)))
        return failure();

      // 4. Attempt to resolve (trivial) tile ID conflicts.
      if (failed(resolveTrivialTileConflicts(value)))
        return failure();
    }
  }
  return success();
}

/// Prints live ranges alongside operation names for debugging.
///
/// Writes to `llvm::errs()` one line per operation of every block in
/// `function`. Each line has one column per entry of `liveRanges`, followed by
/// the operation name:
///   'S' - a segment of the range starts at this operation,
///   'E' - a segment of the range stops at this operation,
///   '|' - the range is live here (or one segment stops where another starts),
///   ' ' - the range is not live here.
/// Operation positions come from `operationToIndexMap` (looked up with `at`,
/// so every printed operation must have an entry).
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  // Classifies `opIndex` against every segment of `range`. Segments are
  // visited in order and a later match overrides an earlier one.
  auto markerFor = [](LiveRange const *range, unsigned opIndex) {
    char marker = ' ';
    for (auto seg = range->ranges->begin(); seg != range->ranges->end();
         ++seg) {
      auto segStart = seg.start();
      auto segStop = seg.stop();
      if (opIndex == segStart) {
        marker = (marker == 'E') ? '|' : 'S';
      } else if (opIndex == segStop) {
        marker = (marker == 'S') ? '|' : 'E';
      } else if (segStart <= opIndex && opIndex < segStop) {
        marker = '|';
      }
    }
    return marker;
  };

  llvm::errs() << "SME Tile Liveness: @" << function.getName()
               << "\nKey:\nS - Start\nE - End\n| - Live\n";
  unsigned blockNumber = 0;
  for (auto &block : function.getBlocks()) {
    llvm::errs() << "^bb" << blockNumber++ << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned opIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges)
        llvm::errs() << markerFor(range, opIndex);
      llvm::errs() << ' ' << op.getName() << '\n';
    }
  }
  llvm::errs() << "==========\n";
}

/// Test pass driving SME tile allocation on a single function.
///
/// With the `preprocessOnly` option only the preprocessing step is run;
/// otherwise the full allocation runs (optionally dumping live ranges when
/// `dumpTileLiveRanges` is set) and the pass fails if allocation fails.
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;
  void runOnOperation() override {
    FunctionOpInterface function = getOperation();
    if (preprocessOnly) {
      IRRewriter rewriter(function);
      preprocessForTileAllocation(rewriter, function);
      return;
    }
    LogicalResult allocation =
        arm_sme::allocateSMETiles(function, dumpTileLiveRanges);
    if (failed(allocation))
      signalPassFailure();
  }
};
} // namespace

/// Allocates SME tile IDs to the ArmSME operations in `function`.
///
/// Steps:
///   1. preprocess the IR,
///   2. gather the tile live ranges,
///   3. coalesce non-overlapping ranges,
///   4. allocate tile IDs to the coalesced ranges,
///   5. write the IDs back into the IR, resolving trivial conflicts,
///   6. erase trivially dead tile ops.
/// When `dumpRanges` is set, the initial (non-empty, sorted) and the coalesced
/// live ranges are printed to `llvm::errs()`.
///
/// Returns success early for an empty function or when no tile live ranges
/// are found; returns failure if step 5 fails.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  // TODO: Also return early if the function contains no ArmSME ops?
  if (function.empty())
    return success();

  // Declared before any live range so it is destroyed after them.
  // NOTE(review): presumably backs the live ranges' storage; keep this order.
  LiveRange::Allocator rangeAllocator;
  IRRewriter rewriter(function.getContext());

  // 1. Preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // 2. Gather live ranges for each ArmSME tile within the function.
  Liveness functionLiveness(function);
  auto opIndices = generateOperationNumbering(function);
  auto rangesByValue = gatherTileLiveRanges(opIndices, rangeAllocator,
                                            functionLiveness, function);
  if (rangesByValue.empty())
    return success();

  if (dumpRanges) {
    // Collect the non-empty initial ranges and order them for printing.
    SmallVector<LiveRange const *> printable;
    for (LiveRange const &range : llvm::make_second_range(rangesByValue)) {
      if (!range.empty())
        printable.push_back(&range);
    }
    std::sort(printable.begin(), printable.end(),
              [](LiveRange const *lhs, LiveRange const *rhs) {
                return *lhs < *rhs;
              });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(opIndices, printable, function);
  }

  // 3. Coalesce (non-overlapping) live ranges where it benefits tile
  // allocation, e.g. unify the result of an operation with its operands.
  auto coalesced = coalesceTileLiveRanges(rangesByValue);
  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(opIndices, coalesced, function);
  }

  // 4. Allocate tile IDs to live ranges.
  allocateTilesToLiveRanges(coalesced);

  // 5. Assign the tile IDs back to the ArmSME operations.
  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
                                                     coalesced)))
    return failure();

  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no users) so
  // the LLVM conversion does not needlessly insert spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}