File: RegisterBankEmitter.cpp

package info (click to toggle)
llvm-toolchain-21 1:21.1.7-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,245,064 kB
  • sloc: cpp: 7,619,731; ansic: 1,434,018; asm: 1,058,748; python: 252,740; f90: 94,671; objc: 70,685; lisp: 42,813; pascal: 18,401; sh: 8,601; ml: 5,111; perl: 4,720; makefile: 3,676; awk: 3,523; javascript: 2,409; xml: 892; fortran: 770
file content (443 lines) | stat: -rw-r--r-- 16,615 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
//===- RegisterBankEmitter.cpp - Generate a Register Bank Desc. -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This tablegen backend is responsible for emitting a description of a target
// register bank for a code generator.
//
//===----------------------------------------------------------------------===//

#include "Common/CodeGenRegisters.h"
#include "Common/CodeGenTarget.h"
#include "Common/InfoByHwMode.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TGTimer.h"
#include "llvm/TableGen/TableGenBackend.h"

#define DEBUG_TYPE "register-bank-emitter"

using namespace llvm;

namespace {
class RegisterBank {

  /// A vector of register classes that are included in the register bank.
  typedef std::vector<const CodeGenRegisterClass *> RegisterClassesTy;

private:
  const Record &TheDef;

  /// The register classes that are covered by the register bank.
  RegisterClassesTy RCs;

  /// The register class with the largest register size.
  std::vector<const CodeGenRegisterClass *> RCsWithLargestRegSize;

public:
  RegisterBank(const Record &TheDef, unsigned NumModeIds)
      : TheDef(TheDef), RCsWithLargestRegSize(NumModeIds) {}

  /// Get the human-readable name for the bank.
  StringRef getName() const { return TheDef.getValueAsString("Name"); }
  /// Get the name of the enumerator in the ID enumeration.
  std::string getEnumeratorName() const {
    return (TheDef.getName() + "ID").str();
  }

  /// Get the name of the array holding the register class coverage data;
  std::string getCoverageArrayName() const {
    return (TheDef.getName() + "CoverageData").str();
  }

  /// Get the name of the global instance variable.
  StringRef getInstanceVarName() const { return TheDef.getName(); }

  const Record &getDef() const { return TheDef; }

  /// Get the register classes listed in the RegisterBank.RegisterClasses field.
  std::vector<const CodeGenRegisterClass *>
  getExplicitlySpecifiedRegisterClasses(
      const CodeGenRegBank &RegisterClassHierarchy) const {
    std::vector<const CodeGenRegisterClass *> RCs;
    for (const auto *RCDef : getDef().getValueAsListOfDefs("RegisterClasses"))
      RCs.push_back(RegisterClassHierarchy.getRegClass(RCDef));
    return RCs;
  }

  /// Add a register class to the bank without duplicates.
  void addRegisterClass(const CodeGenRegisterClass *RC) {
    if (llvm::is_contained(RCs, RC))
      return;

    // FIXME? We really want the register size rather than the spill size
    //        since the spill size may be bigger on some targets with
    //        limited load/store instructions. However, we don't store the
    //        register size anywhere (we could sum the sizes of the subregisters
    //        but there may be additional bits too) and we can't derive it from
    //        the VT's reliably due to Untyped.
    unsigned NumModeIds = RCsWithLargestRegSize.size();
    for (unsigned M = 0; M < NumModeIds; ++M) {
      if (RCsWithLargestRegSize[M] == nullptr)
        RCsWithLargestRegSize[M] = RC;
      else if (RCsWithLargestRegSize[M]->RSI.get(M).SpillSize <
               RC->RSI.get(M).SpillSize)
        RCsWithLargestRegSize[M] = RC;
      assert(RCsWithLargestRegSize[M] && "RC was nullptr?");
    }

    RCs.emplace_back(RC);
  }

  const CodeGenRegisterClass *getRCWithLargestRegSize(unsigned HwMode) const {
    return RCsWithLargestRegSize[HwMode];
  }

  iterator_range<typename RegisterClassesTy::const_iterator>
  register_classes() const {
    return llvm::make_range(RCs.begin(), RCs.end());
  }
};

class RegisterBankEmitter {
private:
  const CodeGenTarget Target;
  const RecordKeeper &Records;

