File: LdShrink.cpp

package info (click to toggle)
intel-graphics-compiler 1.0.12504.6-1%2Bdeb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 83,912 kB
  • sloc: cpp: 910,147; lisp: 202,655; ansic: 15,197; python: 4,025; yacc: 2,241; lex: 1,570; pascal: 244; sh: 104; makefile: 25
file content (156 lines) | stat: -rw-r--r-- 5,217 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2021 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

#include "common/LLVMWarningsPush.hpp"
#include <llvm/Pass.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/IRBuilder.h>
#include <llvmWrapper/Support/Alignment.h>
#include <llvm/Support/Debug.h>
#include <llvm/Support/MathExtras.h>
#include <llvm/Support/raw_ostream.h>
#include <llvmWrapper/IR/DerivedTypes.h>
#include "common/LLVMWarningsPop.hpp"
#include "Compiler/CISACodeGen/ShaderCodeGen.hpp"
#include "Compiler/IGCPassSupport.h"
#include "Compiler/CISACodeGen/LdShrink.h"
#include "Probe/Assertion.h"

using namespace llvm;
using namespace IGC;

namespace {

    // Function pass that shrinks a vector load when only part of its
    // elements is used. All users of the load must be extractelement
    // instructions with constant indices (see getExtractIndexMask()); the
    // load is then replaced by a narrower one covering only the used lanes.
    class LdShrink : public FunctionPass {
        // Data layout of the module that owns the function being processed.
        // Assigned at the start of every runOnFunction() call; null until
        // then.
        const DataLayout* DL = nullptr;

    public:
        // Pass identification; its address is handed to the FunctionPass
        // base class in the constructor.
        static char ID;

        LdShrink() : FunctionPass(ID) {
            initializeLdShrinkPass(*PassRegistry::getPassRegistry());
        }

        // Entry point of the pass; returns whether F was modified.
        bool runOnFunction(Function& F) override;

    private:
        // The pass rewrites loads only and never touches the control flow
        // graph.
        void getAnalysisUsage(AnalysisUsage& AU) const override {
            AU.setPreservesCFG();
        }

        // Returns a bit mask with bit i set when element i of the vector
        // loaded by LI is extracted by one of its users, or 0 when the load
        // does not qualify for shrinking.
        unsigned getExtractIndexMask(LoadInst* LI) const;
    };

    char LdShrink::ID = 0;

} // End anonymous namespace

// Factory for the load-shrink pass: allocates a fresh LdShrink instance and
// hands it back through the generic FunctionPass interface.
// NOTE(review): the returned object is heap-allocated and not freed here;
// presumably the pass manager it is added to takes ownership - confirm at the
// call sites.
FunctionPass* createLdShrinkPass() {
    FunctionPass* const Pass = new LdShrink();
    return Pass;
}

// Registration parameters forwarded to the IGC_INITIALIZE_PASS_* macros below.
// NOTE(review): presumably PASS_FLAG is the name the pass is registered under
// and PASS_DESC its human-readable description, while PASS_CFG_ONLY and
// PASS_ANALYSIS mirror the usual LLVM pass-registration flags - confirm
// against Compiler/IGCPassSupport.h.
#define PASS_FLAG     "igc-ldshrink"
#define PASS_DESC     "IGC Load Shrink"
#define PASS_CFG_ONLY false
#define PASS_ANALYSIS false
// No dependencies are declared between BEGIN and END.
// NOTE(review): these macros presumably define initializeLdShrinkPass(),
// which the LdShrink constructor calls - confirm.
IGC_INITIALIZE_PASS_BEGIN(LdShrink, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)
IGC_INITIALIZE_PASS_END(LdShrink, PASS_FLAG, PASS_DESC, PASS_CFG_ONLY, PASS_ANALYSIS)

