File: DimLvlMap.h

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (323 lines) | stat: -rw-r--r-- 12,406 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
// be named as the storage representation of the parameter for the tblgen
// defn of STEA.  We may well need to make the other classes public too,
// so that the rest of the compiler can use them when necessary.
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H

#include "Var.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

namespace mlir {
namespace sparse_tensor {
namespace ir_detail {

//===----------------------------------------------------------------------===//
// TODO(wrengr): This enum wants a better name, one that fits together
// with the name of the `DimLvlExpr` class (which may itself want a better
// name).  It could perhaps become a nested type as well.
//
// NOTE: This enum will eventually be extended to cover the "counting
// expressions" required for supporting ITPACK/ELL.  Consequently, neither
// the current underlying type nor the current representation values
// should be relied upon.
enum class ExprKind : bool {
  Dimension = false,
  Level = true,
};

// TODO(wrengr): still needs a better name....
/// Returns the kind of variable that is allowed to occur inside an
/// expression of kind `ek`: dimension-expressions may use level-variables,
/// and level-expressions may use dimension-variables.
///
/// The mapping is spelled out enumerator-by-enumerator rather than computed
/// from the enums' numeric encodings, since those encodings are not meant
/// to be relied upon.  The `static_assert` below pins the mapping.
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
  return ek == ExprKind::Dimension ? VarKind::Level : VarKind::Dimension;
}
static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
              getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);

//===----------------------------------------------------------------------===//
// TODO(wrengr): The goal of this class is to capture a proof that
// we've verified that the given `AffineExpr` only has variables of the
// appropriate kind(s).  So we need to actually prove/verify that in the
// ctor or all its callsites!
class DimLvlExpr {
private:
  // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
  // the `kind` field should be private and const.  However, beware
  // that if we mark any field as `const` or if the fields have differing
  // `private`/`protected` privileges then the `IsZeroCostAbstraction`
  // assertion will fail!
  // (Also, iirc, if we end up moving the `expr` to the subclasses
  // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
  ExprKind kind;
  AffineExpr expr;

public:
  constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}

  //
  // Boolean operators.
  //
  constexpr bool operator==(DimLvlExpr other) const {
    return kind == other.kind && expr == other.expr;
  }
  constexpr bool operator!=(DimLvlExpr other) const {
    return !(*this == other);
  }
  explicit operator bool() const { return static_cast<bool>(expr); }

  //
  // RTTI support (for the `DimLvlExpr` class itself).
  //
  template <typename U>
  constexpr bool isa() const;
  template <typename U>
  constexpr U cast() const;
  template <typename U>
  constexpr U dyn_cast() const;

  //
  // Simple getters.
  //
  constexpr ExprKind getExprKind() const { return kind; }
  constexpr VarKind getAllowedVarKind() const {
    return getVarKindAllowedInExpr(kind);
  }
  constexpr AffineExpr getAffineExpr() const { return expr; }
  AffineExprKind getAffineKind() const {
    assert(expr);
    return expr.getKind();
  }
  MLIRContext *getContext() const { return expr ? expr.getContext() : nullptr; }

  //
  // Getters for handling `AffineExpr` subclasses.
  //
  // TODO(wrengr): is there any way to make these typesafe without too much
  // templating?
  // TODO(wrengr): Most if not all of these don't actually need to be
  // methods, they could be free-functions instead.
  //
  SymVar castSymVar() const;
  Var castDimLvlVar() const;
  int64_t castConstantValue() const;
  std::optional<int64_t> tryGetConstantValue() const;
  bool hasConstantValue(int64_t val) const;
  DimLvlExpr getLHS() const;
  DimLvlExpr getRHS() const;
  std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  bool isValid(Ranks const &ranks) const;

  void print(llvm::raw_ostream &os) const;
  void print(AsmPrinter &printer) const;
  void dump() const;

protected:
  // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
  enum class BindingStrength : bool { Weak = false, Strong = true };

