File: VectorToArmSME.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (118 lines) | stat: -rw-r--r-- 4,144 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

/// Extent required of each of the two dimensions of the vector type matched by
/// TransferWriteToArmSMELowering (the shape is compared against
/// {kMinNumElts, kMinNumElts}).
static constexpr unsigned kMinNumElts = 16;

/// Returns true when 'val' is a non-null splat attribute whose single value is
/// zero, interpreted as a float or an integer according to 'elemType'. Any
/// other element type yields false.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  // A null attribute or a non-splat constant can never be a zero splat.
  if (!val || !val.isSplat())
    return false;

  if (llvm::isa<FloatType>(elemType))
    return val.getSplatValue<APFloat>().isZero();
  if (llvm::isa<IntegerType>(elemType))
    return val.getSplatValue<APInt>().isZero();

  // Neither a float nor an integer element type.
  return false;
}

namespace {

/// Look at `vector.transfer_write` operations and convert suitable candidates
/// to ArmSME operations, e.g.:
///
///    %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
///    vector.transfer_write %cst, %arg0[%c0, %c0] : vector<[16]x[16]xi8>,
///    memref<?x?xi8>
///
/// is converted to:
///
///    %0 = arm_sme.zero : vector<[16]x[16]xi8>
///    arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
///    vector<[16]x[16]xi8>
///
/// A write is a candidate only if all of the following hold:
///   * the written vector is 2-D with shape kMinNumElts x kMinNumElts,
///   * both of its dimensions are scalable,
///   * its element type is i8,
///   * the destination is a memref,
///   * the written value is an `arith.constant` holding a splat of zero.
///
/// NOTE(review): the mask, permutation map and in_bounds attributes of the
/// transfer_write are not inspected by this pattern -- confirm that callers
/// only feed it unmasked, identity-mapped writes.
struct TransferWriteToArmSMELowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    auto vType = writeOp.getVectorType();
    if (vType.getRank() != 2)
      return failure();
    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
      return failure();
    if (vType.getElementType() != rewriter.getI8Type())
      return failure();
    // Both dimensions must actually be scalable. The scalable-dims array holds
    // one flag per dimension, so checking only its size would merely repeat the
    // rank check above and let a fixed-size vector<16x16xi8> through.
    if (vType.getScalableDims() != ArrayRef<bool>({true, true}))
      return failure();

    // Only memref destinations are handled.
    if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
      return failure();

    // The stored value has to come straight from a constant op ...
    auto constant = writeOp.getVector().getDefiningOp<arith::ConstantOp>();
    if (!constant)
      return failure();

    // ... whose value is a dense splat of zero.
    auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
      return failure();

    // Materialise the zero tile and store it at the original indices.
    auto loc = writeOp.getLoc();
    auto zero = rewriter.create<arm_sme::ZeroOp>(loc, vType);

    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        writeOp, zero, writeOp.getSource(), writeOp.getIndices());
    return success();
  }
};

/// Rewrites `vector.load` as `arm_sme.tile_load` when the loaded vector type is
/// accepted by `arm_sme::isValidSMETileVectorType`. Loads of any other type
/// are left untouched (the pattern reports failure).
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto tileType = load.getVectorType();
    const bool isSMETile = arm_sme::isValidSMETileVectorType(tileType);
    if (!isSMETile)
      return failure();

    // The tile load keeps the result type, base and indices of the original.
    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
        load, tileType, load.getBase(), load.getIndices());
    return success();
  }
};

/// Rewrites `vector.store` as `arm_sme.tile_store` when the stored vector type
/// is accepted by `arm_sme::isValidSMETileVectorType`. Stores of any other
/// type are left untouched (the pattern reports failure).
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp store,
                                PatternRewriter &rewriter) const override {
    auto storedType = store.getVectorType();
    const bool isSMETile = arm_sme::isValidSMETileVectorType(storedType);
    if (!isSMETile)
      return failure();

    // The tile store keeps the value, base and indices of the original.
    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
        store, store.getValueToStore(), store.getBase(), store.getIndices());
    return success();
  }
};

} // namespace

/// Adds the Vector -> ArmSME rewrite patterns defined in this file to
/// `patterns`, constructing each of them with the context `ctx`.
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  // Registered one at a time, in the same order as a single combined add<>().
  patterns.add<TransferWriteToArmSMELowering>(&ctx);
  patterns.add<VectorLoadToArmSMELowering>(&ctx);
  patterns.add<VectorStoreToArmSMELowering>(&ctx);
}