/*========================== begin_copyright_notice ============================
Copyright (C) 2020-2021 Intel Corporation
SPDX-License-Identifier: MIT
============================= end_copyright_notice ===========================*/
#pragma once
#include "RTStackFormat.h"
#include "common/MDFrameWork.h"
#include "Compiler/CodeGenPublic.h"
#include "RTBuilder.h"
#include "common/LLVMWarningsPush.hpp"
#include <llvm/IR/Function.h>
#include <llvm/ADT/Optional.h>
#include "common/LLVMWarningsPop.hpp"
namespace IGC {

// Simple class to query the position of each arg in a raytracing shader.
class ArgQuery
{
public:
    ArgQuery(const llvm::Function& F, const CodeGenContext& Ctx);
    ArgQuery(const FunctionMetaData& FMD);
    ArgQuery(CallableShaderTypeMD FuncType, const FunctionMetaData& FMD);

    llvm::Argument* getPayloadArg(const llvm::Function *F) const;
    llvm::Argument* getHitAttribArg(const llvm::Function *F) const;
private:
    // Unlike DXR, Vulkan raytracing lets a shader omit these arguments, so
    // each index is optional.
    llvm::Optional<uint32_t> TraceRayPayloadIdx;
    llvm::Optional<uint32_t> HitAttributeIdx;
    llvm::Optional<uint32_t> CallableShaderPayloadIdx;

    CallableShaderTypeMD ShaderTy = NumberOfCallableShaderTypes;
private:
    llvm::Optional<uint32_t> getPayloadArgNo() const;
    llvm::Optional<uint32_t> getHitAttribArgNo() const;
    const llvm::Argument* getArg(
        const llvm::Function *F,
        llvm::Optional<uint32_t> ArgNo) const;
    void init(CallableShaderTypeMD FuncType, const FunctionMetaData& FMD);
};
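
// Illustrative sketch only (hypothetical names; nothing in IGC uses it): one
// way the per-shader-type choice behind getPayloadArgNo() and
// getHitAttribArgNo() could be expressed. It assumes that callable shaders
// take their payload index from CallableShaderPayloadIdx, that any-hit,
// closest-hit and miss shaders take it from TraceRayPayloadIdx, and that only
// any-hit and closest-hit shaders receive hit attributes. OptIdx stands in for
// llvm::Optional<uint32_t> so the sketch compiles without extra headers.
namespace argquery_sketch {

// Hypothetical stand-in for llvm::Optional<uint32_t>.
struct OptIdx
{
    bool HasValue;
    unsigned Value;
};

constexpr OptIdx none() { return OptIdx{ false, 0u }; }
constexpr OptIdx some(unsigned V) { return OptIdx{ true, V }; }

// Hypothetical shader kinds; the real code uses CallableShaderTypeMD.
enum class Kind { RayGen, Intersection, AnyHit, ClosestHit, Miss, Callable };

// Argument indices as recorded in the function metadata (any may be absent).
struct Indices
{
    OptIdx TraceRayPayload;
    OptIdx HitAttribute;
    OptIdx CallablePayload;
};

// Picks the payload index that applies to shader kind K.
constexpr OptIdx payloadArgNo(Kind K, Indices I)
{
    return K == Kind::Callable ? I.CallablePayload
        : (K == Kind::AnyHit || K == Kind::ClosestHit || K == Kind::Miss)
            ? I.TraceRayPayload
            : none();
}

// Only hit-group shaders (assumed: any-hit and closest-hit) get attributes.
constexpr OptIdx hitAttribArgNo(Kind K, Indices I)
{
    return (K == Kind::AnyHit || K == Kind::ClosestHit) ? I.HitAttribute : none();
}

} // namespace argquery_sketch

// Example: for a closest-hit shader with Indices{ some(0), some(1), none() },
// payloadArgNo() yields 0 and hitAttribArgNo() yields 1, while a miss shader
// with the same indices has no hit-attribute argument.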

// When invoking TraceRay() or CallShader(), these are the arguments that must
// be supplied to the callee.
class TraceRayRTArgs
{
public:
    using ReturnIP = uint64_t;
    using PointerSize = uint64_t;

    llvm::Value* getReturnIPPtr(
        llvm::IRBuilder<>& IRB,
        llvm::Type* PayloadTy,
        llvm::RTBuilder::SWStackPtrVal* FrameAddr,
        const llvm::Twine &FrameName = "");
    llvm::Value* getPayloadPtr(
        llvm::IRBuilder<>& IRB,
        llvm::Type* PayloadTy,
        llvm::RTBuilder::SWStackPtrVal* FrameAddr,
        const llvm::Twine &FrameName = "");
    llvm::Value* getPayloadPaddingPtr(
        llvm::IRBuilder<>& IRB,
        llvm::Type* PayloadTy,
        llvm::RTBuilder::SWStackPtrVal* FrameAddr,
        const llvm::Twine &FrameName = "");
    bool needPayloadPadding() const;

    RayDispatchShaderContext &Ctx;
public:
    static constexpr uint32_t ReturnIPSlot = 0;
    static constexpr uint32_t PayloadSlot = 1;
    static constexpr uint32_t PayloadPaddingSlot = 2;
protected:
    TraceRayRTArgs(
        RayDispatchShaderContext &Ctx,
        RayTracingSWTypes& RTSWTypes,
        const llvm::DataLayout &DL);

    static uint32_t getReturnIPOffset();
    static uint32_t getPayloadOffset();

    RayTracingSWTypes& RTSWTypes;
    uint32_t SWStackAddrSpace = 0;
    const llvm::DataLayout& DL;
private:
    using TypeCacheTy = llvm::DenseMap<llvm::PointerType*, llvm::StructType*>;
    llvm::PointerType* getType(llvm::PointerType* PayloadTy);
    TypeCacheTy ExistingStructs;
    TypeCacheTy& getCache();
};
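
// Illustrative sketch only (hypothetical names; nothing in IGC uses it): how
// byte offsets like those returned by getReturnIPOffset() and
// getPayloadOffset() follow from the slot order, assuming the argument frame
// is an ordinary non-packed LLVM struct. Each field starts at the next
// multiple of its ABI alignment and the total size is rounded up to the
// largest field alignment, which is the rule llvm::DataLayout applies when it
// lays out a non-packed StructType. The field sizes assume 8-byte ReturnIP and
// PointerSize values, matching the uint64_t aliases above.
namespace rtframe_layout_sketch {

struct Field
{
    unsigned Size;  // in bytes
    unsigned Align; // ABI alignment in bytes (a power of two)
};

// Rounds Value up to the next multiple of Align.
constexpr unsigned alignTo(unsigned Value, unsigned Align)
{
    return (Value + Align - 1) / Align * Align;
}

// Byte offset of field `Slot` within the struct described by `Fields`.
template <unsigned N>
constexpr unsigned offsetOf(const Field (&Fields)[N], unsigned Slot)
{
    unsigned Offset = 0;
    for (unsigned I = 0; I < Slot; ++I)
        Offset = alignTo(Offset, Fields[I].Align) + Fields[I].Size;
    return alignTo(Offset, Fields[Slot].Align);
}

// Total struct size, including any tail padding.
template <unsigned N>
constexpr unsigned structSize(const Field (&Fields)[N])
{
    unsigned Offset = 0;
    unsigned MaxAlign = 1;
    for (unsigned I = 0; I < N; ++I)
    {
        Offset = alignTo(Offset, Fields[I].Align) + Fields[I].Size;
        if (Fields[I].Align > MaxAlign)
            MaxAlign = Fields[I].Align;
    }
    return alignTo(Offset, MaxAlign);
}

// Slot 0 (ReturnIPSlot): return IP; slot 1 (PayloadSlot): payload pointer.
constexpr Field TraceRayFrame[] = { { 8u, 8u }, { 8u, 8u } };

} // namespace rtframe_layout_sketch

// Example: offsetOf(TraceRayFrame, 1) is 8. Appending a hypothetical 4-byte
// field { 4u, 4u } would place it at byte 16 and grow the struct from 16 to
// 24 bytes, 4 of them tail padding.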

// Handle arguments in a raytracing shader.
class RTArgs : public TraceRayRTArgs
{
public:
    RTArgs(
        const llvm::Function *RootFunc,
        CallableShaderTypeMD FuncType,
        llvm::Optional<RTStackFormat::HIT_GROUP_TYPE> HitGroupTy,
        RayDispatchShaderContext *Ctx,
        const FunctionMetaData& FMD,
        RayTracingSWTypes &RTSWTypes,
        bool LogStackFrameEntries = false);
public:
    llvm::Argument* getPayloadArg(const llvm::Function *F) const;
    llvm::Argument* getHitAttribArg(const llvm::Function *F) const;

    llvm::Value* getCustomHitAttribPtr(
        llvm::IRBuilder<> &IRB,
        llvm::RTBuilder::SWStackPtrVal* FrameAddr,
        llvm::Type* CustomHitAttrTy);
    llvm::Value* getHitKindPtr(
        llvm::IRBuilder<> &IRB,
        llvm::RTBuilder::SWStackPtrVal* FrameAddr);

    bool isProcedural() const;
public:
    CallableShaderTypeMD FuncType;
    const llvm::Function *RootFunction = nullptr;
protected:
    bool LogStackFrameEntries = false;
    const FunctionMetaData& FMD;
protected:
    uint32_t getCustomHitAttrOffset() const;
    uint32_t getHitKindOffset() const;

    uint32_t ArgumentSize = 0;
    void addArguments();
    llvm::SmallVector<llvm::Type*, 4> ArgTys;
    std::vector<StackFrameEntry> ArgumentEntries;
    void recordArgEntry(
        const std::string &Name,
        const std::string &TypeRepr,
        uint32_t Size,
        uint32_t Offset);
    llvm::StructType* getArgumentType(llvm::Type* CustomHitAttrTy = nullptr);
private:
    using TypeCacheTy =
        llvm::DenseMap<
            std::pair<llvm::PointerType*, llvm::Type*>,
            llvm::StructType*>;
    TypeCacheTy ExistingStructs;
    TypeCacheTy& getCache();

    llvm::Optional<RTStackFormat::HIT_GROUP_TYPE> HitGroupTy;
    llvm::Optional<uint32_t> HitKindSlot;
    llvm::Optional<uint32_t> CustomHitAttrSlot;

    ArgQuery Args;
};
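
// Illustrative sketch only (hypothetical names; nothing in IGC uses it): one
// way the optional CustomHitAttrSlot and HitKindSlot could be numbered after
// the fixed TraceRayRTArgs slots. It assumes the padding slot exists only when
// needPayloadPadding() would return true, that the optional slots are appended
// in this order, and that only procedural hit groups (see isProcedural())
// need them. OptSlot stands in for llvm::Optional<uint32_t>.
namespace rtargs_slots_sketch {

// Mirrors TraceRayRTArgs::PayloadSlot and TraceRayRTArgs::PayloadPaddingSlot.
constexpr unsigned PayloadSlot = 1;
constexpr unsigned PayloadPaddingSlot = 2;

// Hypothetical stand-in for llvm::Optional<uint32_t>.
struct OptSlot
{
    bool HasValue;
    unsigned Value;
};

struct SlotAssignment
{
    OptSlot CustomHitAttr;
    OptSlot HitKind;
    unsigned NumSlots; // total number of fields in the argument struct
};

constexpr SlotAssignment assignSlots(bool NeedPayloadPadding, bool IsProcedural)
{
    // The first free slot follows the payload, or the padding when present.
    unsigned Next = NeedPayloadPadding ? PayloadPaddingSlot + 1 : PayloadSlot + 1;
    SlotAssignment S{ { false, 0u }, { false, 0u }, 0u };
    if (IsProcedural)
    {
        S.CustomHitAttr.HasValue = true;
        S.CustomHitAttr.Value = Next++;
        S.HitKind.HasValue = true;
        S.HitKind.Value = Next++;
    }
    S.NumSlots = Next;
    return S;
}

} // namespace rtargs_slots_sketch

// Example: assignSlots(false, true) places the custom hit attribute in slot 2
// and the hit kind in slot 3; with padding, they move to slots 3 and 4.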
} // namespace IGC