1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
|
//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
// be named as the storage representation of the parameter for the tblgen
// defn of STEA. We may well need to make the other classes public too,
// so that the rest of the compiler can use them when necessary.
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#include "Var.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
namespace mlir {
namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
// TODO(wrengr): Give this enum a better name, so that it fits together
// with the name of the `DimLvlExpr` class (which may also want a better
// name). Perhaps make this a nested-type too.
//
// NOTE: In the future we will extend this enum to include "counting
// expressions" required for supporting ITPACK/ELL. Therefore the current
// underlying-type and representation values should not be relied upon.
enum class ExprKind : bool { Dimension = false, Level = true };

// TODO(wrengr): still needs a better name....
/// Returns the kind of variable that is allowed to occur in an expression
/// of the given kind: level-variables in dimension-expressions, and
/// dimension-variables in level-expressions.
///
/// This is written as an explicit case analysis rather than as arithmetic
/// on the underlying values of `ExprKind`/`VarKind`, because those
/// representation values are explicitly not meant to be relied upon (e.g.,
/// `ExprKind` is expected to grow beyond two alternatives).
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
  return ek == ExprKind::Dimension ? VarKind::Level : VarKind::Dimension;
}
static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
              getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
//===----------------------------------------------------------------------===//
// TODO(wrengr): The goal of this class is to capture a proof that
// we've verified that the given `AffineExpr` only has variables of the
// appropriate kind(s). So we need to actually prove/verify that in the
// ctor or all its callsites!
class DimLvlExpr {
private:
  // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
  // the `kind` field should be private and const. However, beware
  // that if we mark any field as `const` or if the fields have differing
  // `private`/`protected` privileges then the `IsZeroCostAbstraction`
  // assertion will fail!
  // (Also, iirc, if we end up moving the `expr` to the subclasses
  // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
  ExprKind kind;
  AffineExpr expr;

public:
  /// Tags an `AffineExpr` with the kind of expression it is meant to be.
  constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}

  //===--- Equality and truthiness --------------------------------------===//

  /// Equal iff both the kind tags and the wrapped `AffineExpr`s are equal.
  constexpr bool operator==(DimLvlExpr other) const {
    if (kind != other.kind)
      return false;
    return expr == other.expr;
  }
  constexpr bool operator!=(DimLvlExpr other) const {
    return !operator==(other);
  }
  /// True iff the wrapped `AffineExpr` is non-null.
  explicit operator bool() const { return expr ? true : false; }

  //===--- RTTI support (for the `DimLvlExpr` class itself) -------------===//

  template <typename U>
  constexpr bool isa() const;
  template <typename U>
  constexpr U cast() const;
  template <typename U>
  constexpr U dyn_cast() const;

  //===--- Simple getters -----------------------------------------------===//

  constexpr ExprKind getExprKind() const { return kind; }
  /// The kind of variable allowed to occur in this expression.
  constexpr VarKind getAllowedVarKind() const {
    return getVarKindAllowedInExpr(getExprKind());
  }
  constexpr AffineExpr getAffineExpr() const { return expr; }
  /// The `AffineExprKind` of the wrapped expression, which must be non-null.
  AffineExprKind getAffineKind() const {
    assert(expr);
    return expr.getKind();
  }
  /// The context of the wrapped expression, or null when it is null.
  MLIRContext *getContext() const {
    if (!expr)
      return nullptr;
    return expr.getContext();
  }

  //===--- Getters for handling `AffineExpr` subclasses -----------------===//
  //
  // TODO(wrengr): is there any way to make these typesafe without too much
  // templating?
  // TODO(wrengr): Most if not all of these don't actually need to be
  // methods, they could be free-functions instead.
  //
  SymVar castSymVar() const;
  Var castDimLvlVar() const;
  int64_t castConstantValue() const;
  std::optional<int64_t> tryGetConstantValue() const;
  bool hasConstantValue(int64_t val) const;
  DimLvlExpr getLHS() const;
  DimLvlExpr getRHS() const;
  std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  bool isValid(Ranks const &ranks) const;

  void print(llvm::raw_ostream &os) const;
  void print(AsmPrinter &printer) const;
  void dump() const;

