File: SparseTensorDescriptor.h

package info (click to toggle)
llvm-toolchain-17 1:17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (278 lines) | stat: -rw-r--r-- 10,443 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
//===- SparseTensorDescriptor.h ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for the sparse memory layout.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// SparseTensorDescriptor and helpers that manage the sparse tensor memory
// layout scheme during "direct code generation" (i.e. when sparsification
// generates the buffers as part of actual IR, in contrast with the library
// approach where data structures are hidden behind opaque pointers).
//===----------------------------------------------------------------------===//

/// Thin wrapper around a value of `StorageSpecifierType` that provides
/// accessors for its individual fields (level sizes and memory sizes).
class SparseTensorSpecifier {
public:
  /// Wraps `specifier`. The `cast` asserts (in debug builds) that the value is
  /// typed as `StorageSpecifierType`.
  explicit SparseTensorSpecifier(Value specifier)
      : specifier(cast<TypedValue<StorageSpecifierType>>(specifier)) {}

  // Undef value for level-sizes, all zero values for memory-sizes.
  static Value getInitValue(OpBuilder &builder, Location loc,
                            SparseTensorType stt);

  /// Unwraps to the underlying specifier value. This does not modify the
  /// wrapper, so it is `const` and also usable on const instances.
  /*implicit*/ operator Value() const { return specifier; }

  /// Returns the value of specifier field `kind` (for level `lvl`, or no level
  /// when `std::nullopt`). Defined out of line; presumably materializes IR
  /// through `builder` at `loc` -- TODO confirm against the definition.
  Value getSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind, std::optional<Level> lvl);

  /// Sets specifier field `kind` (for level `lvl`, if any) to `v`. The wrapped
  /// value is updated in place; callers read the new specifier back through
  /// the conversion to `Value`.
  void setSpecifierField(OpBuilder &builder, Location loc, Value v,
                         StorageSpecifierKind kind, std::optional<Level> lvl);

private:
  TypedValue<StorageSpecifierType> specifier;
};

/// A helper class around an array of values that corresponds to a sparse
/// tensor. This class provides a set of meaningful APIs to query and update
/// a particular field in a consistent way. Users should not make assumptions
/// on how a sparse tensor is laid out but instead rely on this class to access
/// the right value for the right field.
///
/// `ValueArrayRef` is the container the fields are viewed through:
/// `ValueRange` for `SparseTensorDescriptor` and `SmallVectorImpl<Value> &` for
/// `MutSparseTensorDescriptor`. All fields but the last are memrefs; the last
/// field is the storage specifier.
template <typename ValueArrayRef>
class SparseTensorDescriptorImpl {
protected:
  // TODO: Functions/methods marked with [NUMFIELDS] should probably use
  // `FieldIndex` for their return type, via the same reasoning for why
  // `Dimension`/`Level` are used both for identifiers and ranks.
  SparseTensorDescriptorImpl(SparseTensorType stt, ValueArrayRef fields)
      : rType(stt), fields(fields), layout(stt) {
    // The number of fields must match what the storage layout prescribes.
    assert(layout.getNumFields() == getNumFields());
    // We should make sure the class is trivially copyable (and should be small
    // enough) such that we can pass it by value.
    static_assert(std::is_trivially_copyable_v<
                  SparseTensorDescriptorImpl<ValueArrayRef>>);
  }

public:
  /// Returns the index of memref field `kind` (for level `lvl`, if any), as
  /// computed by the storage layout.
  FieldIndex getMemRefFieldIndex(SparseTensorFieldKind kind,
                                 std::optional<Level> lvl) const {
    // Delegates to storage layout.
    return layout.getMemRefFieldIndex(kind, lvl);
  }

  /// Returns the total number of fields, including the trailing specifier.
  // TODO: See note [NUMFIELDS].
  unsigned getNumFields() const { return fields.size(); }

  ///
  /// Getters: get the value for required field.
  ///

  /// Returns the storage specifier, which is always the last field.
  Value getSpecifier() const { return fields.back(); }

  /// Returns specifier field `kind` (for level `lvl`, if any), read through a
  /// `SparseTensorSpecifier` wrapper using `builder` at `loc`.
  Value getSpecifierField(OpBuilder &builder, Location loc,
                          StorageSpecifierKind kind,
                          std::optional<Level> lvl) const {
    SparseTensorSpecifier md(getSpecifier());
    return md.getSpecifierField(builder, loc, kind, lvl);
  }

