File: MemProfUse.cpp

package info (click to toggle)
llvm-toolchain-21 1%3A21.1.6-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,245,028 kB
  • sloc: cpp: 7,619,726; ansic: 1,434,018; asm: 1,058,748; python: 252,740; f90: 94,671; objc: 70,685; lisp: 42,813; pascal: 18,401; sh: 8,601; ml: 5,111; perl: 4,720; makefile: 3,675; awk: 3,523; javascript: 2,409; xml: 892; fortran: 770
file content (739 lines) | stat: -rw-r--r-- 30,326 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
//===- MemProfUse.cpp - memory allocation profile use pass --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the MemProfUsePass which reads memory profiling data
// and uses it to add metadata to instructions to guide optimization.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Instrumentation/MemProfUse.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/MemoryProfileInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/ProfileData/InstrProf.h"
#include "llvm/ProfileData/InstrProfReader.h"
#include "llvm/ProfileData/MemProfCommon.h"
#include "llvm/Support/BLAKE3.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/HashBuilder.h"
#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/Transforms/Utils/LongestCommonSequence.h"
#include <map>
#include <set>

using namespace llvm;
using namespace llvm::memprof;

#define DEBUG_TYPE "memprof"

// Command-line flags defined in another translation unit. readMemprof() below
// consults them to decide whether a missing-profile or hash-mismatch warning
// should be suppressed.
namespace llvm {
extern cl::opt<bool> PGOWarnMissing;
extern cl::opt<bool> NoPGOWarnMismatch;
extern cl::opt<bool> NoPGOWarnMismatchComdatWeak;
} // namespace llvm

// By default disable matching of allocation profiles onto operator new that
// already explicitly pass a hot/cold hint, since we don't currently
// override these hints anyway.
static cl::opt<bool> ClMemProfMatchHotColdNew(
    "memprof-match-hot-cold-new",
    cl::desc(
        "Match allocation profiles onto existing hot/cold operator new calls"),
    cl::Hidden, cl::init(false));

// When set, MemProfUsePass::run() prints every matched allocation context and
// matched callsite to errs() after all functions have been processed.
static cl::opt<bool>
    ClPrintMemProfMatchInfo("memprof-print-match-info",
                            cl::desc("Print matching stats for each allocation "
                                     "context in this module's profiles"),
                            cl::Hidden, cl::init(false));

// When set, an undrift map is computed for the module and each function's
// MemProfRecord is rewritten through it before matching.
static cl::opt<bool>
    SalvageStaleProfile("memprof-salvage-stale-profile",
                        cl::desc("Salvage stale MemProf profile"),
                        cl::init(false), cl::Hidden);

// Gates addVPMetadata(): when false no value profile metadata is attached.
static cl::opt<bool> ClMemProfAttachCalleeGuids(
    "memprof-attach-calleeguids",
    cl::desc(
        "Attach calleeguids as value profile metadata for indirect calls."),
    cl::init(true), cl::Hidden);

// The default of 100 disables the "dominant" cold hint in readMemprof(), which
// is only considered when this value is below 100.
static cl::opt<unsigned> MinMatchedColdBytePercent(
    "memprof-matching-cold-threshold", cl::init(100), cl::Hidden,
    cl::desc("Min percent of cold bytes matched to hint allocation cold"));

// Matching statistics. The first three are bumped once per function depending
// on the outcome of the profile lookup in readMemprof(); the rest count the
// profile entries seen and the entries successfully matched to instructions.
STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile.");
STATISTIC(NumOfMemProfMismatch,
          "Number of functions having mismatched memory profile hash.");
STATISTIC(NumOfMemProfFunc, "Number of functions having valid memory profile.");
STATISTIC(NumOfMemProfAllocContextProfiles,
          "Number of alloc contexts in memory profile.");
STATISTIC(NumOfMemProfCallSiteProfiles,
          "Number of callsites in memory profile.");
STATISTIC(NumOfMemProfMatchedAllocContexts,
          "Number of matched memory profile alloc contexts.");
STATISTIC(NumOfMemProfMatchedAllocs,
          "Number of matched memory profile allocs.");
STATISTIC(NumOfMemProfMatchedCallSites,
          "Number of matched memory profile callsites.");

// Attach !callsite metadata to I, built from the stack ids in
// InlinedCallStack.
static void addCallsiteMetadata(Instruction &I,
                                ArrayRef<uint64_t> InlinedCallStack,
                                LLVMContext &Ctx) {
  auto *CallsiteMD = buildCallstackMetadata(InlinedCallStack, Ctx);
  I.setMetadata(LLVMContext::MD_callsite, CallsiteMD);
}

