| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 
 | //===- Set.cpp - MLIR PresburgerSet Class ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/PresburgerSet.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
/// Construct a set made of the single disjunct `fac`. The number of dimensions
/// and symbols of the set are taken from `fac`.
PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
    : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
  // Go through the regular union entry point so that the disjunct is stored
  // the same way as any later addition.
  this->unionFACInPlace(fac);
}
/// Return the number of FlatAffineConstraints disjuncts stored in this set.
unsigned PresburgerSet::getNumFACs() const {
  const unsigned numDisjuncts = flatAffineConstraints.size();
  return numDisjuncts;
}
/// Return the number of dimensions this set was created with.
unsigned PresburgerSet::getNumDims() const {
  return nDim;
}
/// Return the number of symbols this set was created with.
unsigned PresburgerSet::getNumSyms() const {
  return nSym;
}
/// Return a reference to the list of all the disjuncts making up this set.
ArrayRef<FlatAffineConstraints>
PresburgerSet::getAllFlatAffineConstraints() const {
  return ArrayRef<FlatAffineConstraints>(flatAffineConstraints);
}
/// Return the disjunct stored at position `index`. `index` must be smaller
/// than the number of disjuncts in this set.
const FlatAffineConstraints &
PresburgerSet::getFlatAffineConstraints(unsigned index) const {
  assert(index < flatAffineConstraints.size() && "index out of bounds!");
  const FlatAffineConstraints &disjunct = flatAffineConstraints[index];
  return disjunct;
}
/// Assert that the FlatAffineConstraints and PresburgerSet live in
/// compatible spaces, i.e., that they have the same number of dimensions and
/// the same number of symbols.
static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
                                       const PresburgerSet &set) {
  // Note the trailing space in the first literal of each message: adjacent
  // string literals are concatenated verbatim.
  assert(fac.getNumDimIds() == set.getNumDims() &&
         "Number of dimensions of the FlatAffineConstraints and PresburgerSet "
         "do not match!");
  assert(fac.getNumSymbolIds() == set.getNumSyms() &&
         "Number of symbols of the FlatAffineConstraints and PresburgerSet "
         "do not match!");
}
/// Assert that the two PresburgerSets live in compatible spaces, i.e., that
/// they have the same number of dimensions and the same number of symbols.
///
/// The checks are plain `assert`s, so nothing is verified when assertions are
/// compiled out.
static void assertDimensionsCompatible(const PresburgerSet &setA,
                                       const PresburgerSet &setB) {
  // Dimensions are checked first, then symbols.
  assert(setA.getNumDims() == setB.getNumDims() &&
         "Number of dimensions of the PresburgerSets do not match!");
  assert(setA.getNumSyms() == setB.getNumSyms() &&
         "Number of symbols of the PresburgerSets do not match!");
}
/// Turn this set into the union of itself and `fac` by storing a copy of `fac`
/// as one more disjunct. `fac` must have the same number of dimensions and
/// symbols as this set.
void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
  assertDimensionsCompatible(fac, *this);
  flatAffineConstraints.emplace_back(fac);
}
/// Mutate this set, turning it into the union of this set and the given set.
///
/// This is accomplished by simply adding all the FACs of the given set to this
/// set. `set` must have the same number of dimensions and symbols as this set.
/// Passing this set itself is allowed and leaves it unchanged.
void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
  assertDimensionsCompatible(set, *this);
  // S union S = S. We must not fall through to the loop below in that case:
  // it would append to the very container it is iterating over, which can
  // reallocate the storage and invalidate the iteration.
  if (&set == this)
    return;
  for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
    unionFACInPlace(fac);
}
/// Return a new set that is the union of this set and `set`; neither operand
/// is modified.
PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  // Start from a copy of this set and add the disjuncts of the other one.
  PresburgerSet combined(*this);
  combined.unionSetInPlace(set);
  return combined;
}
/// Return true iff at least one disjunct of this set contains `point`. A set
/// with no disjuncts contains no point.
bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
    if (!disjunct.containsPoint(point))
      continue;
    // One containing disjunct is enough; no need to look at the others.
    return true;
  }
  return false;
}
/// Return a set with `nDim` dimensions and `nSym` symbols whose only disjunct
/// is the universe FlatAffineConstraints of that space.
PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
  PresburgerSet universe(nDim, nSym);
  const FlatAffineConstraints everything =
      FlatAffineConstraints::getUniverse(nDim, nSym);
  universe.unionFACInPlace(everything);
  return universe;
}
/// Return a set with `nDim` dimensions and `nSym` symbols that has no
/// disjuncts at all.
PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
  PresburgerSet empty(nDim, nSym);
  return empty;
}
/// Return the intersection of this set with the given set.
///
/// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
/// as (S_1 and T_1) or (S_1 and T_2) or ...
/// Pairwise intersections that `isEmpty()` reports as empty are dropped from
/// the result. `set` must have the same number of dimensions and symbols as
/// this set.
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  PresburgerSet result(nDim, nSym);
  for (const FlatAffineConstraints &csA : flatAffineConstraints) {
    for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
      // Build csA ^ csB by appending the constraints of csB to a copy of csA.
      FlatAffineConstraints intersection(csA);
      intersection.append(csB);
      if (intersection.isEmpty())
        continue;
      // unionFACInPlace takes its argument by const reference, so handing it
      // std::move(intersection) would still copy. Store the disjunct directly
      // so that it can really be moved, keeping the same compatibility check.
      assertDimensionsCompatible(intersection, result);
      result.flatAffineConstraints.push_back(std::move(intersection));
    }
  }
  return result;
}
/// Return a copy of `coeffs` in which every element has its sign flipped.
static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
  // Copy first, then negate in place.
  SmallVector<int64_t, 8> negated(coeffs.begin(), coeffs.end());
  for (int64_t &value : negated)
    value = -value;
  return negated;
}
/// Return the complement of the given inequality.
///
/// The complement of a_1 x_1 + ... + a_n x_n + c >= 0 is
/// a_1 x_1 + ... + a_n x_n + c < 0, i.e.,
/// -a_1 x_1 - ... - a_n x_n - c - 1 >= 0.
///
/// `ineq` holds the coefficients followed by the constant term as its last
/// element, so it must not be empty.
static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
  assert(!ineq.empty() &&
         "An inequality must at least contain a constant term!");
  SmallVector<int64_t, 8> coeffs;
  coeffs.reserve(ineq.size());
  for (int64_t coeff : ineq)
    coeffs.emplace_back(-coeff);
  // Over the integers, "expr < 0" is the same as "-expr - 1 >= 0"; the extra
  // -1 goes into the constant term, which is the last element.
  --coeffs.back();
  return coeffs;
}
/// Compute the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
///
/// In the following, V denotes union, ^ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
/// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
/// b \ (V_i s_i).
///
/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b ^ ~s_i, we partition ~s_i based on the first violated
/// inequality:
/// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
/// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
/// We recurse by subtracting V_{j > i} s_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
///
/// As a heuristic, we try adding all the constraints and check if simplex
/// says that the intersection is empty. Also, in the process we find out that
/// some constraints are redundant. These redundant constraints are ignored.
///
/// `i` is the index of the first disjunct of `s` that still has to be
/// subtracted. `b` and `simplex` are modified while this function runs, but
/// are restored to their state on entry before it returns.
static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
                                const PresburgerSet &s, unsigned i,
                                PresburgerSet &result) {
  // Base case: every disjunct of s has been subtracted, so what is left of b
  // is part of the difference.
  if (i == s.getNumFACs()) {
    result.unionFACInPlace(b);
    return;
  }
  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
  assert(sI.getNumLocalIds() == 0 &&
         "Subtracting sets with divisions is not yet supported!");
  const unsigned initialSnapshot = simplex.getSnapshot();
  // Index in the simplex of the first constraint contributed by s_i.
  const unsigned simplexOffset = simplex.numConstraints();
  simplex.intersectFlatAffineConstraints(sI);
  if (simplex.isEmpty()) {
    // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
    simplex.rollback(initialSnapshot);
    subtractRecursively(b, simplex, s, i + 1, result);
    return;
  }
  simplex.detectRedundant();
  const unsigned numIneqs = sI.getNumInequalities();
  const unsigned numEqs = sI.getNumEqualities();
  // Record the redundancy marks before rolling back the simplex. The indexing
  // used here and below assumes that the simplex received the inequalities of
  // s_i first, followed by two inequalities for each equality.
  llvm::SmallBitVector isMarkedRedundant;
  for (unsigned j = 0, e = 2 * numEqs + numIneqs; j < e; ++j)
    isMarkedRedundant.push_back(simplex.isMarkedRedundant(simplexOffset + j));
  simplex.rollback(initialSnapshot);
  // Recurse with the part b ^ ~ineq. Note that b is modified throughout
  // subtractRecursively. At the time this function is called, the current b is
  // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
  // inequality, s_{i,j+1}. This function recurses into the next level i + 1
  // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
  auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
    const unsigned snapshot = simplex.getSnapshot();
    b.addInequality(ineq);
    simplex.addInequality(ineq);
    subtractRecursively(b, simplex, s, i + 1, result);
    b.removeInequality(b.getNumInequalities() - 1);
    simplex.rollback(snapshot);
  };
  // For each inequality ineq, we first recurse with the part where ineq
  // is not satisfied, and then add the ineq to b and simplex because
  // ineq must be satisfied by all later parts.
  auto processInequality = [&](ArrayRef<int64_t> ineq) {
    recurseWithInequality(getComplementIneq(ineq));
    b.addInequality(ineq);
    simplex.addInequality(ineq);
  };
  // processInequality appends some additional constraints to b. We want to
  // rollback b to its initial state before returning, which we will do by
  // removing all constraints beyond the original number of inequalities
  // and equalities, so we store these counts first.
  const unsigned originalNumIneqs = b.getNumInequalities();
  const unsigned originalNumEqs = b.getNumEqualities();
  for (unsigned j = 0; j < numIneqs; ++j) {
    if (isMarkedRedundant[j])
      continue;
    processInequality(sI.getInequality(j));
  }
  // Within isMarkedRedundant, the marks for the equalities start right after
  // those for the inequalities.
  const unsigned eqOffset = numIneqs;
  for (unsigned j = 0; j < numEqs; ++j) {
    ArrayRef<int64_t> coeffs = sI.getEquality(j);
    // Same as the above loop for inequalities, done once each for the positive
    // and negative inequalities that make up this equality.
    if (!isMarkedRedundant[eqOffset + 2 * j])
      processInequality(coeffs);
    if (!isMarkedRedundant[eqOffset + 2 * j + 1])
      processInequality(getNegatedCoeffs(coeffs));
  }
  // Rollback b and simplex to their initial states. The counters are named
  // differently from the parameter `i` on purpose, to avoid shadowing it.
  for (unsigned k = b.getNumInequalities(); k > originalNumIneqs; --k)
    b.removeInequality(k - 1);
  for (unsigned k = b.getNumEqualities(); k > originalNumEqs; --k)
    b.removeEquality(k - 1);
  simplex.rollback(initialSnapshot);
}
/// Return the set difference fac \ set.
///
/// `fac` is taken by value: subtractRecursively needs a mutable
/// FlatAffineConstraints to work on, and working on a private copy keeps the
/// caller's object untouched.
PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
                                              const PresburgerSet &set) {
  assertDimensionsCompatible(fac, set);
  assert(fac.getNumLocalIds() == 0 &&
         "Subtracting sets with divisions is not yet supported!");
  // Starts out with no disjuncts, which is also the answer when the GCD test
  // already shows that `fac` is empty.
  PresburgerSet difference(fac.getNumDimIds(), fac.getNumSymbolIds());
  if (!fac.isEmptyByGCDTest()) {
    Simplex simplex(fac);
    subtractRecursively(fac, simplex, set, 0, difference);
  }
  return difference;
}
/// Return the complement of this set, computed as the universe of the same
/// space minus this set.
PresburgerSet PresburgerSet::complement() const {
  const unsigned numDims = getNumDims();
  const unsigned numSyms = getNumSyms();
  return getSetDifference(FlatAffineConstraints::getUniverse(numDims, numSyms),
                          *this);
}
/// Return the result of subtracting the given set from this set, i.e.,
/// return `this \ set`.
///
/// With this = (V_i t_i), the difference is computed disjunct by disjunct as
/// V_i (t_i \ set).
PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  PresburgerSet difference(nDim, nSym);
  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
    const PresburgerSet remainder = getSetDifference(disjunct, set);
    difference.unionSetInPlace(remainder);
  }
  return difference;
}
/// Two sets S and T are equal iff S contains T and T contains S.
/// By "S contains T", we mean that S is a superset of or equal to T.
///
/// S contains T iff T \ S is empty, since if T \ S contains a
/// point then this is a point that is contained in T but not S.
///
/// Therefore, S is equal to T iff S \ T and T \ S are both empty.
bool PresburgerSet::isEqual(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  return this->subtract(set).isIntegerEmpty() &&
         set.subtract(*this).isIntegerEmpty();
}
/// Return true if every disjunct of this set reports itself as integer empty,
/// and false as soon as one does not. A set with no disjuncts is empty.
bool PresburgerSet::isIntegerEmpty() const {
  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
    if (disjunct.isIntegerEmpty())
      continue;
    // A single non-empty disjunct makes the whole union non-empty.
    return false;
  }
  return true;
}
/// Look for an integer sample in the disjuncts of this set, in order. On
/// success, store the first sample found into `sample` and return true.
/// Otherwise return false and leave `sample` untouched.
bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
    Optional<SmallVector<int64_t, 8>> maybeSample =
        disjunct.findIntegerSample();
    if (!maybeSample)
      continue;
    sample = std::move(*maybeSample);
    return true;
  }
  return false;
}
/// Print this set to `os`: first a header line with the number of disjuncts,
/// then every disjunct in order, each followed by a newline.
void PresburgerSet::print(raw_ostream &os) const {
  os << getNumFACs();
  os << " FlatAffineConstraints:\n";
  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
    disjunct.print(os);
    os << '\n';
  }
}
void PresburgerSet::dump() const { print(llvm::errs()); }
 |