File: Var.h

package info (click to toggle)
llvm-toolchain-19 1:19.1.7-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,998,520 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (403 lines) | stat: -rw-r--r-- 15,671 bytes parent folder | download | duplicates (13)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
//===- Var.h ----------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H

#include "TemplateExtras.h"

#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
namespace sparse_tensor {
namespace ir_detail {

//===----------------------------------------------------------------------===//
/// The three kinds of variables that `Var` can be.
///
/// NOTE: The numerical values used to represent this enum should be
/// treated as an implementation detail, not as part of the API.  In the
/// API below we use the canonical ordering `{Symbol,Dimension,Level}` even
/// though that does not agree with the ordering of the numerical
/// representation.
///
/// Within this file the numerical values are nevertheless relied upon:
/// `Var::Impl` packs them into the two low-order bits of its storage,
/// `Ranks` uses them as array indices, and `VarKindArray` passes
/// `VarKind::Level` as the largest enumerator.  Keep all values within
/// `[0..2]` with `Level` the largest if this enum is ever changed.
enum class VarKind { Symbol = 1, Dimension = 0, Level = 2 };

/// Checks whether `vk` is well-formed, i.e. equal to one of the three
/// declared `VarKind` enumerators (rather than some other value that was
/// cast to `VarKind`).
[[nodiscard]] constexpr bool isWF(VarKind vk) {
  switch (vk) {
  case VarKind::Symbol:
  case VarKind::Dimension:
  case VarKind::Level:
    return true;
  }
  return false;
}

/// Gets the ASCII character used as the prefix when printing `Var`:
/// `'s'` for `VarKind::Symbol`, `'d'` for `VarKind::Dimension`, and
/// `'l'` for `VarKind::Level`.
///
/// The mapping is spelled out explicitly rather than derived arithmetically
/// from the enum's numerical values, since those values are documented as an
/// implementation detail of `VarKind`.  Like the other consumers of
/// `VarKind` in this file, a malformed kind is rejected by assertion.
constexpr char toChar(VarKind vk) {
  assert(isWF(vk) && "unknown VarKind");
  switch (vk) {
  case VarKind::Symbol:
    return 's';
  case VarKind::Dimension:
    return 'd';
  case VarKind::Level:
    return 'l';
  }
  // Only reachable for a malformed `VarKind` when assertions are disabled.
  return '?';
}
static_assert(toChar(VarKind::Symbol) == 's' &&
              toChar(VarKind::Dimension) == 'd' &&
              toChar(VarKind::Level) == 'l');

//===----------------------------------------------------------------------===//
/// The type of arrays indexed by `VarKind`, holding one `T` per kind.
///
/// `VarKind::Level` is passed as the largest enumerator, so this alias
/// relies on `Level` having the numerically-largest value of the kinds.
template <typename T>
using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;

//===----------------------------------------------------------------------===//
/// A concrete variable, to be used in our variant of `AffineExpr`.
/// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
/// support for subclasses with a fixed `VarKind`.
class Var {
public:
  /// Typedef for the type of variable numbers.
  using Num = unsigned;

private:
  /// Typedef for the underlying storage of `Var::Impl`.
  using Storage = unsigned;

  /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
  /// Two low-order bits are reserved for storing the `VarKind`,
  /// and one high-order bit is reserved for future use (e.g., to support
  /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
  /// "empty" and "tombstone").
  static constexpr Num kMaxNum =
      static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);

public:
  /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
  //
  // This must be public for `VarInfo` to use it (whereas we don't want
  // to expose the `impl` field via friendship).
  [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }

protected:
  /// The underlying implementation of `Var`.  Note that this must be kept
  /// distinct from `Var` itself, since we want to ensure that the RTTI
  /// methods will select the `U(Var::Impl)` ctor rather than selecting
  /// the `U(Var::Num)` ctor.
  class Impl final {
    Storage data;

