1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C runtime wrappers around the VulkanRuntime.
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <numeric>

#include "VulkanRuntime.h"
namespace {

/// Wraps a single `VulkanRuntime` and serializes access to it.
///
/// Every public method takes the internal mutex for its whole duration and
/// then forwards to the wrapped runtime, so calls issued through the same
/// manager never overlap.
class VulkanRuntimeManager {
public:
  VulkanRuntimeManager() = default;
  // Copying is disabled. The assignment operator is declared with the
  // conventional reference return type even though it is deleted.
  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
  VulkanRuntimeManager &operator=(const VulkanRuntimeManager &) = delete;
  ~VulkanRuntimeManager() = default;

  /// Forwards the host buffer `memBuffer` to the runtime for the resource at
  /// (`setIndex`, `bindIndex`).
  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
                       const VulkanHostMemoryBuffer &memBuffer) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
  }

  /// Forwards the entry point name to the runtime.
  void setEntryPoint(const char *entryPoint) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setEntryPoint(entryPoint);
  }

  /// Forwards the work group counts to the runtime.
  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setNumWorkGroups(numWorkGroups);
  }

  /// Forwards the shader binary (`size` is passed through unchanged) to the
  /// runtime.
  void setShaderModule(uint8_t *shader, uint32_t size) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setShaderModule(shader, size);
  }

  /// Executes the runtime stages in order: `initRuntime`, `run`,
  /// `updateHostMemoryBuffers`, `destroy`. The chain stops at the first stage
  /// that reports failure, and "runOnVulkan failed" is written to stderr.
  void runOnVulkan() {
    std::lock_guard<std::mutex> lock(mutex);
    // NOTE(review): because of the short-circuit evaluation, `destroy()` is
    // not invoked when an earlier stage fails. Whether `destroy()` is safe to
    // call on a partially initialized runtime cannot be determined from this
    // file, so the original ordering is kept — confirm against VulkanRuntime.
    if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
        failed(vulkanRuntime.updateHostMemoryBuffers()) ||
        failed(vulkanRuntime.destroy())) {
      std::cerr << "runOnVulkan failed";
    }
  }

private:
  VulkanRuntime vulkanRuntime;
  // Guards every access to `vulkanRuntime`.
  std::mutex mutex;
};

} // namespace
/// Plain-data descriptor for a rank-`N` memref with element type `T`.
/// NOTE(review): the field order appears to mirror the memref descriptor that
/// MLIR passes through its C-interface wrappers — confirm against the lowering
/// before reordering or adding fields.
template <typename T, int N>
struct MemRefDescriptor {
// Base data pointer; the bind and fill helpers visible in this file access
// the buffer through this field.
T *allocated;
// Not accessed by the code visible here; presumably the aligned data
// pointer — TODO confirm.
T *aligned;
// Not accessed by the code visible here; presumably the element offset from
// `aligned` — TODO confirm.
int64_t offset;
// Per-dimension extents; the helpers visible in this file multiply them to
// obtain the element count.
int64_t sizes[N];
// Not accessed by the code visible here; presumably per-dimension strides —
// TODO confirm.
int64_t strides[N];
};
template <typename T, uint32_t S>
void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
uint32_t size = sizeof(T);
for (unsigned i = 0; i < S; i++)
size *= ptr->sizes[i];
VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
->setResourceData(setIndex, bindIndex, memBuffer);
}
extern "C" {
/// Heap-allocates a fresh `VulkanRuntimeManager` and hands it back as an
/// opaque pointer.
void *initVulkan() {
  auto *manager = new VulkanRuntimeManager();
  return manager;
}
/// Deinitializes `VulkanRuntimeManager` by the given pointer.
void deinitVulkan(void *vkRuntimeManager) {
delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
}
/// Calls `runOnVulkan()` on the `VulkanRuntimeManager` behind the opaque
/// pointer `vkRuntimeManager`.
void runOnVulkan(void *vkRuntimeManager) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->runOnVulkan();
}
/// Passes the entry point name `entryPoint` to the `VulkanRuntimeManager`
/// behind the opaque pointer `vkRuntimeManager`.
void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setEntryPoint(entryPoint);
}
/// Passes the work group counts `x`, `y` and `z` (in that order) to the
/// `VulkanRuntimeManager` behind the opaque pointer `vkRuntimeManager`.
void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y,
                      uint32_t z) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setNumWorkGroups({x, y, z});
}
/// Passes the shader binary `shader` together with its `size` to
/// `setShaderModule` on the `VulkanRuntimeManager` behind the opaque pointer
/// `vkRuntimeManager`.
void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setShaderModule(shader, size);
}
/// Defines the entry point `bindMemRef<rank>D<suffix>`, which forwards a
/// `rank`-dimensional memref descriptor with `elemType` elements to the
/// `bindMemRef` template so that it is bound to the given descriptor set and
/// binding index.
#define DECLARE_BIND_MEMREF(rank, elemType, suffix)                            \
  void bindMemRef##rank##D##suffix(                                            \
      void *vkRuntimeManager, DescriptorSetIndex setIndex,                     \
      BindingIndex bindIndex, MemRefDescriptor<elemType, rank> *ptr) {         \
    bindMemRef<elemType, rank>(vkRuntimeManager, setIndex, bindIndex, ptr);    \
  }

// float
DECLARE_BIND_MEMREF(1, float, Float)
DECLARE_BIND_MEMREF(2, float, Float)
DECLARE_BIND_MEMREF(3, float, Float)

// int32_t
DECLARE_BIND_MEMREF(1, int32_t, Int32)
DECLARE_BIND_MEMREF(2, int32_t, Int32)
DECLARE_BIND_MEMREF(3, int32_t, Int32)

// int16_t
DECLARE_BIND_MEMREF(1, int16_t, Int16)
DECLARE_BIND_MEMREF(2, int16_t, Int16)
DECLARE_BIND_MEMREF(3, int16_t, Int16)

// int8_t
DECLARE_BIND_MEMREF(1, int8_t, Int8)
DECLARE_BIND_MEMREF(2, int8_t, Int8)
DECLARE_BIND_MEMREF(3, int8_t, Int8)

// The Half entry points use int16_t as the element type.
DECLARE_BIND_MEMREF(1, int16_t, Half)
DECLARE_BIND_MEMREF(2, int16_t, Half)
DECLARE_BIND_MEMREF(3, int16_t, Half)
/// Writes `value` into the first `sizes[0]` elements of the 1D float memref's
/// `allocated` buffer.
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                      float value) {
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1]` elements of the 2D
/// float memref's `allocated` buffer.
void _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
                                      float value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1] * sizes[2]` elements of
/// the 3D float memref's `allocated` buffer.
void _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
                                      float value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0]` elements of the 1D int32 memref's
/// `allocated` buffer.
void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
                                    int32_t value) {
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1]` elements of the 2D
/// int32 memref's `allocated` buffer.
void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
                                    int32_t value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1] * sizes[2]` elements of
/// the 3D int32 memref's `allocated` buffer.
void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
                                    int32_t value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0]` elements of the 1D int8 memref's
/// `allocated` buffer.
void _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
                                     int8_t value) {
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1]` elements of the 2D
/// int8 memref's `allocated` buffer.
void _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
                                     int8_t value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1];
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0] * sizes[1] * sizes[2]` elements of
/// the 3D int8 memref's `allocated` buffer.
void _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
                                     int8_t value) {
  const int64_t numElements = ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2];
  std::fill_n(ptr->allocated, numElements, value);
}
}
|