File: OpenACCToLLVM.cpp

package info (click to toggle)
llvm-toolchain-16 1%3A16.0.6-15~deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,634,792 kB
  • sloc: cpp: 6,179,261; ansic: 1,216,205; asm: 741,319; python: 196,614; objc: 75,325; f90: 49,640; lisp: 32,396; pascal: 12,286; sh: 9,394; perl: 7,442; ml: 5,494; awk: 3,523; makefile: 2,723; javascript: 1,206; xml: 886; fortran: 581; cs: 573
file content (245 lines) | stat: -rw-r--r-- 10,501 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
//===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTOPENACCTOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// DataDescriptor implementation
//===----------------------------------------------------------------------===//

/// Name given to the identified LLVM struct type that models an OpenACC data
/// descriptor. `DataDescriptor::isValid` matches it with `startswith`,
/// presumably because `getNewIdentified` may suffix the name to keep it
/// unique — confirm against the LLVM dialect docs.
constexpr StringRef getStructName() {
  return "openacc_data";
}

/// Wraps an existing descriptor value so its fields can be populated through
/// the setters below. The wrapped value must not be null.
DataDescriptor::DataDescriptor(Value descriptor)
    : StructBuilder(descriptor) {
  // `value` is the StructBuilder member presumably initialized from
  // `descriptor` by the base constructor.
  assert(value != nullptr && "value cannot be null");
}

/// Emits an `LLVM::UndefOp` whose type is a newly created identified struct
/// `{basePtrTy, ptrTy, i64}` named after getStructName(), and returns it
/// wrapped in a DataDescriptor ready to be filled in.
DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
                                     Type basePtrTy, Type ptrTy) {
  MLIRContext *ctx = builder.getContext();
  Type sizeTy = builder.getI64Type();
  Type structTy = LLVM::LLVMStructType::getNewIdentified(
      ctx, getStructName(), {basePtrTy, ptrTy, sizeTy});
  Value undefValue = builder.create<LLVM::UndefOp>(loc, structTy);
  return DataDescriptor(undefValue);
}

/// Returns true when `descriptor` has the shape produced by
/// DataDescriptor::undef: an identified LLVM struct whose name starts with
/// getStructName() and whose three fields are
///   - a base pointer (an LLVM pointer or a nested LLVM struct),
///   - an LLVM pointer,
///   - a 64-bit integer size.
bool DataDescriptor::isValid(Value descriptor) {
  auto structTy = descriptor.getType().dyn_cast<LLVM::LLVMStructType>();
  if (!structTy)
    return false;
  // Only identified structs carry a name; check that before reading it.
  if (!structTy.isIdentified() ||
      !structTy.getName().startswith(getStructName()))
    return false;

  auto fields = structTy.getBody();
  if (fields.size() != 3)
    return false;

  Type baseField = fields[kPtrBasePosInDataDescriptor];
  bool baseFieldOk = baseField.isa<LLVM::LLVMPointerType>() ||
                     baseField.isa<LLVM::LLVMStructType>();
  if (!baseFieldOk)
    return false;
  if (!fields[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>())
    return false;
  return fields[kSizePosInDataDescriptor].isInteger(64);
}

/// Emits IR that stores `basePtr` into the base-pointer field
/// (kPtrBasePosInDataDescriptor) of the wrapped descriptor.
void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
                                    Value basePtr) {
  const unsigned field = kPtrBasePosInDataDescriptor;
  setPtr(builder, loc, field, basePtr);
}

/// Emits IR that stores `ptr` into the data-pointer field
/// (kPtrPosInDataDescriptor) of the wrapped descriptor.
void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
  const unsigned field = kPtrPosInDataDescriptor;
  setPtr(builder, loc, field, ptr);
}

