1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
namespace mlir {
namespace arith {
namespace {
/// External model bounding the index-typed result of `arith.addi` as the sum
/// of its two operands.
struct AddIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
  /// Adds the constraint `result == lhs + rhs` to `cstr`.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto addOp = cast<AddIOp>(op);
    assert(value == addOp.getResult() && "invalid value");

    // `getExpr` may append a column to the constraint system, and C++ leaves
    // the evaluation order of the operands of `+` unspecified. Materialize
    // the operand expressions one statement at a time (lhs first, then rhs)
    // so every compiler builds the exact same, FileCheck-able system.
    AffineExpr lhsExpr = cstr.getExpr(addOp.getLhs());
    AffineExpr rhsExpr = cstr.getExpr(addOp.getRhs());
    AffineExpr sumExpr = lhsExpr + rhsExpr;
    cstr.bound(value) == sumExpr;
  }
};
/// External model pinning the index-typed result of `arith.constant` to its
/// literal value.
struct ConstantOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
                                                   ConstantOp> {
  /// Adds `result == <literal>` when the constant's value is an
  /// `IntegerAttr`; any other attribute kind contributes no constraint.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto cstOp = cast<ConstantOp>(op);
    assert(value == cstOp.getResult() && "invalid value");

    auto intAttr = llvm::dyn_cast<IntegerAttr>(cstOp.getValue());
    if (!intAttr)
      return;
    cstr.bound(value) == intAttr.getInt();
  }
};
/// External model bounding the index-typed result of `arith.subi` as the
/// difference of its two operands.
struct SubIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
  /// Adds the constraint `result == lhs - rhs` to `cstr`.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto subOp = cast<SubIOp>(op);
    assert(value == subOp.getResult() && "invalid value");

    // `getExpr` may append a column to the constraint system; query the
    // operands in separate statements (lhs, then rhs) so the column order
    // does not depend on the compiler's operand evaluation order.
    AffineExpr minuend = cstr.getExpr(subOp.getLhs());
    AffineExpr subtrahend = cstr.getExpr(subOp.getRhs());
    AffineExpr difference = minuend - subtrahend;
    cstr.bound(value) == difference;
  }
};
/// External model bounding the index-typed result of `arith.muli` as the
/// product of its two operands.
struct MulIOpInterface
    : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
  /// Adds the constraint `result == lhs * rhs` to `cstr`.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto mulOp = cast<MulIOp>(op);
    assert(value == mulOp.getResult() && "invalid value");

    // `getExpr` may append a column to the constraint system; query the
    // operands in separate statements (lhs, then rhs) so the column order
    // does not depend on the compiler's operand evaluation order.
    AffineExpr lhsExpr = cstr.getExpr(mulOp.getLhs());
    AffineExpr rhsExpr = cstr.getExpr(mulOp.getRhs());
    AffineExpr productExpr = lhsExpr * rhsExpr;
    cstr.bound(value) == productExpr;
  }
};
/// External model bounding the result of `arith.select`, either as an
/// index-typed value or per dimension of a shaped result.
struct SelectOpInterface
    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
                                                   SelectOp> {
  /// Shared implementation: `dim == std::nullopt` bounds an index-typed
  /// result, otherwise dimension `*dim` of a shaped result is bounded.
  ///
  /// * Shaped condition: the selection is element-wise, so the requested
  ///   dimension of the result equals that dimension of the true value, the
  ///   false value and the condition.
  /// * Scalar condition: the result is one of the two operands. Whenever
  ///   `cstr` can prove one operand <= the other, the result is bounded below
  ///   by the smaller and above by the larger one.
  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
                             ValueBoundsConstraintSet &cstr) {
    Value result = selectOp.getResult();
    Value cond = selectOp.getCondition();
    Value onTrue = selectOp.getTrueValue();
    Value onFalse = selectOp.getFalseValue();

    if (isa<ShapedType>(cond.getType())) {
      // Element-wise select: all three operands share the result's shape.
      // NOTE(review): `*dim` assumes this path is only reached with a
      // dimension, i.e. for a shaped result.
      for (Value operand : {onTrue, onFalse, cond})
        cstr.bound(result)[*dim] == cstr.getExpr(operand, dim);
      return;
    }

    // Populate constraints for both candidate values (and all values on their
    // backward slices, as long as the current stop condition is not
    // satisfied) so the comparisons below have information to work with.
    cstr.populateConstraints(onTrue, dim);
    cstr.populateConstraints(onFalse, dim);

    // NOTE(review): this builder is never used afterwards; presumably it is
    // kept for whatever checking `bound`/`operator[]` performs on the
    // (value, dim) pair — confirm before removing.
    auto resultBuilder = cstr.bound(result);
    if (dim)
      resultBuilder[*dim];

    // If `lower <= upper` is provable, then `lower <= result <= upper`,
    // because the result is one of those two values.
    auto boundIfOrdered = [&](Value lower, Value upper) {
      if (!cstr.compare(/*lhs=*/{lower, dim},
                        ValueBoundsConstraintSet::ComparisonOperator::LE,
                        /*rhs=*/{upper, dim}))
        return;
      if (!dim) {
        cstr.bound(result) >= lower;
        cstr.bound(result) <= upper;
        return;
      }
      cstr.bound(result)[*dim] >= cstr.getExpr(lower, dim);
      cstr.bound(result)[*dim] <= cstr.getExpr(upper, dim);
    };
    boundIfOrdered(/*lower=*/onTrue, /*upper=*/onFalse);
    boundIfOrdered(/*lower=*/onFalse, /*upper=*/onTrue);
  }

  /// Bounds an index-typed `arith.select` result.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
  }

  /// Bounds dimension `dim` of a shaped `arith.select` result.
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<SelectOp>(op), dim, cstr);
  }
};
} // namespace
} // namespace arith
} // namespace mlir
void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
});
}
|