File: Var.h

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (495 lines) | stat: -rw-r--r-- 20,970 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
//===- Var.h ----------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H

#include "TemplateExtras.h"

#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
namespace sparse_tensor {
namespace ir_detail {

// Throughout this namespace we use the name `isWF` (is "well-formed")
// for predicates that detect intrinsic structural integrity criteria,
// and hence which should always be assertively true.  Whereas we reserve
// the name `isValid` for predicates that detect extrinsic semantic
// integrity criteria, and hence which may legitimately return false even
// in well-formed programs.  Moreover, "validity" is often a relational
// or contextual property, and therefore the same term may be considered
// valid in one context yet invalid in another.
//
// As an example of why we make this distinction, consider `Var`.
// A variable is well-formed if its kind and identifier are both well-formed;
// this can be checked locally, and the resulting truth-value holds globally.
// Whereas, a variable is valid with respect to a particular `Ranks` only if
// it is within bounds; and a variable is valid with respect to a particular
// `DimLvlMap` only if the variable is bound and all uses of the variable
// are within the scope of that binding.

// Throughout this namespace we use `enum class` types to form "newtypes".
// The enum-based implementation of newtypes only serves to block implicit
// conversions; it cannot enforce any well-formedness constraints, since
// `enum class` permits using direct-list-initialization to construct
// arbitrary values[1].  Consequently, we use the syntax "`E{u}`" whenever
// we intend that ctor to be a noop (i.e., `std::is_same_v<decltype(u),
// std::underlying_type_t<E>>`), since list-initialization at least lets the
// compiler reject any narrowing conversion there (though it still accepts
// non-narrowing ones).  Whereas we only use the "`static_cast<E>(u)`" syntax
// when we specifically intend to introduce conversions.
//
// [1]:
// <https://en.cppreference.com/w/cpp/language/enum#enum_relaxed_init_cpp17>

//===----------------------------------------------------------------------===//
/// Enumerates the three kinds of variable that a `Var` may denote:
/// symbols, dimensions, and levels.
///
/// NOTE: Clients must not depend on the numerical values assigned to these
/// enumerators; they are an implementation detail, not part of the API.
/// The API below always lists the kinds in the canonical order
/// `{Symbol,Dimension,Level}`, even though that differs from the numerical
/// order of the representation.  (Other code in this header, such as
/// `flipVarKind`, `toChar`, and `VarKindArray`, relies on these particular
/// values.)
enum class VarKind : int {
  Symbol = 1,
  Dimension = 0,
  Level = 2,
};

/// Checks whether `vk` holds one of the three declared `VarKind`
/// enumerators, i.e., whether its underlying value lies in `[0, 2]`.
constexpr bool isWF(VarKind vk) {
  const auto raw = to_underlying(vk);
  return !(raw < 0 || raw > 2);
}

/// Exchanges `Dimension` with `Level`; `Symbol` is mapped to itself.
constexpr VarKind flipVarKind(VarKind vk) {
  // Reflecting the underlying value about 1 (i.e., computing `2 - vk`)
  // swaps the values 0 and 2 while fixing 1.
  const auto reflected = 2 - to_underlying(vk);
  return VarKind{reflected};
}
static_assert(flipVarKind(VarKind::Dimension) == VarKind::Level &&
              flipVarKind(VarKind::Level) == VarKind::Dimension &&
              flipVarKind(VarKind::Symbol) == VarKind::Symbol);

/// Returns the ASCII character used as the prefix when printing a `Var`
/// of the given kind: 's' for symbols, 'd' for dimensions, 'l' for levels.
constexpr char toChar(VarKind vk) {
  // Branch-free quadratic through (0,'d'), (1,'s'), (2,'l'), namely
  // `f(k) = 100 + k * (26 - 11 * k)`.  For well-formed `vk`, every
  // intermediate value lies in [-44..126], even under the worst-case
  // rearrangement of the expression.  `int_fast8_t` is the fastest type
  // that covers that range without over-/underflow.
  const auto k = static_cast<int_fast8_t>(to_underlying(vk));
  const auto slope = 26 - k * 11;
  return static_cast<char>(100 + k * slope);
}
static_assert(toChar(VarKind::Dimension) == 'd' &&
              toChar(VarKind::Symbol) == 's' &&
              toChar(VarKind::Level) == 'l');

//===----------------------------------------------------------------------===//
/// Fixed-size array holding one `T` per `VarKind`, subscripted directly by
/// `VarKind` values.  `VarKind::Level` is supplied as the largest enumerator.
template <class T>
using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;

//===----------------------------------------------------------------------===//
/// A concrete variable, to be used in our variant of `AffineExpr`.
///
/// Each `Var` pairs a `VarKind` with a `Var::Num`, packed into a single
/// `Var::Storage` word (see `Var::Impl` for the bit layout).  The subclasses
/// `SymVar`, `DimVar`, and `LvlVar` each fix a particular `VarKind`, and a
/// `Var` can be converted to them via `isa`/`cast`/`dyn_cast`.
class Var {
  // Design Note: This class makes several distinctions which may at first
  // seem unnecessary but are in fact needed for implementation reasons.
  // These distinctions are summarized as follows:
  //
  // * `Var`
  //   Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
  //   support for subclasses with a fixed `VarKind`.
  // * `Var::Num`
  //   Client-facing typedef for the type of variable numbers; defined
  //   so that client code can use it to disambiguate/document when things
  //   are intended to be variable numbers, as opposed to some other thing
  //   which happens to be represented as `unsigned`.
  // * `Var::Storage`
  //   Private typedef for the storage of `Var::Impl`; defined only because
  //   it's also needed for defining `kMaxNum`.  Note that this type must be
  //   kept distinct from `Var::Num`: not only can they be different C++ types
  //   (even though they currently happen to be the same), but also because
  //   they use different bitwise representations.
  // * `Var::Impl`
  //   The underlying implementation of `Var`; needed by RTTI to serve as
  //   an intermediary between `Var` and `Var::Storage`.  That is, we want
  //   the RTTI methods to select the `U(Var::Impl)` ctor, without any
  //   possibility of confusing that with the `U(Var::Num)` ctor nor with
  //   the copy-ctor.  (Although the `U(Var::Impl)` ctor is effectively
  //   identical to the copy-ctor, it doesn't have the type that C++ expects
  //   for a copy-ctor.)
  //
  // TODO: See if it'd be cleaner to use "llvm/ADT/Bitfields.h" in lieu
  // of doing our own bitbashing (though that seems to only be used by LLVM
  // for defining machine/assembly ops, and not anywhere else in LLVM/MLIR).
public:
  /// Typedef for the type of variable numbers.
  using Num = unsigned;

private:
  /// Typedef for the underlying storage of `Var::Impl`.
  using Storage = unsigned;

