/*
 * Copyright (c) 2025 The Khronos Group Inc.
 * Copyright (C) 2025 Arm Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 */
#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "binding.h"
#include "descriptor_helper.h"

/**
 * This file defines the structures needed to create a simple DataGraphPipelineARM.
 * dg::ShaderModule creates a vkt::ShaderModule from SPIR-V generated from the following simple shader, which uses a tensor:
 *
 *     #extension GL_ARM_tensors : require
 *     #extension GL_EXT_shader_explicit_arithmetic_types : require
 *     layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
 *     layout(set=0, binding=0) uniform tensorARM<int32_t, 1> tens;
 *     void main()
 *     {
 *         const uint size_x = tensorSizeARM(tens, 0);
 *     }
 *
 * dg::PipelineLayout creates a vkt::PipelineLayout that matches the shader module described above.
 */
namespace vkt {
namespace dg {
class CreateDataGraphPipelineHelper {
  public:
    std::vector<VkDescriptorSetLayoutBinding> descriptor_set_layout_bindings_;
    std::unique_ptr<OneOffDescriptorSet> descriptor_set_;
    VkPipelineLayoutCreateInfo pipeline_layout_ci_ = {};
    vkt::PipelineLayout pipeline_layout_;
    VkDataGraphPipelineCreateInfoARM pipeline_ci_ = {};
    vkt::ShaderModule shader_;
    VkDataGraphPipelineShaderModuleCreateInfoARM shader_module_ci_ = {};
    std::vector<VkDataGraphPipelineResourceInfoARM> resources_;
    vkt::Tensor in_tensor_;
    vkt::Tensor out_tensor_;
    vkt::TensorView in_tensor_view_;
    vkt::TensorView out_tensor_view_;
    VkLayerTest &layer_test_;
    vkt::Device *device_;

    explicit CreateDataGraphPipelineHelper(VkLayerTest &test, bool is_data_graph = false, bool protected_tensors = false,
                                           const char *inserted_line = "");
    virtual ~CreateDataGraphPipelineHelper();
    void Destroy();
    static std::string GetGraphSpirvSource(const char *inserted_line = "");
    void CreateComputeShaderModule();
    void CreateGraphShaderModule(const char *spirv_text);
    void InitPipelineResources(const std::vector<vkt::Tensor *> &tensors = {},
                               VkDescriptorType desc_type = VK_DESCRIPTOR_TYPE_TENSOR_ARM,
                               VkDescriptorSetLayoutCreateFlags layout_flags = 0);
    void CreatePipelineLayout(const std::vector<VkPushConstantRange> &push_constant_ranges = {});
    const VkPipeline &Handle() const { return pipeline_; }
    VkResult CreateDataGraphPipeline();

    // Helper function to create a simple test case (positive or negative).
    //
    // info_override can be any callable that takes a CreateDataGraphPipelineHelper &.
    // flags and each entry of errors can be any arguments accepted by "SetDesiredFailureMsg".
    template <typename Test, typename OverrideFunc, typename Error>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags, const std::vector<Error> &errors,
                            bool positive_test = false) {
        CreateDataGraphPipelineHelper helper(test);
        info_override(helper);
        // Allow the lambda to decide whether to skip creating the pipeline, to prevent crashes
        for (const auto &error : errors) {
            test.Monitor().SetDesiredFailureMsg(flags, error);
        }
        helper.CreateDataGraphPipeline();
        if (!errors.empty()) {
            test.Monitor().VerifyFound();
        }
    }

    template <typename Test, typename OverrideFunc, typename Error>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags, Error error) {
        OneshotTest(test, info_override, flags, std::vector<Error>(1, error));
    }

    template <typename Test, typename OverrideFunc>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags) {
        OneshotTest(test, info_override, flags, std::vector<std::string>{});
    }
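
    // Example usage from inside a VkLayerTest-derived test body (illustrative sketch only). It assumes the
    // test framework's kErrorBit flag; the lambda body and the "VUID-..." string are placeholders, not
    // taken from an existing test:
    //
    //   // Negative case: register one expected error, create the pipeline, then call VerifyFound().
    //   vkt::dg::CreateDataGraphPipelineHelper::OneshotTest(
    //       *this, [](vkt::dg::CreateDataGraphPipelineHelper &helper) { /* invalidate one field of helper.pipeline_ci_ */ },
    //       kErrorBit, "VUID-...");
    //
    //   // Positive case: errors is empty, so the pipeline is created and VerifyFound() is not called.
    //   vkt::dg::CreateDataGraphPipelineHelper::OneshotTest(*this, [](vkt::dg::CreateDataGraphPipelineHelper &) {}, kErrorBit);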

  private:
    void CreateShaderModule(const char *spirv_source);
    void InitTensor(vkt::Tensor &tensor, vkt::TensorView &tensor_view, const std::vector<int64_t> &tensor_dims, bool is_protected);
    VkPipeline pipeline_ = VK_NULL_HANDLE;
};
}  // namespace dg
}  // namespace vkt