1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Interfaces.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/ScopeExit.h"
#include <optional>
using namespace mlir;
namespace {

/// Looks up the operation named `opName` among the operations registered in
/// `context`. Returns std::nullopt when no such operation is registered.
std::optional<RegisteredOperationName>
getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
  return RegisteredOperationName::lookup(StringRef(opName.data, opName.length),
                                         unwrap(context));
}

/// Converts a C location into an optional C++ location. A null C location
/// maps to std::nullopt.
std::optional<Location> maybeGetLocation(MlirLocation location) {
  std::optional<Location> maybeLocation;
  if (!mlirLocationIsNull(location))
    maybeLocation = unwrap(location);
  return maybeLocation;
}

/// Unwraps the `nOperands` C values starting at `operands` into C++ values.
SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
  SmallVector<Value> unwrappedOperands;
  (void)unwrapList(nOperands, operands, unwrappedOperands);
  return unwrappedOperands;
}

/// Unwraps an optional attribute dictionary. A null C attribute yields a null
/// DictionaryAttr; a non-null one must be a DictionaryAttr (llvm::cast asserts
/// this).
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
  DictionaryAttr attributeDict;
  if (!mlirAttributeIsNull(attributes))
    attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
  return attributeDict;
}

/// Unwraps the `nRegions` C regions starting at `regions` into non-owning C++
/// region pointers.
///
/// The regions remain owned by the C API caller, so no ownership is taken
/// here. RegionRange can be built from ArrayRef<Region *>, which avoids
/// wrapping borrowed regions in std::unique_ptr: a scope-exit guard inside a
/// helper that returns such a vector fires when the helper returns, leaving
/// the caller with either nulled-out pointers or pointers that would delete
/// regions it does not own.
SmallVector<Region *> unwrapRegions(intptr_t nRegions, MlirRegion *regions) {
  SmallVector<Region *> unwrappedRegions;
  unwrappedRegions.reserve(nRegions);
  for (intptr_t i = 0; i < nRegions; ++i)
    unwrappedRegions.push_back(unwrap(regions[i]));
  return unwrappedRegions;
}

} // namespace

/// Returns true if `operation` is a registered operation whose registration
/// declares the interface identified by `interfaceTypeID`.
bool mlirOperationImplementsInterface(MlirOperation operation,
                                      MlirTypeID interfaceTypeID) {
  std::optional<RegisteredOperationName> info =
      unwrap(operation)->getRegisteredInfo();
  return info && info->hasInterface(unwrap(interfaceTypeID));
}

/// Returns true if the operation named `operationName` is registered in
/// `context` and declares the interface identified by `interfaceTypeID`.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
                                            MlirContext context,
                                            MlirTypeID interfaceTypeID) {
  std::optional<RegisteredOperationName> info =
      getRegisteredOperationName(context, operationName);
  return info && info->hasInterface(unwrap(interfaceTypeID));
}

/// Returns the TypeID of InferTypeOpInterface.
MlirTypeID mlirInferTypeOpInterfaceTypeID() {
  return wrap(InferTypeOpInterface::getInterfaceID());
}

/// Infers the result types of operation `opName` through its
/// InferTypeOpInterface implementation, without creating the operation.
///
/// `location` and `attributes` may be null. On success, `callback` is invoked
/// once with all inferred types and `userData`. Returns failure when the
/// operation is not registered, does not implement the interface, or when
/// inference itself fails; `callback` is not invoked in those cases.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
    MlirStringRef opName, MlirContext context, MlirLocation location,
    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
    void *properties, intptr_t nRegions, MlirRegion *regions,
    MlirTypesCallback callback, void *userData) {
  std::optional<RegisteredOperationName> info =
      getRegisteredOperationName(context, opName);
  if (!info)
    return mlirLogicalResultFailure();

  // A registered op that does not implement the interface has no
  // implementation to call; report failure rather than dereference null.
  auto *inferInterface = info->getInterface<InferTypeOpInterface>();
  if (!inferInterface)
    return mlirLogicalResultFailure();

  std::optional<Location> maybeLocation = maybeGetLocation(location);
  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
  DictionaryAttr attributeDict = unwrapAttributes(attributes);
  SmallVector<Region *> unwrappedRegions = unwrapRegions(nRegions, regions);

  SmallVector<Type> inferredTypes;
  if (failed(inferInterface->inferReturnTypes(
          unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
          properties, RegionRange(unwrappedRegions), inferredTypes)))
    return mlirLogicalResultFailure();

  SmallVector<MlirType> wrappedInferredTypes;
  wrappedInferredTypes.reserve(inferredTypes.size());
  for (Type t : inferredTypes)
    wrappedInferredTypes.push_back(wrap(t));
  callback(static_cast<intptr_t>(wrappedInferredTypes.size()),
           wrappedInferredTypes.data(), userData);
  return mlirLogicalResultSuccess();
}

/// Returns the TypeID of InferShapedTypeOpInterface.
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
  return wrap(InferShapedTypeOpInterface::getInterfaceID());
}

/// Infers the shaped result type components of operation `opName` through its
/// InferShapedTypeOpInterface implementation, without creating the operation.
///
/// `location` and `attributes` may be null. On success, `callback` is invoked
/// once per inferred component with (hasRank, rank, shape, elementType,
/// attribute, userData); unranked components are reported with rank 0 and a
/// null shape pointer. Returns failure when the operation is not registered,
/// does not implement the interface, or when inference itself fails.
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
    MlirStringRef opName, MlirContext context, MlirLocation location,
    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
    void *properties, intptr_t nRegions, MlirRegion *regions,
    MlirShapedTypeComponentsCallback callback, void *userData) {
  std::optional<RegisteredOperationName> info =
      getRegisteredOperationName(context, opName);
  if (!info)
    return mlirLogicalResultFailure();

  // A registered op that does not implement the interface has no
  // implementation to call; report failure rather than dereference null.
  auto *shapedInterface = info->getInterface<InferShapedTypeOpInterface>();
  if (!shapedInterface)
    return mlirLogicalResultFailure();

  std::optional<Location> maybeLocation = maybeGetLocation(location);
  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
  DictionaryAttr attributeDict = unwrapAttributes(attributes);
  SmallVector<Region *> unwrappedRegions = unwrapRegions(nRegions, regions);

  SmallVector<ShapedTypeComponents> inferredTypeComponents;
  // The explicit ValueRange is needed: the interface takes a ValueShapeRange,
  // which is constructible from ValueRange but not directly from SmallVector.
  if (failed(shapedInterface->inferReturnTypeComponents(
          unwrap(context), maybeLocation,
          mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), attributeDict,
          properties, RegionRange(unwrappedRegions), inferredTypeComponents)))
    return mlirLogicalResultFailure();

  for (const ShapedTypeComponents &components : inferredTypeComponents) {
    bool hasRank = components.hasRank();
    intptr_t rank = 0;
    const int64_t *shapeData = nullptr;
    // getDims() is only meaningful for ranked components.
    if (hasRank) {
      ArrayRef<int64_t> dims = components.getDims();
      rank = static_cast<intptr_t>(dims.size());
      shapeData = dims.data();
    }
    callback(hasRank, rank, shapeData, wrap(components.getElementType()),
             wrap(components.getAttribute()), userData);
  }
  return mlirLogicalResultSuccess();
}
|