  public:
    constexpr Impl(VarKind vk, Num n)
        : data((static_cast<Storage>(n) << 2) |
               static_cast<Storage>(llvm::to_underlying(vk))) {
      assert(isWF(vk) && "unknown VarKind");
      assert(isWF_Num(n) && "Var::Num is too large");
    }
    constexpr bool operator==(Impl other) const { return data == other.data; }
    constexpr bool operator!=(Impl other) const { return !(*this == other); }
    constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
    constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
  };
  static_assert(IsZeroCostAbstraction<Impl>);

private:
  Impl impl;

protected:
  /// Protected ctor for the RTTI methods to use.
  constexpr explicit Var(Impl impl) : impl(impl) {}

public:
  constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
  Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
  Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
    assert(vk != VarKind::Symbol);
  }

  constexpr bool operator==(Var other) const { return impl == other.impl; }
  constexpr bool operator!=(Var other) const { return !(*this == other); }

  constexpr VarKind getKind() const { return impl.getKind(); }
  constexpr Num getNum() const { return impl.getNum(); }

  template <typename U>
  constexpr bool isa() const;
  template <typename U>
  constexpr U cast() const;
  template <typename U>
  constexpr std::optional<U> dyn_cast() const;

  std::string str() const;
  void print(llvm::raw_ostream &os) const;
  void print(AsmPrinter &printer) const;
  void dump() const;
};
static_assert(IsZeroCostAbstraction<Var>);

class SymVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Symbol;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr SymVar(Num sym) : Var(Kind, sym) {}
  SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
};
static_assert(IsZeroCostAbstraction<SymVar>);

class DimVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Dimension;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr DimVar(Num dim) : Var(Kind, dim) {}
  DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
};
static_assert(IsZeroCostAbstraction<DimVar>);

class LvlVar final : public Var {
  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
  static constexpr VarKind Kind = VarKind::Level;
  static constexpr bool classof(Var const *var) {
    return var->getKind() == Kind;
  }
  constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
  LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
};
static_assert(IsZeroCostAbstraction<LvlVar>);

template <typename U>
constexpr bool Var::isa() const {
  if constexpr (std::is_same_v<U, SymVar>)
    return getKind() == VarKind::Symbol;
  if constexpr (std::is_same_v<U, DimVar>)
    return getKind() == VarKind::Dimension;
  if constexpr (std::is_same_v<U, LvlVar>)
    return getKind() == VarKind::Level;
}

template <typename U>
constexpr U Var::cast() const {
  assert(isa<U>());
  // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
  return U(impl);
}

/// Converts this `Var` into the subclass `U` when `isa<U>()` holds, and
/// returns `std::nullopt` otherwise.
template <typename U>
constexpr std::optional<U> Var::dyn_cast() const {
  if (!isa<U>())
    return std::nullopt;
  // NOTE: This must select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`.
  return std::make_optional(U(impl));
}

//===----------------------------------------------------------------------===//
// Forward-decl so that we can declare methods of `Ranks` and `VarSet`.
class DimLvlExpr;

//===----------------------------------------------------------------------===//
class Ranks final {
  // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
  unsigned impl[3];

  static constexpr unsigned to_index(VarKind vk) {
    assert(isWF(vk) && "unknown VarKind");
    return static_cast<unsigned>(llvm::to_underlying(vk));
  }

public:
  constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
      : impl() {
    impl[to_index(VarKind::Symbol)] = symRank;
    impl[to_index(VarKind::Dimension)] = dimRank;
    impl[to_index(VarKind::Level)] = lvlRank;
  }
  Ranks(VarKindArray<unsigned> const &ranks)
      : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
              ranks[VarKind::Level]) {}

  bool operator==(Ranks const &other) const;
  bool operator!=(Ranks const &other) const { return !(*this == other); }

  constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
  constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
  constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
  constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }

  [[nodiscard]] constexpr bool isValid(Var var) const {
    return var.getNum() < getRank(var.getKind());
  }
  [[nodiscard]] bool isValid(DimLvlExpr expr) const;
};
static_assert(IsZeroCostAbstraction<Ranks>);

