File: DimLvlMap.h

package info (click to toggle)
llvm-toolchain-19 1%3A19.1.7-3
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 1,998,520 kB
  • sloc: cpp: 6,951,680; ansic: 1,486,157; asm: 913,598; python: 232,024; f90: 80,126; objc: 75,281; lisp: 37,276; pascal: 16,990; sh: 10,009; ml: 5,058; perl: 4,724; awk: 3,523; makefile: 3,167; javascript: 2,504; xml: 892; fortran: 664; cs: 573
file content (280 lines) | stat: -rw-r--r-- 9,591 bytes parent folder | download | duplicates (20)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
//===- DimLvlMap.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H

#include "Var.h"

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "llvm/ADT/STLForwardCompat.h"

namespace mlir {
namespace sparse_tensor {
namespace ir_detail {

//===----------------------------------------------------------------------===//
/// Tag distinguishing the two kinds of `DimLvlExpr`.  The underlying type
/// is `bool`, with `Dimension` mapped to `false` and `Level` to `true`.
enum class ExprKind : bool {
  Dimension = false,
  Level = true,
};

/// Maps an expression kind to the kind of variable permitted inside such
/// an expression: a dimension-expression may mention level-variables and
/// a level-expression may mention dimension-variables.
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
  return ek == ExprKind::Dimension ? VarKind::Level : VarKind::Dimension;
}
// Pin the mapping at compile time.
static_assert(getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension &&
              getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level);

//===----------------------------------------------------------------------===//
/// Pairs an `AffineExpr` with an `ExprKind` tag recording whether it is
/// to be treated as a dimension-expression or a level-expression.  The
/// wrapped expression may be absent; see `operator bool`.
class DimLvlExpr {
private:
  ExprKind kind;
  AffineExpr expr;

public:
  constexpr DimLvlExpr(ExprKind ek, AffineExpr expr) : kind(ek), expr(expr) {}

  //
  // Boolean operators.
  //

  /// Equality requires both the kind tag and the wrapped `AffineExpr`
  /// to compare equal.
  constexpr bool operator==(DimLvlExpr other) const {
    if (kind != other.kind)
      return false;
    return expr == other.expr;
  }
  constexpr bool operator!=(DimLvlExpr other) const {
    return !operator==(other);
  }
  /// Forwards to the boolean conversion of the wrapped `AffineExpr`.
  explicit operator bool() const { return static_cast<bool>(expr); }

  //
  // RTTI support (for the `DimLvlExpr` class itself).
  // The definitions follow the `DimExpr` and `LvlExpr` classes.
  //
  template <typename U>
  constexpr bool isa() const;
  template <typename U>
  constexpr U cast() const;
  template <typename U>
  constexpr U dyn_cast() const;

  //
  // Simple getters.
  //
  constexpr ExprKind getExprKind() const { return kind; }
  /// The kind of variable that may appear in an expression of this kind.
  constexpr VarKind getAllowedVarKind() const {
    return getVarKindAllowedInExpr(getExprKind());
  }
  constexpr AffineExpr getAffineExpr() const { return expr; }
  /// Returns the kind of the wrapped `AffineExpr`; asserts that the
  /// expression is present.
  AffineExprKind getAffineKind() const {
    assert(expr);
    return expr.getKind();
  }
  /// Returns the context of the wrapped expression, or null when the
  /// expression is absent.
  MLIRContext *tryGetContext() const {
    if (expr)
      return expr.getContext();
    return nullptr;
  }

  //
  // Getters for handling `AffineExpr` subclasses.
  //
  SymVar castSymVar() const;
  std::optional<SymVar> dyn_castSymVar() const;
  Var castDimLvlVar() const;
  std::optional<Var> dyn_castDimLvlVar() const;
  std::tuple<DimLvlExpr, AffineExprKind, DimLvlExpr> unpackBinop() const;

  /// Checks whether the variables used by this expression are valid
  /// with respect to the given ranks.
  [[nodiscard]] bool isValid(Ranks const &ranks) const;

protected:
  // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
  enum class BindingStrength : bool { Weak = false, Strong = true };
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);

class DimExpr final : public DimLvlExpr {
  friend class DimLvlExpr;
  constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

public:
  static constexpr ExprKind Kind = ExprKind::Dimension;
  static constexpr bool classof(DimLvlExpr const *expr) {
    return expr->getExprKind() == Kind;
  }
  constexpr explicit DimExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}