  /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
  /// Two low-order bits are reserved for storing the `VarKind`,
  /// and one high-order bit is reserved for future use (e.g., to support
  /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
  /// "empty" and "tombstone").
  static constexpr Num kMaxNum =
      static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);

public:
  /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
  //
  // This must be public for `VarInfo` to use it (whereas we don't want
  // to expose the `impl` field via friendship).
  static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }

protected:
  /// The underlying implementation of `Var`.  Note that this must be kept
  /// distinct from `Var` itself, since we want to ensure that the RTTI
  /// methods will select the `U(Var::Impl)` ctor rather than selecting
  /// the `U(Var::Num)` ctor.
  class Impl final {
    // Packed representation: the `Var::Num` shifted left by two bits,
    // bitwise-or'd with the `VarKind` in the two low-order bits.
    Storage data;

  public:
    /// Packs `vk` and `n` into `data`, asserting that both are well-formed.
    constexpr Impl(VarKind vk, Num n)
        : data((static_cast<Storage>(n) << 2) |
               static_cast<Storage>(to_underlying(vk))) {
      assert(isWF(vk) && "unknown VarKind");
      assert(isWF_Num(n) && "Var::Num is too large");
    }
    constexpr bool operator==(Impl other) const { return data == other.data; }
    constexpr bool operator!=(Impl other) const { return !(*this == other); }
    /// Extracts the `VarKind` from the two low-order bits.
    constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
    /// Extracts the `Var::Num` from the bits above the `VarKind`.
    constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
  };
  static_assert(IsZeroCostAbstraction<Impl>);

private:
  Impl impl;

protected:
  /// Protected ctor for the RTTI methods to use.
  constexpr explicit Var(Impl impl) : impl(impl) {}

public:
  /// Constructs the variable of kind `vk` and number `n`; asserts that
  /// both are well-formed (see `Var::Impl`).
  constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
  /// Constructs the `VarKind::Symbol` variable numbered by the position
  /// of `sym`.
  Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
  /// Constructs the variable of kind `vk` numbered by the position of
  /// `var`.  The kind is supplied separately because both `DimVar` and
  /// `LvlVar` are built from an `AffineDimExpr`.
  Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {}

