1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
|
//===- Var.h ----------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
#include "TemplateExtras.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringMap.h"
namespace mlir {
namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
/// Enumerates the three kinds of variable that a `Var` may denote.
///
/// NOTE: The integer assigned to each enumerator is an implementation
/// detail and is not part of the API. Throughout the API below we list the
/// kinds in the canonical order `{Symbol,Dimension,Level}`, even though
/// that order differs from the order of their integer values.
enum class VarKind {
  Symbol = 1,
  Dimension = 0,
  Level = 2,
};
/// Returns true iff `vk`'s underlying value is one of the three declared
/// `VarKind` values (i.e., lies in the closed range [0..2]).
[[nodiscard]] constexpr bool isWF(VarKind vk) {
  const auto raw = llvm::to_underlying(vk);
  return raw >= 0 && raw <= 2;
}
/// Returns the ASCII character printed as the prefix of a `Var` of kind
/// `vk`: 's' for symbols, 'd' for dimensions, and 'l' for levels.
///
/// Rather than branching, the character is computed as
/// `100 + k * (26 - 11 * k)` where `k` is the underlying value of `vk`,
/// which maps 0 -> 100 ('d'), 1 -> 115 ('s'), and 2 -> 108 ('l').
constexpr char toChar(VarKind vk) {
  // When `isWF(vk)` holds, every intermediate value stays within [-44..126]
  // (even under worst-case reassociation of the expression), so
  // `int_fast8_t` is the fastest type that cannot over-/underflow here.
  const auto k = static_cast<int_fast8_t>(llvm::to_underlying(vk));
  const auto slope = 26 - k * 11;
  return static_cast<char>(100 + k * slope);
}
static_assert(toChar(VarKind::Symbol) == 's', "symbols print as 's'");
static_assert(toChar(VarKind::Dimension) == 'd', "dimensions print as 'd'");
static_assert(toChar(VarKind::Level) == 'l', "levels print as 'l'");
//===----------------------------------------------------------------------===//
/// Fixed-size array holding one `T` per `VarKind`, subscripted directly
/// with `VarKind` values (`VarKind::Level` is the largest enumerator).
template <class T>
using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
//===----------------------------------------------------------------------===//
/// A concrete variable, for use in our variant of `AffineExpr`.
/// This is the client-facing pairing of a `VarKind` with a `Var::Num`;
/// the subclasses `SymVar`, `DimVar`, and `LvlVar` each fix a single
/// `VarKind` and are reached through the `isa`/`cast`/`dyn_cast` methods.
class Var {
public:
  /// The type used for variable numbers.
  using Num = unsigned;

private:
  /// The raw integer type backing `Var::Impl`.
  using Storage = unsigned;

  /// Largest `Var::Num` that fits in `Var::Storage` alongside the kind.
  /// The two lowest bits hold the `VarKind`, and the highest bit is kept
  /// free for future use (e.g., so a `DenseMapInfo<Var>` could keep the
  /// conventional numeric values for "empty" and "tombstone").
  static constexpr Num kMaxNum =
      static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);

public:
  /// Returns true iff `n` is small enough to be passed to
  /// `Var(VarKind,Var::Num)`.
  //
  // Public so that `VarInfo` can call it without being granted friendship
  // (which would also expose the `impl` field).
  [[nodiscard]] static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }

protected:
  /// Packed (kind, number) representation backing `Var`. It must remain a
  /// type distinct from `Var` itself so that the RTTI methods' `U(impl)`
  /// calls resolve to the `U(Var::Impl)` ctor and not to `U(Var::Num)`.
  class Impl final {
    /// Number of low-order bits reserved for the `VarKind`.
    static constexpr Storage kKindBits = 2;
    /// Mask selecting the `VarKind` bits.
    static constexpr Storage kKindMask = (Storage{1} << kKindBits) - 1;

    /// Layout: `(num << kKindBits) | kind`.
    Storage bits;

  public:
    constexpr Impl(VarKind vk, Num n)
        : bits((static_cast<Storage>(n) << kKindBits) |
               static_cast<Storage>(llvm::to_underlying(vk))) {
      assert(isWF(vk) && "unknown VarKind");
      assert(isWF_Num(n) && "Var::Num is too large");
    }
    constexpr bool operator==(Impl rhs) const { return bits == rhs.bits; }
    constexpr bool operator!=(Impl rhs) const { return !(*this == rhs); }
    constexpr VarKind getKind() const {
      return static_cast<VarKind>(bits & kKindMask);
    }
    constexpr Num getNum() const {
      return static_cast<Num>(bits >> kKindBits);
    }
  };
  static_assert(IsZeroCostAbstraction<Impl>);