// Hash a (function GUID, line offset, column) triple into a 64-bit stack id.
// The id is the 8-byte truncated BLAKE3 digest of the three values, fed to the
// hasher in little-endian form.
static uint64_t computeStackId(GlobalValue::GUID Function, uint32_t LineOffset,
                               uint32_t Column) {
  using StackIdHasher =
      llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::endianness::little>;

  StackIdHasher Hasher;
  Hasher.add(Function, LineOffset, Column);
  const llvm::BLAKE3Result<8> Digest = Hasher.final();

  // Reinterpret the digest bytes as the id.
  uint64_t StackId;
  std::memcpy(&StackId, Digest.data(), sizeof(Digest));
  return StackId;
}

// Convenience overload: compute the stack id of a profile Frame from its
// Function, LineOffset and Column fields.
static uint64_t computeStackId(const memprof::Frame &Frame) {
  return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column);
}

// Add the call stack of AllocInfo to AllocTrie and return the allocation type
// derived from its profile counters. FullStackId is only consumed (and must be
// non-zero) when context size info is being recorded for analysis.
static AllocationType addCallStack(CallStackTrie &AllocTrie,
                                   const AllocationInfo *AllocInfo,
                                   uint64_t FullStackId) {
  // Translate every profile frame into its stack id.
  SmallVector<uint64_t> FrameIds;
  for (const auto &ProfileFrame : AllocInfo->CallStack) {
    uint64_t FrameId = computeStackId(ProfileFrame);
    FrameIds.push_back(FrameId);
  }

  const auto &Counters = AllocInfo->Info;
  auto Type = getAllocType(Counters.getTotalLifetimeAccessDensity(),
                           Counters.getAllocCount(),
                           Counters.getTotalLifetime());

  // Optionally carry the context's total size alongside the stack.
  std::vector<ContextTotalSize> SizeInfo;
  if (recordContextSizeInfoForAnalysis()) {
    auto ProfiledSize = Counters.getTotalSize();
    assert(ProfiledSize);
    assert(FullStackId != 0);
    SizeInfo.push_back({FullStackId, ProfiledSize});
  }

  AllocTrie.addCallStack(Type, FrameIds, std::move(SizeInfo));
  return Type;
}

// Return true if the stack ids in InlinedCallStack (computed from a call
// instruction's debug info) match, position by position, the stack ids of the
// leading Frames of ProfileCallStack (from either allocation or callsite
// profile data). An InlinedCallStack longer than ProfileCallStack never
// matches.
static bool
stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack,
                                   ArrayRef<uint64_t> InlinedCallStack) {
  const size_t NumInlined = InlinedCallStack.size();
  if (ProfileCallStack.size() < NumInlined)
    return false;

  for (size_t Idx = 0; Idx != NumInlined; ++Idx) {
    if (computeStackId(ProfileCallStack[Idx]) != InlinedCallStack[Idx])
      return false;
  }
  return true;
}

// Return true if Callee is a library allocation function onto which an
// allocation profile may be matched. A null Callee, or one that TLI does not
// recognize as a library function, is never a candidate.
static bool isAllocationWithHotColdVariant(const Function *Callee,
                                           const TargetLibraryInfo &TLI) {
  LibFunc Func;
  if (!Callee || !TLI.getLibFunc(*Callee, Func))
    return false;

  switch (Func) {
  // Plain operator new / size-returning new: always eligible.
  case LibFunc_Znwm:
  case LibFunc_ZnwmRKSt9nothrow_t:
  case LibFunc_ZnwmSt11align_val_t:
  case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
  case LibFunc_Znam:
  case LibFunc_ZnamRKSt9nothrow_t:
  case LibFunc_ZnamSt11align_val_t:
  case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
  case LibFunc_size_returning_new:
  case LibFunc_size_returning_new_aligned:
    return true;
  // Variants that already take a __hot_cold_t hint: eligible only when
  // -memprof-match-hot-cold-new is set.
  case LibFunc_Znwm12__hot_cold_t:
  case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t:
  case LibFunc_ZnwmSt11align_val_t12__hot_cold_t:
  case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
  case LibFunc_Znam12__hot_cold_t:
  case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t:
  case LibFunc_ZnamSt11align_val_t12__hot_cold_t:
  case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
  case LibFunc_size_returning_new_hot_cold:
  case LibFunc_size_returning_new_aligned_hot_cold:
    return ClMemProfMatchHotColdNew;
  default:
    return false;
  }
}

