File: BasicPtxBuilderInterface.cpp

package info (click to toggle)
llvm-toolchain-18 1%3A18.1.8-18
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,908,340 kB
  • sloc: cpp: 6,667,937; ansic: 1,440,452; asm: 883,619; python: 230,549; objc: 76,880; f90: 74,238; lisp: 35,989; pascal: 16,571; sh: 10,229; perl: 7,459; ml: 5,047; awk: 3,523; makefile: 2,987; javascript: 2,149; xml: 892; fortran: 649; cs: 573
file content (158 lines) | stat: -rw-r--r-- 5,118 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//===- BasicPtxBuilderInterface.cpp - PTX builder interface ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by the NVVM-to-LLVM pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Support/LogicalResult.h"

#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// BasicPtxBuilderInterface
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"

using namespace mlir;
using namespace NVVM;

// Address-space number that getRegisterType() treats as shared memory:
// pointers in this address space are given the 32-bit 'r' register constraint
// instead of the 64-bit 'l' constraint used for every other pointer.
static constexpr int64_t kSharedMemorySpace = 3;

/// Returns the inline-assembly register constraint letter for `type`.
///
/// Mapping implemented below:
///   i1  -> 'b'    i16 -> 'h'    i32 -> 'r'    i64 -> 'l'
///   f32 -> 'f'    f64 -> 'd'
///   LLVM pointer -> 'r' when it lives in kSharedMemorySpace, 'l' otherwise.
///
/// Any other type (including structs, which callers must expand element by
/// element) is a programming error and hits llvm_unreachable.
static char getRegisterType(Type type) {
  if (type.isInteger(1))
    return 'b';
  if (type.isInteger(16))
    return 'h';
  if (type.isInteger(32))
    return 'r';
  if (type.isInteger(64))
    return 'l';
  if (type.isF32())
    return 'f';
  if (type.isF64())
    return 'd';
  // Use the free-function cast, matching the rest of this file; the member
  // form `type.dyn_cast<...>()` is deprecated in MLIR.
  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
    // Shared address spaces is addressed with 32-bit pointers.
    if (ptr.getAddressSpace() == kSharedMemorySpace) {
      return 'r';
    }
    return 'l';
  }
  // register type for struct is not supported.
  llvm_unreachable("The register type could not deduced from MLIR type");
  return '?';
}

/// Returns the register constraint letter for the value `v`: 'n' when `v` is
/// produced by an LLVM::ConstantOp, otherwise the letter derived from the
/// value's type.
static char getRegisterType(Value v) {
  const bool definedByConstant =
      static_cast<bool>(v.getDefiningOp<LLVM::ConstantOp>());
  return definedByConstant ? 'n' : getRegisterType(v.getType());
}

/// Registers `v` as an operand of the PTX inline assembly being built, using
/// the register modifier `itype`, and appends the matching constraint(s) to
/// `registerConstraints` (each followed by a ',').
///
/// Operand bookkeeping:
///   - Read:      the value is appended to `ptxOperands`.
///   - Write:     only `hasResult` is set; nothing is appended.
///   - ReadWrite: the value is appended and `hasResult` is set.
///
/// Struct-typed values are expanded one constraint per element:
///   - for Write the struct itself is recorded once as the result;
///   - otherwise every element is extracted with LLVM::ExtractValueOp and
///     recorded individually;
///   - for ReadWrite the constraint is the element index, otherwise it is the
///     modifier followed by the register letter of the element type.
///
/// Scalar values get a single constraint: the modifier followed by the
/// register letter of the value ('n' if it is defined by an LLVM::ConstantOp).
/// Requesting the modifier string for ReadWrite trips an assertion, so
/// ReadWrite scalars are rejected in assert-enabled builds.
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
  // Print the modifier's numeric value. Streaming `&itype` would only print
  // the address of this by-value parameter, which carries no information.
  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << static_cast<int>(itype)
                    << "\n");
  // Constraint prefix for the current modifier ("=" for Write, "" for Read).
  auto getModifier = [&]() -> const char * {
    if (itype == PTXRegisterMod::ReadWrite) {
      assert(false && "Read-Write modifier is not supported. Try setting the "
                      "same value as Write and Read seperately.");
      return "+";
    }
    if (itype == PTXRegisterMod::Write) {
      return "=";
    }
    return "";
  };
  // Records `operand` according to the modifier (see the table above).
  auto addValue = [&](Value operand) {
    if (itype == PTXRegisterMod::Read) {
      ptxOperands.push_back(operand);
      return;
    }
    if (itype == PTXRegisterMod::ReadWrite)
      ptxOperands.push_back(operand);
    hasResult = true;
  };

  // Constraints are appended directly to the member string through `ss`.
  llvm::raw_string_ostream ss(registerConstraints);
  // Handle Structs
  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
    if (itype == PTXRegisterMod::Write) {
      addValue(v);
    }
    for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
      if (itype != PTXRegisterMod::Write) {
        Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
            interfaceOp->getLoc(), v, idx);
        addValue(extractValue);
      }
      if (itype == PTXRegisterMod::ReadWrite) {
        ss << idx << ",";
      } else {
        ss << getModifier() << getRegisterType(t) << ",";
      }
      ss.flush();
    }
    return;
  }
  // Handle Scalars
  addValue(v);
  ss << getModifier() << getRegisterType(v) << ",";
  ss.flush();
}

/// Creates the LLVM::InlineAsmOp for the accumulated PTX: the instruction text
/// from `interfaceOp`, the operands in `ptxOperands`, and the constraint
/// string in `registerConstraints`. The op is created at the location of
/// `interfaceOp` and carries its result types.
LLVM::InlineAsmOp PtxBuilder::build() {
  // Each constraint was appended with a trailing ','; drop the final one.
  if (!registerConstraints.empty() && registerConstraints.back() == ',')
    registerConstraints.pop_back();

  std::string asmText = interfaceOp.getPtx();

  // When a predicate is present, guard the instruction with "@%<idx>", where
  // <idx> is the index of the last entry in `ptxOperands`.
  auto predicate = interfaceOp.getPredicate();
  if (predicate.has_value() && predicate.value()) {
    asmText = "@%" + std::to_string(ptxOperands.size() - 1) + " " + asmText;
  }

  // Tablegen doesn't accept '$', so the PTX text is written with '%'; inline
  // assembly expects '$', hence every '%' is rewritten here.
  std::replace(asmText.begin(), asmText.end(), '%', '$');

  auto dialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                               LLVM::AsmDialect::AD_ATT);
  auto resultTypes = interfaceOp->getResultTypes();

  return rewriter.create<LLVM::InlineAsmOp>(
      interfaceOp->getLoc(),
      /*result types=*/resultTypes,
      /*operands=*/ptxOperands,
      /*asm_string=*/llvm::StringRef(asmText),
      /*constraints=*/registerConstraints.data(),
      /*has_side_effects=*/interfaceOp.hasSideEffect(),
      /*is_align_stack=*/false,
      /*asm_dialect=*/dialectAttr,
      /*operand_attrs=*/ArrayAttr());
}

void PtxBuilder::buildAndReplaceOp() {
  LLVM::InlineAsmOp inlineAsmOp = build();
  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
  if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
    rewriter.replaceOp(interfaceOp, inlineAsmOp);
  } else {
    rewriter.eraseOp(interfaceOp);
  }
}