  LvlVar castLvlVar() const { return castDimLvlVar().cast<LvlVar>(); }
  std::optional<LvlVar> dyn_castLvlVar() const {
    const auto var = dyn_castDimLvlVar();
    return var ? std::make_optional(var->cast<LvlVar>()) : std::nullopt;
  }
};
static_assert(IsZeroCostAbstraction<DimExpr>);

class LvlExpr final : public DimLvlExpr {
  friend class DimLvlExpr;
  constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}

public:
  static constexpr ExprKind Kind = ExprKind::Level;
  static constexpr bool classof(DimLvlExpr const *expr) {
    return expr->getExprKind() == Kind;
  }
  constexpr explicit LvlExpr(AffineExpr expr) : DimLvlExpr(Kind, expr) {}

  DimVar castDimVar() const { return castDimLvlVar().cast<DimVar>(); }
  std::optional<DimVar> dyn_castDimVar() const {
    const auto var = dyn_castDimLvlVar();
    return var ? std::make_optional(var->cast<DimVar>()) : std::nullopt;
  }
};
static_assert(IsZeroCostAbstraction<LvlExpr>);

/// Tests whether the kind of this expression matches the kind associated
/// with `U`.
///
/// \tparam U the target class; must be `DimExpr` or `LvlExpr`.
/// \returns true iff `getExprKind() == U::Kind`.
template <typename U>
constexpr bool DimLvlExpr::isa() const {
  // Reject every other `U` at compile time.  Without this check such an
  // instantiation would contain no return statement at all, and calling
  // it would flow off the end of a non-void function.
  static_assert(std::is_same_v<U, DimExpr> || std::is_same_v<U, LvlExpr>,
                "DimLvlExpr::isa<U> requires U to be DimExpr or LvlExpr");
  // Same comparison as `U::classof`, applied directly to `*this`.
  return getExprKind() == U::Kind;
}

template <typename U>
constexpr U DimLvlExpr::cast() const {
  assert(isa<U>());
  return U(*this);
}

template <typename U>
constexpr U DimLvlExpr::dyn_cast() const {
  return isa<U>() ? U(*this) : U();
}

//===----------------------------------------------------------------------===//
/// The full `dimVar = dimExpr : dimSlice` specification for a given dimension.
///
/// Holds the bound variable, an expression that may be absent, a flag
/// controlling whether the expression is elided when printing, and a slice
/// attribute.  The constructor and `isValid` are only declared here.
class DimSpec final {
  /// The dimension-variable bound by this specification.
  DimVar var;
  /// The dimension-expression.  The `DimSpec` ctor treats this field
  /// as optional; whereas the `DimLvlMap` ctor will fill in (or verify)
  /// the expression via function-inversion inference.
  DimExpr expr;
  /// Can the `expr` be elided when printing? The `DimSpec` ctor assumes
  /// not (though if `expr` is null it will elide printing that); whereas
  /// the `DimLvlMap` ctor will reset it as appropriate.
  bool elideExpr = false;
  /// The dimension-slice; optional, default is null.
  SparseTensorDimSliceAttr slice;

public:
  DimSpec(DimVar var, DimExpr expr, SparseTensorDimSliceAttr slice);

  /// Forwards to `DimLvlExpr::tryGetContext` on the stored expression,
  /// which yields null when the expression is absent.
  MLIRContext *tryGetContext() const { return expr.tryGetContext(); }

  constexpr DimVar getBoundVar() const { return var; }
  /// Whether an expression is currently stored (the boolean conversion
  /// of `expr`).
  bool hasExpr() const { return static_cast<bool>(expr); }
  constexpr DimExpr getExpr() const { return expr; }
  /// Stores `newExpr`.  Asserts that no expression has been stored yet,
  /// i.e. an existing expression must not be overwritten.
  void setExpr(DimExpr newExpr) {
    assert(!hasExpr());
    expr = newExpr;
  }
  constexpr bool canElideExpr() const { return elideExpr; }
  void setElideExpr(bool b) { elideExpr = b; }
  constexpr SparseTensorDimSliceAttr getSlice() const { return slice; }