// Per-context information recorded for -memprof-print-match-info: the total
// profiled size of a matched allocation context and the allocation type that
// was assigned to it.
struct AllocMatchInfo {
  // Taken from the context's Info.getTotalSize() in readMemprof().
  uint64_t TotalSize = 0;
  // Allocation type returned by addCallStack() for the context.
  AllocationType AllocType = AllocationType::None;
};

// Collect the direct, non-intrinsic call edges of every function defined in M,
// keyed by caller GUID. Each call contributes one edge per level of its
// inlined-at debug location chain: the edge's location is the (line offset,
// column) within the caller at that level and its target is the callee GUID.
//
// For calls to allocation functions accepted by
// isAllocationWithHotColdVariant(), the callee GUID is replaced by 0 at the
// leaf and at every enclosing inline level until a callee for which
// IsPresentInProfile returns true is reached.
//
// The returned per-caller lists are sorted and contain no duplicates.
DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>>
memprof::extractCallsFromIR(Module &M, const TargetLibraryInfo &TLI,
                            function_ref<bool(uint64_t)> IsPresentInProfile) {
  DenseMap<uint64_t, SmallVector<CallEdgeTy, 0>> Calls;

  // Line offset of DIL relative to the start of its subprogram, truncated to
  // 16 bits.
  auto GetOffset = [](const DILocation *DIL) {
    return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
           0xffff;
  };

  for (Function &F : M) {
    if (F.isDeclaration())
      continue;

    for (auto &BB : F) {
      for (auto &I : BB) {
        // A single checked dyn_cast both filters out non-calls and yields the
        // CallBase pointer, so the result is never dereferenced unchecked.
        auto *CB = dyn_cast<CallBase>(&I);
        if (!CB || isa<IntrinsicInst>(CB))
          continue;

        auto *CalledFunction = CB->getCalledFunction();
        // Disregard indirect calls and intrinsics.
        if (!CalledFunction || CalledFunction->isIntrinsic())
          continue;

        StringRef CalleeName = CalledFunction->getName();
        // True if we are calling a heap allocation function that supports
        // hot/cold variants.
        bool IsAlloc = isAllocationWithHotColdVariant(CalledFunction, TLI);
        // True for the first iteration below, indicating that we are looking at
        // a leaf node.
        bool IsLeaf = true;
        for (const DILocation *DIL = I.getDebugLoc(); DIL;
             DIL = DIL->getInlinedAt()) {
          StringRef CallerName = DIL->getSubprogramLinkageName();
          assert(!CallerName.empty() &&
                 "Be sure to enable -fdebug-info-for-profiling");
          uint64_t CallerGUID = memprof::getGUID(CallerName);
          uint64_t CalleeGUID = memprof::getGUID(CalleeName);
          // Pretend that we are calling a function with GUID == 0 if we are
          // in the inline stack leading to a heap allocation function.
          if (IsAlloc) {
            if (IsLeaf) {
              // For leaf nodes, set CalleeGUID to 0 without consulting
              // IsPresentInProfile.
              CalleeGUID = 0;
            } else if (!IsPresentInProfile(CalleeGUID)) {
              // In addition to the leaf case above, continue to set CalleeGUID
              // to 0 as long as we don't see CalleeGUID in the profile.
              CalleeGUID = 0;
            } else {
              // Once we encounter a callee that exists in the profile, stop
              // setting CalleeGUID to 0.
              IsAlloc = false;
            }
          }

          LineLocation Loc = {GetOffset(DIL), DIL->getColumn()};
          Calls[CallerGUID].emplace_back(Loc, CalleeGUID);
          // Moving one level up the inline chain: this caller is the callee of
          // the next level.
          CalleeName = CallerName;
          IsLeaf = false;
        }
      }
    }
  }

  // Sort each call list by the source location and drop duplicate edges.
  for (auto &[CallerGUID, CallList] : Calls) {
    llvm::sort(CallList);
    CallList.erase(llvm::unique(CallList), CallList.end());
  }

  return Calls;
}

