File: RISCVInsertReadWriteCSR.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (135 lines) | stat: -rw-r--r-- 4,417 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
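// This file implements the machine function pass that inserts the reads and
// writes of CSRs required by RISC-V instructions.
//
// Currently the pass naively inserts a write to vxrm before each instruction
// that uses vxrm, and a save/set/restore sequence of frm around each
// instruction that carries a static (non-DYN) frm rounding mode operand.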
//
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/RISCVBaseInfo.h"
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
using namespace llvm;

// Identifier string for this pass; also handed to INITIALIZE_PASS below.
#define DEBUG_TYPE "riscv-insert-read-write-csr"
// Descriptive pass name; returned by getPassName() and handed to
// INITIALIZE_PASS below.
#define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"

namespace {

// Machine function pass that inserts the CSR accesses needed by instructions
// carrying a rounding mode operand. The per-block work is done by
// emitWriteRoundingMode().
class RISCVInsertReadWriteCSR : public MachineFunctionPass {
  // Instruction info of the subtarget of the function being processed.
  // Assigned at the start of runOnMachineFunction(); kept null until then so
  // that an accidental earlier use does not read an indeterminate pointer.
  const TargetInstrInfo *TII = nullptr;

public:
  // Pass identifier handed to the MachineFunctionPass constructor.
  static char ID;

  // Constructs the pass and runs its initialization routine against the
  // global pass registry.
  RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {
    initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry());
  }

  // Entry point: processes every basic block of MF. Returns true if any
  // instruction was inserted.
  bool runOnMachineFunction(MachineFunction &MF) override;

  // Declares that the pass preserves the CFG, then defers to the base class
  // for the remaining analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Returns the descriptive name of the pass.
  StringRef getPassName() const override {
    return RISCV_INSERT_READ_WRITE_CSR_NAME;
  }

private:
  // Inserts the CSR accesses required by the rounding-mode instructions of
  // MBB. Returns true if MBB was modified.
  bool emitWriteRoundingMode(MachineBasicBlock &MBB);
};

} // end anonymous namespace

// Storage for the static pass identifier declared in the class; it is the
// value handed to the MachineFunctionPass constructor.
char RISCVInsertReadWriteCSR::ID = 0;

// Pass registration, keyed by DEBUG_TYPE and described by
// RISCV_INSERT_READ_WRITE_CSR_NAME.
// NOTE(review): this macro is expected to provide the
// initializeRISCVInsertReadWriteCSRPass() routine called by the constructor,
// and the two trailing `false` arguments are presumably the CFG-only and
// is-analysis flags -- confirm against llvm/PassSupport.h.
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
                RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)

// Locates the rounding mode immediate among the explicit operands of MI.
// Yields std::nullopt when the instruction's TSFlags report no rounding mode
// operand; otherwise yields the index of that operand.
static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) {
  const uint64_t Flags = MI.getDesc().TSFlags;
  if (RISCVII::hasRoundModeOp(Flags)) {
    // Trailing explicit operands, read from the end backwards:
    //   policy (only if present), sew, vl, rm
    // so rm sits 3 slots from the end, or 4 when a policy operand exists.
    const unsigned SlotsFromEnd = RISCVII::hasVecPolicyOp(Flags) ? 4u : 3u;
    return MI.getNumExplicitOperands() - SlotsFromEnd;
  }
  return std::nullopt;
}

// Scans MBB for instructions that carry a rounding mode operand and inserts
// the CSR accesses they need:
//  - Instructions using VXRM get a WriteVXRMImm of the requested mode placed
//    directly in front of them.
//  - All other rounding-mode instructions (FRM) get a SwapFRMImm in front,
//    whose result register holds the saved frm, and a WriteFRM of that
//    register directly behind them. A DYN rounding mode leaves the
//    instruction untouched.
// Every instruction handled this way also gains an implicit use of the
// corresponding CSR register. Returns true if anything was inserted.
bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
  bool Modified = false;

  for (MachineInstr &Inst : MBB) {
    const auto RMOpIdx = getRoundModeIdx(Inst);
    if (!RMOpIdx)
      continue;

    const unsigned RMImm = Inst.getOperand(*RMOpIdx).getImm();

    if (RISCVII::usesVXRM(Inst.getDesc().TSFlags)) {
      BuildMI(MBB, Inst, Inst.getDebugLoc(), TII->get(RISCV::WriteVXRMImm))
          .addImm(RMImm);
      Inst.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false,
                                                /*IsImp*/ true));
      Modified = true;
      continue;
    }

    // FRM from here on. DYN is a hint to this pass to not alter the frm value.
    if (RMImm == RISCVFPRndMode::DYN)
      continue;

    MachineFunction &Fn = *MBB.getParent();

    // Save: swap in the requested mode and keep the old frm in a fresh
    // virtual register.
    Register PrevFRM =
        Fn.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
    BuildMI(MBB, Inst, Inst.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
            PrevFRM)
        .addImm(RMImm);
    Inst.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
                                              /*IsImp*/ true));

    // Restore: write the saved value back right after the instruction.
    MachineInstrBuilder RestoreFRM =
        BuildMI(Fn, {}, TII->get(RISCV::WriteFRM)).addReg(PrevFRM);
    MBB.insertAfter(Inst, RestoreFRM);
    Modified = true;
  }

  return Modified;
}

// Pass entry point. Does nothing (and reports no change) when the subtarget
// has no vector instructions; otherwise caches the instruction info in TII
// and processes every basic block of MF. Returns true if any block changed.
bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
  const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();

  // Without the vector extension there is nothing for this pass to do.
  if (!Subtarget.hasVInstructions())
    return false;

  TII = Subtarget.getInstrInfo();

  bool Modified = false;
  for (MachineBasicBlock &Block : MF) {
    // Every block is visited regardless of earlier results.
    if (emitWriteRoundingMode(Block))
      Modified = true;
  }
  return Modified;
}

// Factory: allocates a new RISCVInsertReadWriteCSR pass and returns it as a
// FunctionPass pointer.
FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
  FunctionPass *NewPass = new RISCVInsertReadWriteCSR();
  return NewPass;
}