  /// Two variables are equal iff they have the same kind and number.
  constexpr bool operator==(Var other) const { return impl == other.impl; }
  constexpr bool operator!=(Var other) const { return !(*this == other); }

  constexpr VarKind getKind() const { return impl.getKind(); }
  constexpr Num getNum() const { return impl.getNum(); }

  /// Checks whether this variable's kind is the one fixed by the subclass
  /// `U` (one of `SymVar`, `DimVar`, or `LvlVar`).
  template <typename U>
  constexpr bool isa() const;
  /// Converts to the subclass `U`, asserting that `isa<U>()` holds.
  template <typename U>
  constexpr U cast() const;
  /// Converts to the subclass `U` if `isa<U>()` holds; otherwise returns
  /// `std::nullopt`.
  template <typename U>
  constexpr std::optional<U> dyn_cast() const;

  /// Printing and debugging helpers (defined out-of-line).
  void print(llvm::raw_ostream &os) const;
  void print(AsmPrinter &printer) const;
  void dump() const;
};
static_assert(IsZeroCostAbstraction<Var>);

class SymVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Symbol;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr SymVar(Num sym) : Var(Kind, sym) {}
  SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
};
static_assert(IsZeroCostAbstraction<SymVar>);

class DimVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Dimension;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr DimVar(Num dim) : Var(Kind, dim) {}
  DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
};
static_assert(IsZeroCostAbstraction<DimVar>);

class LvlVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Level;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
  LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
};
static_assert(IsZeroCostAbstraction<LvlVar>);

/// Checks whether this variable's kind is the one fixed by the subclass `U`.
///
/// `U` must be one of `SymVar`, `DimVar`, or `LvlVar`; any other type is
/// rejected at compile time with a readable message.  We assert this
/// explicitly because a chain of `if constexpr` with no fallthrough `return`
/// would, for any other `U`, fall off the end of this value-returning
/// function.  That is undefined behavior rather than a compile error.
template <typename U>
constexpr bool Var::isa() const {
  static_assert(std::is_same_v<U, SymVar> || std::is_same_v<U, DimVar> ||
                    std::is_same_v<U, LvlVar>,
                "Var::isa<U> requires U to be SymVar, DimVar, or LvlVar");
  // Each subclass records its fixed kind in the static member `U::Kind`.
  return getKind() == U::Kind;
}

/// Unchecked-at-runtime (assert-only) conversion to the subclass `U`.
template <typename U>
constexpr U Var::cast() const {
  assert(isa<U>());
  // Handing over the `Impl` (rather than a `Num`) is what makes overload
  // resolution pick the RTTI-only `U(Var::Impl)` ctor.
  return U{impl};
}

/// Checked conversion to the subclass `U`: yields `std::nullopt` whenever
/// this variable's kind does not match `U`.
template <typename U>
constexpr std::optional<U> Var::dyn_cast() const {
  if (!isa<U>())
    return std::nullopt;
  // Pass the `Impl` (not a `Num`) so that the `U(Var::Impl)` ctor is chosen.
  return U(impl);
}

//===----------------------------------------------------------------------===//
// Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
class DimLvlExpr;

//===----------------------------------------------------------------------===//
/// The number of variables of each `VarKind` (the symbol, dimension, and
/// level ranks), used to check whether a variable is in bounds.
class Ranks final {
  // Indexed by the numeric value of `VarKind` (see `to_index`).  This is a
  // plain array rather than a `VarKindArray` because `EnumeratedArray` does
  // not support constexpr.
  unsigned impl[3];