/// Emits IR that stores `size` into the size field (kSizePosInDataDescriptor)
/// of the wrapped descriptor. isValid expects that field to be i64.
void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
  // NOTE(review): StructBuilder::setPtr is reused for a non-pointer value;
  // presumably it is a generic field insert — confirm in StructBuilder.
  const unsigned field = kSizePosInDataDescriptor;
  setPtr(builder, loc, field, size);
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Legalizes the data operands of an OpenACC op for LLVM IR translation.
///
/// Every data operand (as enumerated by `Op::getNumDataOperands()` /
/// `Op::getDataOperand()`) is rewritten as follows:
///  - a `memref` operand is packed into an `openacc_data` DataDescriptor that
///    holds the converted memref descriptor, its aligned data pointer and the
///    buffer size in bytes;
///  - an LLVM pointer operand is forwarded unchanged;
///  - any other type makes the pattern fail to match.
/// The op is then rebuilt with the new operands and its original attributes,
/// and its regions (if any) are moved into the rebuilt op.
template <typename Op>
class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &builder) const override {
    Location loc = op.getLoc();
    TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
    Operation *operation = op.getOperation();

    unsigned numDataOperand = op.getNumDataOperands();

    // Validate every data operand up front so an unsupported one is reported
    // before any IR has been emitted for the preceding operands.
    for (unsigned idx = 0; idx < numDataOperand; ++idx) {
      Type type = op.getDataOperand(idx).getType();
      if (auto memRefType = type.dyn_cast<MemRefType>()) {
        // getMemRefDescriptorSizes only handles identity-layout memrefs with
        // a convertible element type, and a null converted type would break
        // the cast emitted in packMemRef, so reject such memrefs here.
        if (!ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
                memRefType) ||
            !converter->convertType(memRefType))
          return builder.notifyMatchFailure(op, "unsupported memref type");
      } else if (!type.isa<LLVM::LLVMPointerType>()) {
        // Type not supported.
        return builder.notifyMatchFailure(op, "unsupported type");
      }
    }

    // Keep the non data operands without modification. This assumes the data
    // operands are the trailing operands of the op.
    auto nonDataOperands = adaptor.getOperands().take_front(
        adaptor.getOperands().size() - numDataOperand);
    SmallVector<Value> convertedOperands;
    convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());

    // Go over the data operands and legalize them for translation.
    for (unsigned idx = 0; idx < numDataOperand; ++idx) {
      Value originalDataOperand = op.getDataOperand(idx);
      auto memRefType = originalDataOperand.getType().dyn_cast<MemRefType>();
      if (!memRefType) {
        // LLVM pointers (validated above) are already translatable.
        convertedOperands.push_back(originalDataOperand);
        continue;
      }
      Type structType = converter->convertType(memRefType);
      convertedOperands.push_back(packMemRef(loc, originalDataOperand,
                                             memRefType, structType, builder));
    }

    // Rebuild the op with the legalized operands. The (types, operands,
    // attributes) builder takes no regions, so the original bodies must be
    // moved over explicitly before replaceOp erases the old op; otherwise the
    // contents of e.g. acc.data / acc.parallel would be dropped.
    TypeRange resultTypes = operation->getResultTypes();
    auto newOp = builder.create<Op>(loc, resultTypes, convertedOperands,
                                    operation->getAttrs());
    Operation *newOperation = newOp.getOperation();
    for (unsigned i = 0, e = operation->getNumRegions(); i < e; ++i) {
      Region &dest = newOperation->getRegion(i);
      builder.inlineRegionBefore(operation->getRegion(i), dest, dest.end());
    }
    builder.replaceOp(operation, newOperation->getResults());

    return success();
  }

  /// Packs `memref` (of type `memRefType`, whose LLVM lowering is
  /// `structType`) into an `openacc_data` descriptor holding
  /// {memref descriptor, aligned data pointer, buffer size in bytes}.
  Value packMemRef(Location loc, Value memref, MemRefType memRefType,
                   Type structType, ConversionPatternRewriter &builder) const {
    Value memRefDescriptor =
        builder.create<UnrealizedConversionCastOp>(loc, structType, memref)
            .getResult(0);
    MemRefDescriptor descriptor(memRefDescriptor);

    // Dynamic extents are only known at runtime: read them back from the
    // memref descriptor so the byte size can be computed for any shape
    // (getMemRefDescriptorSizes expects one value per dynamic dimension).
    SmallVector<Value> dynamicSizes;
    ArrayRef<int64_t> shape = memRefType.getShape();
    for (unsigned dim = 0, rank = shape.size(); dim < rank; ++dim)
      if (ShapedType::isDynamic(shape[dim]))
        dynamicSizes.push_back(descriptor.size(builder, loc, dim));

    // Calculate the size of the memref and get the pointer to the aligned
    // buffer.
    SmallVector<Value> sizes;
    SmallVector<Value> strides;
    Value sizeBytes;
    ConvertToLLVMPattern::getMemRefDescriptorSizes(
        loc, memRefType, dynamicSizes, builder, sizes, strides, sizeBytes);
    Value dataPtr = descriptor.alignedPtr(builder, loc);
    auto ptrType = descriptor.getElementPtrType();

    auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
    descr.setBasePointer(builder, loc, memRefDescriptor);
    descr.setPointer(builder, loc, dataPtr);
    descr.setSize(builder, loc, sizeBytes);
    return descr;
  }
};
} // namespace

