File: SCCP.cpp

package info (click to toggle)
llvm-toolchain-13 1%3A13.0.1-6~deb11u1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 1,418,812 kB
  • sloc: cpp: 5,290,827; ansic: 996,570; asm: 544,593; python: 188,212; objc: 72,027; lisp: 30,291; f90: 25,395; sh: 24,900; javascript: 9,780; pascal: 9,398; perl: 7,484; ml: 5,432; awk: 3,523; makefile: 2,892; xml: 953; cs: 573; fortran: 539
file content (254 lines) | stat: -rw-r--r-- 9,831 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a sparse conditional constant propagation
// in MLIR. It identifies values known to be constant, propagates that
// information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
// proven otherwise.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// SCCP Analysis
//===----------------------------------------------------------------------===//

namespace {
struct SCCPLatticeValue {
  /// Build a lattice value. A null `constant` encodes "not known to be a
  /// constant"; `dialect` records which dialect produced the constant.
  SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
      : constant(constant), constantDialect(dialect) {}

  /// The pessimistic (most conservative) SCCP state is "not a constant", i.e.
  /// a value holding a null attribute. It does not depend on the context or
  /// on the value being queried.
  static SCCPLatticeValue getPessimisticValueState(MLIRContext *) {
    return SCCPLatticeValue{};
  }
  static SCCPLatticeValue getPessimisticValueState(Value) {
    return SCCPLatticeValue{};
  }

  /// Two lattice values are equal when they hold the same constant attribute.
  /// The originating dialect is intentionally not compared.
  bool operator==(const SCCPLatticeValue &rhs) const {
    return rhs.constant == constant;
  }

  /// Join two lattice values: agreeing values are kept, any disagreement
  /// collapses to the non-constant state.
  static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
                               const SCCPLatticeValue &rhs) {
    if (lhs == rhs)
      return lhs;
    return SCCPLatticeValue();
  }

  /// The constant attribute value; null when the value is not a constant.
  Attribute constant;

  /// Dialect that produced `constant`. It takes no part in equality and is
  /// only used when the constant has to be materialized.
  Dialect *constantDialect;
};

/// Forward dataflow analysis that determines, for each value, whether it is
/// known to hold a constant. Operations are simulated by folding them with
/// the constants currently known for their operands. Branch and region-branch
/// operations use the known constant operands to prune successors.
struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
  using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
  ~SCCPAnalysis() override = default;

  /// Simulate `op` by folding it against the constant values held by its
  /// operand lattices, and join the fold results into the lattices of its
  /// results.
  ///
  /// All results are moved to the pessimistic fixpoint when `op` has regions,
  /// when folding fails, or when folding happened in-place.
  ///
  /// Returns whether any result lattice changed.
  ChangeResult
  visitOperation(Operation *op,
                 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
    // Don't try to simulate the results of a region operation as we can't
    // guarantee that folding will be out-of-place. We don't allow in-place
    // folds as the desire here is for simulated execution, and not general
    // folding.
    if (op->getNumRegions())
      return markAllPessimisticFixpoint(op->getResults());

    SmallVector<Attribute> constantOperands = collectConstantOperands(operands);

    // Save the original operands and attributes just in case the operation
    // folds in-place. The constant passed in may not correspond to the real
    // runtime value, so in-place updates are not allowed.
    SmallVector<Value, 8> originalOperands(op->getOperands());
    DictionaryAttr originalAttrs = op->getAttrDictionary();

    // Simulate the result of folding this operation to a constant. If folding
    // fails, mark the results as overdefined.
    SmallVector<OpFoldResult, 8> foldResults;
    foldResults.reserve(op->getNumResults());
    if (failed(op->fold(constantOperands, foldResults)))
      return markAllPessimisticFixpoint(op->getResults());

    // A successful fold with no results means the op was updated in-place.
    // Restore the operation and treat its results as overdefined, since only
    // simulated (out-of-place) execution is wanted here.
    if (foldResults.empty()) {
      op->setOperands(originalOperands);
      op->setAttrs(originalAttrs);
      return markAllPessimisticFixpoint(op->getResults());
    }

    // Merge the fold results into the lattice for this operation.
    assert(foldResults.size() == op->getNumResults() && "invalid result size");
    Dialect *dialect = op->getDialect();
    ChangeResult result = ChangeResult::NoChange;
    for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
      LatticeElement<SCCPLatticeValue> &lattice =
          getLatticeElement(op->getResult(i));

      // A fold result is either a constant attribute (tagged with this op's
      // dialect for later materialization) or an existing value, whose
      // lattice is joined in.
      OpFoldResult foldResult = foldResults[i];
      if (Attribute attr = foldResult.dyn_cast<Attribute>())
        result |= lattice.join(SCCPLatticeValue(attr, dialect));
      else
        result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
    }
    return result;
  }

  /// Use the known constant operands of `branch` to try to select a single
  /// live successor.
  ///
  /// On success, that successor is appended to `successors`. Returns failure
  /// when no single successor can be determined. In that case the framework
  /// presumably treats all successors as live — confirm in
  /// DataFlowAnalysis.h.
  LogicalResult getSuccessorsForOperands(
      BranchOpInterface branch,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<Block *> &successors) final {
    SmallVector<Attribute> constantOperands = collectConstantOperands(operands);
    if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
      successors.push_back(singleSucc);
      return success();
    }
    return failure();
  }

  /// Use the known constant operands of `branch` to compute the region
  /// successors reachable from `sourceIndex`, which are appended to
  /// `successors`.
  void getSuccessorsForOperands(
      RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
      SmallVectorImpl<RegionSuccessor> &successors) final {
    SmallVector<Attribute> constantOperands = collectConstantOperands(operands);
    branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
  }

