//===- SparseTensorDescriptor.h ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for the sparse memory layout.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
namespace mlir {
namespace sparse_tensor {
//===----------------------------------------------------------------------===//
// SparseTensorDescriptor and helpers that manage the sparse tensor memory
// layout scheme during "direct code generation" (i.e. when sparsification
// generates the buffers as part of the actual IR, in contrast with the library
// approach where data structures are hidden behind opaque pointers).
//===----------------------------------------------------------------------===//

class SparseTensorSpecifier {
public:
explicit SparseTensorSpecifier(Value specifier)
: specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}
  // Returns the initial specifier value: undef values for the level-sizes,
  // and all-zero values for the memory-sizes.
static Value getInitValue(OpBuilder &builder, Location loc,
SparseTensorType stt);
/*implicit*/ operator Value() { return specifier; }
Value getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, std::optional<Level> lvl);
void setSpecifierField(OpBuilder &builder, Location loc, Value v,
StorageSpecifierKind kind, std::optional<Level> lvl);
private:
TypedValue<StorageSpecifierType> specifier;
};
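
// A minimal usage sketch (illustrative only; `builder`, `loc`, `specVal`,
// and `lvl` are placeholders, not part of this API): read a field from a
// specifier value, then write an updated value back.
//
//   SparseTensorSpecifier spec(specVal);
//   Value sz = spec.getSpecifierField(builder, loc,
//                                     StorageSpecifierKind::LvlSize, lvl);
//   spec.setSpecifierField(builder, loc, sz,
//                          StorageSpecifierKind::LvlSize, lvl);
//   Value updated = spec; // implicit conversion back to Value
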
/// A helper class around an array of values that corresponds to a sparse
/// tensor. This class provides a set of meaningful APIs to query and update
/// a particular field in a consistent way. Users should not make assumptions
/// on how a sparse tensor is laid out but instead rely on this class to access
/// the right value for the right field.
template <typename ValueArrayRef>
class SparseTensorDescriptorImpl {
protected:
  // TODO: Functions/methods marked with [NUMFIELDS] should probably use
  // `FieldIndex` for their return type, for the same reason that
  // `Dimension`/`Level` are used both for identifiers and for ranks.
SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
: rType(stt), fields(fields), layout(stt) {
assert(layout.getNumFields() == getNumFields());
    // Make sure the class is trivially copyable (and small enough) so that
    // it can be passed by value.
static_assert(std::is_trivially_copyable_v<
SparseTensorDescriptorImpl<ValueArrayRef>>);
}
public:
FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
std::optional<Level> lvl) const {
// Delegates to storage layout.
return layout.getMemRefFieldIndex(kind, lvl);
}
// TODO: See note [NUMFIELDS].
unsigned getNumFields() const { return fields.size(); }
///
  /// Getters: get the value of the requested field.
///
Value getSpecifier() const { return fields.back(); }
Value getSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind,
std::optional<Level> lvl) const {
SparseTensorSpecifier md(fields.back());
return md.getSpecifierField(builder, loc, kind, lvl);
}
Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
}
Value getPosMemRef(Level lvl) const {
return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl);
}
Value getValMemRef() const {
return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
}
Value getMemRefField(SparseTensorFieldKind kind,
std::optional<Level> lvl) const {
return getField(getMemRefFieldIndex(kind, lvl));
}
Value getMemRefField(FieldIndex fidx) const {
assert(fidx < fields.size() - 1);
return getField(fidx);
}
Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
lvl);
}
Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
lvl);
}
Value getValMemSize(OpBuilder &builder, Location loc) const {
return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
std::nullopt);
}
Type getMemRefElementType(SparseTensorFieldKind kind,
std::optional<Level> lvl) const {
return getMemRefType(getMemRefField(kind, lvl)).getElementType();
}
Value getField(FieldIndex fidx) const {
assert(fidx < fields.size());
return fields[fidx];
}
ValueRange getMemRefFields() const {
    // Drop the last field, which holds the metadata specifier.
return fields.drop_back();
}
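  /// Returns the field index of the coordinates memref for the given level,
  /// together with the stride to use when reading from it. For a trailing
  /// run of COO levels, the coordinates of all levels in the run are stored
  /// interleaved (AoS) in one shared memref, hence a stride larger than one.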
std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
}
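  /// Returns the memref that stores, in AoS form, the coordinates of the
  /// trailing COO levels (asserts that such a run exists).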
Value getAOSMemRef() const {
const Level cooStart = getCOOStart(rType.getEncoding());
assert(cooStart < rType.getLvlRank());
return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
}
RankedTensorType getRankedTensorType() const { return rType; }
ValueArrayRef getFields() const { return fields; }
StorageLayout getLayout() const { return layout; }
protected:
SparseTensorType rType;
ValueArrayRef fields;
StorageLayout layout;
};

/// Uses ValueRange for immutable descriptors.
class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
public:
SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
: SparseTensorDescriptorImpl<ValueRange>(stt, buffers) {}
Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
};
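
// A minimal usage sketch (illustrative only; `builder`, `loc`, `stt`,
// `buffers`, and `lvl` are placeholders): wrap the flattened field values
// of a sparse tensor and query individual buffers through the descriptor.
//
//   SparseTensorDescriptor desc(stt, buffers);
//   Value positions = desc.getPosMemRef(lvl);
//   Value values = desc.getValMemRef();
//   Value lvlSize = desc.getLvlSize(builder, loc, lvl);
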
/// Uses SmallVectorImpl<Value> & for mutable descriptors.
/// Using a SmallVector for the mutable descriptor allows users to reuse it
/// as a temporary buffer to append values in some special cases, though
/// users are then responsible for restoring the buffer to a legal state
/// after such use. This is arguably not the cleanest design, but it is the
/// most efficient way to avoid copying the fields into another SmallVector.
/// If a cleaner design is desired, this should be changed to use
/// MutableArrayRef instead.
class MutSparseTensorDescriptor
: public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
public:
MutSparseTensorDescriptor(SparseTensorType stt,
SmallVectorImpl<Value> &buffers)
: SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>(stt, buffers) {}
// Allow implicit type conversion from mutable descriptors to immutable ones
// (but not vice versa).
/*implicit*/ operator SparseTensorDescriptor() const {
return SparseTensorDescriptor(rType, fields);
}
///
  /// Setters: update the value of the requested field (only available on
  /// the mutable descriptor).
///
void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
Value v) {
fields[getMemRefFieldIndex(kind, lvl)] = v;
}
void setMemRefField(FieldIndex fidx, Value v) {
assert(fidx < fields.size() - 1);
fields[fidx] = v;
}
void setField(FieldIndex fidx, Value v) {
assert(fidx < fields.size());
fields[fidx] = v;
}
void setSpecifier(Value newSpec) { fields.back() = newSpec; }
void setSpecifierField(OpBuilder &builder, Location loc,
StorageSpecifierKind kind, std::optional<Level> lvl,
Value v) {
SparseTensorSpecifier md(fields.back());
md.setSpecifierField(builder, loc, v, kind, lvl);
fields.back() = md;
}
void setValMemSize(OpBuilder &builder, Location loc, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
std::nullopt, v);
}
void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
}
void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
}
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
}
};
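
// A minimal usage sketch (illustrative only; `builder`, `loc`, `stt`,
// `fields`, `newValBuf`, and `newValSize` are placeholders): swap in a new
// values buffer, record its size in the specifier, and hand out an
// immutable view.
//
//   MutSparseTensorDescriptor desc(stt, fields);
//   desc.setMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt,
//                       newValBuf);
//   desc.setValMemSize(builder, loc, newValSize);
//   SparseTensorDescriptor imm = desc; // implicit conversion
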
/// Returns the "tuple" value of the adapted tensor.
inline UnrealizedConversionCastOp getTuple(Value tensor) {
return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
}
/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
return builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values)
.getResult(0);
}
inline Value genTuple(OpBuilder &builder, Location loc,
SparseTensorDescriptor desc) {
return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}
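
// A minimal round-trip sketch (illustrative only; `builder`, `loc`, and
// `desc` are assumed to be in scope): the flattened fields travel through
// the IR as a single "tuple" value built from an unrealized conversion
// cast, and the defining cast op recovers them.
//
//   Value tensor = genTuple(builder, loc, desc);
//   auto tuple = getTuple(tensor);
//   assert(tuple.getInputs().size() == desc.getNumFields());
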
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return SparseTensorDescriptor(stt, tuple.getInputs());
}
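
// For example (illustrative only; `tensor` is assumed to be a "tuple"
// value produced by genTuple):
//
//   SparseTensorDescriptor desc = getDescriptorFromTensorTuple(tensor);
//   Value vals = desc.getValMemRef();
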
inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return MutSparseTensorDescriptor(stt, fields);
}
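
// A minimal usage sketch (illustrative only; `builder`, `loc`, `tensor`,
// and `newSpec` are placeholders): unpack into a caller-owned buffer so
// the fields can be mutated, then repack into a new "tuple" value.
//
//   SmallVector<Value> fields;
//   MutSparseTensorDescriptor desc =
//       getMutDescriptorFromTensorTuple(tensor, fields);
//   desc.setSpecifier(newSpec);
//   Value updated = genTuple(builder, loc, desc);
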
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_