protected:
  // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
  enum class BindingStrength : bool { Weak = false, Strong = true };

  // TODO(wrengr): Does our version of `printAffineExprInternal` really
  // need to be a method, or could it be a free-function instead? (assuming
  // `BindingStrength` goes with it).
  void printAffineExprInternal(llvm::raw_ostream &os,
                               BindingStrength enclosingTightness) const;

  /// Prints as if enclosed in a tightly-binding context.
  void printStrong(llvm::raw_ostream &os) const {
    printAffineExprInternal(os, BindingStrength::Strong);
  }
  /// Prints as if enclosed in a loosely-binding context.
  void printWeak(llvm::raw_ostream &os) const {
    printAffineExprInternal(os, BindingStrength::Weak);
  }
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
// FUTURE_CL(wrengr): It would be nice to have the subclasses override
// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
// the proper covariant return types.
//
/// A `DimLvlExpr` statically known to be a dimension-expression.
class DimExpr final : public DimLvlExpr {
public:
  /// The `ExprKind` tag shared by every `DimExpr`.
  static constexpr ExprKind Kind = ExprKind::Dimension;

  /// LLVM-style RTTI hook: is `expr` tagged as a dimension-expression?
  static constexpr bool classof(DimLvlExpr const *expr) {
    return Kind == expr->getExprKind();
  }

  /// Wraps `expr` as a dimension-expression.
  constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}

private:
  // FIXME(wrengr): The friend declaration and the converting ctor below
  // are needed for the current RTTI implementation.
  friend class DimLvlExpr;
  constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
};
static_assert(IsZeroCostAbstraction<DimExpr>);
/// A `DimLvlExpr` statically known to be a level-expression.
class LvlExpr final : public DimLvlExpr {
public:
  /// The `ExprKind` tag shared by every `LvlExpr`.
  static constexpr ExprKind Kind = ExprKind::Level;

  /// LLVM-style RTTI hook: is `expr` tagged as a level-expression?
  static constexpr bool classof(DimLvlExpr const *expr) {
    return Kind == expr->getExprKind();
  }

  /// Wraps `expr` as a level-expression.
  constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}

private:
  // FIXME(wrengr): The friend declaration and the converting ctor below
  // are needed for the current RTTI implementation.
  friend class DimLvlExpr;
  constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
};
static_assert(IsZeroCostAbstraction<LvlExpr>);
// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
template <typename U>
constexpr bool DimLvlExpr::isa() const {
if constexpr (std::is_same_v<U, DimExpr>)
return getExprKind() == ExprKind::Dimension;
if constexpr (std::is_same_v<U, LvlExpr>)
return getExprKind() == ExprKind::Level;
}
template <typename U>
constexpr U DimLvlExpr::cast() const {
assert(isa<U>());
return U(*this);
}
template <typename U>
constexpr U DimLvlExpr::dyn_cast() const {
return isa<U>() ? U(*this) : U();
}
//===----------------------------------------------------------------------===//
/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
class DimSpec final {
  /// Variable bound by this specification.
  DimVar var;
  /// The dimension-expression. It may be null when passed to the `DimSpec`
  /// ctor; the `DimLvlMap` ctor later fills it in (or verifies it) by
  /// function-inversion inference.
  DimExpr expr;
  /// Whether printing may omit `expr`. Starts out false (a null `expr` is
  /// elided from printing regardless); the `DimLvlMap` ctor resets it as
  /// appropriate.
  bool elideExpr = false;
  /// The dimension-slice; optional, null by default.
  SparseTensorDimSliceAttr slice;

public:
  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);

  //===--- Accessors -----------------------------------------------------===//

  constexpr DimVar getBoundVar() const { return var; }
  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }

  /// True iff a (non-null) dimension-expression is present.
  bool hasExpr() const { return expr ? true : false; }
  constexpr DimExpr getExpr() const { return expr; }
  /// Installs the dimension-expression; only valid while none is present.
  void setExpr(DimExpr newExpr) {
    assert(!hasExpr());
    expr = newExpr;
  }

  constexpr bool canElideExpr() const { return elideExpr; }
  void setElideExpr(bool b) { elideExpr = b; }

  //===--- Queries -------------------------------------------------------===//

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  bool isValid(Ranks const &ranks) const;

  // TODO(wrengr): Use it or lose it.
  bool isFunctionOf(Var var) const;
  bool isFunctionOf(VarSet const &vars) const;
  void getFreeVars(VarSet &vars) const;

  //===--- Printing ------------------------------------------------------===//

  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};