//===----------------------------------------------------------------------===//
/// Efficient representation of a set of `Var`.
class VarSet final {
  VarKindArray<llvm::SmallBitVector> impl;

public:
  explicit VarSet(Ranks const &ranks);

  unsigned getRank(VarKind vk) const { return impl[vk].size(); }
  unsigned getSymRank() const { return getRank(VarKind::Symbol); }
  unsigned getDimRank() const { return getRank(VarKind::Dimension); }
  unsigned getLvlRank() const { return getRank(VarKind::Level); }
  Ranks getRanks() const {
    return Ranks(getSymRank(), getDimRank(), getLvlRank());
  }
  /// For the `contains` method: if variables occurring in
  /// the method parameter are OOB for the `VarSet`, then these methods will
  /// always return false.
  bool contains(Var var) const;

  /// For the `add` methods: OOB parameters cause undefined behavior.
  /// Currently the `add` methods will raise an assertion error.
  void add(Var var);
  void add(VarSet const &vars);
  void add(DimLvlExpr expr);
};

//===----------------------------------------------------------------------===//
/// A record of metadata for/about a variable, used by `VarEnv`.
///
/// The principal goal of this record is to support incremental parsing.
/// Each record is identified by its `VarInfo::ID` rather than by its
/// `Var::Num`, so the `Var::Num` may remain unknown for a while.  This lets
/// `VarEnv` allocate `VarInfo::ID`s in whatever order it likes, regardless
/// of the binding order (`Var::Num`) of the associated variables.
class VarInfo final {
public:
  /// Newtype for unique identifiers of `VarInfo` records, to ensure
  /// they aren't confused with `Var::Num`.
  enum class ID : unsigned {};

private:
  StringRef name;              // The bare-id used in the MLIR source.
  llvm::SMLoc loc;             // The location of the first occurrence.
  ID id;                       // The unique `VarInfo`-identifier.
  std::optional<Var::Num> num; // The unique `Var`-identifier (if resolved).
  VarKind kind;                // The kind of variable.

public:
  constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
                    std::optional<Var::Num> n = {})
      : name(name), loc(loc), id(id), num(n), kind(vk) {
    assert(!name.empty() && "null StringRef");
    assert(loc.isValid() && "null SMLoc");
    assert(isWF(vk) && "unknown VarKind");
    assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
  }

  constexpr StringRef getName() const { return name; }
  constexpr llvm::SMLoc getLoc() const { return loc; }
  /// Converts the recorded `SMLoc` into a `Location` using `parser`.
  Location getLocation(AsmParser &parser) const {
    const llvm::SMLoc where = getLoc();
    return parser.getEncodedSourceLoc(where);
  }
  constexpr ID getID() const { return id; }
  constexpr VarKind getKind() const { return kind; }
  constexpr std::optional<Var::Num> getNum() const { return num; }
  constexpr bool hasNum() const { return static_cast<bool>(num); }
  void setNum(Var::Num n);
  /// Gets the resolved `Var`; asserts that the `Var::Num` is known.
  constexpr Var getVar() const {
    assert(hasNum());
    const Var::Num resolved = *num;
    return Var(getKind(), resolved);
  }
};

//===----------------------------------------------------------------------===//
/// Creation policy for `VarEnv::lookupOrCreate`:
///  - `MustNot`: the variable must already exist (lookup fails otherwise);
///  - `May`: an existing variable is used, or else a new one may be created;
///  - `Must`: the variable must not already exist (lookup fails otherwise).
enum class Policy { MustNot, May, Must };

//===----------------------------------------------------------------------===//
/// An environment of named variables: it maps names to `VarInfo` records
/// and hands out the next free `Var::Num` per `VarKind` as variables are
/// bound.
class VarEnv final {
  /// Map from `VarKind` to the next free `Var::Num`; used by `bindVar`.
  VarKindArray<Var::Num> nextNum;
  /// Map from `VarInfo::ID` to shared storage for the actual `VarInfo` objects.
  SmallVector<VarInfo> vars;
  /// Map from variable names to their `VarInfo::ID`.
  llvm::StringMap<VarInfo::ID> ids;

