1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
//===- DimLvlMap.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "DimLvlMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
//===----------------------------------------------------------------------===//
// `DimLvlExpr` implementation.
//===----------------------------------------------------------------------===//
// Returns the wrapped expression as a symbol variable. The wrapped expression
// is required to be an `AffineSymbolExpr` (this uses the unconditional
// `llvm::cast`); use `dyn_castSymVar` when that is not guaranteed.
SymVar DimLvlExpr::castSymVar() const {
  const auto symExpr = llvm::cast<AffineSymbolExpr>(expr);
  return SymVar(symExpr);
}
// Returns the wrapped expression as a symbol variable when it is an
// `AffineSymbolExpr`. Otherwise returns `std::nullopt`; a null wrapped
// expression is accepted because the `_or_null` cast variant is used.
std::optional<SymVar> DimLvlExpr::dyn_castSymVar() const {
  const auto symExpr = dyn_cast_or_null<AffineSymbolExpr>(expr);
  if (!symExpr)
    return std::nullopt;
  return SymVar(symExpr);
}
// Returns the wrapped expression as a variable of kind `getAllowedVarKind()`.
// The wrapped expression is required to be an `AffineDimExpr` (unconditional
// `llvm::cast`); use `dyn_castDimLvlVar` when that is not guaranteed.
Var DimLvlExpr::castDimLvlVar() const {
  const auto dimExpr = llvm::cast<AffineDimExpr>(expr);
  return Var(getAllowedVarKind(), dimExpr);
}
// Returns the wrapped expression as a variable of kind `getAllowedVarKind()`
// when it is an `AffineDimExpr`. Otherwise returns `std::nullopt`; a null
// wrapped expression is accepted (the `_or_null` cast variant is used).
std::optional<Var> DimLvlExpr::dyn_castDimLvlVar() const {
  const auto dimExpr = dyn_cast_or_null<AffineDimExpr>(expr);
  if (!dimExpr)
    return std::nullopt;
  return Var(getAllowedVarKind(), dimExpr);
}
std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr>
DimLvlExpr::unpackBinop() const {
const auto ak = getAffineKind();
const auto binop = llvm::dyn_cast<AffineBinaryOpExpr>(expr);
const DimLvlExpr lhs(kind, binop ? binop.getLHS() : nullptr);
const DimLvlExpr rhs(kind, binop ? binop.getRHS() : nullptr);
return {lhs, ak, rhs};
}
//===----------------------------------------------------------------------===//
// `DimSpec` implementation.
//===----------------------------------------------------------------------===//
DimSpec::DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice)
: var(var), expr(expr), slice(slice) {}
// A dimension spec is valid when its bound variable is valid with respect to
// `ranks` and its expression, if present, is valid as well.
bool DimSpec::isValid(Ranks const &ranks) const {
  if (!ranks.isValid(var))
    return false;
  // A missing (null) expression is deliberately treated as vacuously valid.
  if (!expr)
    return true;
  // Nothing in `slice` needs additional validation.
  return ranks.isValid(expr);
}
//===----------------------------------------------------------------------===//
// `LvlSpec` implementation.
//===----------------------------------------------------------------------===//
LvlSpec::LvlSpec(LvlVar var, LvlExpr expr, LevelType type)
: var(var), expr(expr), type(type) {
assert(expr);
assert(isValidLT(type) && !isUndefLT(type));
}
// A level spec is valid when both its bound variable and its defining
// expression are valid with respect to `ranks`.
bool LvlSpec::isValid(Ranks const &ranks) const {
  if (!ranks.isValid(var))
    return false;
  // Nothing in `type` needs additional validation.
  return ranks.isValid(expr);
}
//===----------------------------------------------------------------------===//
// `DimLvlMap` implementation.
//===----------------------------------------------------------------------===//
// Builds the map from its symbol rank and its dimension/level specs. It then
// decides, per level variable, whether that variable may be elided: a level
// variable is elided iff no non-elidable dimension expression refers to it.
// `mustPrintLvlVars` ends up true iff at least one level variable is
// referenced that way.
DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
                     ArrayRef<LvlSpec> lvlSpecs)
    : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs),
      mustPrintLvlVars(false) {
  // Verify the variable-binding structure before anything else. This is
  // what guarantees that the `VarSet::add` calls below cannot go out of
  // bounds.
  assert(isWF());

  // Gather every variable mentioned by a dimension expression that has to be
  // printed overtly.
  VarSet referenced(getRanks());
  for (const DimSpec &spec : dimSpecs) {
    if (spec.canElideExpr())
      continue;
    referenced.add(spec.getExpr());
  }

  // A level variable may be elided only if nothing gathered above refers to
  // it. If even one of them is referenced, all level variables must be
  // forward-declared.
  for (LvlSpec &spec : this->lvlSpecs) {
    const bool referencedByDim = referenced.contains(spec.getBoundVar());
    spec.setElideVar(!referencedByDim);
    if (referencedByDim)
      mustPrintLvlVars = true;
  }
}
// Well-formedness check.
// The dimension specs must bind variables numbered 0, 1, 2, ... in order,
// and each spec must be valid with respect to `getRanks()`. The same
// requirement applies, independently, to the level specs.
bool DimLvlMap::isWF() const {
  const auto ranks = getRanks();

  unsigned dimNum = 0;
  for (const auto &dimSpec : dimSpecs) {
    if (dimSpec.getBoundVar().getNum() != dimNum)
      return false;
    if (!dimSpec.isValid(ranks))
      return false;
    ++dimNum;
  }
  assert(dimNum == ranks.getDimRank());

  unsigned lvlNum = 0;
  for (const auto &lvlSpec : lvlSpecs) {
    if (lvlSpec.getBoundVar().getNum() != lvlNum)
      return false;
    if (!lvlSpec.isValid(ranks))
      return false;
    ++lvlNum;
  }
  assert(lvlNum == ranks.getLvlRank());

  return true;
}
// Builds the dim-to-lvl `AffineMap`.
// The map has `getDimRank()` dimensions and `getSymRank()` symbols, and one
// result per level spec, in level order.
AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
  SmallVector<AffineExpr> results;
  results.reserve(getLvlRank());
  for (const LvlSpec &spec : lvlSpecs)
    results.push_back(spec.getExpr().getAffineExpr());
  return AffineMap::get(getDimRank(), getSymRank(), results, context);
}
// Builds the lvl-to-dim `AffineMap` from the dimension specs' expressions.
// The map has `getLvlRank()` dimensions and `getSymRank()` symbols.
// Dimension specs without an expression (null) contribute no result.
// If no dimension spec carries an expression, a null `AffineMap` is returned.
// NOTE(review): if only *some* dimension specs carry an expression, the
// resulting map has fewer results than `getDimRank()`; confirm that callers
// reject or handle that case.
AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
  SmallVector<AffineExpr> dimAffines;
  dimAffines.reserve(getDimRank());
  for (const auto &dimSpec : dimSpecs)
    if (const auto expr = dimSpec.getExpr().getAffineExpr())
      dimAffines.push_back(expr);
  // If no lvlToDim map was passed in, return a null AffineMap; it is then
  // inferred in SparseTensorEncodingAttr::parse. Check this *before* calling
  // `AffineMap::get`, so that we don't build (and unique in the context) a
  // map that would immediately be discarded.
  if (dimAffines.empty())
    return AffineMap();
  return AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
}
//===----------------------------------------------------------------------===//
|