File: EnableArmStreaming.cpp

package info (click to toggle)
swiftlang 6.1.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,791,604 kB
  • sloc: cpp: 9,901,740; ansic: 2,201,431; asm: 1,091,827; python: 308,252; objc: 82,166; f90: 80,126; lisp: 38,358; pascal: 25,559; sh: 20,429; ml: 5,058; perl: 4,745; makefile: 4,484; awk: 3,535; javascript: 3,018; xml: 918; fortran: 664; cs: 573; ruby: 396
file content (148 lines) | stat: -rw-r--r-- 5,592 bytes parent folder | download | duplicates (8)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
// (SSVE) mode [1][2] by adding either of the following attributes to
// 'func.func' ops:
//
//   * 'arm_streaming' (default)
//   * 'arm_locally_streaming'
//
// It can also optionally enable the ZA storage array (currently only together
// with streaming mode; nothing is added when streaming mode is disabled).
//
// Streaming-mode is part of the interface (ABI) for functions with the
// first attribute and it's the responsibility of the caller to manage
// PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
// backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
// streaming functions.
//
// In locally streaming functions PSTATE.SM is kept internal and managed by
// the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
// 'smstop sm' in the prologue / epilogue for functions with this
// attribute.
//
// [1] https://developer.arm.com/documentation/ddi0616/aa
// [2] https://llvm.org/docs/AArch64SME.html
// [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
// [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"

#include "mlir/Dialect/Func/IR/FuncOps.h"

#define DEBUG_TYPE "enable-arm-streaming"

namespace mlir {
namespace arm_sme {
#define GEN_PASS_DEF_ENABLEARMSTREAMING
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace arm_sme
} // namespace mlir

using namespace mlir;
using namespace mlir::arm_sme;
namespace {

/// Name of the attribute that opts a function out of this pass: any
/// 'func.func' carrying it is left unmodified.
constexpr StringLiteral kEnableArmStreamingIgnoreAttr{
    "enable_arm_streaming_ignore"};

template <typename... Ops>
constexpr auto opList() {
  return std::array{TypeID::get<Ops>()...};
}

/// Returns true when `type` is a VectorType that reports itself as scalable;
/// every non-vector type yields false.
bool isScalableVector(Type type) {
  auto asVector = dyn_cast<VectorType>(type);
  return asVector && asVector.isScalable();
}

/// Function pass that marks 'func.func' ops with the requested SME streaming
/// attribute (the stringified ArmStreamingMode) and, optionally, the
/// requested ZA attribute (the stringified ArmZaMode).
///
/// Gating, evaluated in this order:
///   * `if-required-by-ops` and `if-scalable-and-supported` are mutually
///     exclusive; requesting both emits an error and fails the pass.
///   * Functions carrying `enable_arm_streaming_ignore`, and runs whose
///     streaming mode is Disabled, are left untouched.
///   * With `if-required-by-ops`, only functions containing an op that
///     implements ArmSMETileOpInterface are modified.
///   * With `if-scalable-and-supported`, only functions that use scalable
///     vectors and contain no vector.gather / vector.scatter are modified.
struct EnableArmStreamingPass
    : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
                         bool ifRequiredByOps, bool ifScalableAndSupported) {
    this->streamingMode = streamingMode;
    this->zaMode = zaMode;
    this->ifRequiredByOps = ifRequiredByOps;
    this->ifScalableAndSupported = ifScalableAndSupported;
  }

  void runOnOperation() override {
    auto function = getOperation();

    if (ifRequiredByOps && ifScalableAndSupported) {
      function->emitOpError(
          "enable-arm-streaming: `if-required-by-ops` and "
          "`if-scalable-and-supported` are mutually exclusive");
      return signalPassFailure();
    }

    // Take the cheap, attribute/option based exits before walking the
    // function body. The walks below have no side effects, so skipping them
    // for functions that would be left untouched anyway does not change the
    // result; it only avoids needless IR traversals.
    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
        streamingMode == ArmStreamingMode::Disabled)
      return;

    if (ifRequiredByOps && !containsTileOp(function))
      return;

    if (ifScalableAndSupported && !isCompatibleScalableFunction(function))
      return;

    auto unitAttr = UnitAttr::get(&getContext());

    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);

    // The pass currently only supports enabling ZA when in streaming-mode, but
    // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
    // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
    // supporting this later.
    if (zaMode != ArmZaMode::Disabled)
      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
  }

private:
  /// Returns true if `function` (or any op nested in it) implements
  /// ArmSMETileOpInterface. The walk stops at the first match.
  static bool containsTileOp(Operation *function) {
    return function
        ->walk([](Operation *op) {
          return llvm::isa<ArmSMETileOpInterface>(op)
                     ? WalkResult::interrupt()
                     : WalkResult::advance();
        })
        .wasInterrupted();
  }

  /// Returns true if some op in `function` has a scalable-vector operand or
  /// result, and no op in `function` is a vector.gather or vector.scatter.
  static bool isCompatibleScalableFunction(Operation *function) {
    // FIXME: This should be based on target information (i.e., the presence
    // of FEAT_SME_FA64). This currently errs on the side of caution. If
    // possible gathers/scatters should be lowered to regular vector
    // loads/stores before invoking this pass.
    auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
    bool usesScalableVectors = false;
    WalkResult result = function->walk([&](Operation *op) {
      // A single disallowed op rules the whole function out.
      if (llvm::is_contained(disallowedOperations, op->getName().getTypeID()))
        return WalkResult::interrupt();
      // Keep walking after the first scalable use: a disallowed op may still
      // appear later.
      if (!usesScalableVectors &&
          (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
           llvm::any_of(op->getResultTypes(), isScalableVector)))
        usesScalableVectors = true;
      return WalkResult::advance();
    });
    return !result.wasInterrupted() && usesScalableVectors;
  }
};
} // namespace

/// Factory for the enable-arm-streaming pass: forwards the streaming mode,
/// ZA mode and the two gating flags to a new EnableArmStreamingPass and
/// hands ownership to the caller.
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
    const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
    bool ifRequiredByOps, bool ifScalableAndSupported) {
  std::unique_ptr<Pass> pass = std::make_unique<EnableArmStreamingPass>(
      streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
  return pass;
}