  void emitHeader(raw_ostream &OS, const StringRef TargetName,
                  ArrayRef<RegisterBank> Banks);
  void emitBaseClassDefinition(raw_ostream &OS, const StringRef TargetName,
                               ArrayRef<RegisterBank> Banks);
  void emitBaseClassImplementation(raw_ostream &OS, const StringRef TargetName,
                                   ArrayRef<RegisterBank> Banks);

public:
  RegisterBankEmitter(const RecordKeeper &R) : Target(R), Records(R) {}

  void run(raw_ostream &OS);
};

} // end anonymous namespace

/// Emit the GET_REGBANK_DECLARATIONS fragment: an unnamed enumeration inside
/// namespace llvm::<Target> that starts with InvalidRegBankID (~0u), gives each
/// bank in \p Banks a consecutive ID starting from 0, and ends with
/// NumRegisterBanks.
void RegisterBankEmitter::emitHeader(raw_ostream &OS,
                                     const StringRef TargetName,
                                     ArrayRef<RegisterBank> Banks) {
  // <Target>RegisterBankInfo.h
  OS << "namespace llvm {\n";
  OS << "namespace " << TargetName << " {\n";
  OS << "enum : unsigned {\n";
  OS << "  InvalidRegBankID = ~0u,\n";

  // A bank's ID is its position in Banks.
  for (size_t Idx = 0, End = Banks.size(); Idx != End; ++Idx)
    OS << "  " << Banks[Idx].getEnumeratorName() << " = " << Idx << ",\n";

  OS << "  NumRegisterBanks,\n";
  OS << "};\n";
  OS << "} // end namespace " << TargetName << "\n";
  OS << "} // end namespace llvm\n";
}

/// Emit the GET_TARGET_REGBANK_CLASS fragment: the member declarations of the
/// generated <Target>GenRegisterBankInfo class. These are the private RegBanks
/// and Sizes tables, the public getRegBankFromRegClass override, and the
/// protected constructor taking a HwMode (default 0). \p Banks is not used.
void RegisterBankEmitter::emitBaseClassDefinition(
    raw_ostream &OS, const StringRef TargetName, ArrayRef<RegisterBank> Banks) {
  // Tables whose definitions come from the GET_TARGET_REGBANK_IMPL fragment.
  OS << "private:\n";
  OS << "  static const RegisterBank *RegBanks[];\n";
  OS << "  static const unsigned Sizes[];\n\n";

  OS << "public:\n";
  OS << "  const RegisterBank &getRegBankFromRegClass(const "
        "TargetRegisterClass &RC, LLT Ty) const override;\n";

  OS << "protected:\n";
  OS << "  " << TargetName << "GenRegisterBankInfo(unsigned HwMode = 0);\n";
  OS << "\n";
}

/// Visit each register class belonging to the given register bank.
///
/// A class belongs to the bank iff any of these apply:
/// * It is explicitly specified
/// * It is a subclass of a class that is a member.
/// * It is a class containing subregisters of the registers of a class that
///   is a member. This is known as a subreg-class.
///
/// This function must be called for each explicitly specified register class.
///
/// \param RC The register class to search.
/// \param Kind A debug string containing the path the visitor took to reach RC.
/// \param VisitFn The action to take for each class visited. It may be called
///                multiple times for a given class if there are multiple paths
///                to the class.
static void visitRegisterBankClasses(
    const CodeGenRegBank &RegisterClassHierarchy,
    const CodeGenRegisterClass *RC, const Twine &Kind,
    std::function<void(const CodeGenRegisterClass *, StringRef)> VisitFn,
    DenseSet<const CodeGenRegisterClass *> &VisitedRCs) {

  // Make sure we only visit each class once to avoid infinite loops.
  if (!VisitedRCs.insert(RC).second)
    return;

  // Visit each explicitly named class.
  VisitFn(RC, Kind.str());

  for (const auto &PossibleSubclass : RegisterClassHierarchy.getRegClasses()) {
    std::string TmpKind =
        (Kind + " (" + PossibleSubclass.getName() + ")").str();

    // Visit each subclass of an explicitly named class.
    if (RC != &PossibleSubclass && RC->hasSubClass(&PossibleSubclass))
      visitRegisterBankClasses(RegisterClassHierarchy, &PossibleSubclass,
                               TmpKind + " " + RC->getName() + " subclass",
                               VisitFn, VisitedRCs);

    // Visit each class that contains only subregisters of RC with a common
    // subregister-index.
    //
    // More precisely, PossibleSubclass is a subreg-class iff Reg:SubIdx is in
    // PossibleSubclass for all registers Reg from RC using any
    // subregister-index SubReg
    for (const auto &SubIdx : RegisterClassHierarchy.getSubRegIndices()) {
      BitVector BV(RegisterClassHierarchy.getRegClasses().size());
      PossibleSubclass.getSuperRegClasses(&SubIdx, BV);
      if (BV.test(RC->EnumValue)) {
        std::string TmpKind2 = (Twine(TmpKind) + " " + RC->getName() +
                                " class-with-subregs: " + RC->getName())
                                   .str();
        VisitFn(&PossibleSubclass, TmpKind2);
      }
    }
  }
}