  /// Maps a well-formed `VarKind` to its slot in `impl`.
  static constexpr unsigned to_index(VarKind vk) {
    assert(isWF(vk) && "unknown VarKind");
    const auto slot = to_underlying(vk);
    return static_cast<unsigned>(slot);
  }

public:
  constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
      : impl() {
    impl[to_index(VarKind::Dimension)] = dimRank;
    impl[to_index(VarKind::Symbol)] = symRank;
    impl[to_index(VarKind::Level)] = lvlRank;
  }
  /// Converts from a `VarKindArray` holding one rank per kind.
  Ranks(VarKindArray<unsigned> const &ranks)
      : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
              ranks[VarKind::Level]) {}

  /// Returns the rank recorded for the given kind of variable.
  constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
  constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
  constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
  constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }

  /// Checks whether `var` is in bounds, i.e., whether its number is
  /// strictly less than the rank recorded for its kind.
  constexpr bool isValid(Var var) const {
    const unsigned rank = getRank(var.getKind());
    return var.getNum() < rank;
  }
  bool isValid(DimLvlExpr expr) const;
};
static_assert(IsZeroCostAbstraction<Ranks>);

//===----------------------------------------------------------------------===//
/// A set of `Var`, stored efficiently as one bitvector per `VarKind`.
///
/// Out-of-bounds arguments: the query methods (`contains`/`occursIn`)
/// always return false when their argument mentions a variable that is out
/// of bounds for this `VarSet`.  Passing out-of-bounds variables to the
/// `add` methods is undefined behavior instead.  Currently that raises an
/// assertion error, though this may change in the future (e.g., to resize
/// the underlying bitvectors).
class VarSet final {
  // One bitvector per `VarKind`.  Flattening these into a single bitvector
  // (akin to `mlir::presburger::PresburgerSpace`) would give up the option
  // of resizing each kind independently.  It would also greatly complicate
  // the implementation of `occursIn(VarSet)`.
  VarKindArray<llvm::SmallBitVector> impl;

public:
  /// Builds a set whose bounds are taken from `ranks` (defined out-of-line).
  explicit VarSet(Ranks const &ranks);

  // Membership queries.
  bool contains(Var var) const;
  bool occursIn(VarSet const &vars) const;
  bool occursIn(DimLvlExpr expr) const;

  // Insertion: adds a single variable, every variable of another set, or
  // every variable occurring in an expression.
  void add(Var var);
  void add(VarSet const &vars);
  void add(DimLvlExpr expr);
};

//===----------------------------------------------------------------------===//
/// A record of metadata for/about a variable, used by `VarEnv`.
/// The principal goal of this record is to enable `VarEnv` to be used for
/// incremental parsing; in particular, `VarInfo` allows the `Var::Num` to
/// remain unknown, since each record is instead identified by `VarInfo::ID`.
/// Therefore the `VarEnv` can freely allocate `VarInfo::ID` in whatever
/// order it likes, irrespective of the binding order (`Var::Num`) of the
/// associated variable.
class VarInfo final {
public:
  /// Newtype for unique identifiers of `VarInfo` records, to ensure
  /// they aren't confused with `Var::Num`.
  enum class ID : unsigned {};

private:
  // FUTURE_CL(wrengr): We could use the high-bit of `Var::Impl` to
  // store the `std::optional` bit, therefore allowing us to bitbash the
  // `num` and `kind` fields together.
  //
  StringRef name;              // The bare-id used in the MLIR source.
  llvm::SMLoc loc;             // The location of the first occurence.
                               // TODO(wrengr): See the above `LocatedVar` note.
  ID id;                       // The unique `VarInfo`-identifier.
  std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
  VarKind kind;                // The kind of variable.

public:
  constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
                    std::optional<Var::Num> n = {})
      : name(name), loc(loc), id(id), num(n), kind(vk) {
    assert(!name.empty() && "null StringRef");
    assert(loc.isValid() && "null SMLoc");
    assert(isWF(vk) && "unknown VarKind");
    assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
  }

  constexpr StringRef getName() const { return name; }
  constexpr llvm::SMLoc getLoc() const { return loc; }
  Location getLocation(AsmParser &parser) const {
    return parser.getEncodedSourceLoc(loc);
  }
  constexpr ID getID() const { return id; }
  constexpr VarKind getKind() const { return kind; }
  constexpr std::optional<Var::Num> getNum() const { return num; }
  constexpr bool hasNum() const { return num.has_value(); }
  void setNum(Var::Num n);
  constexpr Var getVar() const {
    assert(hasNum());
    return Var(kind, *num);
  }
  constexpr std::optional<Var> tryGetVar() const {
    return num ? std::make_optional(Var(kind, *num)) : std::nullopt;
  }
};
// We don't actually require this, since `VarInfo` is a proper struct
// rather than a newtype.  It is expected to hold, but it is currently
// disabled because it breaks the build with gcc7.
// TODO: Re-enable the static assert once gcc7 is no longer supported.
// static_assert(IsZeroCostAbstraction<VarInfo>);