  /// Returns the size recorded in the specifier for level `lvl`.
  Value getLvlSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl);
  }

  /// Returns the positions memref of level `lvl`.
  Value getPosMemRef(Level lvl) const {
    return getMemRefField(SparseTensorFieldKind::PosMemRef, lvl);
  }

  /// Returns the values memref (not associated with any level).
  Value getValMemRef() const {
    return getMemRefField(SparseTensorFieldKind::ValMemRef, std::nullopt);
  }

  /// Returns memref field `kind` (for level `lvl`, if any).
  Value getMemRefField(SparseTensorFieldKind kind,
                       std::optional<Level> lvl) const {
    return getField(getMemRefFieldIndex(kind, lvl));
  }

  /// Returns the memref field at index `fidx`; the trailing specifier is not
  /// a memref, so it is excluded by the assertion.
  Value getMemRefField(FieldIndex fidx) const {
    assert(fidx < fields.size() - 1);
    return getField(fidx);
  }

  /// Returns the positions-memory size recorded in the specifier for `lvl`.
  Value getPosMemSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize,
                             lvl);
  }

  /// Returns the coordinates-memory size recorded in the specifier for `lvl`.
  Value getCrdMemSize(OpBuilder &builder, Location loc, Level lvl) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize,
                             lvl);
  }

  /// Returns the values-memory size recorded in the specifier.
  Value getValMemSize(OpBuilder &builder, Location loc) const {
    return getSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
                             std::nullopt);
  }

  /// Returns the element type of memref field `kind` (for level `lvl`).
  Type getMemRefElementType(SparseTensorFieldKind kind,
                            std::optional<Level> lvl) const {
    return getMemRefType(getMemRefField(kind, lvl)).getElementType();
  }

  /// Returns the field at index `fidx` (memref or specifier).
  Value getField(FieldIndex fidx) const {
    assert(fidx < fields.size());
    return fields[fidx];
  }

  /// Returns all fields except the trailing storage specifier.
  ValueRange getMemRefFields() const {
    // Drop the last metadata field. Go through `ValueRange` explicitly:
    // `SmallVectorImpl<Value>` (the mutable descriptor's storage) has no
    // `drop_back`, so `fields.drop_back()` only compiled for `ValueRange`.
    return ValueRange(fields).drop_back();
  }

  /// Returns the field index of the coordinates memref for `lvl` together
  /// with its stride, as reported by the storage layout.
  std::pair<FieldIndex, unsigned> getCrdMemRefIndexAndStride(Level lvl) const {
    return layout.getFieldIndexAndStride(SparseTensorFieldKind::CrdMemRef, lvl);
  }

  /// Returns the coordinates memref at the level computed by `getCOOStart`;
  /// asserts that this level lies within the level rank.
  Value getAOSMemRef() const {
    const Level cooStart = getCOOStart(rType.getEncoding());
    assert(cooStart < rType.getLvlRank());
    return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
  }

  RankedTensorType getRankedTensorType() const { return rType; }
  ValueArrayRef getFields() const { return fields; }
  StorageLayout getLayout() const { return layout; }

protected:
  SparseTensorType rType;
  ValueArrayRef fields;
  StorageLayout layout;
};

/// Immutable descriptor: the tensor fields are viewed through a `ValueRange`,
/// so none of them can be replaced through this object.
class SparseTensorDescriptor : public SparseTensorDescriptorImpl<ValueRange> {
  using Base = SparseTensorDescriptorImpl<ValueRange>;

public:
  SparseTensorDescriptor(SparseTensorType stt, ValueRange buffers)
      : Base(stt, buffers) {}

  /// Returns the coordinates memref for level `lvl`, or possibly a view of it.
  /// Defined out of line -- NOTE(review): confirm when a view is produced.
  Value getCrdMemRefOrView(OpBuilder &builder, Location loc, Level lvl) const;
};