/// Emit the GET_TARGET_REGBANK_IMPL fragment:
/// * per bank, a coverage bitmask array (one bit per register class ID) and a
///   constexpr RegisterBank instance referencing it,
/// * the RegBanks table, listing the bank instances in ID order,
/// * the Sizes table, with one row per HwMode and one entry per bank holding
///   the spill size of the bank's largest register class in that mode,
/// * the <Target>GenRegisterBankInfo constructor, and
/// * getRegBankFromRegClass, which maps a register class ID to a bank through
///   a packed lookup table.
///
/// Reports a fatal error if there are no banks or if a bank covers no register
/// class. In those cases neither the packed table width nor the Sizes entries
/// can be computed.
void RegisterBankEmitter::emitBaseClassImplementation(
    raw_ostream &OS, StringRef TargetName, ArrayRef<RegisterBank> Banks) {
  const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank();
  const CodeGenHwModes &CGH = Target.getHwModes();
  const size_t NumRegClasses = RegisterClassHierarchy.getRegClasses().size();

  // Validate before emitting anything. With zero banks the packed-table field
  // width below would be 0 (and 32 / 0 is evaluated). With an empty bank,
  // getRCWithLargestRegSize returns nullptr and the Sizes table would
  // dereference it.
  if (Banks.empty())
    PrintFatalError("no RegisterBank definitions found for target '" +
                    TargetName + "'");
  for (const auto &Bank : Banks) {
    auto BankRCs = Bank.register_classes();
    if (BankRCs.begin() == BankRCs.end())
      PrintFatalError(Bank.getDef().getLoc(),
                      "RegisterBank '" + Bank.getName() +
                          "' does not cover any register class");
  }

  OS << "namespace llvm {\n"
     << "namespace " << TargetName << " {\n";
  for (const auto &Bank : Banks) {
    std::vector<std::vector<const CodeGenRegisterClass *>> RCsGroupedByWord(
        (NumRegClasses + 31) / 32);

    for (const auto &RC : Bank.register_classes())
      RCsGroupedByWord[RC->EnumValue / 32].push_back(RC);

    OS << "const uint32_t " << Bank.getCoverageArrayName() << "[] = {\n";
    unsigned LowestIdxInWord = 0;
    for (const auto &RCs : RCsGroupedByWord) {
      OS << "    // " << LowestIdxInWord << "-" << (LowestIdxInWord + 31)
         << "\n";
      for (const auto &RC : RCs) {
        OS << "    (1u << (" << RC->getQualifiedIdName() << " - "
           << LowestIdxInWord << ")) |\n";
      }
      OS << "    0,\n";
      LowestIdxInWord += 32;
    }
    OS << "};\n";
  }
  OS << "\n";

  for (const auto &Bank : Banks) {
    std::string QualifiedBankID =
        (TargetName + "::" + Bank.getEnumeratorName()).str();
    OS << "constexpr RegisterBank " << Bank.getInstanceVarName() << "(/* ID */ "
       << QualifiedBankID << ", /* Name */ \"" << Bank.getName() << "\", "
       << "/* CoveredRegClasses */ " << Bank.getCoverageArrayName()
       << ", /* NumRegClasses */ " << NumRegClasses << ");\n";
  }
  OS << "} // end namespace " << TargetName << "\n"
     << "\n";

  OS << "const RegisterBank *" << TargetName
     << "GenRegisterBankInfo::RegBanks[] = {\n";
  for (const auto &Bank : Banks)
    OS << "    &" << TargetName << "::" << Bank.getInstanceVarName() << ",\n";
  OS << "};\n\n";

  unsigned NumModeIds = CGH.getNumModeIds();
  OS << "const unsigned " << TargetName << "GenRegisterBankInfo::Sizes[] = {\n";
  for (unsigned M = 0; M < NumModeIds; ++M) {
    OS << "    // Mode = " << M << " (";
    if (M == DefaultMode)
      OS << "Default";
    else
      OS << CGH.getMode(M).Name;
    OS << ")\n";
    for (const auto &Bank : Banks) {
      // Non-null: every bank covers at least one class (checked above).
      // RegisterBank::addRegisterClass records a largest class for every mode.
      const CodeGenRegisterClass &RC = *Bank.getRCWithLargestRegSize(M);
      unsigned Size = RC.RSI.get(M).SpillSize;
      OS << "    " << Size << ",\n";
    }
  }
  OS << "};\n\n";

  OS << TargetName << "GenRegisterBankInfo::" << TargetName
     << "GenRegisterBankInfo(unsigned HwMode)\n"
     << "    : RegisterBankInfo(RegBanks, " << TargetName
     << "::NumRegisterBanks, Sizes, HwMode) {\n"
     << "  // Assert that RegBank indices match their ID's\n"
     << "#ifndef NDEBUG\n"
     << "  for (auto RB : enumerate(RegBanks))\n"
     << "    assert(RB.index() == RB.value()->getID() && \"Index != ID\");\n"
     << "#endif // NDEBUG\n"
     << "}\n";

  // Pack one bank ID per register class into 32-bit words. BitSize is the
  // smallest power of two strictly greater than floor(log2(NumRegBanks)).
  // That leaves room for every bank ID below the all-ones value BitMask,
  // which encodes "no unique bank" (the masked InvalidRegBankID).
  const uint32_t NumRegBanks = Banks.size();
  const uint32_t BitSize = NextPowerOf2(Log2_32(NumRegBanks));
  const uint32_t ElemsPerWord = 32 / BitSize;
  const uint32_t BitMask = (1u << BitSize) - 1;
  bool HasAmbiguousOrMissingEntry = false;
  struct Entry {
    std::string RCIdName;
    std::string RBIdName;
  };
  SmallVector<Entry, 0> Entries;
  for (const auto &Bank : Banks) {
    for (const auto *RC : Bank.register_classes()) {
      if (RC->EnumValue >= Entries.size())
        Entries.resize(RC->EnumValue + 1);
      Entry &E = Entries[RC->EnumValue];
      E.RCIdName = RC->getIdName();
      if (!E.RBIdName.empty()) {
        // The class is covered by more than one bank: no unique answer.
        HasAmbiguousOrMissingEntry = true;
        E.RBIdName = "InvalidRegBankID";
      } else {
        E.RBIdName = (TargetName + "::" + Bank.getEnumeratorName()).str();
      }
    }
  }
  // Class IDs below the highest covered one that no bank covers.
  for (auto &E : Entries) {
    if (E.RBIdName.empty()) {
      HasAmbiguousOrMissingEntry = true;
      E.RBIdName = "InvalidRegBankID";
    }
  }
  OS << "const RegisterBank &\n"
     << TargetName
     << "GenRegisterBankInfo::getRegBankFromRegClass"
        "(const TargetRegisterClass &RC, LLT) const {\n";
  if (HasAmbiguousOrMissingEntry) {
    OS << "  constexpr uint32_t InvalidRegBankID = uint32_t(" << TargetName
       << "::InvalidRegBankID) & " << BitMask << ";\n";
  }
  unsigned TableSize = divideCeil(Entries.size(), ElemsPerWord);
  OS << "  static const uint32_t RegClass2RegBank[" << TableSize << "] = {\n";
  uint32_t Shift = 32 - BitSize;
  bool First = true;
  std::string TrailingComment;
  for (const auto &E : Entries) {
    Shift += BitSize;
    if (Shift == 32) {
      // Start a new word.
      Shift = 0;
      if (First)
        First = false;
      else
        OS << ',' << TrailingComment << '\n';
    } else {
      OS << " |" << TrailingComment << '\n';
    }
    // Every entry has a non-empty RBIdName at this point (normalized above).
    OS << "    (uint32_t(" << E.RBIdName << ") << " << Shift << ')';
    if (!E.RCIdName.empty())
      TrailingComment = " // " + E.RCIdName;
    else
      TrailingComment = "";
  }
  OS << TrailingComment
     << "\n  };\n"
        "  const unsigned RegClassID = RC.getID();\n"
        "  if (LLVM_LIKELY(RegClassID < "
     << Entries.size()
     << ")) {\n"
        "    unsigned RegBankID = (RegClass2RegBank[RegClassID / "
     << ElemsPerWord << "] >> ((RegClassID % " << ElemsPerWord << ") * "
     << BitSize << ")) & " << BitMask << ";\n";
  if (HasAmbiguousOrMissingEntry) {
    OS << "    if (RegBankID != InvalidRegBankID)\n"
          "      return getRegBank(RegBankID);\n";
  } else {
    OS << "    return getRegBank(RegBankID);\n";
  }
  OS << "  }\n"
        "  llvm_unreachable(llvm::Twine(\"Target needs to handle register "
        "class ID "
        "0x\").concat(llvm::Twine::utohexstr(RegClassID)).str().c_str());\n"
        "}\n";

  OS << "} // end namespace llvm\n";
}

