File: LowerVectorShapeCast.cpp

//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

namespace {
/// ShapeCastOp 2-D -> 1-D downcast serves the purpose of progressively
/// flattening 2-D vectors into 1-D vectors on the way to targeting the
/// llvm.matrix intrinsics.
/// The pattern iterates over the most major dimension of the 2-D vector and
/// rewrites each row into:
///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
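///
/// E.g. (a sketch; the concrete types and SSA value names are illustrative
/// only):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<2x3xi32> to vector<6xi32>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<0> : vector<6xi32>
/// %0 = vector.extract %arg0[0] : vector<3xi32> from vector<2x3xi32>
/// %1 = vector.insert_strided_slice %0, %cst {offsets = [0], strides = [1]}
///        : vector<3xi32> into vector<6xi32>
/// %2 = vector.extract %arg0[1] : vector<3xi32> from vector<2x3xi32>
/// %3 = vector.insert_strided_slice %2, %1 {offsets = [3], strides = [1]}
///        : vector<3xi32> into vector<6xi32>
/// ```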
class ShapeCastOp2DDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// ShapeCastOp 1-D -> 2-D upcast serves the purpose of progressively
/// unflattening 1-D vectors back into 2-D vectors.
/// The pattern iterates over the most major dimension of the 2-D vector and
/// rewrites each row into:
///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
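///
/// E.g. (a sketch; the concrete types and SSA value names are illustrative
/// only):
/// ```
/// %0 = vector.shape_cast %arg0 : vector<6xi32> to vector<2x3xi32>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<0> : vector<2x3xi32>
/// %0 = vector.extract_strided_slice %arg0
///        {offsets = [0], sizes = [3], strides = [1]}
///        : vector<6xi32> to vector<3xi32>
/// %1 = vector.insert %0, %cst [0] : vector<3xi32> into vector<2x3xi32>
/// %2 = vector.extract_strided_slice %arg0
///        {offsets = [3], sizes = [3], strides = [1]}
///        : vector<6xi32> to vector<3xi32>
/// %3 = vector.insert %2, %1 [1] : vector<3xi32> into vector<2x3xi32>
/// ```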
class ShapeCastOp2DUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
      return failure();

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
    }
    rewriter.replaceOp(op, desc);
    return success();
  }
};

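/// Increments the n-D index `idx` by `initialStep` in dimension `dimIdx`,
/// wrapping around at the bounds of the vector type `tp` and carrying over
/// into the next more major dimension. E.g. for a vector<2x3xf32>, calling
/// incIdx on idx = [0, 2] with dimIdx = 1 yields idx = [1, 0].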
static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int step = initialStep;
  for (int d = dimIdx; d >= 0; d--) {
    idx[d] += step;
    if (idx[d] >= tp.getDimSize(d)) {
      idx[d] = 0;
      step = 1;
    } else {
      break;
    }
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
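//
// E.g. (an illustrative sketch; types and SSA value names are arbitrary),
// a cast such as:
//   %0 = vector.shape_cast %arg0 : vector<1x2x2xi32> to vector<2x2xi32>
// is unrolled into a chain of scalar extracts/inserts along the lines of:
//   %cst = arith.constant dense<0> : vector<2x2xi32>
//   %0 = vector.extract %arg0[0, 0, 0] : i32 from vector<1x2x2xi32>
//   %1 = vector.insert %0, %cst [0, 0] : i32 into vector<2x2xi32>
//   %2 = vector.extract %arg0[0, 0, 1] : i32 from vector<1x2x2xi32>
//   %3 = vector.insert %2, %1 [0, 1] : i32 into vector<2x2xi32>
//   ... (and so on for the remaining elements)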
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special-case 2-D / 1-D lowerings, which have better implementations
    // above.
    // TODO: make this ND / 1D to allow generic ND -> 1D -> MD.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //    x[0,0,0] = y[0,0]
    //    x[0,0,1] = y[0,1]
    //    x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and result
    // types have a single trailing scalable dimension. This is because there is
    // no legal representation of other scalable types in LLVM (and likely won't
    // be soon). There are also (currently) no operations that can index or
    // extract from >= 2-D scalable vectors or scalable vectors of fixed
    // vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

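  /// Returns true if the trailing dimension of `type` is scalable and no other
  /// dimension is scalable.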
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

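/// Populates `patterns` with the shape_cast lowering patterns defined above.
/// A typical use (a minimal sketch, assuming the standard greedy pattern
/// driver) is:
/// ```
/// RewritePatternSet patterns(op->getContext());
/// vector::populateVectorShapeCastLoweringPatterns(patterns);
/// (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
/// ```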
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}