/// Mutable descriptor backed by a caller-owned `SmallVectorImpl<Value> &`.
/// Holding the SmallVector by reference (instead of a `MutableArrayRef`) lets
/// users temporarily append extra values to the buffer for special cases.
/// Such users are responsible for restoring the buffer to a legal state
/// afterwards. This is not the cleanest design, but it avoids copying the
/// fields into another SmallVector. If a cleaner design is wanted, switch to
/// `MutableArrayRef` instead.
class MutSparseTensorDescriptor
    : public SparseTensorDescriptorImpl<SmallVectorImpl<Value> &> {
  using Base = SparseTensorDescriptorImpl<SmallVectorImpl<Value> &>;

public:
  MutSparseTensorDescriptor(SparseTensorType stt,
                            SmallVectorImpl<Value> &buffers)
      : Base(stt, buffers) {}

  /// A mutable descriptor may be viewed as an immutable one (the reverse
  /// conversion is intentionally not provided).
  /*implicit*/ operator SparseTensorDescriptor() const {
    return SparseTensorDescriptor(rType, fields);
  }

  ///
  /// Setters: overwrite the value stored for a given field.
  ///

  /// Replaces memref field `kind` (for level `lvl`, if any) with `v`.
  void setMemRefField(SparseTensorFieldKind kind, std::optional<Level> lvl,
                      Value v) {
    const FieldIndex fidx = getMemRefFieldIndex(kind, lvl);
    fields[fidx] = v;
  }

  /// Replaces the memref field at index `fidx`; the trailing specifier slot
  /// is rejected by the assertion.
  void setMemRefField(FieldIndex fidx, Value v) {
    assert(fidx < fields.size() - 1);
    fields[fidx] = v;
  }

  /// Replaces any field (memref or specifier) at index `fidx`.
  void setField(FieldIndex fidx, Value v) {
    assert(fidx < fields.size());
    fields[fidx] = v;
  }

  /// Replaces the storage specifier, i.e. the last field.
  void setSpecifier(Value newSpec) { fields.back() = newSpec; }

  /// Updates specifier field `kind` (for level `lvl`, if any) to `v` and stores
  /// the resulting specifier back as the last field.
  void setSpecifierField(OpBuilder &builder, Location loc,
                         StorageSpecifierKind kind, std::optional<Level> lvl,
                         Value v) {
    SparseTensorSpecifier spec(getSpecifier());
    spec.setSpecifierField(builder, loc, v, kind, lvl);
    setSpecifier(spec);
  }

  /// Records `v` as the values-memory size.
  void setValMemSize(OpBuilder &builder, Location loc, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::ValMemSize,
                      std::nullopt, v);
  }

  /// Records `v` as the coordinates-memory size of level `lvl`.
  void setCrdMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::CrdMemSize, lvl, v);
  }

  /// Records `v` as the positions-memory size of level `lvl`.
  void setPosMemSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::PosMemSize, lvl, v);
  }

  /// Records `v` as the size of level `lvl`.
  void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value v) {
    setSpecifierField(builder, loc, StorageSpecifierKind::LvlSize, lvl, v);
  }
};

/// Returns the `UnrealizedConversionCastOp` that defines the adapted `tensor`
/// (its "tuple" of fields). `llvm::cast` asserts that `tensor` is actually
/// produced by such an op; e.g. a block argument (no defining op) trips it.
inline UnrealizedConversionCastOp getTuple(Value tensor) {
  Operation *defOp = tensor.getDefiningOp();
  return llvm::cast<UnrealizedConversionCastOp>(defOp);
}

/// Packs `values` into a single "tuple" value of type `tp` by creating an
/// `UnrealizedConversionCastOp` at `loc`, and returns that op's sole result.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
                      ValueRange values) {
  auto castOp =
      builder.create<UnrealizedConversionCastOp>(loc, TypeRange(tp), values);
  return castOp.getResult(0);
}

/// Packs every field of `desc` into a "tuple" value typed as the descriptor's
/// ranked tensor type.
inline Value genTuple(OpBuilder &builder, Location loc,
                      SparseTensorDescriptor desc) {
  const Type tupleTp = desc.getRankedTensorType();
  const ValueRange tupleFields = desc.getFields();
  return genTuple(builder, loc, tupleTp, tupleFields);
}

/// Builds an immutable descriptor over the inputs of the cast op that defines
/// `tensor`. The descriptor views the op's operands directly; nothing is
/// copied.
inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
  UnrealizedConversionCastOp castOp = getTuple(tensor);
  auto rtp = cast<RankedTensorType>(castOp.getResultTypes()[0]);
  return SparseTensorDescriptor(SparseTensorType(rtp), castOp.getInputs());
}

/// Copies the inputs of the cast op that defines `tensor` into `fields`
/// (replacing its previous contents), then returns a mutable descriptor that
/// refers to `fields`. `fields` must therefore outlive the descriptor.
inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
  UnrealizedConversionCastOp castOp = getTuple(tensor);
  auto inputs = castOp.getInputs();
  fields.assign(inputs.begin(), inputs.end());
  auto rtp = cast<RankedTensorType>(castOp.getResultTypes()[0]);
  return MutSparseTensorDescriptor(SparseTensorType(rtp), fields);
}

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_SPARSETENSORDESCRIPTOR_H_