File: BubbleUpExtractSlice.cpp

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (138 lines) | stat: -rw-r--r-- 5,384 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//===- BubbleUpExtractSlice.cpp - bubble up tensor.extract_slice ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that transform linalg.<op> +
// tensor.extract_slice into tensor.extract_slice + linalg.<op> to reduce
// the computation for the linalg op.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Bubble up extract_slice above Linalg operation.
///
/// A sequence of operations
///
/// ```mlir
/// %0 = linalg.<op> ... arg0, arg1, ...
/// %1 = tensor.extract_slice %0 ...
/// ```
///
/// can be replaced with
///
/// ```mlir
/// %0 = tensor.extract_slice %arg0
/// %1 = tensor.extract_slice %arg1
/// %2 = linalg.<op> ... %0, %1, ...
/// ```
///
/// This reduces the amount of computation performed by the linalg operation.
///
/// The rewrite is only applied when all of the following hold:
///   * the slice source is defined by a LinalgOp that has exactly one use,
///     a single init operand, and tensor semantics;
///   * the slice has unit strides and is not rank-reducing;
///   * the indexing map of the init operand is a projected permutation;
///   * the LinalgOp provides a shapes-to-loops map.
/// Otherwise a match failure carrying a diagnostic message is returned.
///
struct BubbleUpExtractSliceOpPattern
    : OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  /// Replaces `sliceOp` with a clone of its producing LinalgOp that operates
  /// on tiled (sliced) operands, or reports a match failure if any of the
  /// preconditions listed on the struct is not met.
  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const final {
    Value source = sliceOp.getSource();
    auto linalgOp = source.getDefiningOp<LinalgOp>();
    if (!linalgOp) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected source to be linalg op");
    }

    // TODO: we might relax this if we want heuristics to detect that all uses
    // are small portion of the output.
    if (!linalgOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single use of linalg op");
    }

    if (linalgOp.getNumDpsInits() != 1) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected single output of linalg op");
    }

    if (!linalgOp.hasTensorSemantics()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "expected tensor of linalg op");
    }

    if (!sliceOp.hasUnitStride())
      return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");

    if (sliceOp.getType().getRank() != sliceOp.getSourceType().getRank()) {
      return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction");
    }

    OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand);
    if (!indexingMap.isProjectedPermutation()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expected a projected permutation for output");
    }

    // Query the shapes-to-loops map before anything is built through the
    // rewriter, so that every failure path above and here returns without
    // having touched the IR.
    AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
    if (!shapeSizesToLoopsMap) {
      return rewriter.notifyMatchFailure(
          linalgOp, "failed to get loops map from shape sizes");
    }

    // From this point on the pattern always succeeds; the calls below take
    // the rewriter and may therefore create new operations.
    auto linalgLoc = linalgOp.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        linalgOp.createFlatListOfOperandDims(rewriter, linalgLoc);
    SmallVector<OpFoldResult> sizeBounds =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, linalgLoc, shapeSizesToLoopsMap, allShapeSizes);

    // The offsets and sizes from the slice operation only give you the tile
    // size of the output. Use that to compute the tile sizes and offsets of
    // the loops. For loops not used to access the output, set the tile sizes
    // to loop bounds and set the offset to 0.
    SmallVector<OpFoldResult> tileOffsets(sizeBounds.size(),
                                          rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> tileSizes = sizeBounds;
    // Fetch the mixed offsets/sizes once instead of rebuilding the vectors on
    // every loop iteration.
    const SmallVector<OpFoldResult> sliceOffsets = sliceOp.getMixedOffsets();
    const SmallVector<OpFoldResult> sliceSizes = sliceOp.getMixedSizes();
    for (auto const &result : enumerate(indexingMap.getResults())) {
      // Every result is a plain dimension because the map was checked to be a
      // projected permutation above.
      unsigned position = result.value().cast<AffineDimExpr>().getPosition();
      tileOffsets[position] = sliceOffsets[result.index()];
      tileSizes[position] = sliceSizes[result.index()];
    }

    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands =
        makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                        tileOffsets, tileSizes, sizeBounds,
                        /*omitPartialTileCheck=*/true);

    // The result types of the cloned op are the types of its tiled init
    // operands.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : linalgOp.getDpsInitOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    Operation *newOp =
        clone(rewriter, linalgOp, resultTensorTypes, tiledOperands);
    rewriter.replaceOp(sliceOp, newOp->getResults());
    return success();
  }
};
} // namespace

/// Adds the pattern that bubbles tensor.extract_slice above its producing
/// Linalg operation to `patterns`, constructed with the context owned by the
/// pattern set.
void mlir::linalg::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExtractSliceOpPattern>(patterns.getContext());
}