File: FoldTensorSubsetOps.cpp

package info (click to toggle)
llvm-toolchain-17 1:17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (258 lines) | stat: -rw-r--r-- 10,369 bytes parent folder | download | duplicates (3)
// (Line-number gutter 1-258 from the web source listing removed; the file content follows.)
//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producer / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;

/// Returns the shaped operand that the given vector.transfer_read reads from,
/// i.e. the value reported by the op's `getSource()` accessor.
static Value getTensorOperand(vector::TransferReadOp op) {
  Value readSource = op.getSource();
  return readSource;
}

/// Returns the value being inserted by the given tensor.insert_slice, i.e. the
/// value reported by the op's `getSource()` accessor (not its destination).
static Value getTensorOperand(tensor::InsertSliceOp op) {
  Value insertedValue = op.getSource();
  return insertedValue;
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite pattern rooted at vector.transfer_read. It matches reads whose
/// source is produced by a tensor.extract_slice and folds the slice into the
/// read. The rewrite itself is defined out of line.
struct TransferReadOfExtractSliceOpFolder final
    : OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrite pattern rooted at tensor.insert_slice. It matches inserts whose
/// source is produced by a vector.transfer_write and folds the insert into the
/// write. The rewrite itself is defined out of line.
struct InsertSliceOfTransferWriteOpFolder final
    : OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

/// Checks the conditions shared by the transfer/slice folding patterns.
///
/// Fails, reporting the reason against `xferOp` through `rewriter`, when the
/// transfer has an out-of-bounds dimension, when it carries a mask, or when
/// `extractOrInsertSliceOp` does not have unit strides. Succeeds otherwise.
/// The checks run in that order, so only the first violated one is reported.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  // Every rejection is attached to the transfer op.
  auto reject = [&](const char *reason) -> LogicalResult {
    return rewriter.notifyMatchFailure(xferOp, reason);
  };

  if (xferOp.hasOutOfBoundsDim())
    return reject("out of bounds transfer dim");
  if (xferOp.getMask())
    return reject("masked transfer");
  if (extractOrInsertSliceOp.hasUnitStride())
    return success();
  return reject("non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
}

/// Replaces a vector.transfer_read of a tensor.extract_slice result with a
/// transfer_read of the slice's source tensor. The read indices are resolved
/// through the slice's offsets/strides, and the permutation map is expanded to
/// the rank of the slice source, accounting for the slice's dropped dims.
LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
    vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
  Value readSource = getTensorOperand(readOp);
  auto sliceOp = readSource.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  // Bail out on out-of-bounds, masked or non-unit-stride cases.
  if (failed(preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                            sliceOp)))
    return failure();

  const llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // Translate the indices used on the slice into indices on its source.
  auto readIndices = readOp.getIndices();
  SmallVector<Value> sliceIndices(readIndices.begin(), readIndices.end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), sliceOp.getMixedOffsets(),
      sliceOp.getMixedStrides(), droppedDims, sliceIndices, sourceIndices);

  // The new read is unmasked (masked reads were rejected above) and keeps the
  // original padding and in_bounds attribute.
  AffineMap expandedMap =
      expandDimsToRank(readOp.getPermutationMap(),
                       sliceOp.getSourceType().getRank(), droppedDims);
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      readOp, readOp.getVectorType(), sliceOp.getSource(), sourceIndices,
      AffineMapAttr::get(expandedMap), readOp.getPadding(),
      /*mask=*/Value(), readOp.getInBoundsAttr());

  return success();
}

LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = getTensorOperand(insertSliceOp)
                     .template getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
                                          insertSliceOp.getDestType().getRank(),
                                          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}

/// Folds an insert_slice-like op `OpTy` whose source is produced by a
/// tensor.insert_slice into a single `OpTy` that inserts the inner op's source
/// directly into the outer op's destination.
///
/// Requirements checked by the pattern:
///  * both ops have unit strides;
///  * for every non-dropped dimension of the outer op, its size equals the
///    corresponding size of the inner op (otherwise a copy would be needed).
///
/// The replacement uses the inner op's source, the outer op's destination and
/// strides, and offsets/sizes resolved via the affine view-like helpers.
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    // Materialize the mixed sizes once: each getMixedSizes() call builds a new
    // vector, so calling it inside the loop below is needlessly repeated work.
    const SmallVector<OpFoldResult> outerSizes = insertSliceOp.getMixedSizes();
    const SmallVector<OpFoldResult> innerSizes =
        sourceInsertSliceOp.getMixedSizes();

    // Walk the outer op's destination dims; `srcDim` tracks the matching dim
    // of the inner op, which only advances on non-dropped dims.
    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (outerSizes[d] != innerSizes[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveSizesIntoOpWithSizes(outerSizes, innerSizes, droppedDims,
                                        resolvedSizes);

    // If we are inside an InParallel region, temporarily set the insertion
    // point outside: only tensor.parallel_insert_slice ops are allowed in
    // there.
    if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      auto inParallelOp =
          insertSliceOp->template getParentOfType<scf::InParallelOp>();
      // Bail out before creating any IR rather than handing a null op to
      // setInsertionPoint when no enclosing scf.in_parallel is found.
      if (!inParallelOp) {
        return rewriter.notifyMatchFailure(
            insertSliceOp, "expected an enclosing scf.in_parallel op");
      }
      rewriter.setInsertionPoint(inParallelOp);
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point.
    rewriter.setInsertionPoint(insertSliceOp);
    // Replace original op.
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());

    return success();
  }
};

/// Adds all tensor subset folding patterns defined in this file to
/// `patterns`: first the vector-transfer folders, then the
/// insert_slice-of-insert_slice folders for tensor.insert_slice and
/// tensor.parallel_insert_slice.
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);

  MLIRContext *ctx = patterns.getContext();
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>>(ctx);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      ctx);
}

/// Adds only the patterns that fold tensor slice ops into vector transfer ops:
/// transfer_read of extract_slice, and insert_slice of transfer_write.
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransferReadOfExtractSliceOpFolder>(ctx);
  patterns.add<InsertSliceOfTransferWriteOpFolder>(ctx);
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

/// Pass built on the generated `FoldTensorSubsetOpsBase`; its
/// `runOnOperation` is defined out of line.
class FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
public:
  void runOnOperation() override;
};

} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

/// Creates a new instance of the tensor subset op folding pass.
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<FoldTensorSubsetOpsPass>();
  return pass;
}