/// Registers the data-operand legalization pattern for each supported
/// OpenACC op (data, enter_data, exit_data, parallel, update), all sharing
/// `converter`.
void mlir::populateOpenACCToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>,
               LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>,
               LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>,
               LegalizeDataOpForLLVMTranslation<acc::ParallelOp>,
               LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
}

namespace {
/// Pass that applies the OpenACC data-operand legalization patterns through a
/// partial dialect conversion (see runOnOperation).
class ConvertOpenACCToLLVMPass
    : public impl::ConvertOpenACCToLLVMBase<ConvertOpenACCToLLVMPass> {
public:
  void runOnOperation() override;
};
} // namespace

void ConvertOpenACCToLLVMPass::runOnOperation() {
  auto op = getOperation();
  auto *context = op.getContext();

  // Convert to OpenACC operations with LLVM IR dialect
  RewritePatternSet patterns(context);
  LLVMTypeConverter converter(context);
  populateOpenACCToLLVMConversionPatterns(converter, patterns);

  ConversionTarget target(*context);
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addLegalOp<UnrealizedConversionCastOp>();

  auto allDataOperandsAreConverted = [](ValueRange operands) {
    for (Value operand : operands) {
      if (!DataDescriptor::isValid(operand) &&
          !operand.getType().isa<LLVM::LLVMPointerType>())
        return false;
    }
    return true;
  };

  target.addDynamicallyLegalOp<acc::DataOp>(
      [allDataOperandsAreConverted](acc::DataOp op) {
        return allDataOperandsAreConverted(op.getCopyOperands()) &&
               allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
               allDataOperandsAreConverted(op.getPresentOperands()) &&
               allDataOperandsAreConverted(op.getDeviceptrOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands());
      });

  target.addDynamicallyLegalOp<acc::EnterDataOp>(
      [allDataOperandsAreConverted](acc::EnterDataOp op) {
        return allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands());
      });

  target.addDynamicallyLegalOp<acc::ExitDataOp>(
      [allDataOperandsAreConverted](acc::ExitDataOp op) {
        return allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getDeleteOperands()) &&
               allDataOperandsAreConverted(op.getDetachOperands());
      });

  target.addDynamicallyLegalOp<acc::ParallelOp>(
      [allDataOperandsAreConverted](acc::ParallelOp op) {
        return allDataOperandsAreConverted(op.getReductionOperands()) &&
               allDataOperandsAreConverted(op.getCopyOperands()) &&
               allDataOperandsAreConverted(op.getCopyinOperands()) &&
               allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutOperands()) &&
               allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
               allDataOperandsAreConverted(op.getCreateOperands()) &&
               allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
               allDataOperandsAreConverted(op.getNoCreateOperands()) &&
               allDataOperandsAreConverted(op.getPresentOperands()) &&
               allDataOperandsAreConverted(op.getDevicePtrOperands()) &&
               allDataOperandsAreConverted(op.getAttachOperands()) &&
               allDataOperandsAreConverted(op.getGangPrivateOperands()) &&
               allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
      });

  target.addDynamicallyLegalOp<acc::UpdateOp>(
      [allDataOperandsAreConverted](acc::UpdateOp op) {
        return allDataOperandsAreConverted(op.getHostOperands()) &&
               allDataOperandsAreConverted(op.getDeviceOperands());
      });

  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    signalPassFailure();
}

/// Creates a new instance of the pass that legalizes OpenACC data operands
/// for LLVM IR translation; it runs on a ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertOpenACCToLLVMPass() {
  auto pass = std::make_unique<ConvertOpenACCToLLVMPass>();
  return pass;
}