private:
  /// Map each operand lattice to the constant attribute it holds, in operand
  /// order. Operands not known to be constant map to a null attribute.
  static SmallVector<Attribute> collectConstantOperands(
      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) {
    SmallVector<Attribute> constants;
    constants.reserve(operands.size());
    for (LatticeElement<SCCPLatticeValue> *operand : operands)
      constants.push_back(operand->getValue().constant);
    return constants;
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// SCCP Rewrites
//===----------------------------------------------------------------------===//

/// Replace the given value with a constant if the corresponding lattice
/// represents a constant. Returns success if the value was replaced, failure
/// otherwise.
/// If the analysis proved `value` to be a constant, materialize that constant
/// through `folder` and redirect every use of `value` to it.
///
/// Returns success only when the uses were replaced. Returns failure if
/// `value` has no lattice, is not a constant, or the constant could not be
/// materialized.
static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
                                         OpBuilder &builder,
                                         OperationFolder &folder, Value value) {
  auto *lattice = analysis.lookupLatticeElement(value);
  if (!lattice)
    return failure();

  const SCCPLatticeValue &state = lattice->getValue();
  Attribute constantAttr = state.constant;
  if (!constantAttr)
    return failure();

  // The dialect that produced the constant is asked to build it; the folder
  // uniques the resulting constant op.
  Value materialized =
      folder.getOrCreateConstant(builder, state.constantDialect, constantAttr,
                                 value.getType(), value.getLoc());
  if (!materialized)
    return failure();

  value.replaceAllUsesWith(materialized);
  return success();
}

/// Rewrite the given regions using the computing analysis. This replaces the
/// uses of all values that have been computed to be constant, and erases as
/// many newly dead operations.
/// Rewrite the IR nested under `initialRegions` using the results of
/// `analysis`.
///
/// Every value whose lattice holds a constant has its uses replaced by a
/// materialized constant. Operations left trivially dead once all of their
/// results are replaced are erased.
static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  // Blocks awaiting a visit. The stack is popped from the back, so each
  // region's blocks are pushed in reverse to be visited in program order.
  SmallVector<Block *> pending;
  auto enqueue = [&pending](MutableArrayRef<Region> regions) {
    for (Region &region : regions)
      for (Block &block : llvm::reverse(region))
        pending.push_back(&block);
  };

  // Creates and uniques the materialized constants.
  OperationFolder folder(context);
  OpBuilder builder(context);

  enqueue(initialRegions);
  while (!pending.empty()) {
    Block *block = pending.pop_back_val();

    // Early-increment iteration lets the current operation be erased.
    for (Operation &op : llvm::make_early_inc_range(*block)) {
      builder.setInsertionPoint(&op);

      // Try every result, deliberately without short-circuiting, so each
      // constant result is replaced even if an earlier one was not.
      unsigned numResults = op.getNumResults();
      unsigned numReplaced = 0;
      for (Value result : op.getResults())
        if (succeeded(replaceWithConstant(analysis, builder, folder, result)))
          ++numReplaced;

      bool erasable = numResults != 0 && numReplaced == numResults &&
                      wouldOpBeTriviallyDead(&op);
      if (!erasable) {
        // The operation survives, so its nested regions still need visiting.
        enqueue(op.getRegions());
        continue;
      }

      assert(op.use_empty() && "expected all uses to be replaced");
      op.erase();
    }

    // Block arguments proven constant are replaced too, with the builder
    // positioned at the start of the block.
    builder.setInsertionPointToStart(block);
    for (BlockArgument arg : block->getArguments())
      (void)replaceWithConstant(analysis, builder, folder, arg);
  }
}

//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//

namespace {
/// Pass that runs the SCCP analysis over the regions of the operation it is
/// scheduled on, then rewrites those regions using the analysis results.
class SCCP : public SCCPBase<SCCP> {
public:
  void runOnOperation() override;
};
} // end anonymous namespace

void SCCP::runOnOperation() {
  Operation *op = getOperation();

  SCCPAnalysis analysis(op->getContext());
  analysis.run(op);
  rewrite(analysis, op->getContext(), op->getRegions());
}

/// Create a new instance of the sparse conditional constant propagation pass.
std::unique_ptr<Pass> mlir::createSCCPPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SCCP>();
  return pass;
}