private:
  Impl impl;

protected:
  /// Ctor reserved for the RTTI methods (see `Var::Impl`).
  constexpr explicit Var(Impl impl) : impl(impl) {}

public:
  constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
  /// Symbol variable taken from an `AffineSymbolExpr`'s position.
  Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
  /// Dimension/level variable taken from an `AffineDimExpr`'s position;
  /// `vk` must not be `VarKind::Symbol`.
  Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
    assert(vk != VarKind::Symbol);
  }

  constexpr bool operator==(Var other) const { return impl == other.impl; }
  constexpr bool operator!=(Var other) const { return !(impl == other.impl); }

  constexpr VarKind getKind() const { return impl.getKind(); }
  constexpr Num getNum() const { return impl.getNum(); }

  template <typename U>
  constexpr bool isa() const;
  template <typename U>
  constexpr U cast() const;
  template <typename U>
  constexpr std::optional<U> dyn_cast() const;

  std::string str() const;
  void print(llvm::raw_ostream &os) const;
  void print(AsmPrinter &printer) const;
  void dump() const;
};
static_assert(IsZeroCostAbstraction<Var>);
class SymVar final : public Var {
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
static constexpr VarKind Kind = VarKind::Symbol;
static constexpr bool classof(Var const *var) {
return var->getKind() == Kind;
}
constexpr SymVar(Num sym) : Var(Kind, sym) {}
SymVar(AffineSymbolExpr symExpr) : Var(symExpr) {}
};
static_assert(IsZeroCostAbstraction<SymVar>);
class DimVar final : public Var {
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
static constexpr VarKind Kind = VarKind::Dimension;
static constexpr bool classof(Var const *var) {
return var->getKind() == Kind;
}
constexpr DimVar(Num dim) : Var(Kind, dim) {}
DimVar(AffineDimExpr dimExpr) : Var(Kind, dimExpr) {}
};
static_assert(IsZeroCostAbstraction<DimVar>);
class LvlVar final : public Var {
using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
public:
static constexpr VarKind Kind = VarKind::Level;
static constexpr bool classof(Var const *var) {
return var->getKind() == Kind;
}
constexpr LvlVar(Num lvl) : Var(Kind, lvl) {}
LvlVar(AffineDimExpr lvlExpr) : Var(Kind, lvlExpr) {}
};
static_assert(IsZeroCostAbstraction<LvlVar>);
/// Returns true iff this variable's kind is the one fixed by `U`.
/// `U` must be one of `SymVar`, `DimVar`, or `LvlVar`; any other type is
/// rejected at compile time instead of producing a function that flows
/// off its end without returning a value.
template <typename U>
constexpr bool Var::isa() const {
  if constexpr (std::is_same_v<U, SymVar>)
    return getKind() == VarKind::Symbol;
  else if constexpr (std::is_same_v<U, DimVar>)
    return getKind() == VarKind::Dimension;
  else if constexpr (std::is_same_v<U, LvlVar>)
    return getKind() == VarKind::Level;
  else
    // The condition is dependent on `U`, so this only fires when this
    // branch is actually instantiated.
    static_assert(!std::is_same_v<U, U>,
                  "Var::isa<U> requires U to be SymVar, DimVar, or LvlVar");
}
template <typename U>
constexpr U Var::cast() const {
assert(isa<U>());
// NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
return U(impl);
}
/// Converts this variable to the subclass `U` if its kind matches, and
/// returns `std::nullopt` otherwise.
template <typename U>
constexpr std::optional<U> Var::dyn_cast() const {
  if (!isa<U>())
    return std::nullopt;
  // Constructing from the packed `Var::Impl` (rather than from the number)
  // makes overload resolution pick `U(Var::Impl)`, not `U(Var::Num)`.
  return std::make_optional(U(impl));
}
//===----------------------------------------------------------------------===//
// Forward declaration of `DimLvlExpr` (not defined in this header), so that
// `Ranks::isValid` and `VarSet::add` below can be declared taking it by
// value; those methods are only declared here and defined out of line.
class DimLvlExpr;
//===----------------------------------------------------------------------===//
/// The number of variables ("rank") of each `VarKind`.
class Ranks final {
  // One rank per `VarKind`, indexed via `to_index`. This cannot be a
  // `VarKindArray`, because `EnumeratedArray` does not support constexpr.
  unsigned impl[3];