// Build, for every caller GUID that has call edges in both the profile and the
// IR of M, a map from profile source locations to IR source locations. The
// pairing is produced by longestCommonSequence over the two edge lists, with
// callee GUIDs compared for equality.
DenseMap<uint64_t, LocToLocMap>
memprof::computeUndriftMap(Module &M, IndexedInstrProfReader *MemProfReader,
                           const TargetLibraryInfo &TLI) {
  using CallEdgeMap = DenseMap<uint64_t, SmallVector<memprof::CallEdgeTy, 0>>;

  DenseMap<uint64_t, LocToLocMap> Result;

  CallEdgeMap ProfileCalls = MemProfReader->getMemProfCallerCalleePairs();
  CallEdgeMap IRCalls = extractCallsFromIR(
      M, TLI, [&](uint64_t GUID) { return ProfileCalls.contains(GUID); });

  for (const auto &[CallerGUID, IRAnchors] : IRCalls) {
    // Callers unknown to the profile have nothing to align against.
    auto ProfileIt = ProfileCalls.find(CallerGUID);
    if (ProfileIt == ProfileCalls.end())
      continue;

    LocToLocMap ProfileToIR;
    longestCommonSequence<LineLocation, GlobalValue::GUID>(
        ProfileIt->second, IRAnchors, std::equal_to<GlobalValue::GUID>(),
        [&](LineLocation ProfileLoc, LineLocation IRLoc) {
          ProfileToIR.try_emplace(ProfileLoc, IRLoc);
        });

    // Each caller GUID is visited exactly once, so the key cannot already be
    // present.
    [[maybe_unused]] bool IsNew =
        Result.try_emplace(CallerGUID, std::move(ProfileToIR)).second;
    assert(IsNew);
  }

  return Result;
}

// Given a MemProfRecord, undrift all the source locations present in the
// record in place.
static void
undriftMemProfRecord(const DenseMap<uint64_t, LocToLocMap> &UndriftMaps,
                     memprof::MemProfRecord &MemProfRec) {
  // Undrift a call stack in place.
  auto UndriftCallStack = [&](std::vector<Frame> &CallStack) {
    for (auto &F : CallStack) {
      auto I = UndriftMaps.find(F.Function);
      if (I == UndriftMaps.end())
        continue;
      auto J = I->second.find(LineLocation(F.LineOffset, F.Column));
      if (J == I->second.end())
        continue;
      auto &NewLoc = J->second;
      F.LineOffset = NewLoc.LineOffset;
      F.Column = NewLoc.Column;
    }
  };

  for (auto &AS : MemProfRec.AllocSites)
    UndriftCallStack(AS.CallStack);

  for (auto &CS : MemProfRec.CallSites)
    UndriftCallStack(CS.Frames);
}

// Annotate I with indirect-call-target value profile metadata listing
// CalleeGuids, each with a count of 1. Does nothing when the feature is
// disabled, when CalleeGuids is empty, or when I already carries
// indirect-call-target value profile data.
static void addVPMetadata(Module &M, Instruction &I,
                          ArrayRef<GlobalValue::GUID> CalleeGuids) {
  if (CalleeGuids.empty() || !ClMemProfAttachCalleeGuids)
    return;

  if (I.getMetadata(LLVMContext::MD_prof)) {
    // TODO: When merging is implemented, increase this to a typical ICP value
    // (e.g., 3-6) For now, we only need to check if existing data exists, so 1
    // is sufficient
    uint64_t IgnoredTotal;
    auto Existing = getValueProfDataFromInst(I, IPVK_IndirectCallTarget,
                                             /*MaxNumValueData=*/1,
                                             IgnoredTotal);
    // We don't know how to merge value profile data yet.
    if (!Existing.empty())
      return;
  }

  SmallVector<InstrProfValueData, 4> Targets;
  uint64_t CountSum = 0;
  for (const GlobalValue::GUID Guid : CalleeGuids) {
    // For MemProf, we don't have actual call counts, so we assign
    // a weight of 1 to each potential target.
    // TODO: Consider making this weight configurable or increasing it to
    // improve effectiveness for ICP.
    InstrProfValueData Entry;
    Entry.Value = Guid;
    Entry.Count = 1;
    CountSum += Entry.Count;
    Targets.push_back(Entry);
  }

  // CalleeGuids is non-empty here, so there is at least one target to attach.
  annotateValueSite(M, I, Targets, CountSum, IPVK_IndirectCallTarget,
                    Targets.size());
}