  /// Checks whether the variables bound/used by this spec are valid with
  /// respect to the given ranks.  Note that null `DimExpr` is considered
  /// to be vacuously valid, and therefore calling `setExpr` invalidates
  /// the result of this predicate.
  [[nodiscard]] bool isValid(Ranks const &ranks) const;
};

// Compile-time check of `DimSpec` against the `IsZeroCostAbstraction`
// trait, which is declared elsewhere.
static_assert(IsZeroCostAbstraction<DimSpec>);

//===----------------------------------------------------------------------===//
/// The full `lvlVar = lvlExpr : lvlType` specification for a given level.
///
/// Holds the bound variable, a flag controlling whether the variable is
/// elided when printing, the level-expression, and the level-type.  The
/// constructor and `isValid` are only declared here.
class LvlSpec final {
  /// The level-variable bound by this specification.
  LvlVar var;
  /// Can the `var` be elided when printing?  The `LvlSpec` ctor assumes not;
  /// whereas the `DimLvlMap` ctor will reset this as appropriate.
  bool elideVar = false;
  /// The level-expression.
  LvlExpr expr;
  /// The level-type (== level-format + lvl-properties).
  LevelType type;

public:
  LvlSpec(LvlVar var, LvlExpr expr, LevelType type);

  /// Returns the context obtained from the stored expression.  Asserts
  /// that the context is non-null, so the expression must be present.
  MLIRContext *getContext() const {
    MLIRContext *ctx = expr.tryGetContext();
    assert(ctx);
    return ctx;
  }

  constexpr LvlVar getBoundVar() const { return var; }
  constexpr bool canElideVar() const { return elideVar; }
  void setElideVar(bool b) { elideVar = b; }
  constexpr LvlExpr getExpr() const { return expr; }
  constexpr LevelType getType() const { return type; }

  /// Checks whether the variables bound/used by this spec are valid
  /// with respect to the given ranks.
  [[nodiscard]] bool isValid(Ranks const &ranks) const;
};

// Compile-time check of `LvlSpec` against the `IsZeroCostAbstraction`
// trait, which is declared elsewhere.
static_assert(IsZeroCostAbstraction<LvlSpec>);

//===----------------------------------------------------------------------===//
class DimLvlMap final {
public:
  DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
            ArrayRef<LvlSpec> lvlSpecs);

  unsigned getSymRank() const { return symRank; }
  unsigned getDimRank() const { return dimSpecs.size(); }
  unsigned getLvlRank() const { return lvlSpecs.size(); }
  unsigned getRank(VarKind vk) const { return getRanks().getRank(vk); }
  Ranks getRanks() const { return {getSymRank(), getDimRank(), getLvlRank()}; }

  ArrayRef<DimSpec> getDims() const { return dimSpecs; }
  const DimSpec &getDim(Dimension dim) const { return dimSpecs[dim]; }
  SparseTensorDimSliceAttr getDimSlice(Dimension dim) const {
    return getDim(dim).getSlice();
  }

  ArrayRef<LvlSpec> getLvls() const { return lvlSpecs; }
  const LvlSpec &getLvl(Level lvl) const { return lvlSpecs[lvl]; }
  LevelType getLvlType(Level lvl) const { return getLvl(lvl).getType(); }

  AffineMap getDimToLvlMap(MLIRContext *context) const;
  AffineMap getLvlToDimMap(MLIRContext *context) const;

private:
  /// Checks for integrity of variable-binding structure.
  /// This is already called by the ctor.
  [[nodiscard]] bool isWF() const;

  /// Helper function to call `DimSpec::setExpr` while asserting that
  /// the invariant established by `DimLvlMap:isWF` is maintained.
  /// This is used by the ctor.
  void setDimExpr(Dimension dim, DimExpr expr) {
    assert(expr && getRanks().isValid(expr));
    dimSpecs[dim].setExpr(expr);
  }

  // All these fields are const-after-ctor.
  unsigned symRank;
  SmallVector<DimSpec> dimSpecs;
  SmallVector<LvlSpec> lvlSpecs;
  bool mustPrintLvlVars;
};

//===----------------------------------------------------------------------===//

} // namespace ir_detail
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H