//===----------------------------------------------------------------------===//
/// Creation policy for `VarEnv::lookupOrCreate`.
enum class Policy {
  MustNot, ///< Must not create: the variable has to exist already.
  May,     ///< May create: look up the variable, creating it if absent.
  Must,    ///< Must create: the variable must not exist already.
};

//===----------------------------------------------------------------------===//
class VarEnv final {
  /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
  VarKindArray<Var::Num> nextNum;
  /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
  SmallVector<VarInfo> vars;
  /// Map from variable names to their `VarInfo::ID`.
  llvm::StringMap<VarInfo::ID> ids;

  VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }

public:
  VarEnv() : nextNum(0) {}

  /// Gets the underlying storage for the `VarInfo` identified by
  /// the `VarInfo::ID`.
  ///
  /// NOTE: The returned reference can become dangling if the `VarEnv`
  /// object is mutated during the lifetime of the pointer.  Therefore,
  /// client code should not store the reference nor otherwise allow it
  /// to live too long.
  //
  // FUTURE_CL(wrengr): Consider trying to define/use a nested class
  // `struct{VarEnv*; VarInfo::ID}` akin to `BitVector::reference`.
  VarInfo const &access(VarInfo::ID id) const {
    // `SmallVector::operator[]` already asserts the index is in-bounds.
    return vars[to_underlying(id)];
  }
  VarInfo const *access(std::optional<VarInfo::ID> oid) const {
    return oid ? &access(*oid) : nullptr;
  }

private:
  VarInfo &access(VarInfo::ID id) {
    return const_cast<VarInfo &>(std::as_const(*this).access(id));
  }
  VarInfo *access(std::optional<VarInfo::ID> oid) {
    return const_cast<VarInfo *>(std::as_const(*this).access(oid));
  }

public:
  /// Attempts to look up the variable with the given name.
  std::optional<VarInfo::ID> lookup(StringRef name) const;

  /// Attempts to create a new currently-unbound variable.  When a variable
  /// of that name already exists: if `verifyUsage` is true, then will assert
  /// that the variable has the same kind and a consistent location; otherwise,
  /// when `verifyUsage` is false, this is a noop.  Returns the identifier
  /// for the variable with the given name (i.e., either the newly created
  /// variable, or the pre-existing variable), and a bool indicating whether
  /// a new variable was created.
  std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
                                      VarKind vk, bool verifyUsage = false);

  /// Attempts to lookup or create a variable according to the given
  /// `Policy`.  Returns nullopt in one of two circumstances:
  /// (1) the policy says we `Must` create, yet the variable already exists;
  /// (2) the policy says we `MustNot` create, yet no such variable exists.
  /// Otherwise, if the variable already exists then it is validated against
  /// the given kind and location to ensure consistency.
  //
  // TODO(wrengr): Define an enum of error codes, to avoid `nullopt`-blindness
  // TODO(wrengr): Prolly want to rename this to `create` and move the
  // current method of that name to being a private `createImpl`.
  std::optional<std::pair<VarInfo::ID, bool>>
  lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
                 VarKind vk);

  /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
  Var bindVar(VarInfo::ID id);

  /// Creates a new variable of the given kind and immediately binds it.
  /// This should only be used whenever the variable is known to be unused
  /// and therefore does not have a name.
  Var bindUnusedVar(VarKind vk);

  InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;

  /// Returns the current ranks of bound variables.  This method should
  /// only be used after the environment is "finished", since binding new
  /// variables will (semantically) invalidate any previously returned `Ranks`.
  Ranks getRanks() const { return Ranks(nextNum); }

  /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
  /// failure if the variable is not bound.
  Var getVar(VarInfo::ID id) const { return access(id).getVar(); }

  /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
  /// if the variable is not bound.
  std::optional<Var> tryGetVar(VarInfo::ID id) const {
    return access(id).tryGetVar();
  }
};

//===----------------------------------------------------------------------===//

} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H