File: data_graph_objects.h

package info (click to toggle)
vulkan-validationlayers 1.4.328.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 49,412 kB
  • sloc: cpp: 615,223; python: 12,115; sh: 24; makefile: 20; xml: 14
file content (107 lines) | stat: -rw-r--r-- 4,287 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*
 * Copyright (c) 2025 The Khronos Group Inc.
 * Copyright (C) 2025 Arm Limited.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 */
#pragma once

#include "binding.h"
#include "descriptor_helper.h"

/**
 * This file defines the structures needed to create a simple DataGraphPipelineARM.
 * dg::ShaderModule creates a vkt::ShaderModule using SPIRV generated from a simple shader which uses a Tensor
 *
 *  #extension GL_ARM_tensors : require
 *  #extension GL_EXT_shader_explicit_arithmetic_types : require
 *  layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
 *  layout(set=0, binding=0) uniform tensorARM<int32_t, 1> tens;
 *  void main()
 *  {
 *      const uint size_x = tensorSizeARM(tens, 0);
 *  }
 *
 *  dg::PipelineLayout creates a vkt::PipelineLayout which correlates to the shader module provided above
 */

namespace vkt {
namespace dg {

// Test helper that gathers the state used to build a VkDataGraphPipelineARM in one place: descriptor set layout
// bindings, a pipeline layout, a shader module, input/output tensors with their views, and the create-info
// structs that describe the pipeline.
//
// The data members are public so that callers (for example the `info_override` callable handed to OneshotTest)
// can adjust them before CreateDataGraphPipeline() is invoked.
//
// The helper keeps a raw VkPipeline handle and declares a destructor, so it is deliberately non-copyable.
class CreateDataGraphPipelineHelper {
  public:
    std::vector<VkDescriptorSetLayoutBinding> descriptor_set_layout_bindings_;
    std::unique_ptr<OneOffDescriptorSet> descriptor_set_;
    VkPipelineLayoutCreateInfo pipeline_layout_ci_ = {};
    vkt::PipelineLayout pipeline_layout_;
    VkDataGraphPipelineCreateInfoARM pipeline_ci_ = {};
    vkt::ShaderModule shader_;
    // Zero-initialized like the other create-info members so it never holds indeterminate values.
    VkDataGraphPipelineShaderModuleCreateInfoARM shader_module_ci_ = {};
    std::vector<VkDataGraphPipelineResourceInfoARM> resources_;
    vkt::Tensor in_tensor_;
    vkt::Tensor out_tensor_;
    vkt::TensorView in_tensor_view_;
    vkt::TensorView out_tensor_view_;

    VkLayerTest &layer_test_;
    // Non-owning; presumably set by the constructor from `test` (definition not visible in this header).
    vkt::Device *device_ = nullptr;

    explicit CreateDataGraphPipelineHelper(VkLayerTest &test, bool is_data_graph = false, bool protected_tensors = false, const char *inserted_line = "");
    virtual ~CreateDataGraphPipelineHelper();

    // Copying was already implicitly disabled by the unique_ptr and reference members; state it explicitly so the
    // single-owner contract of the raw pipeline_ handle is visible at a glance.
    CreateDataGraphPipelineHelper(const CreateDataGraphPipelineHelper &) = delete;
    CreateDataGraphPipelineHelper &operator=(const CreateDataGraphPipelineHelper &) = delete;

    void Destroy();

    static std::string GetGraphSpirvSource(const char *inserted_line = "");
    void CreateComputeShaderModule();
    void CreateGraphShaderModule(const char *spirv_text);
    void InitPipelineResources(const std::vector<vkt::Tensor *> &tensors = {},
                               VkDescriptorType desc_type = VK_DESCRIPTOR_TYPE_TENSOR_ARM,
                               VkDescriptorSetLayoutCreateFlags layout_flags = 0);
    void CreatePipelineLayout(const std::vector<VkPushConstantRange> &push_constant_ranges = {});

    // Returns the pipeline handle held by this helper (VK_NULL_HANDLE until one has been stored in pipeline_).
    const VkPipeline &Handle() const { return pipeline_; }

    VkResult CreateDataGraphPipeline();

    // Helper function to create a simple test case (positive or negative).
    //
    // test           - test fixture; used to construct the helper and to reach the error monitor via test.Monitor().
    // info_override  - any callable that takes a CreateDataGraphPipelineHelper &; it runs after the helper is
    //                  constructed and before the pipeline is created.
    // flags, errors  - each element of `errors` is registered with Monitor().SetDesiredFailureMsg(flags, error).
    // positive_test  - accepted for source compatibility only; it is not consulted. Whether the monitor is
    //                  verified depends solely on `errors` being non-empty.
    template <typename Test, typename OverrideFunc, typename Error>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags, const std::vector<Error> &errors,
                            [[maybe_unused]] bool positive_test = false) {
        CreateDataGraphPipelineHelper helper(test);
        info_override(helper);
        // Register every expected message before attempting pipeline creation.
        for (const auto &error : errors) {
            test.Monitor().SetDesiredFailureMsg(flags, error);
        }
        // The VkResult is intentionally not inspected here; outcomes are checked through the error monitor.
        helper.CreateDataGraphPipeline();

        // With no expected errors there is nothing to verify.
        if (!errors.empty()) {
            test.Monitor().VerifyFound();
        }
    }

    // Convenience overload for a single expected error: wraps it in a one-element vector and forwards.
    template <typename Test, typename OverrideFunc, typename Error>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags, Error error) {
        OneshotTest(test, info_override, flags, std::vector<Error>(1, error));
    }

    // Convenience overload for the case with no expected errors: forwards an empty error list.
    template <typename Test, typename OverrideFunc>
    static void OneshotTest(Test &test, const OverrideFunc &info_override, const VkFlags flags) {
        OneshotTest(test, info_override, flags, std::vector<std::string>{});
    }

  private:
    void CreateShaderModule(const char *spirv_source);
    void InitTensor(vkt::Tensor &tensor, vkt::TensorView &tensor_view, const std::vector<int64_t> &tensor_dims, bool is_protected);

    // Raw handle exposed through Handle(); starts out null.
    VkPipeline pipeline_ = VK_NULL_HANDLE;
};

}  // namespace dg
}  // namespace vkt