// Match the MemProf record of function F against F's call instructions and
// annotate the IR accordingly.
//
// The record is looked up in MemProfReader by the GUID of F's name. If the
// lookup fails, the appropriate statistic is bumped, a warning is diagnosed
// unless suppressed by the PGO warning flags, and F is left untouched.
//
// Otherwise each non-intrinsic call in F is matched through its debug
// locations:
//  - calls to functions accepted by isAllocationWithHotColdVariant() whose
//    leaf location appears in the allocation profile are handed to a
//    CallStackTrie, which attaches !memprof metadata or an attribute; when
//    !memprof is attached, !callsite metadata is added as well;
//  - other calls whose leaf location appears in the callsite profile get
//    !callsite metadata, and indirect calls additionally get value profile
//    metadata from the profile's callee GUIDs.
//
// FullStackIdToAllocMatchInfo and MatchedCallSites are only written when
// -memprof-print-match-info is set. UndriftMaps is only consulted when
// -memprof-salvage-stale-profile is set. ORE and MaxColdSize are forwarded to
// the CallStackTrie.
static void readMemprof(Module &M, Function &F,
                        IndexedInstrProfReader *MemProfReader,
                        const TargetLibraryInfo &TLI,
                        std::map<std::pair<uint64_t, unsigned>, AllocMatchInfo>
                            &FullStackIdToAllocMatchInfo,
                        std::set<std::vector<uint64_t>> &MatchedCallSites,
                        DenseMap<uint64_t, LocToLocMap> &UndriftMaps,
                        OptimizationRemarkEmitter &ORE, uint64_t MaxColdSize) {
  auto &Ctx = M.getContext();
  // Previously we used getIRPGOFuncName() here. If F is local linkage,
  // getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But
  // llvm-profdata uses FuncName in dwarf to create GUID which doesn't
  // contain FileName's prefix. It caused local linkage function can't
  // find MemProfRecord. So we use getName() now.
  // 'unique-internal-linkage-names' can make MemProf work better for local
  // linkage function.
  auto FuncName = F.getName();
  auto FuncGUID = Function::getGUIDAssumingExternalLinkage(FuncName);
  std::optional<memprof::MemProfRecord> MemProfRec;
  auto Err = MemProfReader->getMemProfRecord(FuncGUID).moveInto(MemProfRec);
  if (Err) {
    // No usable record: update statistics, optionally warn, and bail out.
    handleAllErrors(std::move(Err), [&](const InstrProfError &IPE) {
      auto Err = IPE.get();
      bool SkipWarning = false;
      LLVM_DEBUG(dbgs() << "Error in reading profile for Func " << FuncName
                        << ": ");
      if (Err == instrprof_error::unknown_function) {
        NumOfMemProfMissing++;
        SkipWarning = !PGOWarnMissing;
        LLVM_DEBUG(dbgs() << "unknown function");
      } else if (Err == instrprof_error::hash_mismatch) {
        NumOfMemProfMismatch++;
        SkipWarning =
            NoPGOWarnMismatch ||
            (NoPGOWarnMismatchComdatWeak &&
             (F.hasComdat() ||
              F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
        LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")");
      }

      if (SkipWarning)
        return;

      std::string Msg = (IPE.message() + Twine(" ") + F.getName().str() +
                         Twine(" Hash = ") + std::to_string(FuncGUID))
                            .str();

      Ctx.diagnose(
          DiagnosticInfoPGOProfile(M.getName().data(), Msg, DS_Warning));
    });
    return;
  }

  NumOfMemProfFunc++;

  // If requested, undrift MemProfRecord so that the source locations in it
  // match those in the IR.
  if (SalvageStaleProfile)
    undriftMemProfRecord(UndriftMaps, *MemProfRec);

  // Detect if there are non-zero column numbers in the profile. If not,
  // treat all column numbers as 0 when matching (i.e. ignore any non-zero
  // columns in the IR). The profiled binary might have been built with
  // column numbers disabled, for example.
  bool ProfileHasColumns = false;

  // Build maps of the location hash to all profile data with that leaf location
  // (allocation info and the callsites).
  std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo;

  // Helper struct for maintaining refs to callsite data. As an alternative we
  // could store a pointer to the CallSiteInfo struct but we also need the frame
  // index. Using ArrayRefs instead makes it a little easier to read.
  struct CallSiteEntry {
    // Subset of frames for the corresponding CallSiteInfo.
    ArrayRef<Frame> Frames;
    // Potential targets for indirect calls.
    ArrayRef<GlobalValue::GUID> CalleeGuids;

    // Only the Frames slice participates in equality; CalleeGuids is ignored.
    // Use pointer-based equality instead of ArrayRef's operator== which does
    // element-wise comparison. We want to check if it's the same slice of the
    // underlying array, not just equivalent content.
    bool operator==(const CallSiteEntry &Other) const {
      return Frames.data() == Other.Frames.data() &&
             Frames.size() == Other.Frames.size();
    }
  };

  // Hash a CallSiteEntry by the full stack id of its Frames slice.
  struct CallSiteEntryHash {
    size_t operator()(const CallSiteEntry &Entry) const {
      return computeFullStackId(Entry.Frames);
    }
  };

  // For the callsites we need to record slices of the frame array (see comments
  // below where the map entries are added) along with their CalleeGuids.
  std::map<uint64_t, std::unordered_set<CallSiteEntry, CallSiteEntryHash>>
      LocHashToCallSites;
  for (auto &AI : MemProfRec->AllocSites) {
    NumOfMemProfAllocContextProfiles++;
    // Associate the allocation info with the leaf frame. The later matching
    // code will match any inlined call sequences in the IR with a longer prefix
    // of call stack frames.
    // NOTE(review): indexes CallStack[0] unconditionally, so this presumably
    // relies on every allocation call stack being non-empty — confirm.
    uint64_t StackId = computeStackId(AI.CallStack[0]);
    LocHashToAllocInfo[StackId].insert(&AI);
    ProfileHasColumns |= AI.CallStack[0].Column;
  }
  for (auto &CS : MemProfRec->CallSites) {
    NumOfMemProfCallSiteProfiles++;
    // Need to record all frames from leaf up to and including this function,
    // as any of these may or may not have been inlined at this point.
    unsigned Idx = 0;
    for (auto &StackFrame : CS.Frames) {
      uint64_t StackId = computeStackId(StackFrame);
      // The slice starts at the current frame; Idx then advances to the next.
      ArrayRef<Frame> FrameSlice = ArrayRef<Frame>(CS.Frames).drop_front(Idx++);
      ArrayRef<GlobalValue::GUID> CalleeGuids(CS.CalleeGuids);
      LocHashToCallSites[StackId].insert({FrameSlice, CalleeGuids});

      ProfileHasColumns |= StackFrame.Column;
      // Once we find this function, we can stop recording.
      if (StackFrame.Function == FuncGUID)
        break;
    }
    // The last frame recorded must belong to this function.
    assert(Idx <= CS.Frames.size() && CS.Frames[Idx - 1].Function == FuncGUID);
  }

  // Line offset of DIL relative to the start of its subprogram, truncated to
  // 16 bits.
  auto GetOffset = [](const DILocation *DIL) {
    return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) &
           0xffff;
  };

  // Now walk the instructions, looking up the associated profile data using
  // debug locations.
  for (auto &BB : F) {
    for (auto &I : BB) {
      if (I.isDebugOrPseudoInst())
        continue;
      // We are only interested in calls (allocation or interior call stack
      // context calls).
      auto *CI = dyn_cast<CallBase>(&I);
      if (!CI)
        continue;
      auto *CalledFunction = CI->getCalledFunction();
      if (CalledFunction && CalledFunction->isIntrinsic())
        continue;
      // List of call stack ids computed from the location hashes on debug
      // locations (leaf to inlined at root).
      SmallVector<uint64_t, 8> InlinedCallStack;
      // Was the leaf location found in one of the profile maps?
      bool LeafFound = false;
      // If leaf was found in a map, iterators pointing to its location in both
      // of the maps. It might exist in neither, one, or both (the latter case
      // can happen because we don't currently have discriminators to
      // distinguish the case when a single line/col maps to both an allocation
      // and another callsite).
      auto AllocInfoIter = LocHashToAllocInfo.end();
      auto CallSitesIter = LocHashToCallSites.end();
      for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr;
           DIL = DIL->getInlinedAt()) {
        // Use C++ linkage name if possible. Need to compile with
        // -fdebug-info-for-profiling to get linkage name.
        StringRef Name = DIL->getScope()->getSubprogram()->getLinkageName();
        if (Name.empty())
          Name = DIL->getScope()->getSubprogram()->getName();
        auto CalleeGUID = Function::getGUIDAssumingExternalLinkage(Name);
        auto StackId = computeStackId(CalleeGUID, GetOffset(DIL),
                                      ProfileHasColumns ? DIL->getColumn() : 0);
        // Check if we have found the profile's leaf frame. If yes, collect
        // the rest of the call's inlined context starting here. If not, see if
        // we find a match further up the inlined context (in case the profile
        // was missing debug frames at the leaf).
        if (!LeafFound) {
          AllocInfoIter = LocHashToAllocInfo.find(StackId);
          CallSitesIter = LocHashToCallSites.find(StackId);
          if (AllocInfoIter != LocHashToAllocInfo.end() ||
              CallSitesIter != LocHashToCallSites.end())
            LeafFound = true;
        }
        if (LeafFound)
          InlinedCallStack.push_back(StackId);
      }
      // If leaf not in either of the maps, skip inst.
      if (!LeafFound)
        continue;

      // First add !memprof metadata from allocation info, if we found the
      // instruction's leaf location in that map, and if the rest of the
      // instruction's locations match the prefix Frame locations on an
      // allocation context with the same leaf.
      if (AllocInfoIter != LocHashToAllocInfo.end() &&
          // Only consider allocations which support hinting.
          isAllocationWithHotColdVariant(CI->getCalledFunction(), TLI)) {
        // We may match this instruction's location list to multiple MIB
        // contexts. Add them to a Trie specialized for trimming the contexts to
        // the minimal needed to disambiguate contexts with unique behavior.
        CallStackTrie AllocTrie(&ORE, MaxColdSize);
        uint64_t TotalSize = 0;
        uint64_t TotalColdSize = 0;
        for (auto *AllocInfo : AllocInfoIter->second) {
          // Check the full inlined call stack against this one.
          // If we found and thus matched all frames on the call, include
          // this MIB.
          if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack,
                                                 InlinedCallStack)) {
            NumOfMemProfMatchedAllocContexts++;
            // The full stack id is only needed for match-info printing or
            // context size recording; otherwise it stays 0.
            uint64_t FullStackId = 0;
            if (ClPrintMemProfMatchInfo || recordContextSizeInfoForAnalysis())
              FullStackId = computeFullStackId(AllocInfo->CallStack);
            auto AllocType = addCallStack(AllocTrie, AllocInfo, FullStackId);
            TotalSize += AllocInfo->Info.getTotalSize();
            if (AllocType == AllocationType::Cold)
              TotalColdSize += AllocInfo->Info.getTotalSize();
            // Record information about the allocation if match info printing
            // was requested.
            if (ClPrintMemProfMatchInfo) {
              assert(FullStackId != 0);
              FullStackIdToAllocMatchInfo[std::make_pair(
                  FullStackId, InlinedCallStack.size())] = {
                  AllocInfo->Info.getTotalSize(), AllocType};
            }
          }
        }
        // If the threshold for the percent of cold bytes is less than 100%,
        // and not all bytes are cold, see if we should still hint this
        // allocation as cold without context sensitivity.
        if (TotalColdSize < TotalSize && MinMatchedColdBytePercent < 100 &&
            TotalColdSize * 100 >= MinMatchedColdBytePercent * TotalSize) {
          AllocTrie.addSingleAllocTypeAttribute(CI, AllocationType::Cold,
                                                "dominant");
          continue;
        }

        // We might not have matched any to the full inlined call stack.
        // But if we did, create and attach metadata, or a function attribute if
        // all contexts have identical profiled behavior.
        if (!AllocTrie.empty()) {
          NumOfMemProfMatchedAllocs++;
          // MemprofMDAttached will be false if a function attribute was
          // attached.
          bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI);
          assert(MemprofMDAttached == I.hasMetadata(LLVMContext::MD_memprof));
          if (MemprofMDAttached) {
            // Add callsite metadata for the instruction's location list so that
            // it is simpler later on to identify which part of the MIB contexts
            // are from this particular instruction (including during inlining,
            // when the callsite metadata will be updated appropriately).
            // FIXME: can this be changed to strip out the matching stack
            // context ids from the MIB contexts and not add any callsite
            // metadata here to save space?
            addCallsiteMetadata(I, InlinedCallStack, Ctx);
          }
        }
        continue;
      }

      if (CallSitesIter == LocHashToCallSites.end())
        continue;

      // Otherwise, add callsite metadata. If we reach here then we found the
      // instruction's leaf location in the callsites map and the instruction
      // was not handled as an allocation above (either its leaf is absent from
      // the allocation map or its callee does not support hinting).
      for (const auto &CallSiteEntry : CallSitesIter->second) {
        // If we found and thus matched all frames on the call, create and
        // attach call stack metadata.
        if (stackFrameIncludesInlinedCallStack(CallSiteEntry.Frames,
                                               InlinedCallStack)) {
          NumOfMemProfMatchedCallSites++;
          addCallsiteMetadata(I, InlinedCallStack, Ctx);

          // Try to attach indirect call metadata if possible.
          if (!CalledFunction)
            addVPMetadata(M, I, CallSiteEntry.CalleeGuids);

          // Only need to find one with a matching call stack and add a single
          // callsite metadata.

          // Accumulate call site matching information upon request.
          if (ClPrintMemProfMatchInfo) {
            std::vector<uint64_t> CallStack;
            append_range(CallStack, InlinedCallStack);
            MatchedCallSites.insert(std::move(CallStack));
          }
          break;
        }
      }
    }
  }
}