  // TODO(wrengr): Does our version of `printAffineExprInternal` really
  // need to be a method, or could it be a free-function instead? (assuming
  // `BindingStrength` goes with it).
  void printAffineExprInternal(llvm::raw_ostream &os,
                               BindingStrength enclosingTightness) const;
  void printStrong(llvm::raw_ostream &os) const {
    printAffineExprInternal(os, BindingStrength::Strong);
  }
  void printWeak(llvm::raw_ostream &os) const {
    printAffineExprInternal(os, BindingStrength::Weak);
  }
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);

// FUTURE_CL(wrengr): It would be nice to have the subclasses override
// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
// the proper covariant return types.
//
class DimExpr final : public DimLvlExpr {
  // FIXME(wrengr): These two are needed for the current RTTI implementation.
  friend class DimLvlExpr;
  constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

public:
  static constexpr ExprKind Kind = ExprKind::Dimension;
  static constexpr bool classof(DimLvlExpr const *expr) {
    return expr->getExprKind() == Kind;
  }
  constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
};
static_assert(IsZeroCostAbstraction<DimExpr>);

class LvlExpr final : public DimLvlExpr {
  // FIXME(wrengr): These two are needed for the current RTTI implementation.
  friend class DimLvlExpr;
  constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

public:
  static constexpr ExprKind Kind = ExprKind::Level;
  static constexpr bool classof(DimLvlExpr const *expr) {
    return expr->getExprKind() == Kind;
  }
  constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}
};
static_assert(IsZeroCostAbstraction<LvlExpr>);

// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
/// Returns true iff this expression's kind-tag matches `U::Kind`.
///
/// `U` must be `DimExpr` or `LvlExpr`.  Any other type is rejected at
/// compile time: without that check, both `if constexpr` branches of the
/// previous implementation would be discarded and control would fall off
/// the end of a non-void function.
template <typename U>
constexpr bool DimLvlExpr::isa() const {
  static_assert(std::is_same_v<U, DimExpr> || std::is_same_v<U, LvlExpr>,
                "DimLvlExpr::isa<U>() requires U to be DimExpr or LvlExpr");
  return getExprKind() == U::Kind;
}

template <typename U>
constexpr U DimLvlExpr::cast() const {
  assert(isa<U>());
  return U(*this);
}

template <typename U>
constexpr U DimLvlExpr::dyn_cast() const {
  return isa<U>() ? U(*this) : U();
}

//===----------------------------------------------------------------------===//
/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
class DimSpec final {
  /// The dimension-variable that this specification binds.
  DimVar var;
  /// The dimension-expression.  This field is treated as optional by
  /// the `DimSpec` ctor; the `DimLvlMap` ctor, on the other hand, will
  /// fill it in (or verify it) via function-inversion inference.
  DimExpr expr;
  /// Whether `expr` can be elided when printing.  The `DimSpec` ctor
  /// assumes it cannot (though a null `expr` is elided regardless);
  /// the `DimLvlMap` ctor will reset this as appropriate.
  bool elideExpr = false;
  /// The dimension-slice; optional, default is null.
  SparseTensorDimSliceAttr slice;

public:
  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);

  constexpr DimVar getBoundVar() const { return var; }
  /// True iff the dimension-expression is non-null.
  bool hasExpr() const { return static_cast<bool>(expr); }
  constexpr DimExpr getExpr() const { return expr; }
  /// Installs the dimension-expression; asserts that none is present yet.
  void setExpr(DimExpr replacement) {
    assert(!hasExpr());
    expr = replacement;
  }
  constexpr bool canElideExpr() const { return elideExpr; }
  void setElideExpr(bool canElide) { elideExpr = canElide; }
  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  bool isValid(Ranks const &ranks) const;

  // TODO(wrengr): Use it or lose it.
  bool isFunctionOf(Var var) const;
  bool isFunctionOf(VarSet const &vars) const;
  void getFreeVars(VarSet &vars) const;

  // Printing (defined out-of-line); `wantElision` defaults to true.
  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};
