//===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
// (SSVE) mode [1][2] by adding either of the following attributes to
// 'func.func' ops:
//
// * 'arm_streaming' (default)
// * 'arm_locally_streaming'
//
// It can also optionally enable the ZA storage array.
//
// Streaming-mode is part of the interface (ABI) for functions with the
// first attribute and it's the responsibility of the caller to manage
// PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
// backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
// streaming functions.
//
// In locally streaming functions PSTATE.SM is kept internal and managed by
// the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
// 'smstop sm' in the prologue / epilogue for functions with this
// attribute.
//
// [1] https://developer.arm.com/documentation/ddi0616/aa
// [2] https://llvm.org/docs/AArch64SME.html
// [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
// [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <array>
#define DEBUG_TYPE "enable-arm-streaming"
namespace mlir {
namespace arm_sme {
#define GEN_PASS_DEF_ENABLEARMSTREAMING
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace arm_sme
} // namespace mlir
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
constexpr StringLiteral
    kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
template <typename... Ops>
constexpr auto opList() {
  return std::array{TypeID::get<Ops>()...};
}
bool isScalableVector(Type type) {
  if (auto vectorType = dyn_cast<VectorType>(type))
    return vectorType.isScalable();
  return false;
}
struct EnableArmStreamingPass
    : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
                         bool ifRequiredByOps, bool ifScalableAndSupported) {
    this->streamingMode = streamingMode;
    this->zaMode = zaMode;
    this->ifRequiredByOps = ifRequiredByOps;
    this->ifScalableAndSupported = ifScalableAndSupported;
  }
  void runOnOperation() override {
    auto function = getOperation();
    if (ifRequiredByOps && ifScalableAndSupported) {
      function->emitOpError(
          "enable-arm-streaming: `if-required-by-ops` and "
          "`if-scalable-and-supported` are mutually exclusive");
      return signalPassFailure();
    }
    if (ifRequiredByOps) {
      bool foundTileOp = false;
      function.walk([&](Operation *op) {
        if (llvm::isa<ArmSMETileOpInterface>(op)) {
          foundTileOp = true;
          return WalkResult::interrupt();
        }
        return WalkResult::advance();
      });
      if (!foundTileOp)
        return;
    }
    if (ifScalableAndSupported) {
      // FIXME: This should be based on target information (i.e., the presence
      // of FEAT_SME_FA64). This currently errs on the side of caution. If
      // possible gathers/scatters should be lowered to regular vector
      // loads/stores before invoking this pass.
      auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
      bool isCompatibleScalableFunction = false;
      function.walk([&](Operation *op) {
        if (llvm::is_contained(disallowedOperations,
                               op->getName().getTypeID())) {
          isCompatibleScalableFunction = false;
          return WalkResult::interrupt();
        }
        if (!isCompatibleScalableFunction &&
            (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
             llvm::any_of(op->getResultTypes(), isScalableVector))) {
          isCompatibleScalableFunction = true;
        }
        return WalkResult::advance();
      });
      if (!isCompatibleScalableFunction)
        return;
    }
    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
        streamingMode == ArmStreamingMode::Disabled)
      return;
    auto unitAttr = UnitAttr::get(&getContext());
    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
    // The pass currently only supports enabling ZA when in streaming-mode, but
    // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
    // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
    // supporting this later.
    if (zaMode != ArmZaMode::Disabled)
      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
  }
};
} // namespace
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
    const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
    bool ifRequiredByOps, bool ifScalableAndSupported) {
  return std::make_unique<EnableArmStreamingPass>(
      streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
}