1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
|
/*========================== begin_copyright_notice ============================
Copyright (C) 2018-2024 Intel Corporation
SPDX-License-Identifier: MIT
============================= end_copyright_notice ===========================*/
#include "llvmWrapper/IR/DerivedTypes.h"
#include "PacketBuilder.h"
#include "Probe/Assertion.h"
using namespace llvm;
namespace pktz {
//////////////////////////////////////////////////////////////////////////
/// @brief Constructor for PacketBuilder.
/// @param MIn   - module the builder works on; stored as an IGCLLVM::Module.
/// @param Width - SIMD width used for the packetized vector types
///                (forwarded to setTargetWidth()).
PacketBuilder::PacketBuilder(Module *MIn, uint32_t Width) {
  M = static_cast<IGCLLVM::Module *>(MIn);

  // M is assigned first; getContext() is only queried afterwards.
  LLVMContext &Context = getContext();

  // NOTE(review): IRB is heap-allocated here and is not released in this
  // translation unit -- confirm ownership in the class definition.
  IRB = new IGCLLVM::IRBuilder<>(Context);

  // Cache the built-in scalar types of the context.
  Int1Ty = Type::getInt1Ty(Context);
  Int8Ty = Type::getInt8Ty(Context);
  Int16Ty = Type::getInt16Ty(Context);
  Int32Ty = Type::getInt32Ty(Context);
  Int64Ty = Type::getInt64Ty(Context);
  FP32Ty = Type::getFloatTy(Context);

  // Derive the SIMD vector types; this reads Int32Ty and FP32Ty set above.
  setTargetWidth(Width);
}
/// @brief Records the SIMD width and rebuilds the cached <VWidth x float>
///        and <VWidth x i32> vector types from it.
/// @param Width - new SIMD width.
void PacketBuilder::setTargetWidth(uint32_t Width) {
  VWidth = Width;
  // Both vector types are built from the stored VWidth and the cached
  // scalar types FP32Ty / Int32Ty, which must already be set.
  SimdFP32Ty = IGCLLVM::FixedVectorType::get(FP32Ty, VWidth);
  SimdInt32Ty = IGCLLVM::FixedVectorType::get(Int32Ty, VWidth);
}
//////////////////////////////////////////////////////////////////////////
/// @brief Packetizes the type. Assumes SOA conversion: every scalar leaf
///        is widened by the current SIMD width (VWidth).
///
/// Mapping (VWidth == 8 shown as the example):
///   void            -> void (returned unchanged)
///   <N x ty>        -> <(N * 8) x ty>  (flattened; LLVM has no vector of
///                                       vectors)
///   [N x ty]        -> [N x packetize(ty)], e.g. [N x float] ->
///                      [N x <8 x float>]
///   {ty0, ty1, ...} -> {packetize(ty0), packetize(ty1), ...}
///                      e.g. {float, int} -> {<8 x float>, <8 x int>}
///   scalar ty       -> <8 x ty>
/// @param Ty - type to packetize.
/// @return The packetized type.
Type *PacketBuilder::getVectorType(Type *Ty) {
  if (Ty->isVoidTy())
    return Ty;

  // An existing fixed vector is widened by multiplying its element count.
  if (auto *VecTy = dyn_cast<IGCLLVM::FixedVectorType>(Ty)) {
    uint32_t VecSize = VecTy->getNumElements();
    auto *ElemTy = VecTy->getElementType();
    return IGCLLVM::FixedVectorType::get(ElemTy, VecSize * VWidth);
  }

  // Arrays keep their length; only the element type is packetized.
  if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
    // Array lengths are 64-bit in LLVM; keep the full width so very large
    // arrays are not silently truncated.
    uint64_t ArrSize = ArrTy->getNumElements();
    auto *VecArrTy = getVectorType(ArrTy->getElementType());
    return ArrayType::get(VecArrTy, ArrSize);
  }

  // Structs are packetized member by member.
  if (auto *STy = dyn_cast<StructType>(Ty)) {
    SmallVector<Type *, 8> VecTys;
    VecTys.reserve(STy->getNumElements());
    for (Type *ElemTy : STy->elements())
      VecTys.push_back(getVectorType(ElemTy));
    // NOTE(review): the result is a literal, non-packed struct; the source
    // struct's name and isPacked() flag are not carried over -- confirm
    // this is intended.
    return StructType::get(getContext(), VecTys);
  }

  // Scalar leaf: <ty> becomes <VWidth x ty>.
  return IGCLLVM::FixedVectorType::get(Ty, VWidth);
}
} // end of namespace pktz
|