1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
//===- UnsignedWhenEquivalent.cpp - Replace signed ops with unsigned ones -===//
//
// Pass to replace signed operations with unsigned ones when all their
// arguments and results are statically non-negative.
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
/// Succeeds iff the integer range analysis holds an initialized range for `v`
/// whose signed lower bound (smin) is non-negative, i.e. `v` is provably >= 0
/// when interpreted as a signed integer. Fails when the solver has no state
/// for `v` or when that state is still uninitialized.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
  const auto *lattice = solver.lookupState<IntegerValueRangeLattice>(v);
  if (lattice == nullptr)
    return failure();

  const auto &valueRange = lattice->getValue();
  if (valueRange.isUninitialized())
    return failure();

  const ConstantIntRanges &bounds = valueRange.getValue();
  return success(bounds.smin().isNonNegative());
}
/// Succeeds if `op` can be replaced by its unsigned equivalent without
/// changing its semantics: every operand and every result must be statically
/// non-negative when viewed as a signed integer. Operands are checked before
/// results, and checking stops at the first value that is not provably >= 0.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
                                           Operation *op) {
  for (Value operand : op->getOperands())
    if (failed(staticallyNonNegative(solver, operand)))
      return failure();

  for (Value result : op->getResults())
    if (failed(staticallyNonNegative(solver, result)))
      return failure();

  return success();
}
/// Succeeds when `op` uses a signed ordering predicate (sle, slt, sge or sgt)
/// and all of its operands are statically non-negative. In that case the
/// predicate can be swapped for its unsigned counterpart. Any other predicate
/// (equality or already-unsigned) is reported as not convertible.
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
  // Reject everything except the four signed ordering predicates up front.
  switch (op.getPredicate()) {
  case CmpIPredicate::sle:
  case CmpIPredicate::slt:
  case CmpIPredicate::sge:
  case CmpIPredicate::sgt:
    break;
  default:
    return failure();
  }

  for (Value operand : op.getOperands())
    if (failed(staticallyNonNegative(solver, operand)))
      return failure();
  return success();
}
/// Maps a signed ordering predicate to the unsigned predicate with the same
/// ordering (sle -> ule, slt -> ult, sge -> uge, sgt -> ugt). Every other
/// predicate is returned unchanged.
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
  CmpIPredicate unsignedPred = pred;
  switch (pred) {
  case CmpIPredicate::sle:
    unsignedPred = CmpIPredicate::ule;
    break;
  case CmpIPredicate::slt:
    unsignedPred = CmpIPredicate::ult;
    break;
  case CmpIPredicate::sge:
    unsignedPred = CmpIPredicate::uge;
    break;
  case CmpIPredicate::sgt:
    unsignedPred = CmpIPredicate::ugt;
    break;
  default:
    break;
  }
  return unsignedPred;
}
namespace {
/// Conversion pattern that replaces a signed op `Signed` with the unsigned op
/// `Unsigned`. The replacement keeps the original result types and attributes
/// and takes the adaptor's (remapped) operands. Whether an op is rewritten at
/// all is decided by the conversion target's legality callbacks, not here.
template <typename Signed, typename Unsigned>
struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
  using OpConversionPattern<Signed>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultTypes = op->getResultTypes();
    auto newOperands = adaptor.getOperands();
    auto attributes = op->getAttrs();
    rewriter.replaceOpWithNewOp<Unsigned>(op, resultTypes, newOperands,
                                          attributes);
    return success();
  }
};
/// Conversion pattern that rewrites an `arith.cmpi` by replacing its predicate
/// with the unsigned equivalent returned by toUnsignedPred. The conversion
/// target only marks cmpi ops illegal when isCmpIConvertable succeeds, so in
/// practice this is applied to signed ordering predicates with non-negative
/// operands.
struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
                                ConversionPatternRewriter &rw) const override {
    // Build the new op from the adaptor's operands rather than
    // op.getLhs()/op.getRhs(). The adaptor carries the operands as remapped by
    // the conversion driver. This is what the dialect-conversion contract
    // expects, and it matches ConvertOpToUnsigned.
    rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
                                  adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};
/// Pass that replaces signed arith ops (divsi, ceildivsi, floordivsi, remsi,
/// minsi, maxsi, extsi, and cmpi with a signed ordering predicate) with their
/// unsigned equivalents. A replacement happens only when integer range
/// analysis proves the relevant values are non-negative when viewed as signed.
struct ArithUnsignedWhenEquivalentPass
    : public arith::impl::ArithUnsignedWhenEquivalentBase<
          ArithUnsignedWhenEquivalentPass> {
  /// Implementation structure: the data-flow solver first runs dead-code and
  /// integer range analysis over the whole target op. The dialect conversion
  /// that follows only reads those precomputed results through the legality
  /// callbacks and never re-runs the solver. The analysis therefore reflects
  /// the IR as it was before any rewrite.
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    ConversionTarget target(*ctx);
    target.addLegalDialect<ArithDialect>();
    // Each signed op is illegal, and therefore gets rewritten, exactly when
    // all of its operands and results are statically non-negative. Only ops
    // that have a conversion pattern below may appear in this list. Listing
    // an already-unsigned op such as CeilDivUIOp would mark it illegal with
    // no pattern to legalize it, making applyPartialConversion fail.
    target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
                                 MinSIOp, MaxSIOp, ExtSIOp>(
        [&solver](Operation *op) -> std::optional<bool> {
          return failed(staticallyNonNegative(solver, op));
        });
    target.addDynamicallyLegalOp<CmpIOp>(
        [&solver](CmpIOp op) -> std::optional<bool> {
          return failed(isCmpIConvertable(solver, op));
        });

    RewritePatternSet patterns(ctx);
    // For non-negative operands, floor division equals truncating division,
    // so floordivsi maps to divui rather than to a dedicated floor op.
    patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
                 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
                 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
                 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
                 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
                 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
                 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
        ctx);

    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};
} // end anonymous namespace
/// Creates a new instance of the arith unsigned-when-equivalent pass, owned by
/// the returned pointer.
std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
  std::unique_ptr<Pass> pass =
      std::make_unique<ArithUnsignedWhenEquivalentPass>();
  return pass;
}
|