//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
using namespace linalg;
using namespace mlir::bufferization;
namespace {
/// Generic conversion for any DestinationStyleOpInterface on tensors.
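///
/// A rough sketch of the rewrite, with illustrative types and names (the
/// actual op, indexing maps, and region are whatever the destination-style op
/// carries):
///
///   %r = linalg.generic ... ins(%in : tensor<8xf32>)
///                           outs(%init : tensor<8xf32>) -> tensor<8xf32>
///
/// becomes, approximately,
///
///   linalg.generic ... ins(%in_buf : memref<8xf32>)
///                      outs(%init_buf : memref<8xf32>)
///
/// where %in_buf and %init_buf are the buffers computed by `getBuffer`, and
/// all uses of %r are replaced with %init_buf.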
static LogicalResult
bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
                                     DestinationStyleOpInterface op,
                                     const BufferizationOptions &options) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  // Nothing to do. This op is already bufferized.
  if (op.hasPureBufferSemantics())
    return success();

  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
  // basis.
  if (!op.hasPureTensorSemantics())
    return op->emitError() << "op does not have pure tensor semantics";

  // New input operands for the cloned op.
  SmallVector<Value> newInputBuffers;
  newInputBuffers.reserve(op.getNumDpsInputs());
  for (OpOperand *opOperand : op.getDpsInputOperands()) {
    if (op.isScalar(opOperand)) {
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
    if (failed(buffer))
      return failure();
    newInputBuffers.push_back(*buffer);
  }

  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
    FailureOr<Value> resultBuffer =
        getBuffer(rewriter, opOperand->get(), options);
    if (failed(resultBuffer))
      return failure();
    newOutputBuffers.push_back(*resultBuffer);
  }

  // Merge input/output operands.
  SmallVector<Value> newOperands = newInputBuffers;
  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());

  // Set insertion point now that potential alloc/dealloc are introduced.
  rewriter.setInsertionPoint(op);

  // Clone the op, but use the new operands. Move the existing block into the
  // new op. Since the new op does not have any tensor results, it does not
  // return anything.
  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
  OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
                       op->getAttrs());
  state.addRegion();
  Operation *newOp = Operation::create(state);
  newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
                                         op->getRegion(0).getBlocks());
  // We don't want the rewriter to track an incomplete operation, so only
  // insert the new operation once it has been fully constructed.
  rewriter.insert(newOp);

  // Replace the results of the old op with the new output buffers.
  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
  return success();
}
/// Bufferization of linalg structured ops (e.g., linalg.generic): replace the
/// op with a new op of the same kind that operates entirely on memrefs.
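///
/// As an illustrative example (operand names are made up): for a
/// `linalg.matmul` with `ins(%A, %B)` and `outs(%C)`, %A and %B bufferize to
/// memory reads; %C bufferizes to a memory write, and also to a read because
/// the matmul payload accumulates into it. For an op whose payload does not
/// use the init block argument (e.g., a fill-like op), the init is a write
/// but not a read.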
template <typename OpTy>
struct LinalgOpInterface
    : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
                                                     OpTy> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Operand is read if it is used in the computation.
    auto linalgOp = cast<linalg::LinalgOp>(op);
    return linalgOp.payloadUsesValueFromOperand(&opOperand);
  }
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Operand is written to if it is a DPS init operand (i.e., not an input).
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    return dpsOp.isDpsInit(&opOperand);
  }
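
  /// Conservative elementwise check, illustrated with hypothetical indexing
  /// maps: an op qualifies only if every loop is parallel and every tensor
  /// operand in `opOperands` uses an identity map such as
  ///   affine_map<(d0, d1) -> (d0, d1)>
  /// whereas a reduction loop, a sparse operand, or a transposed map such as
  ///   affine_map<(d0, d1) -> (d1, d0)>
  /// disqualifies it.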
  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
                                     ArrayRef<OpOperand *> opOperands) const {
    auto linalgOp = cast<linalg::LinalgOp>(op);

    // Accesses into sparse data structures are not necessarily elementwise.
    if (sparse_tensor::hasAnySparseOperand(linalgOp))
      return false;

    // All loops must be parallel.
    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
      return false;

    // All index maps of tensors must be identity maps.
    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
    assert(linalgOp->getNumOperands() == indexingMaps.size() &&
           "unexpected number of indexing maps");
    for (auto [operand, map] :
         llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
      // Non-tensors do not participate in bufferization, so they can be
      // ignored.
      if (!isa<RankedTensorType, MemRefType>(operand.get().getType()))
        continue;
      // Only consider operands in `opOperands`.
      if (!llvm::is_contained(opOperands, &operand))
        continue;
      // TODO: This could be generalized to other indexing maps. (All indexing
      // must be the same.)
      if (!map.isIdentity())
        return false;
    }

    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    return bufferizeDestinationStyleOpInterface(
        rewriter, cast<DestinationStyleOpInterface>(op), options);
  }
};
/// Helper structure that iterates over all LinalgOps in `Ops...` and registers
/// the `BufferizableOpInterface` with each of them.
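///
/// A minimal usage sketch (this particular op list is hypothetical; the real
/// registration below expands the generated LinalgStructuredOps op list):
///
///   LinalgOpInterfaceHelper<linalg::GenericOp,
///                           linalg::MatmulOp>::registerOpInterface(ctx);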
template <typename... Ops>
struct LinalgOpInterfaceHelper {
  static void registerOpInterface(MLIRContext *ctx) {
    (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
  }
};
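
/// Bufferization of linalg.softmax. A rough sketch with illustrative types:
///
///   %r = linalg.softmax dimension(1) ins(%in : tensor<2x16xf32>)
///                       outs(%init : tensor<2x16xf32>) -> tensor<2x16xf32>
///
/// is rewritten to an equivalent softmax on the buffers of %in and %init
/// (memrefs, no results), and %r is replaced with the output buffer.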
struct SoftmaxOpInterface
    : public DstBufferizableOpInterfaceExternalModel<SoftmaxOpInterface,
                                                     linalg::SoftmaxOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Output operand is not read.
    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
    return &opOperand == &softmaxOp.getInputMutable();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto softmaxOp = cast<linalg::SoftmaxOp>(op);
    FailureOr<Value> inputBuffer =
        getBuffer(rewriter, softmaxOp.getInput(), options);
    if (failed(inputBuffer))
      return failure();
    FailureOr<Value> outputBuffer =
        getBuffer(rewriter, softmaxOp.getOutput(), options);
    if (failed(outputBuffer))
      return failure();
    rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
                                       /*result=*/TypeRange(), *inputBuffer,
                                       *outputBuffer, softmaxOp.getDimension());
    replaceOpWithBufferizedValues(rewriter, op, *outputBuffer);
    return success();
  }
};
} // namespace
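
// Illustrative usage of the registration entry point below (exact setup
// varies by tool; `context` is assumed to be an existing MLIRContext):
//
//   DialectRegistry registry;
//   linalg::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);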
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);

    SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
  });
}