// Construct the pass for the profile named by MemoryProfileFile. FS is the
// filesystem later passed to IndexedInstrProfReader::create() in run(); when
// it is null, the real filesystem is used instead.
MemProfUsePass::MemProfUsePass(std::string MemoryProfileFile,
                               IntrusiveRefCntPtr<vfs::FileSystem> FS)
    : MemoryProfileFileName(MemoryProfileFile), FS(FS) {
  // The parameter FS shadows the member of the same name, so the test below
  // inspects the argument while this->FS assigns the member.
  if (!FS)
    this->FS = vfs::getRealFileSystem();
}

// Run the pass over M: open the memory profile, apply it to every defined
// function via readMemprof(), and optionally print matching information.
// Returns PreservedAnalyses::all() when the module is empty or the profile
// cannot be used (after diagnosing the problem), and
// PreservedAnalyses::none() otherwise.
PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) {
  // Return immediately if the module doesn't contain any function.
  if (M.empty())
    return PreservedAnalyses::all();

  LLVM_DEBUG(dbgs() << "Read in memory profile:");
  auto &Ctx = M.getContext();

  // Emit a profile diagnostic attributed to the profile file.
  auto DiagnoseProfile = [&](const Twine &Msg) {
    Ctx.diagnose(DiagnosticInfoPGOProfile(MemoryProfileFileName.data(), Msg));
  };

  auto ReaderOrErr = IndexedInstrProfReader::create(MemoryProfileFileName, *FS);
  if (Error E = ReaderOrErr.takeError()) {
    handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) {
      DiagnoseProfile(EI.message());
    });
    return PreservedAnalyses::all();
  }

  std::unique_ptr<IndexedInstrProfReader> MemProfReader =
      std::move(ReaderOrErr.get());
  if (!MemProfReader) {
    DiagnoseProfile(StringRef("Cannot get MemProfReader"));
    return PreservedAnalyses::all();
  }

  if (!MemProfReader->hasMemoryProfile()) {
    DiagnoseProfile("Not a memory profile");
    return PreservedAnalyses::all();
  }

  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();

  // Library info of the module's first function, used only for the
  // module-wide undrift map.
  TargetLibraryInfo &FirstFuncTLI =
      FAM.getResult<TargetLibraryAnalysis>(*M.begin());
  DenseMap<uint64_t, LocToLocMap> UndriftMaps;
  if (SalvageStaleProfile)
    UndriftMaps = computeUndriftMap(M, MemProfReader.get(), FirstFuncTLI);

  // Map from the stack hash and matched frame count of each allocation context
  // in the function profiles to the total profiled size (bytes) and allocation
  // type.
  std::map<std::pair<uint64_t, unsigned>, AllocMatchInfo> AllocMatches;

  // Set of the matched call sites, each expressed as a sequence of an inline
  // call stack.
  std::set<std::vector<uint64_t>> CallSiteMatches;

  // Stays 0 when the profile carries no summary.
  uint64_t MaxColdSize = 0;
  if (auto *Summary = MemProfReader->getMemProfSummary())
    MaxColdSize = Summary->getMaxColdTotalSize();

  for (auto &F : M) {
    if (F.isDeclaration())
      continue;

    const TargetLibraryInfo &FuncTLI = FAM.getResult<TargetLibraryAnalysis>(F);
    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    readMemprof(M, F, MemProfReader.get(), FuncTLI, AllocMatches,
                CallSiteMatches, UndriftMaps, ORE, MaxColdSize);
  }

  if (!ClPrintMemProfMatchInfo)
    return PreservedAnalyses::none();

  for (const auto &[Key, Info] : AllocMatches) {
    const uint64_t Id = Key.first;
    const unsigned Length = Key.second;
    errs() << "MemProf " << getAllocTypeAttributeString(Info.AllocType)
           << " context with id " << Id << " has total profiled size "
           << Info.TotalSize << " is matched with " << Length << " frames\n";
  }

  for (const auto &InlineStack : CallSiteMatches) {
    errs() << "MemProf callsite match for inline call stack";
    for (uint64_t StackId : InlineStack)
      errs() << " " << StackId;
    errs() << "\n";
  }

  return PreservedAnalyses::none();
}