  /// Maps a well-formed `VarKind` to its slot in `impl`.
  static constexpr unsigned to_index(VarKind vk) {
    assert(isWF(vk) && "unknown VarKind");
    return static_cast<unsigned>(llvm::to_underlying(vk));
  }

public:
  constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
      : impl{} {
    impl[to_index(VarKind::Dimension)] = dimRank;
    impl[to_index(VarKind::Symbol)] = symRank;
    impl[to_index(VarKind::Level)] = lvlRank;
  }
  Ranks(VarKindArray<unsigned> const &ranks)
      : Ranks(/*symRank=*/ranks[VarKind::Symbol],
              /*dimRank=*/ranks[VarKind::Dimension],
              /*lvlRank=*/ranks[VarKind::Level]) {}

  bool operator==(Ranks const &other) const;
  bool operator!=(Ranks const &other) const { return !operator==(other); }

  /// The rank recorded for kind `vk`.
  constexpr unsigned getRank(VarKind vk) const {
    const unsigned slot = to_index(vk);
    return impl[slot];
  }
  constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
  constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
  constexpr unsigned getLvlRank() const { return getRank(VarKind::Level); }

  /// Returns true iff `var`'s number is below the rank of its kind.
  [[nodiscard]] constexpr bool isValid(Var var) const {
    return getRank(var.getKind()) > var.getNum();
  }
  [[nodiscard]] bool isValid(DimLvlExpr expr) const;
};
static_assert(IsZeroCostAbstraction<Ranks>);
//===----------------------------------------------------------------------===//
/// Efficient representation of a set of `Var`.
class VarSet final {
VarKindArray<llvm::SmallBitVector> impl;
public:
explicit VarSet(Ranks const &ranks);
unsigned getRank(VarKind vk) const { return impl[vk].size(); }
unsigned getSymRank() const { return getRank(VarKind::Symbol); }
unsigned getDimRank() const { return getRank(VarKind::Dimension); }
unsigned getLvlRank() const { return getRank(VarKind::Level); }
Ranks getRanks() const {
return Ranks(getSymRank(), getDimRank(), getLvlRank());
}
/// For the `contains` method: if variables occurring in
/// the method parameter are OOB for the `VarSet`, then these methods will
/// always return false.
bool contains(Var var) const;
/// For the `add` methods: OOB parameters cause undefined behavior.
/// Currently the `add` methods will raise an assertion error.
void add(Var var);
void add(VarSet const &vars);
void add(DimLvlExpr expr);
};
//===----------------------------------------------------------------------===//
/// Metadata about a single variable, as tracked by `VarEnv`.
///
/// The main purpose of this record is to let `VarEnv` support incremental
/// parsing. Each record is identified by its own `VarInfo::ID`, so the
/// variable's `Var::Num` may remain unknown for a while. As a result,
/// `VarEnv` can allocate `VarInfo::ID`s in whatever order it likes,
/// independent of the order in which variables are bound (their
/// `Var::Num`).
class VarInfo final {
public:
  /// Strongly-typed identifier for `VarInfo` records, so that it cannot be
  /// confused with a `Var::Num`.
  enum class ID : unsigned {};

private:
  StringRef name;              // The bare-id spelling in the MLIR source.
  llvm::SMLoc loc;             // The location of the first occurrence.
  ID id;                       // The unique `VarInfo` identifier.
  std::optional<Var::Num> num; // The unique `Var` identifier, once resolved.
  VarKind kind;                // The kind of variable.

public:
  constexpr VarInfo(ID id, StringRef name, llvm::SMLoc loc, VarKind vk,
                    std::optional<Var::Num> n = {})
      : name(name), loc(loc), id(id), num(n), kind(vk) {
    assert(!name.empty() && "null StringRef");
    assert(loc.isValid() && "null SMLoc");
    assert(isWF(vk) && "unknown VarKind");
    assert((!n || Var::isWF_Num(*n)) && "Var::Num is too large");
  }