// Computes which elements of the vector loaded by LI are actually used.
//
// Bit i of the returned mask is set when at least one user of LI is an
// extractelement with constant index i. The function returns 0, meaning
// "do not shrink this load", when:
//   - LI does not load a fixed vector, or the vector has more than 32
//     elements (the mask has only 32 bits);
//   - the element type is an integer whose bit width is not a power of two
//     of at least 8 bits (not byte addressable);
//   - any user is something other than an extractelement with a constant,
//     in-range index;
//   - LI has no users at all.
unsigned LdShrink::getExtractIndexMask(LoadInst* LI) const {
    IGCLLVM::FixedVectorType* VTy = dyn_cast<IGCLLVM::FixedVectorType>(LI->getType());
    // Skip non-vector loads.
    if (!VTy)
        return 0;
    // Skip if there are more than 32 elements.
    const auto NumElts = VTy->getNumElements();
    if (NumElts > 32)
        return 0;
    Type* Ty = VTy->getScalarType();
    // Skip non-BYTE addressable data types. So far, check integer types
    // only.
    if (IntegerType * ITy = dyn_cast<IntegerType>(Ty)) {
        // Unroll isPowerOf2ByteWidth, it was removed in LLVM 12.
        unsigned BitWidth = ITy->getBitWidth();
        if (!((BitWidth > 7) && isPowerOf2_32(BitWidth)))
            return 0;
    }

    // Check whether all users are ExtractElement with constant index.
    // Collect index mask at the same time.
    unsigned Mask = 0; // Maximally 32 elements, one bit per element.

    for (User* U : LI->users()) {
        ExtractElementInst* EEI = dyn_cast<ExtractElementInst>(U);
        if (!EEI)
            return 0;
        // Skip non-constant index.
        auto Idx = dyn_cast<ConstantInt>(EEI->getIndexOperand());
        if (!Idx)
            return 0;
        // getLimitedValue() saturates instead of asserting on constants that
        // do not fit in 64 bits.
        const uint64_t Index = Idx->getLimitedValue();
        // An out-of-range constant index is still valid IR (the extract
        // yields an undefined value), so bail out instead of shifting by an
        // amount that could reach or exceed the width of the mask.
        if (Index >= NumElts)
            return 0;
        IGC_ASSERT_MESSAGE(Index < 32, "Index is out of range!");
        // Shift an unsigned one so that index 31 does not touch a sign bit.
        Mask |= (1U << Index);
    }

    return Mask;
}

// Scans every simple (non-volatile, non-atomic) load in F. When the load
// produces a vector of which exactly one element is extracted, a scalar load
// of just that element is inserted before the original load, and every
// extractelement user is rewired to the new scalar load. The original vector
// load and its extracts are left in place, now without uses.
//
// Returns true when at least one load was shrunk, false otherwise.
bool LdShrink::runOnFunction(Function& F) {
    // getDataLayout() returns a reference, so the stored pointer is never
    // null and needs no check.
    DL = &F.getParent()->getDataLayout();

    bool Changed = false;
    for (auto& BB : F) {
        for (auto BI = BB.begin(), BE = BB.end(); BI != BE; /*EMPTY*/) {
            // Advance the iterator first: new instructions are inserted
            // before LI, so BI already points past everything added below.
            LoadInst* LI = dyn_cast<LoadInst>(BI++);
            // Skip non-load instructions.
            if (!LI)
                continue;
            // Skip non-simple load.
            if (!LI->isSimple())
                continue;
            // Replace it with scalar load or narrow vector load.
            unsigned Mask = getExtractIndexMask(LI);
            if (!Mask)
                continue;
            // Only a contiguous run of used elements can be covered by a
            // single narrower load.
            if (!isShiftedMask_32(Mask))
                continue;
            // Offset: index of the first used element.
            // Length: number of consecutive used elements.
            unsigned Offset = llvm::countTrailingZeros(Mask);
            unsigned Length = llvm::countTrailingZeros((Mask >> Offset) + 1);
            // TODO: So far skip narrow vector.
            if (Length != 1)
                continue;

            IGCLLVM::IRBuilder<> Builder(LI);

            // Shrink it to scalar load.
            auto Ptr = LI->getPointerOperand();
            Type* Ty = LI->getType();
            Type* ScalarTy = Ty->getScalarType();
            PointerType* PtrTy = cast<PointerType>(Ptr->getType());
            PointerType* ScalarPtrTy
                = PointerType::get(ScalarTy, PtrTy->getAddressSpace());
            Value* ScalarPtr = Builder.CreatePointerCast(Ptr, ScalarPtrTy);
            if (Offset)
                ScalarPtr = Builder.CreateInBoundsGEP(ScalarPtr, Builder.getInt32(Offset));

            // The element sits Offset elements past the vector base, so its
            // alignment is the largest power of two dividing both the
            // original alignment and that byte offset.
            unsigned alignment
                = int_cast<unsigned int>(MinAlign(LI->getAlignment(),
                    DL->getTypeStoreSize(ScalarTy) * Offset));

            LoadInst* NewLoad = Builder.CreateAlignedLoad(ScalarPtr, IGCLLVM::getAlign(alignment));
            NewLoad->setDebugLoc(LI->getDebugLoc());

            // Length == 1 means every user is an extractelement of the same
            // element, so all of them (not just the first) can be replaced
            // by the scalar load. Rewiring the users of an extract does not
            // modify LI's own use list, so iterating it here is safe.
            for (User* U : LI->users()) {
                ExtractElementInst* EEI = cast<ExtractElementInst>(U);
                EEI->replaceAllUsesWith(NewLoad);
            }
            // Report the modification to the pass manager.
            Changed = true;
        }
    }

    return Changed;
}