  /// Gets the `VarInfo::ID` of the next slot in `vars`.  IDs are indices into
  /// `vars`, so this is `vars.size()`.
  VarInfo::ID nextID() const {
    const auto count = vars.size();
    return static_cast<VarInfo::ID>(count);
  }

public:
  VarEnv() : nextNum(0) {}

  /// Gets the underlying storage for the `VarInfo` identified by
  /// the `VarInfo::ID`.
  ///
  /// NOTE: The returned reference can become dangling if the `VarEnv`
  /// object is mutated during the lifetime of the reference.  Therefore,
  /// client code should not store the reference nor otherwise allow it
  /// to live too long.
  ///
  /// NOTE: The non-const overloads of `access` are private, and overload
  /// resolution prefers them for a non-const `VarEnv`.  Code outside this
  /// class holding a non-const `VarEnv` must therefore call these through
  /// `std::as_const(env).access(...)`.
  VarInfo const &access(VarInfo::ID id) const {
    // `SmallVector::operator[]` already asserts the index is in-bounds.
    const auto index = llvm::to_underlying(id);
    return vars[index];
  }
  /// Like `access(VarInfo::ID)`, but yields null when `oid` is `nullopt`.
  VarInfo const *access(std::optional<VarInfo::ID> oid) const {
    if (!oid)
      return nullptr;
    return &access(*oid);
  }

private:
  // Mutable counterparts of the accessors above, for use within `VarEnv`.
  VarInfo &access(VarInfo::ID id) {
    // `SmallVector::operator[]` already asserts the index is in-bounds.
    return vars[llvm::to_underlying(id)];
  }
  VarInfo *access(std::optional<VarInfo::ID> oid) {
    if (!oid)
      return nullptr;
    return &access(*oid);
  }

public:
  /// Looks up the variable with the given name.
  std::optional<VarInfo::ID> lookup(StringRef name) const;

  /// Creates a new currently-unbound variable.  When a variable of that name
  /// already exists: if `verifyUsage` is true, then asserts that the
  /// variable has the same kind and a consistent location; otherwise, when
  /// `verifyUsage` is false, this is a noop.  Returns the identifier for the
  /// variable with the given name, and a bool indicating whether a new
  /// variable was created.
  std::optional<std::pair<VarInfo::ID, bool>>
  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);

  /// Looks up or creates a variable according to the given
  /// `Policy`.  Returns nullopt in one of two circumstances:
  /// (1) the policy says we `Must` create, yet the variable already exists;
  /// (2) the policy says we `MustNot` create, yet no such variable exists.
  /// Otherwise, if the variable already exists then it is validated against
  /// the given kind and location to ensure consistency.
  std::optional<std::pair<VarInfo::ID, bool>>
  lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
                 VarKind vk);

  /// Binds the given variable to the next free `Var::Num` for its `VarKind`.
  Var bindVar(VarInfo::ID id);

  /// Creates a new variable of the given kind and immediately binds it.
  /// This should only be used whenever the variable is known to be unused
  /// and therefore does not have a name.
  Var bindUnusedVar(VarKind vk);

  // NOTE(review): judging by its name, this reports a diagnostic through
  // `parser` when some variable is still unbound -- confirm against the
  // out-of-line definition.
  InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;

  /// Returns the current ranks of bound variables.  This method should
  /// only be used after the environment is "finished", since binding new
  /// variables will (semantically) invalidate any previously returned `Ranks`.
  Ranks getRanks() const { return Ranks(nextNum); }

  /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
  /// failure if the variable is not bound.
  Var getVar(VarInfo::ID id) const {
    const VarInfo &info = access(id);
    return info.getVar();
  }
};

//===----------------------------------------------------------------------===//

} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H