1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
|
//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#include "Var.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "llvm/ADT/STLForwardCompat.h"
namespace mlir {
namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
/// Tags a `DimLvlExpr` as either a dimension-expression or a
/// level-expression.
enum class ExprKind : bool { Dimension = false, Level = true };

/// Returns the kind of dimension/level variable that may occur in an
/// expression of the given kind: dimension-expressions range over
/// level-variables, and level-expressions range over dimension-variables.
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
  return ek == ExprKind::Dimension ? VarKind::Level : VarKind::Dimension;
}
static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
              getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
//===----------------------------------------------------------------------===//
class DimLvlExpr {
private:
ExprKind kind;
AffineExpr expr;
public:
constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}
//
// Boolean operators.
//
constexpr bool operator==(DimLvlExpr other) const {
return kind == other.kind && expr == other.expr;
}
constexpr bool operator!=(DimLvlExpr other) const {
return !(*this == other);
}
explicit operator bool() const { return static_cast<bool>(expr); }
//
// RTTI support (for the `DimLvlExpr` class itself).
//
template <typename U>
constexpr bool isa() const;
template <typename U>
constexpr U cast() const;
template <typename U>
constexpr U dyn_cast() const;
//
// Simple getters.
//
constexpr ExprKind getExprKind() const { return kind; }
constexpr VarKind getAllowedVarKind() const {
return getVarKindAllowedInExpr(kind);
}
constexpr AffineExpr getAffineExpr() const { return expr; }
AffineExprKind getAffineKind() const {
assert(expr);
return expr.getKind();
}
MLIRContext *tryGetContext() const {
return expr ? expr.getContext() : nullptr;
}
//
// Getters for handling `AffineExpr` subclasses.
//
SymVar castSymVar() const;
std::optional<SymVar> dyn_castSymVar() const;
Var castDimLvlVar() const;
std::optional<Var> dyn_castDimLvlVar() const;
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;
/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
protected:
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
enum class BindingStrength : bool { Weak = false, Strong = true };
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
class DimExpr final : public DimLvlExpr {
friend class DimLvlExpr;
constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
public:
static constexpr ExprKind Kind = ExprKind::Dimension;
static constexpr bool classof(DimLvlExpr const *expr) {
return expr->getExprKind() == Kind;
}
constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
std::optional<LvlVar> dyn_castLvlVar() const {
const auto var = dyn_castDimLvlVar();
return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
}
};
static_assert(IsZeroCostAbstraction<DimExpr>);
class LvlExpr final : public DimLvlExpr {
friend class DimLvlExpr;
constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
public:
static constexpr ExprKind Kind = ExprKind::Level;
static constexpr bool classof(DimLvlExpr const *expr) {
return expr->getExprKind() == Kind;
}
constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
std::optional<DimVar> dyn_castDimVar() const {
const auto var = dyn_castDimLvlVar();
return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
}
};
static_assert(IsZeroCostAbstraction<LvlExpr>);
template <typename U>
constexpr bool DimLvlExpr::isa() const {
if constexpr (std::is_same_v<U, DimExpr>)
return getExprKind() == ExprKind::Dimension;
if constexpr (std::is_same_v<U, LvlExpr>)
return getExprKind() == ExprKind::Level;
}
template <typename U>
constexpr U DimLvlExpr::cast() const {
assert(isa<U>());
return U(*this);
}
template <typename U>
constexpr U DimLvlExpr::dyn_cast() const {
return isa<U>() ? U(*this) : U();
}
//===----------------------------------------------------------------------===//
/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
class DimSpec final {
  /// The dimension-variable bound by this specification.
  DimVar var;
  /// The dimension-expression. The `DimSpec` ctor treats it as optional (it
  /// may be null); the `DimLvlMap` ctor is responsible for filling it in or
  /// verifying it via function-inversion inference.
  DimExpr expr;
  /// Whether printing may omit `expr`. The `DimSpec` ctor assumes not
  /// (although a null `expr` is never printed); the `DimLvlMap` ctor resets
  /// this as appropriate.
  bool elideExpr = false;
  /// The dimension-slice; optional (null by default).
  SparseTensorDimSliceAttr slice;

public:
  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);

  //
  // Accessors.
  //
  constexpr DimVar getBoundVar() const { return var; }
  constexpr DimExpr getExpr() const { return expr; }
  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }
  constexpr bool canElideExpr() const { return elideExpr; }
  /// True iff a (non-null) dimension-expression is present.
  bool hasExpr() const { return !!expr; }
  /// Context of the dimension-expression, or null if there is none.
  MLIRContext *tryGetContext() const { return expr.tryGetContext(); }

  //
  // Mutators.
  //
  /// Installs the dimension-expression; only legal while none is present.
  void setExpr(DimExpr e) {
    assert(!hasExpr());
    expr = e;
  }
  void setElideExpr(bool b) { elideExpr = b; }

  /// Checks whether the variables bound/used by this spec are valid with
  /// respect to the given ranks. A null `DimExpr` counts as vacuously valid,
  /// so a later `setExpr` call invalidates a previously computed result.
  [[nodiscard]] bool isValid(Ranks const &ranks) const;
};
static_assert(IsZeroCostAbstraction<DimSpec>);
//===----------------------------------------------------------------------===//
/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
class LvlSpec final {
  /// The level-variable bound by this specification.
  LvlVar var;
  /// Whether printing may omit `var`. The `LvlSpec` ctor assumes not; the
  /// `DimLvlMap` ctor resets this as appropriate.
  bool elideVar = false;
  /// The level-expression.
  LvlExpr expr;
  /// The level-type (level-format plus level-properties).
  LevelType type;

public:
  LvlSpec(LvlVar var, LvlExpr expr, LevelType type);

