1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
namespace mlir {
namespace scf {
namespace {
struct ForOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {

  /// Add bounds for an iter_arg or an OpResult of `forOp`. If the value (or,
  /// when `dim` is set, the dimension size) yielded by the loop body can be
  /// shown to equal that of the corresponding iter_arg, it does not change
  /// from one iteration to the next, so it is the same as the initial
  /// value/dimension size.
  ///
  /// Example 1:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   ...
  ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> bound(%0)[0] == bound(%t)[0]
  /// --> bound(%arg0)[0] == bound(%t)[0]
  ///
  /// Example 2:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   %sz = tensor.dim %arg0 : tensor<?xf32>
  ///   %incr = arith.addi %sz, %c1 : index
  ///   %1 = tensor.empty(%incr) : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> The yielded tensor dimension size changes with each iteration. Such
  ///     loops are not supported and no constraints are added.
  static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                    std::optional<int64_t> dim,
                                    ValueBoundsConstraintSet &cstr) {
    // `value` is either a region iter_arg (block argument, offset by the
    // induction variables) or a loop result; both map to the same position.
    auto asBlockArg = llvm::dyn_cast<BlockArgument>(value);
    const int64_t position =
        asBlockArg
            ? static_cast<int64_t>(asBlockArg.getArgNumber() -
                                   forOp.getNumInductionVars())
            : static_cast<int64_t>(
                  llvm::cast<OpResult>(value).getResultNumber());

    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yielded = yieldOp.getOperand(position);
    Value regionArg = forOp.getRegionIterArg(position);
    Value init = forOp.getInitArgs()[position];

    // Nothing can be said unless the yielded value (dimension size) is equal
    // to the corresponding block argument (dimension size).
    if (!cstr.populateAndCompare(
            /*lhs=*/{yielded, dim},
            ValueBoundsConstraintSet::ComparisonOperator::EQ,
            /*rhs=*/{regionArg, dim}))
      return;

    // Loop-invariant: tie `value` to the init operand.
    if (dim)
      cstr.bound(value)[*dim] == cstr.getExpr(init, dim);
    else
      cstr.bound(value) == cstr.getExpr(init);
  }

  /// Populate bounds for an index-typed `value` that belongs to the loop `op`.
  /// The induction variable is bounded by `lowerBound <= iv < upperBound`;
  /// any other value is handled as an iter_arg / OpResult.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);

    const bool isInductionVar = value == forOp.getInductionVar();
    if (!isInductionVar) {
      // Handle iter_args and OpResults.
      populateIterArgBounds(forOp, value, std::nullopt, cstr);
      return;
    }

    // TODO: Take into account step size.
    cstr.bound(value) >= forOp.getLowerBound();
    cstr.bound(value) < forOp.getUpperBound();
  }

  /// Populate bounds for dimension `dim` of a shaped `value` that is an
  /// iter_arg or an OpResult of the loop `op`.
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateIterArgBounds(cast<ForOp>(op), value, dim, cstr);
  }
};
struct IfOpInterface
    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {

  /// Populate bounds for `value`, which must be an OpResult of `ifOp`. If
  /// `dim` is set, the bounds apply to that dimension size of `value`;
  /// otherwise they apply to `value` itself.
  ///
  /// The values yielded for this result by the "then" and "else" branches are
  /// compared. Whenever one can be shown to be <= the other, the result is
  /// bounded from below by the smaller and from above by the larger one.
  static void populateBounds(scf::IfOp ifOp, Value value,
                             std::optional<int64_t> dim,
                             ValueBoundsConstraintSet &cstr) {
    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
    Value thenValue = ifOp.thenYield().getResults()[resultNum];
    Value elseValue = ifOp.elseYield().getResults()[resultNum];

    // NOTE(review): this builder is not used to add any constraint below;
    // presumably indexing it only validates `dim` -- confirm before removing.
    auto boundsBuilder = cstr.bound(value);
    if (dim)
      boundsBuilder[*dim];

    // If `lower <= upper` can be shown:
    // * result >= lower
    // * result <= upper
    auto boundBetween = [&](Value lower, Value upper) {
      if (!cstr.populateAndCompare(
              /*lhs=*/{lower, dim},
              ValueBoundsConstraintSet::ComparisonOperator::LE,
              /*rhs=*/{upper, dim}))
        return;

      if (dim) {
        cstr.bound(value)[*dim] >= cstr.getExpr(lower, dim);
        cstr.bound(value)[*dim] <= cstr.getExpr(upper, dim);
      } else {
        cstr.bound(value) >= lower;
        cstr.bound(value) <= upper;
      }
    };

    // Compare yielded values in both directions.
    boundBetween(/*lower=*/thenValue, /*upper=*/elseValue);
    boundBetween(/*lower=*/elseValue, /*upper=*/thenValue);
  }

  /// Populate bounds for an index-typed result `value` of the IfOp `op`.
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
  }

  /// Populate bounds for dimension `dim` of a shaped result `value` of the
  /// IfOp `op`.
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, dim, cstr);
  }
};
} // namespace
} // namespace scf
} // namespace mlir
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
});
}
|