/// Backend entry point. Builds a RegisterBank for every `RegisterBank` record,
/// including the classes reached from its explicit register classes. Warns when
/// a bank name equals a register class name ignoring case. Then emits the
/// three guarded source fragments.
void RegisterBankEmitter::run(raw_ostream &OS) {
  StringRef TargetName = Target.getName();
  const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank();
  const CodeGenHwModes &CGH = Target.getHwModes();

  TGTimer &Timer = Records.getTimer();
  Timer.startTimer("Analyze records");
  auto BankDefs = Records.getAllDerivedDefinitions("RegisterBank");
  std::vector<RegisterBank> Banks;
  Banks.reserve(BankDefs.size());
  for (const auto &V : BankDefs) {
    DenseSet<const CodeGenRegisterClass *> VisitedRCs;
    RegisterBank Bank(*V, CGH.getNumModeIds());

    for (const CodeGenRegisterClass *RC :
         Bank.getExplicitlySpecifiedRegisterClasses(RegisterClassHierarchy)) {
      visitRegisterBankClasses(
          RegisterClassHierarchy, RC, "explicit",
          [&Bank](const CodeGenRegisterClass *RC, StringRef Kind) {
            LLVM_DEBUG(dbgs()
                       << "Added " << RC->getName() << "(" << Kind << ")\n");
            Bank.addRegisterClass(RC);
          },
          VisitedRCs);
    }

    // Move rather than copy: the bank owns two vectors.
    Banks.push_back(std::move(Bank));
  }

  // Warn about ambiguous MIR caused by register bank/class name clashes.
  // Lowercased names are computed once instead of once per (class, bank) pair.
  Timer.startTimer("Warn ambiguous");
  std::vector<std::string> LowerBankNames;
  LowerBankNames.reserve(Banks.size());
  for (const auto &Bank : Banks)
    LowerBankNames.push_back(Bank.getName().lower());
  for (const auto &Class : RegisterClassHierarchy.getRegClasses()) {
    const std::string LowerClassName = StringRef(Class.getName()).lower();
    for (size_t I = 0, E = Banks.size(); I != E; ++I) {
      if (LowerBankNames[I] != LowerClassName)
        continue;
      const RegisterBank &Bank = Banks[I];
      PrintWarning(Bank.getDef().getLoc(), "Register bank names should be "
                                           "distinct from register classes "
                                           "to avoid ambiguous MIR");
      PrintNote(Bank.getDef().getLoc(), "RegisterBank was declared here");
      PrintNote(Class.getDef()->getLoc(), "RegisterClass was declared here");
    }
  }

  Timer.startTimer("Emit output");
  emitSourceFileHeader("Register Bank Source Fragments", OS);
  OS << "#ifdef GET_REGBANK_DECLARATIONS\n"
     << "#undef GET_REGBANK_DECLARATIONS\n";
  emitHeader(OS, TargetName, Banks);
  OS << "#endif // GET_REGBANK_DECLARATIONS\n\n"
     << "#ifdef GET_TARGET_REGBANK_CLASS\n"
     << "#undef GET_TARGET_REGBANK_CLASS\n";
  emitBaseClassDefinition(OS, TargetName, Banks);
  OS << "#endif // GET_TARGET_REGBANK_CLASS\n\n"
     << "#ifdef GET_TARGET_REGBANK_IMPL\n"
     << "#undef GET_TARGET_REGBANK_IMPL\n";
  emitBaseClassImplementation(OS, TargetName, Banks);
  OS << "#endif // GET_TARGET_REGBANK_IMPL\n";
}

// Register this backend with llvm-tblgen under the -gen-register-bank option.
static TableGen::Emitter::OptClass<RegisterBankEmitter>
    X("gen-register-bank", "Generate register bank descriptions");