  //
  // Accessors.
  //
  constexpr LvlVar getBoundVar() const { return var; }
  constexpr LvlExpr getExpr() const { return expr; }
  constexpr LevelType getType() const { return type; }
  constexpr bool canElideVar() const { return elideVar; }
  /// Context of the level-expression; a non-null result is checked by
  /// assertion.
  MLIRContext *getContext() const {
    MLIRContext *const result = expr.tryGetContext();
    assert(result);
    return result;
  }

  //
  // Mutators.
  //
  void setElideVar(bool b) { elideVar = b; }

  /// Checks whether the variables bound/used by this spec are valid with
  /// respect to the given ranks.
  [[nodiscard]] bool isValid(Ranks const &ranks) const;
};
static_assert(IsZeroCostAbstraction<LvlSpec>);
//===----------------------------------------------------------------------===//
class DimLvlMap final {
public:
DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
ArrayRef<LvlSpec> lvlSpecs);
unsigned getSymRank() const { return symRank; }
unsigned getDimRank() const { return dimSpecs.size(); }
unsigned getLvlRank() const { return lvlSpecs.size(); }
unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }
ArrayRef<DimSpec> getDims() const { return dimSpecs; }
const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
return getDim(dim).getSlice();
}
ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }
AffineMap getDimToLvlMap(MLIRContext *context) const;
AffineMap getLvlToDimMap(MLIRContext *context) const;
private:
/// Checks for integrity of variable-binding structure.
/// This is already called by the ctor.
[[nodiscard]] bool isWF() const;
/// Helper function to call `DimSpec::setExpr` while asserting that
/// the invariant established by `DimLvlMap:isWF` is maintained.
/// This is used by the ctor.
void setDimExpr(Dimension dim, DimExpr expr) {
assert(expr && getRanks().isValid(expr));
dimSpecs[dim].setExpr(expr);
}
// All these fields are const-after-ctor.
unsigned symRank;
SmallVector<DimSpec> dimSpecs;
SmallVector<LvlSpec> lvlSpecs;
bool mustPrintLvlVars;
};
//===----------------------------------------------------------------------===//
} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir
#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
|