File: TestIntRangeInference.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (124 lines) | stat: -rw-r--r-- 4,432 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// TODO: This pass is needed to test integer range inference until that
// functionality has been integrated into SCCP.
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/FoldUtils.h"
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;

/// Replaces every use of `value` with a constant when the integer range
/// state held by `solver` pins it to a single value. Patterned after SCCP.
///
/// The constant is requested from `folder`, passing the block `b` currently
/// inserts into, and is materialized through the dialect of the op defining
/// `value` (or, for a block argument, of the op that owns its region).
///
/// Returns failure, leaving `value` untouched, when no initialized range is
/// recorded for `value`, when the range is not a single constant, when no
/// dialect is available to materialize with, or when the folder does not
/// produce a constant. Returns success after the uses have been replaced.
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
                                         OperationFolder &folder, Value value) {
  auto *maybeInferredRange =
      solver.lookupState<IntegerValueRangeLattice>(value);
  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
    return failure();
  const ConstantIntRanges &inferredRange =
      maybeInferredRange->getValue().getValue();
  std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
  if (!maybeConstValue.has_value())
    return failure();

  // Op results use the dialect of their defining op; block arguments have no
  // defining op, so fall back to the op that owns the enclosing region.
  Operation *maybeDefiningOp = value.getDefiningOp();
  Dialect *valueDialect =
      maybeDefiningOp ? maybeDefiningOp->getDialect()
                      : value.getParentRegion()->getParentOp()->getDialect();
  // Operation::getDialect() yields null when the op's dialect is not loaded
  // (e.g. unregistered ops). Bail out instead of handing a null dialect to the
  // folder for constant materialization.
  if (!valueDialect)
    return failure();

  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
  Value constant =
      folder.getOrCreateConstant(b.getInsertionBlock(), valueDialect, constAttr,
                                 value.getType(), value.getLoc());
  if (!constant)
    return failure();

  value.replaceAllUsesWith(constant);
  return success();
}

/// Visits every block reachable from `initialRegions` and asks
/// `replaceWithConstant` to rewrite each op result and block argument using
/// the states held by `solver`. Ops whose results were all replaced and that
/// would be trivially dead are erased; regions of the remaining ops are
/// queued for the same treatment.
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  SmallVector<Block *> pendingBlocks;
  // Blocks are pushed in reverse so that, popping from the back, the blocks
  // of a region come off in their original relative order.
  auto enqueueBlocks = [&pendingBlocks](MutableArrayRef<Region> regions) {
    for (Region &region : regions) {
      for (Block &nested : llvm::reverse(region))
        pendingBlocks.push_back(&nested);
    }
  };

  OpBuilder builder(context);
  OperationFolder folder(context);

  enqueueBlocks(initialRegions);
  while (!pendingBlocks.empty()) {
    Block *current = pendingBlocks.pop_back_val();

    // Early-increment iteration because the visited op may be erased below.
    for (Operation &op : llvm::make_early_inc_range(*current)) {
      builder.setInsertionPoint(&op);

      // Attempt the replacement on every result, even after one has failed;
      // an op without results never counts as fully replaced.
      bool everyResultReplaced = op.getNumResults() != 0;
      for (Value result : op.getResults()) {
        const bool replaced =
            succeeded(replaceWithConstant(solver, builder, folder, result));
        everyResultReplaced = everyResultReplaced && replaced;
      }

      // The op stays in place: queue the regions of this operation so their
      // blocks are rewritten as well.
      if (!everyResultReplaced || !wouldOpBeTriviallyDead(&op)) {
        enqueueBlocks(op.getRegions());
        continue;
      }

      // All results now have constant replacements and the op is trivially
      // dead, so it can be removed entirely.
      assert(op.use_empty() && "expected all uses to be replaced");
      op.erase();
    }

    // Finally, try the block arguments; failures here are deliberately
    // ignored.
    builder.setInsertionPointToStart(current);
    for (BlockArgument arg : current->getArguments())
      (void)replaceWithConstant(solver, builder, folder, arg);
  }
}

namespace {
/// Test pass that runs a data-flow solver loaded with dead code analysis and
/// integer range analysis over the current operation, then calls `rewrite`
/// on that operation's regions with the solver's results.
struct TestIntRangeInference
    : PassWrapper<TestIntRangeInference, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)

  /// Command-line argument that selects this pass.
  StringRef getArgument() const final { return "test-int-range-inference"; }

  /// One-line description of this pass.
  StringRef getDescription() const final {
    return "Test integer range inference analysis";
  }

  /// Runs the analyses on the current operation and rewrites its regions.
  /// The pass is marked as failed, and nothing is rewritten, if the solver
  /// reports failure.
  void runOnOperation() override {
    Operation *root = getOperation();

    DataFlowSolver rangeSolver;
    rangeSolver.load<DeadCodeAnalysis>();
    rangeSolver.load<IntegerRangeAnalysis>();
    if (failed(rangeSolver.initializeAndRun(root))) {
      signalPassFailure();
      return;
    }

    rewrite(rangeSolver, root->getContext(), root->getRegions());
  }
};
} // end anonymous namespace

namespace mlir::test {
/// Registers the TestIntRangeInference pass by constructing a
/// PassRegistration for it.
void registerTestIntRangeInference() {
  PassRegistration<TestIntRangeInference>();
}
} // end namespace mlir::test