// Although this class is more than a plain newtype wrapper, we still want
// storing it in a `SmallVector` to be cheap.
static_assert(IsZeroCostAbstraction<DimSpec>);
//===----------------------------------------------------------------------===//
/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
class LvlSpec final {
  /// Variable bound by this specification.
  LvlVar var;
  /// Whether printing may omit `var`. Starts out false; the `DimLvlMap`
  /// ctor resets it as appropriate.
  bool elideVar = false;
  /// The level-expression.
  //
  // NOTE: For now we use `LvlExpr` because all level-expressions must be
  // `AffineExpr`; however, in the future we will also want to allow "counting
  // expressions", and potentially other kinds of non-affine level-expressions.
  // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
  // so we may consider defining another class for pairing those two together
  // to ensure that the pair is well-formed.
  LvlExpr expr;
  /// The level-type (== level-format + lvl-properties).
  DimLevelType type;

public:
  LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type);

  //===--- Accessors -----------------------------------------------------===//

  constexpr LvlVar getBoundVar() const { return var; }
  constexpr LvlExpr getExpr() const { return expr; }
  constexpr DimLevelType getType() const { return type; }

  constexpr bool canElideVar() const { return elideVar; }
  void setElideVar(bool b) { elideVar = b; }

  //===--- Queries -------------------------------------------------------===//

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  //
  // NOTE: Once "counting expressions" are introduced this will need a more
  // sophisticated implementation than `DimSpec::isValid` has.
  bool isValid(Ranks const &ranks) const;

  // TODO(wrengr): Use it or lose it.
  bool isFunctionOf(Var var) const;
  bool isFunctionOf(VarSet const &vars) const;
  void getFreeVars(VarSet &vars) const;

  //===--- Printing ------------------------------------------------------===//

  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};
// Although this class is more than a plain newtype wrapper, we still want
// storing it in a `SmallVector` to be cheap.
static_assert(IsZeroCostAbstraction<LvlSpec>);
//===----------------------------------------------------------------------===//
/// A dimension-to-level map: the symbol rank, one `DimSpec` per dimension,
/// and one `LvlSpec` per level.
class DimLvlMap final {
  // TODO(wrengr): Need to define getters
  unsigned symRank;
  SmallVector<DimSpec> dimSpecs;
  SmallVector<LvlSpec> lvlSpecs;

  // Checks for integrity of variable-binding structure.
  // This is already called by the ctor.
  bool isWF() const;

public:
  DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
            ArrayRef<LvlSpec> lvlSpecs);

  unsigned getSymRank() const { return symRank; }
  /// Number of dimensions, i.e. the number of `DimSpec`s.
  unsigned getDimRank() const { return dimSpecs.size(); }
  /// Number of levels, i.e. the number of `LvlSpec`s.
  unsigned getLvlRank() const { return lvlSpecs.size(); }
  unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
  Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }

  /// Returns the `DimLevelType` of the `i`-th level spec.
  // `const`-qualified because this is a pure getter; previously it could not
  // be called through a `const DimLvlMap &`.
  DimLevelType getDimLevelType(unsigned i) const {
    return lvlSpecs[i].getType();
  }

  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};
//===----------------------------------------------------------------------===//
} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
|