// This class is more than a mere newtype/wrapper; even so, we want to
// guarantee that storing these in a `SmallVector` is efficient.
static_assert(IsZeroCostAbstraction<DimSpec>);

//===----------------------------------------------------------------------===//
/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
class LvlSpec final {
  /// The level-variable that this specification binds.
  LvlVar var;
  /// Whether `var` can be elided when printing.  The `LvlSpec` ctor assumes
  /// it cannot; the `DimLvlMap` ctor will reset this as appropriate.
  bool elideVar = false;
  /// The level-expression.
  //
  // NOTE: `LvlExpr` suffices for now, since every level-expression must be
  // an `AffineExpr`; in the future, however, we will also want to allow
  // "counting expressions", and potentially other kinds of non-affine
  // level-expressions.  Which kinds of `DimLvlExpr` are allowed will depend
  // on the `DimLevelType`, so we may consider defining another class that
  // pairs those two together to ensure that the pair is well-formed.
  LvlExpr expr;
  /// The level-type (== level-format + lvl-properties).
  DimLevelType type;

public:
  LvlSpec(LvlVar var, LvlExpr expr, DimLevelType type);

  constexpr LvlVar getBoundVar() const { return var; }
  constexpr bool canElideVar() const { return elideVar; }
  void setElideVar(bool canElide) { elideVar = canElide; }
  constexpr LvlExpr getExpr() const { return expr; }
  constexpr DimLevelType getType() const { return type; }

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  //
  // NOTE: Once "counting expressions" are introduced, this will need
  // a more sophisticated implementation than `DimSpec::isValid` has.
  bool isValid(Ranks const &ranks) const;

  // TODO(wrengr): Use it or lose it.
  bool isFunctionOf(Var var) const;
  bool isFunctionOf(VarSet const &vars) const;
  void getFreeVars(VarSet &vars) const;

  // Printing (defined out-of-line); `wantElision` defaults to true.
  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};
// This class is more than a mere newtype/wrapper; even so, we want to
// guarantee that storing these in a `SmallVector` is efficient.
static_assert(IsZeroCostAbstraction<LvlSpec>);

//===----------------------------------------------------------------------===//
/// Bundles a symbol rank with the lists of dimension-specifications and
/// level-specifications.
class DimLvlMap final {
  // TODO(wrengr): Need to define getters
  // (The rank getters exist below; the spec lists themselves have no
  // accessors yet.)
  unsigned symRank;
  SmallVector<DimSpec> dimSpecs;
  SmallVector<LvlSpec> lvlSpecs;

  // Checks for integrity of variable-binding structure.
  // This is already called by the ctor.
  bool isWF() const;

public:
  DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
            ArrayRef<LvlSpec> lvlSpecs);

  /// The symbol rank, as stored.
  unsigned getSymRank() const { return symRank; }
  /// The number of dimension-specifications.
  unsigned getDimRank() const { return dimSpecs.size(); }
  /// The number of level-specifications.
  unsigned getLvlRank() const { return lvlSpecs.size(); }
  /// The rank for the given kind of variable, looked up via `getRanks()`.
  unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
  /// All three ranks, in the order (symbol, dimension, level).
  Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }

  /// Returns the level-type of the `i`-th level-specification.  Note that
  /// `i` indexes `lvlSpecs`, not `dimSpecs`.
  //
  // This is `const`-qualified like every other read-only accessor here, so
  // it can be called on a `const DimLvlMap`.
  DimLevelType getDimLevelType(unsigned i) const {
    return lvlSpecs[i].getType();
  }

  // Printing (defined out-of-line); `wantElision` defaults to true.
  void print(llvm::raw_ostream &os, bool wantElision = true) const;
  void print(AsmPrinter &printer, bool wantElision = true) const;
  void dump() const;
};

//===----------------------------------------------------------------------===//

} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H