//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Bufferizes tensor.cast into the equivalent memref.cast on the
/// already-converted source buffer.
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Map the tensor result type to its memref counterpart via the type
    // converter, then rewrite the cast to operate on the buffer operand.
    Type convertedType = getTypeConverter()->convertType(op.getType());
    Value sourceBuffer = operands[0];
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, convertedType,
                                                sourceBuffer);
    return success();
  }
};
} // namespace
namespace {
/// Bufferizes tensor.dim by querying the dimension of the converted memref
/// instead of the tensor.
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The adaptor provides named access to the converted (memref) operands.
    tensor::DimOp::Adaptor adaptor(operands);
    Value sourceBuffer = adaptor.source();
    Value dimIndex = adaptor.index();
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, sourceBuffer, dimIndex);
    return success();
  }
};
} // namespace
namespace {
/// Bufferizes tensor.extract into a memref.load from the converted buffer at
/// the same indices.
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Named access to the converted operands: the buffer and the index list.
    tensor::ExtractOp::Adaptor adaptor(operands);
    Value sourceBuffer = adaptor.tensor();
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, sourceBuffer,
                                                adaptor.indices());
    return success();
  }
};
} // namespace
namespace {
/// Bufferizes tensor.from_elements by allocating a 1-D memref of the right
/// size and storing each scalar element at its position.
class BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // A from_elements op always produces a statically-shaped 1-D result whose
    // extent is the number of scalar operands.
    int64_t elementCount = op.elements().size();
    Type elementType = op.getType().cast<TensorType>().getElementType();
    MemRefType bufferType = MemRefType::get({elementCount}, elementType);
    Value buffer = rewriter.create<memref::AllocOp>(loc, bufferType);

    // Store each element at its constant index in the buffer.
    int64_t position = 0;
    for (Value element : op.elements()) {
      Value index = rewriter.create<ConstantIndexOp>(loc, position++);
      rewriter.create<memref::StoreOp>(loc, element, buffer, index);
    }

    rewriter.replaceOp(op, {buffer});
    return success();
  }
};
} // namespace
namespace {
/// Bufferizes tensor.generate: allocates a memref for the result and fills it
/// with an scf.parallel loop that runs the op's generator body once per
/// element, storing each yielded value into the buffer.
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory, forwarding the generate op's dynamic-extent operands
    // to the allocation so dynamic dimensions get the right sizes.
    Location loc = op.getLoc();
    tensor::GenerateOp::Adaptor transformed(operands);
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(
        loc, memrefType, transformed.dynamicExtents());

    // Collect loop bounds: iterate the full index space [0, extent) per
    // dimension. Dynamic extents are consumed from the operand list in
    // order; static extents become constants.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      Value upperBound =
          tensorType.isDynamicDim(i)
              ? transformed.dynamicExtents()[nextDynamicIndex++]
              : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    // The parallel loop's induction variables substitute for the generator
    // body's block arguments (the element indices).
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());

    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    // NOTE: the inlined tensor.yield is the op immediately before the
    // scf.yield terminator, hence getPrevNode().
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace
/// Registers all tensor-dialect bufferization conversion patterns into
/// `patterns`, wiring them up with the given type converter.
void mlir::populateTensorBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp>(typeConverter,
                                                             context);
}
namespace {
/// Function pass that bufferizes tensor-dialect ops via partial dialect
/// conversion.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    MLIRContext &ctx = getContext();
    BufferizeTypeConverter typeConverter;

    // Configure the conversion target: the tensor ops handled above must be
    // converted; memref and scf output is legal; standard ops are legal once
    // their types are fully converted.
    ConversionTarget target(ctx);
    populateBufferizeMaterializationLegality(target);
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<memref::MemRefDialect>();
    target.addDynamicallyLegalDialect<StandardOpsDialect>(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addLegalDialect<scf::SCFDialect>();

    // Gather the bufferization patterns and run the conversion; partial
    // conversion tolerates ops outside the target's scope.
    RewritePatternSet patterns(&ctx);
    populateTensorBufferizePatterns(typeConverter, patterns);
    if (failed(
            applyPartialConversion(getFunction(), target, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Factory for the tensor bufferization pass.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  auto pass = std::make_unique<TensorBufferizePass>();
  return pass;
}