  constexpr StringRef getName() const { return name; }
  constexpr llvm::SMLoc getLoc() const { return loc; }
  /// Converts `getLoc()` into a `Location` by way of `parser`.
  Location getLocation(AsmParser &parser) const {
    return parser.getEncodedSourceLoc(getLoc());
  }
  constexpr ID getID() const { return id; }
  constexpr VarKind getKind() const { return kind; }
  constexpr std::optional<Var::Num> getNum() const { return num; }
  constexpr bool hasNum() const { return static_cast<bool>(num); }
  void setNum(Var::Num n);
  /// Returns the resolved `Var`; asserts that `hasNum()` holds.
  constexpr Var getVar() const {
    assert(hasNum());
    return Var(getKind(), *num);
  }
};
//===----------------------------------------------------------------------===//
/// Creation policy passed to `VarEnv::lookupOrCreate`:
///   - `MustNot`: the variable must already exist; never create it.
///   - `May`:     create the variable only if it does not already exist.
///   - `Must`:    the variable must not already exist; always create it.
enum class Policy {
  MustNot = 0,
  May = 1,
  Must = 2,
};
//===----------------------------------------------------------------------===//
/// Environment of variables: owns the `VarInfo` records, maps variable
/// names to their `VarInfo::ID`, and assigns `Var::Num`s when binding.
class VarEnv final {
  /// For each `VarKind`, the next unused `Var::Num`; consumed by `bindVar`.
  VarKindArray<Var::Num> nextNum;
  /// Shared storage for the `VarInfo` records, indexed by `VarInfo::ID`.
  SmallVector<VarInfo> vars;
  /// Map from a variable's name to its `VarInfo::ID`.
  llvm::StringMap<VarInfo::ID> ids;

  /// The `VarInfo::ID` corresponding to the next slot of `vars` (that is,
  /// its current size).
  VarInfo::ID nextID() const {
    const auto slot = vars.size();
    return static_cast<VarInfo::ID>(slot);
  }

public:
  VarEnv() : nextNum(0) {}

  /// Returns the stored `VarInfo` identified by `id`.
  ///
  /// NOTE: The returned reference may dangle once this `VarEnv` is mutated,
  /// so client code must not store it or otherwise let it live too long.
  VarInfo const &access(VarInfo::ID id) const {
    // No explicit bounds check: `SmallVector::operator[]` already asserts.
    const auto index = llvm::to_underlying(id);
    return vars[index];
  }
  /// Like `access(VarInfo::ID)`, except that `std::nullopt` maps to
  /// `nullptr`.
  VarInfo const *access(std::optional<VarInfo::ID> oid) const {
    if (!oid)
      return nullptr;
    return &access(*oid);
  }

private:
  // Mutable counterparts, implemented in terms of the `const` overloads.
  VarInfo &access(VarInfo::ID id) {
    const VarEnv &self = *this;
    return const_cast<VarInfo &>(self.access(id));
  }
  VarInfo *access(std::optional<VarInfo::ID> oid) {
    const VarEnv &self = *this;
    return const_cast<VarInfo *>(self.access(oid));
  }

public:
  /// Finds the variable named `name`, if any.
  std::optional<VarInfo::ID> lookup(StringRef name) const;

  /// Creates a new, not-yet-bound variable. If a variable with this name
  /// already exists, then with `verifyUsage` set this asserts that it has
  /// the same kind and a consistent location; with `verifyUsage` unset it
  /// does nothing. Returns the identifier of the variable with this name,
  /// together with a flag telling whether a new variable was created.
  std::optional<std::pair<VarInfo::ID, bool>>
  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);

  /// Looks up or creates a variable as directed by `creationPolicy`.
  /// Returns nullopt in exactly two situations:
  /// (1) the policy is `Must` but the variable already exists;
  /// (2) the policy is `MustNot` but no such variable exists.
  /// Otherwise an already-existing variable is checked against the given
  /// kind and location for consistency.
  std::optional<std::pair<VarInfo::ID, bool>>
  lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
                 VarKind vk);

  /// Binds the variable to the next free `Var::Num` of its `VarKind`.
  Var bindVar(VarInfo::ID id);

  /// Creates a fresh variable of kind `vk` and binds it right away. Only
  /// meant for variables known to be unused, which therefore have no name.
  Var bindUnusedVar(VarKind vk);

  // NOTE(review): presumably reports variables that never received a
  // `Var::Num`; confirm against the out-of-line definition.
  InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;

  /// The current ranks of bound variables. Only call this once the
  /// environment is "finished": binding further variables (semantically)
  /// invalidates any `Ranks` returned earlier.
  Ranks getRanks() const { return Ranks(nextNum); }

  /// The `Var` identified by `id`; asserts that the variable is bound.
  Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
};
//===----------------------------------------------------------------------===//
} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_VAR_H
|