File: Bufferize.cpp

package info (click to toggle)
llvm-toolchain-13 1%3A13.0.1-6~deb10u4
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 1,418,792 kB
  • sloc: cpp: 5,290,827; ansic: 996,570; asm: 544,593; python: 188,212; objc: 72,027; lisp: 30,291; f90: 25,395; sh: 24,900; javascript: 9,780; pascal: 9,398; perl: 7,484; ml: 5,432; awk: 3,523; makefile: 2,892; xml: 953; cs: 573; fortran: 539
file content (189 lines) | stat: -rw-r--r-- 7,334 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Conversion pattern that rewrites a `tensor::CastOp` into a
/// `memref::CastOp`. The result type is obtained by running the op's result
/// type through the pattern's type converter.
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // The cast has a single operand; `operands` holds the value the
    // conversion driver supplies for it.
    Value castSource = operands.front();
    auto convertedResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, convertedResultType,
                                                castSource);
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern that rewrites a `tensor::DimOp` into a `memref::DimOp`
/// built from the source and index values handed to the pattern.
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Wrap the raw operand list so the values can be fetched by name.
    tensor::DimOp::Adaptor converted(operands);
    Value dimSource = converted.source();
    Value dimIndex = converted.index();
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, dimSource, dimIndex);
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern that rewrites a `tensor::ExtractOp` into a
/// `memref::LoadOp` using the same indices.
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Wrap the raw operand list so the values can be fetched by name.
    tensor::ExtractOp::Adaptor converted(operands);
    Value loadSource = converted.tensor();
    auto loadIndices = converted.indices();
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, loadSource, loadIndices);
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern that rewrites a `tensor::FromElementsOp` into a
/// `memref::AllocOp` of a 1-D memref followed by one `memref::StoreOp` per
/// element. The op is then replaced by the allocated memref.
class BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Read the elements from the operands handed to this pattern by the
    // conversion driver (as the other patterns in this file do) rather than
    // from the original op.
    tensor::FromElementsOp::Adaptor adaptor(operands);

    // Memref extents are 64-bit; avoid narrowing the element count to `int`.
    int64_t numberOfElements =
        static_cast<int64_t>(adaptor.elements().size());
    auto resultType = MemRefType::get(
        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
    Value result = rewriter.create<memref::AllocOp>(loc, resultType);

    // Store each element at its position in the freshly allocated buffer.
    for (auto element : llvm::enumerate(adaptor.elements())) {
      Value index = rewriter.create<ConstantIndexOp>(loc, element.index());
      rewriter.create<memref::StoreOp>(loc, element.value(), result, index);
    }
    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern that rewrites a `tensor::GenerateOp` into:
///   1. a `memref::AllocOp` with the tensor's shape and element type,
///   2. an `scf::ParallelOp` iterating from 0 to each dimension's extent in
///      steps of 1, into which the body of the generate op is moved, and
///   3. a `memref::StoreOp` that replaces the op found just before the
///      parallel body's terminator (the inlined yield).
/// The generate op itself is replaced by the allocated memref.
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // Allocate memory.
    Location loc = op.getLoc();
    tensor::GenerateOp::Adaptor transformed(operands);
    RankedTensorType tensorType = op.getType().cast<RankedTensorType>();
    MemRefType memrefType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value result = rewriter.create<memref::AllocOp>(
        loc, memrefType, transformed.dynamicExtents());

    // Collect loop bounds.
    // Every dimension runs from `zero` with step `one`; only the upper bounds
    // differ per dimension.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    // Position of the next unused entry in `dynamicExtents()`; it advances
    // only when a dynamic dimension is encountered.
    int nextDynamicIndex = 0;
    for (int i = 0; i < rank; i++) {
      // Dynamic dimensions take their bound from the op's dynamic extents;
      // static dimensions get a constant holding the memref's dim size.
      Value upperBound =
          tensorType.isDynamicDim(i)
              ? transformed.dynamicExtents()[nextDynamicIndex++]
              : rewriter.create<ConstantIndexOp>(loc, memrefType.getDimSize(i));
      upperBounds.push_back(upperBound);
    }

    // Generate tensor elements with a parallel loop that stores into
    // each element of the resulting memref.
    //
    // This is a bit tricky. We cannot simply clone the ops because when an op
    // is cloned, it must be legalized. However, we want to allow arbitrary ops
    // in the body that we don't necessarily have legalization patterns for as
    // part of this dialect conversion invocation.
    //
    // To accomplish this, we use mergeBlockBefore to "move" this op's body
    // into the scf.parallel's body.
    auto parallel =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *parallelBody = parallel.getBody();
    // The parallel body's block arguments are passed as the replacement
    // values for the arguments of the generate op's body block.
    rewriter.mergeBlockBefore(op.getBody(), parallelBody->getTerminator(),
                              parallelBody->getArguments());
    // Replace the inlined yield op with a store op. The scf.parallel's builder
    // already populated an scf.yield at the end, so we don't need to worry
    // about creating that.
    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(elementYield);
    // Store the yield's first operand into `result`, indexed by the parallel
    // body's block arguments.
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        elementYield, elementYield->getOperands()[0], result,
        parallelBody->getArguments());

    rewriter.replaceOp(op, {result});
    return success();
  }
};
} // namespace

/// Adds the tensor bufferization patterns defined in this file (cast, dim,
/// extract, from_elements and generate) to `patterns`. Each pattern is
/// constructed with `typeConverter` and the context owned by `patterns`.
void mlir::populateTensorBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto *patternContext = patterns.getContext();
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp>(typeConverter,
                                                             patternContext);
}

namespace {
/// Function pass that applies the tensor bufferization patterns through a
/// partial dialect conversion and signals pass failure if the conversion
/// fails.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    auto &ctx = getContext();
    BufferizeTypeConverter typeConverter;
    RewritePatternSet bufferizePatterns(&ctx);
    ConversionTarget conversionTarget(ctx);

    populateBufferizeMaterializationLegality(conversionTarget);

    populateTensorBufferizePatterns(typeConverter, bufferizePatterns);

    // Tensor ops that must be rewritten. Note that tensor::DimOp has a
    // pattern registered but is not listed as illegal here.
    conversionTarget.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                                  tensor::FromElementsOp, tensor::GenerateOp>();

    // Dialects the patterns are allowed to produce. Standard-dialect ops are
    // legal only when the type converter reports them as legal.
    conversionTarget.addLegalDialect<memref::MemRefDialect>();
    conversionTarget.addDynamicallyLegalDialect<StandardOpsDialect>(
        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
    conversionTarget.addLegalDialect<scf::SCFDialect>();

    const bool conversionFailed = failed(applyPartialConversion(
        getFunction(), conversionTarget, std::move(bufferizePatterns)));
    if (conversionFailed)
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  return std::make_unique<TensorBufferizePass>();
}