//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

namespace mlir {
namespace tensor {
namespace {

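/// Illustrative example (hand-written sketch, not taken from upstream tests):
/// for a cast between two ranked tensor types such as
///
///   %1 = tensor.cast %0 : tensor<?x4xf32> to tensor<4x?xf32>
///
/// the model below adds dim(%1, d) == dim(%0, d) for the queried dimension d.
/// If either the source or the result is unranked, no constraint is added.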
struct CastOpInterface
    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto castOp = cast<CastOp>(op);
    assert(value == castOp.getResult() && "invalid value");

    if (llvm::isa<RankedTensorType>(castOp.getResult().getType()) &&
        llvm::isa<RankedTensorType>(castOp.getSource().getType())) {
      cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
    }
  }
};

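/// Illustrative example (hand-written sketch): with a constant index,
///
///   %c1 = arith.constant 1 : index
///   %d = tensor.dim %t, %c1 : tensor<?x?xf32>
///
/// the model below adds %d == dim(%t, 1). If the index operand is not a
/// constant, no constraint is added.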
struct DimOpInterface
    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto dimOp = cast<DimOp>(op);
    assert(value == dimOp.getResult() && "invalid value");

    auto constIndex = dimOp.getConstantIndex();
    if (!constIndex.has_value())
      return;
    cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
  }
};

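/// Illustrative example (hand-written sketch): for
///
///   %e = tensor.empty(%sz) : tensor<?x8xf32>
///
/// the model below adds dim(%e, 0) == %sz and dim(%e, 1) == 8, taking each
/// size from getMixedSizes() (static sizes come back as attributes, dynamic
/// sizes as SSA values).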
struct EmptyOpInterface
    : public ValueBoundsOpInterface::ExternalModel<EmptyOpInterface, EmptyOp> {
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto emptyOp = cast<EmptyOp>(op);
    assert(value == emptyOp.getResult() && "invalid value");

    cstr.bound(value)[dim] == emptyOp.getMixedSizes()[dim];
  }
};

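/// Illustrative example (hand-written sketch) of a rank-reducing slice:
///
///   %s = tensor.extract_slice %t[0, %off, 0] [1, %sz, 4] [1, 1, 1]
///       : tensor<8x?x4xf32> to tensor<?x4xf32>
///
/// Source dimension 0 (size 1) is dropped, so result dimension 0 maps to
/// slice size index 1 and result dimension 1 maps to index 2. The model below
/// therefore adds dim(%s, 0) == %sz and dim(%s, 1) == 4.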
struct ExtractSliceOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ExtractSliceOpInterface,
                                                   ExtractSliceOp> {
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto extractSliceOp = cast<ExtractSliceOp>(op);
    assert(value == extractSliceOp.getResult() && "invalid value");

    llvm::SmallBitVector dropped = extractSliceOp.getDroppedDims();
    int64_t ctr = -1;
    for (int64_t i = 0, e = extractSliceOp.getMixedSizes().size(); i < e; ++i) {
      // Skip over rank-reduced dimensions.
      if (!dropped.test(i))
        ++ctr;
      if (ctr == dim) {
        cstr.bound(value)[dim] == extractSliceOp.getMixedSizes()[i];
        return;
      }
    }
    llvm_unreachable("could not find non-rank-reduced dim");
  }
};

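/// Illustrative example (hand-written sketch): for
///
///   %p = tensor.pad %t low[1, %l] high[2, 0] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<?x?xf32>
///
/// the model below adds dim(%p, 0) == dim(%t, 0) + 1 + 2 and
/// dim(%p, 1) == dim(%t, 1) + %l + 0.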
struct PadOpInterface
    : public ValueBoundsOpInterface::ExternalModel<PadOpInterface, PadOp> {
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto padOp = cast<PadOp>(op);
    assert(value == padOp.getResult() && "invalid value");

    AffineExpr srcSize = cstr.getExpr(padOp.getSource(), dim);
    AffineExpr lowPad = cstr.getExpr(padOp.getMixedLowPad()[dim]);
    AffineExpr highPad = cstr.getExpr(padOp.getMixedHighPad()[dim]);
    cstr.bound(value)[dim] == srcSize + lowPad + highPad;
  }
};

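/// Illustrative example (hand-written sketch): for
///
///   %r = tensor.rank %t : tensor<2x?xf32>
///
/// the model below adds %r == 2. For an unranked operand such as
/// tensor<*xf32>, no constraint is added.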
struct RankOpInterface
    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto rankOp = cast<RankOp>(op);
    assert(value == rankOp.getResult() && "invalid value");

    auto tensorType =
        llvm::dyn_cast<RankedTensorType>(rankOp.getTensor().getType());
    if (!tensorType)
      return;
    cstr.bound(value) == tensorType.getRank();
  }
};

} // namespace
} // namespace tensor
} // namespace mlir
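
// Usage sketch (hand-written illustration, assuming a client that builds its
// own DialectRegistry; adapt to the surrounding tool's setup):
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);
//
// The extension registered by the function below runs when the tensor dialect
// is loaded into the context, attaching the external models defined above.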
void mlir::tensor::registerValueBoundsOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    tensor::CastOp::attachInterface<tensor::CastOpInterface>(*ctx);
    tensor::DimOp::attachInterface<tensor::DimOpInterface>(*ctx);
    tensor::EmptyOp::attachInterface<tensor::EmptyOpInterface>(*ctx);
    tensor::ExtractSliceOp::attachInterface<tensor::ExtractSliceOpInterface>(
        *ctx);
    tensor::PadOp::attachInterface<tensor::PadOpInterface>(*ctx);
    tensor::RankOp::attachInterface<tensor::RankOpInterface>(*ctx);

    // Note: ValueBoundsOpInterface implementation is not required for ops that
    // implement `DestinationStyleOpInterface` (for querying shaped OpResults).
  });
}